import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import seaborn as sns
# --- Plot styling ---
# Apply the seaborn "whitegrid" theme first, then override the two matplotlib
# settings needed by the Chinese text used in the figure labels.
# NOTE(review): the overrides are kept after set_style on the assumption that
# the seaborn style may touch font settings -- confirm before reordering.
sns.set_style("whitegrid")
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],  # font used so Chinese labels render correctly
    'axes.unicode_minus': False,    # render the minus sign correctly
})
# --- 1. Generate synthetic data ---
# 50 samples with two features. Feature 0 is uniform in [0, 10); feature 1 is
# made linearly dependent on feature 0 plus Gaussian noise so that neighbours
# along feature 0 carry information about feature 1.
np.random.seed(0)  # fixed seed keeps the example reproducible
X = np.random.rand(50, 2) * 10
noise = np.random.randn(50) * 0.5
X[:, 1] = X[:, 1] + 0.5 * X[:, 0] + noise

# --- 2. Define a point whose second feature is missing ---
point_with_missing = np.array([[5, np.nan]])
all_other_points = X

# --- 3. Find neighbours using only the non-missing feature ---
k = 5
# NearestNeighbors expects 2-D input, so the single known feature is reshaped
# into a column vector for both the training set and the query point.
known_train = all_other_points[:, 0].reshape(-1, 1)
known_query = point_with_missing[:, 0].reshape(-1, 1)
nn = NearestNeighbors(n_neighbors=k)
nn.fit(known_train)
distances, indices = nn.kneighbors(known_query)
neighbors = all_other_points[indices.flatten()]

# The imputed value is the mean of the neighbours' second feature.
imputed_value = np.mean(neighbors[:, 1])
# Reuse the query point's known coordinate instead of repeating the literal,
# so the plotted point cannot drift from the point used in the search.
imputed_point = np.array([[point_with_missing[0, 0], imputed_value]])
# --- 4. Visualisation ---
plt.figure(figsize=(10, 7))

# All complete data points.
plt.scatter(
    all_other_points[:, 0], all_other_points[:, 1],
    c='lightblue', label='完整数据点', s=60, alpha=0.8, edgecolors='black',
)

# Highlight the k nearest neighbours.
plt.scatter(
    neighbors[:, 0], neighbors[:, 1],
    c='orange', s=120, label=f'K={k}个最近邻', marker='*',
)

# The point after imputation.
plt.scatter(
    imputed_point[:, 0], imputed_point[:, 1],
    c='red', s=200, label='填充后的数据点', marker='X',
    edgecolors='black', linewidth=2,
)

# Draw a dashed line from the imputed point to each neighbour.
# The loop body must be indented; without it the script fails to parse.
imputed_x, imputed_y = imputed_point[0, 0], imputed_point[0, 1]
for neighbor in neighbors:
    plt.plot(
        [imputed_x, neighbor[0]],
        [imputed_y, neighbor[1]],
        'k--', alpha=0.5,
    )

plt.title('K-近邻 (KNN) 填充的直观解释', fontsize=16)
plt.xlabel('市净率 (已知)', fontsize=12)
plt.ylabel('资产回报率 (部分缺失)', fontsize=12)
plt.legend(fontsize=12)
plt.grid(True)
plt.show()