import numpy as np
import matplotlib.pyplot as plt
# Plot Gini impurity and Shannon entropy for a two-class problem as a
# function of p, the proportion of class 1.

# The labels below contain CJK text, which Matplotlib's default font cannot
# render (glyphs show up as empty boxes). Put commonly installed CJK-capable
# families first and keep the existing entries as a fallback so the script
# still runs on systems that have none of them.
cjk_font_candidates = [
    'SimHei',               # Windows
    'Microsoft YaHei',      # Windows
    'PingFang SC',          # macOS
    'Heiti TC',             # macOS
    'Noto Sans CJK SC',     # Linux
    'WenQuanYi Micro Hei',  # Linux
    'Arial Unicode MS',
]
plt.rcParams['font.sans-serif'] = [
    *cjk_font_candidates,
    *plt.rcParams['font.sans-serif'],
]

# Stay strictly inside (0, 1): log2(0) is -inf and 0 * -inf yields NaN.
p = np.linspace(0.001, 0.999, 200)

# Binary Gini impurity: 1 - p^2 - (1 - p)^2 == 2p(1 - p); peaks at 0.5.
gini = 2 * p * (1 - p)
# Binary Shannon entropy in bits; peaks at 1.0.
entropy = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(p, gini, label='Gini Impurity (基尼不纯度)', color='royalblue', linewidth=2.5)
ax.plot(p, entropy, label='Entropy (信息熵)', color='crimson', linewidth=2.5)

ax.set_title('基尼不纯度 vs. 信息熵 (二分类)', fontsize=16)
ax.set_xlabel('类别 1 的比例 (p)', fontsize=12)
ax.set_ylabel('不纯度', fontsize=12)

# Both measures reach their maximum at p = 0.5; the arrow points at the
# entropy peak (0.5, 1.0).
ax.axvline(0.5, color='grey', linestyle='--', lw=1)
ax.annotate(
    '不纯度最高',
    xy=(0.5, 1.0),
    xytext=(0.55, 0.9),
    arrowprops=dict(facecolor='black', shrink=0.05, width=1, headwidth=8),
    fontsize=12,
)

ax.legend()
plt.show()