import numpy as np
import matplotlib.pyplot as plt
# ==================== 三种相关性场景 ====================
scenarios = [
{'name': '低相关(ρ=0.3)', 'corr': 0.3},
{'name': '中度相关(ρ=0.7)', 'corr': 0.7},
{'name': '高相关(ρ=0.95)', 'corr': 0.95}
]
# ==================== 资产参数 ====================
sigma1 = 0.20 # 资产1年化波动率20%
sigma2 = 0.25 # 资产2年化波动率25%
w1 = 0.5 # 等权配置
w2 = 0.5
print('投资组合风险分析:')
print(f'资产1波动率: {sigma1:.2%}')
print(f'资产2波动率: {sigma2:.2%}')
print(f'等权重配置(w1={w1}, w2={w2})\n')
# ==================== 计算组合风险 ====================
results = []
for scenario in scenarios:
rho = scenario['corr']
# 组合方差: σ² = w₁²σ₁² + w₂²σ₂² + 2w₁w₂σ₁σ₂ρ
portfolio_var = (w1**2) * (sigma1**2) + (w2**2) * (sigma2**2) + 2 * w1 * w2 * sigma1 * sigma2 * rho
portfolio_std = np.sqrt(portfolio_var)
weighted_avg_std = w1 * sigma1 + w2 * sigma2
risk_reduction = (weighted_avg_std - portfolio_std) / weighted_avg_std
results.append({
'场景': scenario['name'],
'相关系数': rho,
'组合波动率': portfolio_std,
'加权平均波动率': weighted_avg_std,
'风险降低': risk_reduction
})
print(f"{scenario['name']}:")
print(f" 组合波动率: {portfolio_std:.4f} ({portfolio_std:.2%})")
print(f" 加权平均波动率: {weighted_avg_std:.4f} ({weighted_avg_std:.2%})")
print(f" 风险降低: {risk_reduction:.2%}\n")
# ==================== 可视化 ====================
fig, ax = plt.subplots(figsize=(10, 6))
scenarios_names = [r['场景'] for r in results]
portfolio_stds = [r['组合波动率'] for r in results]
weighted_stds = [r['加权平均波动率'] for r in results]
x = np.arange(len(scenarios_names))
width = 0.35
bars1 = ax.bar(x - width/2, weighted_stds, width, label='加权平均波动率', color='coral', alpha=0.7)
bars2 = ax.bar(x + width/2, portfolio_stds, width, label='组合波动率', color='steelblue', alpha=0.7)
ax.set_title('相关性对投资组合风险的影响', fontsize=14)
ax.set_ylabel('年化波动率', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(scenarios_names)
ax.legend(fontsize=10)
ax.grid(axis='y', alpha=0.3)
for bars in [bars1, bars2]:
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.2%}', ha='center', va='bottom', fontsize=10)
plt.tight_layout()
plt.show()