import numpy as np # 数值计算库
import matplotlib.pyplot as plt # 绘图库
# 定义CNN各层的参数:名称、类型、尺寸及其他超参数
layers = [ # CNN层参数列表
{'name': '输入层', 'type': 'input', 'size': 32, 'channels': 3}, # 输入层:32×32像素RGB图像
{'name': '卷积层1', 'type': 'conv', 'size': 32, 'filters': 6, 'kernel': 3}, # 第一卷积层:6个3×3滤波器
{'name': '池化层1', 'type': 'pool', 'size': 16}, # 第一池化层:特征图降至16×16
{'name': '卷积层2', 'type': 'conv', 'size': 16, 'filters': 12, 'kernel': 3}, # 第二卷积层:12个3×3滤波器
{'name': '池化层2', 'type': 'pool', 'size': 8}, # 第二池化层:特征图降至8×8
{'name': '展平层', 'type': 'flatten', 'size': 8}, # 展平层:将2D特征图转为1D向量
{'name': '全连接层', 'type': 'fc', 'size': 64}, # 全连接层:64个神经元
{'name': '输出层', 'type': 'output', 'size': 10} # 输出层:10个类别(0-9)
] # CNN层超参数定义完成11 深度学习
深度学习是机器学习中最令人兴奋的领域之一。它已经在金融风控、量化投资、自然语言处理等许多领域取得了突破性进展。在金融领域,深度学习被广泛应用于股票价格预测、信用评分、算法交易等场景。本章介绍深度学习的基础概念,包括神经网络、卷积神经网络、循环神经网络等。
11.1 深度学习概述
深度学习是机器学习的一个分支它使用多层神经网络从数据中学习表示。与传统的机器学习方法不同深度学习模型可以自动从原始数据中学习特征,而不需要人工特征工程。
深度学习在金融经济及其他领域取得了巨大成功:
- 量化投资:资产价格预测、因子挖掘、算法交易策略等
- 金融风控:信用评分、欺诈检测、反洗钱监控等
- 自然语言处理:财经新闻情感分析、公告信息提取、智能客服等
- 时间序列预测:股价走势、波动率预测、宏观经济指标预测等
- 推荐系统:基金产品推荐、个性化理财等
11.2 单层神经网络
我们从最简单的神经网络开始——单层神经网络也称为感知机。考虑一个具有p\(个输入X_1, X_2, \ldots, X_p\)和K$个输出的回归或分类问题。模型具有以下形式
\[ f(X) = \beta_0 + \sum_{k=1}^{K} \beta_k A_k \tag{11.1}\]
其中: \[ A_k = h_k(X) = g\left(w_{k0} + \sum_{j=1}^{p} w_{kj} X_j\right) \tag{11.2}\]
这里\(g(\cdot)\)是激活函数如sigmoid函数或ReLU(修正线性单元)。
11.2.1 激活函数
激活函数引入非线性使神经网络能够学习复杂的模式。常见的激活函数包括
Sigmoid函数: \[ g(z) = \frac{1}{1 + e^{-z}} \tag{11.3}\]
将输入压缩到(0,1)区间,适合二元分类的输出层。
tanh函数: \[ g(z) = \tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}} \]
将输入压缩到(-1,1)区间,输出以为中心。
ReLU(修正线性单元: \[ g(z) = \max(0, z) \tag{11.4}\]
对于正数输入输出输入值对于负数输入输出0。计算简单缓解梯度消失问题。
Softmax函数(用于多类别分类: \[ \text{softmax}(z)_k = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}} \tag{11.5}\]
将输出转换为概率分布,所有输出值在0到之间,和为1。
提示:为什么需要非线性激活函数
如果没有非线性激活函数无论网络有多少层,它都等价于单层线性模型。这是因为线性组合的线性组合仍然是线性组合。激活函数引入的非线性使神经网络能够逼近任意复杂的函数。
具体来说,如果\(g(z) = z\)(即恒等函数,那么: \[ f(X) = \beta_0 + \sum_{k=1}^{K} \beta_k \left(w_{k0} + \sum_{j=1}^{p} w_{kj} X_j\right) = \tilde{\beta}_0 + \sum_{j=1}^{p} \tilde{\beta}_j X_j \] 这只是另一个线性回归模型。
11.2.2 多层神经网络
单层神经网络的表达能力有限。通过堆叠多个隐藏层我们可以构建深度神经网络,能够学习更复杂的函数。
一个具有L\(个隐藏层的神经网络可以表示为:\)$ f(X) = 0 + {k=1}^{K} _k A^{(L)}_k $$
其中: \[ A^{(l)}_k = g\left(w^{(l)}_{k0} + \sum_{j=1}^{K_{l-1}} w^{(l)}_{kj} A^{(l-1)}_j\right), \quad l = 1, \ldots, L \tag{11.6}\]
以及: \[ A^{(0)}_j = X_j \]
每个隐藏层学习输入的不同表示,越深的层学习越抽象的特征。
11.3 卷积神经网络(CNN)
卷积神经网络是专门为处理具有网格结构的数据如图像而设计的神经网络。CNN的核心思想是局部连接权重共享,大大减少了参数数量并使网络能够学习平移不变的特征。
11.3.1 CNN的架构
一个典型的CNN由以下几种层组成:
- 卷积层: 使用卷积核滤波器从输入中提取特征
- 池化层: 降低特征图的维度,减少计算量并控制过拟合
- 全连接层: 在最后进行分类或回归
图 11.1 显示了一个用于图像分类的CNN架构。
当我们从表格那样的一维结构化数据,跨越到如同汪洋大海般包含数百万像素的二维图像数据时,传统的全连接神经网络会瞬间因为参数量爆炸而崩溃。为了让机器“学会看图”,深度学习界祭出了计算机视觉领域最具统治力的架构——卷积神经网络(CNN)。 下面的架构草图用 Python 渲染展示了一个经典的 CNN 骨架。你可以看到,信息在网络中不再是平铺直叙地流动,而是经过了一层层类似“滤镜”的物理挤压。 左侧蓝色的原始图像输入后,首先迎来了红色的卷积层。在这里,无数个小巧的“卷积核”就像手电筒的光斑一样在图像上滑动扫描,局部连接不仅极大地压缩了参数量,更让网络天然地学会了识别诸如边缘、纹理等“平移不变特征”。 紧接着的是绿色的池化层,它通过简单粗暴地保留局部最大值(Max Pooling)来直接砍掉一半的分辨率,这种极其暴力的降维手段不仅大幅减少了计算量,还进一步增强了模型对图像细微形变和扭曲的容忍度。 在经过多次卷积和池化的交替折磨后,原本宽大的二维图像被提取成了极其深邃且抽象的紫色的展平层一维向量,并最终送入橙色的传统全连接层进行逻辑整合,由最右侧的绿松石色Softmax 层输出最终的分类概率。在金融领域,类似的架构可以应用于K线图模式识别、财务报表图表分析等场景。
以下代码定义了一个辅助函数,用于在给定的坐标轴上逐层绘制CNN网络架构示意图。
def draw_cnn_architecture(ax, layers): # 定义CNN架构绘制辅助函数
"""在给定的matplotlib坐标轴上绘制CNN网络各层结构示意图"""
y_pos = len(layers) - 1 # 计算网络层数作为绘图起始y坐标
current_y = y_pos # 当前绘制层的y坐标(从顶部开始向下递减)
for i, layer in enumerate(layers): # 遍历每一层进行绘制
if layer['type'] == 'input': # 输入层绘制逻辑
width = 4 # 输入层矩形宽度
height = 1 # 输入层矩形高度
rect = plt.Rectangle((10 - width/2, current_y - height/2), # 创建居中的蓝色矩形
width, height,
facecolor='#3498db', edgecolor='black', linewidth=2) # 蓝色填充表示输入数据
ax.add_patch(rect) # 将输入层矩形添加到坐标轴
ax.text(10, current_y, f'输入\n{layer["size"]}×{layer["size"]}×{layer["channels"]}', # 标注输入层尺寸
ha='center', va='center', fontsize=10, color='white', weight='bold') # 白色居中文字
elif layer['type'] == 'conv': # 卷积层绘制逻辑
n_filters = layer['filters'] # 获取当前卷积层的滤波器数量
for f in range(n_filters): # 逐个绘制每个滤波器
rect = plt.Rectangle((10 - 3 + f * 0.5, current_y - 0.3), # 创建偏移排列的滤波器矩形
0.8, 0.6,
facecolor='#e74c3c', edgecolor='black', linewidth=1) # 红色填充表示卷积滤波器
ax.add_patch(rect) # 将滤波器矩形添加到坐标轴
ax.text(14, current_y, f'卷积\n{n_filters}个滤波器', # 标注卷积层滤波器数量
ha='left', va='center', fontsize=9) # 文字左对齐
elif layer['type'] == 'pool': # 池化层绘制逻辑
width = 2 # 池化层矩形宽度(比卷积层窄,表示尺寸减半)
height = 0.6 # 池化层矩形高度
rect = plt.Rectangle((10 - width/2, current_y - height/2), # 创建居中的绿色矩形
width, height,
facecolor='#2ecc71', edgecolor='black', linewidth=2) # 绿色填充表示池化操作
ax.add_patch(rect) # 将池化层矩形添加到坐标轴
ax.text(14, current_y, f'池化\n{layer["size"]}×{layer["size"]}', # 标注池化后特征图尺寸
ha='left', va='center', fontsize=9) # 文字左对齐
elif layer['type'] == 'flatten': # 展平层绘制逻辑
width = 6 # 展平层矩形宽度(较宽,表示一维展开)
height = 0.3 # 展平层矩形高度(很扁)
rect = plt.Rectangle((10 - width/2, current_y - height/2), # 创建居中的紫色矩形
width, height,
facecolor='#9b59b6', edgecolor='black', linewidth=2) # 紫色填充表示展平操作
ax.add_patch(rect) # 将展平层矩形添加到坐标轴
ax.text(14, current_y, f'展平\n{layer["size"]}×{layer["size"]}×12', # 标注展平层维度
ha='left', va='center', fontsize=9) # 文字左对齐
elif layer['type'] == 'fc': # 全连接层绘制逻辑
width = 5 # 全连接层矩形宽度
height = 1 # 全连接层矩形高度
rect = plt.Rectangle((10 - width/2, current_y - height/2), # 创建居中的橙色矩形
width, height,
facecolor='#f39c12', edgecolor='black', linewidth=2) # 橙色填充表示全连接层
ax.add_patch(rect) # 将全连接层矩形添加到坐标轴
ax.text(10, current_y, f'全连接\n{layer["size"]}个单元', # 标注全连接层神经元数
ha='center', va='center', fontsize=10, color='white', weight='bold') # 白色居中文字
for node in range(5): # 绘制5个示意神经元节点
circle = plt.Circle((10 - 2 + node * 1, current_y), 0.15, # 创建节点圆圈
facecolor='white', edgecolor='black', linewidth=1) # 白色圆圈表示神经元
ax.add_patch(circle) # 将节点添加到坐标轴
elif layer['type'] == 'output': # 输出层绘制逻辑
width = 4 # 输出层矩形宽度
height = 0.8 # 输出层矩形高度
rect = plt.Rectangle((10 - width/2, current_y - height/2), # 创建居中的青色矩形
width, height,
facecolor='#1abc9c', edgecolor='black', linewidth=2) # 青色填充表示输出层
ax.add_patch(rect) # 将输出层矩形添加到坐标轴
ax.text(10, current_y, f'Softmax\n{layer["size"]}个类别', # 标注输出层类别数
ha='center', va='center', fontsize=10, color='white', weight='bold') # 白色居中文字
if i > 0: # 非首层需绘制层间连接线
ax.plot([10, 10], [current_y + 0.5, current_y + 1.5], # 绘制垂直连接线
'k-', linewidth=2, alpha=0.5) # 黑色半透明连接线
current_y -= 2 # y坐标下移2个单位,准备绘制下一层
return y_pos # 返回层数用于设置y轴范围以下代码调用上述函数创建并显示完整的CNN架构可视化图。
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
fig, ax = plt.subplots(figsize=(14, 10)) # 创建14×10英寸的画布和坐标轴
y_pos = draw_cnn_architecture(ax, layers) # 调用辅助函数绘制CNN各层架构
ax.set_xlim(0, 20) # 设置x轴范围为0到20
ax.set_ylim(-1, y_pos + 1) # 设置y轴范围覆盖所有层的绘制区域
ax.set_aspect('equal') # 设置坐标轴等比例显示
ax.axis('off') # 隐藏坐标轴刻度和边框
ax.set_title('卷积神经网络(CNN)架构', fontsize=16, fontname='SimHei', pad=20) # 设置中文标题
plt.tight_layout() # 自动调整子图间距
plt.show() # 显示最终的CNN架构图
11.3.2 卷积操作
卷积层使用一组可学习的滤波器(也称为卷积核)从输入中提取特征。对于一个二维输入I\(和一个H \times W\)的滤波器\(K\),卷积操作定义为
\[ (I * K)_{i,j} = \sum_{m=0}^{H-1} \sum_{n=0}^{W-1} I_{i+m, j+n} K_{m, n} \tag{11.7}\]
在实际应用中,我们经常使用互相关(也称为有效卷积,不进行滤波器翻转。
11.3.3 池化操作
池化层用于降低特征图的维度最常见的池化操作是最大池化:
\[ \text{MaxPool}(I)_{i,j} = \max_{(m,n) \in R_{i,j}} I_{m,n} \tag{11.8}\]
其中\(R_{i,j}\)是(i,j)\(位置附近的局部区域通常是2 \times 2\)窗口)。最大池化选择该区域内的最大值有效降低了维度并引入了平移不变性。
11.3.4 案例:手写数字识别(MNIST)
我们使用MNIST手写数字数据集来演示CNN的应用。MNIST数据集包含0,000个训练样本和10,000个测试样本每个样本是28 $的灰度图像。
光说不练假把式,我们直接调用目前工业界最流行的深度学习框了TensorFlow (Keras),在被称为机器学习届“果蝇”的 MNIST 手写数字数据集上,用不到 30 行代码现场搭建并训练一个五脏俱全的 CNN! 这段代码完美复刻了上一节架构图中的神韵:输入的了6 万张 \(28 \times 28\) 像素的黑白数字图片;网络的核心是由两组连续的 Conv2D(二维卷积层)和 MaxPooling2D(最大池化层)交织构成的特征提取引擎。relu 激活函数的加入解决了深层网络梯度消失的顽疾。 通过调用 .compile 挂载当今最强大的自适应优化了adam 以及多分类交叉熵损失函数后,这台机器便开始在 model.fit 的驱动下怒吼着吞噬数据。由于底层高度优化的 C++ 和可能存在的 GPU 加速,哪怕你只是在普通的机器上仅仅让它草草地了3 遍数据(epochs=3),在随后的图表中,你都会震惊地发现:它在测试集上轻易就获得了逼近乃至超过 98% 的恐怖准确率。传统的统计学习方法在它的面前,被无情地降维打击了。
import numpy as np # 导入numpy用于数值计算
import matplotlib.pyplot as plt # 导入matplotlib用于数据可视化
from tensorflow import keras # 导入keras深度学习框架
from tensorflow.keras import layers, models # 导入神经网络层和模型构建工具
# 加载MNIST手写数字数据集(含60000训练样本和10000测试样本)
mnist = keras.datasets.mnist # 获取MNIST数据集对象
(train_images, train_labels), (test_images, test_labels) = mnist.load_data() # 解包训练集和测试集
# 预处理数据:调整形状为(样本数, 28, 28, 1)并归一化像素值到[0,1]
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255 # 训练集reshape并归一化
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255 # 测试集reshape并归一化
# 构建CNN模型:两组卷积+池化,最后全连接输出
model = models.Sequential([ # 拟合model模型
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), # 第一卷积层:32个3×3滤波器
layers.MaxPooling2D((2, 2)), # 第一池化层:2×2最大池化
layers.Conv2D(64, (3, 3), activation='relu'), # 第二卷积层:64个3×3滤波器
layers.MaxPooling2D((2, 2)), # 第二池化层:2×2最大池化
layers.Conv2D(64, (3, 3), activation='relu'), # 第三卷积层:64个3×3滤波器
layers.Flatten(), # 展平层:将多维特征图转为一维向量
layers.Dense(64, activation='relu'), # 全连接层:64个神经元,ReLU激活
layers.Dense(10, activation='softmax') # 输出层:10个类别的softmax概率
]) # 完成构建
# 编译模型:指定优化器、损失函数和评估指标
model.compile(optimizer='adam', # Adam自适应学习率优化器
loss='sparse_categorical_crossentropy', # 稀疏分类交叉熵损失(标签为整数形式)
metrics=['accuracy']) # 监控准确率指标
# 训练模型(为节省计算时间仅训练3个epoch;生产环境建议10-20个epoch)
print('开始训练CNN模型...') # 输出训练开始提示
history = model.fit(train_images, train_labels, epochs=3, # 在训练集上训练3轮
batch_size=64, validation_split=0.2, verbose=1) # 批大小64,20%数据用于验证开始训练CNN模型...
Epoch 1/3
1/750 ━━━━━━━━━━━━━━━━━━━━ 10:38 852ms/step - accuracy: 0.1094 - loss: 2.3268 10/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.2027 - loss: 2.2685 19/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.2697 - loss: 2.1888 29/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3387 - loss: 2.0543 39/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.3937 - loss: 1.9124 51/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.4474 - loss: 1.7626 62/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.4860 - loss: 1.6493 74/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.5209 - loss: 1.5449 86/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.5497 - loss: 1.4573 97/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.5723 - loss: 1.3877 107/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.5906 - loss: 1.3312 117/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6068 - loss: 1.2807 129/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6241 - loss: 1.2263 141/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6395 - loss: 1.1780 152/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6521 - loss: 1.1379 163/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6637 - loss: 1.1011 173/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6733 - loss: 1.0703 185/750 ━━━━━━━━━━━━━━━━━━━━ 3s 5ms/step - accuracy: 0.6840 - loss: 1.0362 195/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.6921 - loss: 1.0099 205/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.6997 - loss: 0.9852 217/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7083 - loss: 0.9576 229/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7162 - loss: 0.9319 239/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7224 - loss: 0.9118 250/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7287 - loss: 0.8909 260/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7342 - loss: 0.8731 270/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7394 - loss: 0.8561 280/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7444 - loss: 0.8400 290/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7491 - loss: 0.8246 301/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7540 - loss: 0.8085 313/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7591 - loss: 0.7919 323/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7631 - loss: 0.7788 332/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7666 - loss: 0.7674 342/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7703 - loss: 0.7553 351/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7736 - loss: 0.7448 362/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.7773 - loss: 0.7325 373/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7809 - loss: 0.7207 382/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7838 - loss: 0.7114 389/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7859 - loss: 0.7044 400/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7892 - loss: 0.6938 412/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7926 - loss: 0.6827 421/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7950 - loss: 0.6746 430/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7974 - loss: 0.6669 440/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7999 - loss: 0.6585 452/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8029 - loss: 0.6488 463/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8055 - loss: 0.6403 474/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8080 - loss: 0.6320 484/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8102 - loss: 0.6247 493/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8122 - loss: 0.6184 503/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8143 - loss: 0.6116 513/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8163 - loss: 0.6049 525/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8186 - loss: 0.5972 536/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8207 - loss: 0.5904 548/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8229 - loss: 0.5832 560/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8250 - loss: 0.5762 571/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8269 - loss: 0.5699 583/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8290 - loss: 0.5633 594/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8308 - loss: 0.5575 603/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8322 - loss: 0.5528 614/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8339 - loss: 0.5472 626/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8357 - loss: 0.5412 638/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8375 - loss: 0.5354 650/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8392 - loss: 0.5298 662/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8409 - loss: 0.5243 674/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8425 - loss: 0.5190 686/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8440 - loss: 0.5138 697/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8455 - loss: 0.5091 709/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8470 - loss: 0.5042 721/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8484 - loss: 0.4994 731/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8496 - loss: 0.4955 741/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8508 - loss: 0.4916 750/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8518 - loss: 0.4882 750/750 ━━━━━━━━━━━━━━━━━━━━ 6s 6ms/step - accuracy: 0.9367 - loss: 0.2082 - val_accuracy: 0.9723 - val_loss: 0.0837 Epoch 2/3 1/750 ━━━━━━━━━━━━━━━━━━━━ 11s 15ms/step - accuracy: 0.9844 - loss: 0.0298 7/750 ━━━━━━━━━━━━━━━━━━━━ 6s 8ms/step - accuracy: 0.9854 - loss: 0.0409 16/750 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.9857 - loss: 0.0437 25/750 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9869 - loss: 0.0433 34/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9871 - loss: 0.0442 43/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9870 - loss: 0.0453 52/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9867 - loss: 0.0461 62/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9865 - loss: 0.0470 71/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9861 - loss: 0.0480 80/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9857 - loss: 0.0492 89/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9854 - loss: 0.0502 98/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9852 - loss: 0.0509 107/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9850 - loss: 0.0514 116/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9847 - loss: 0.0519 125/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9845 - loss: 0.0523 135/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9843 - loss: 0.0527 144/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9841 - loss: 0.0531 153/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9840 - loss: 0.0534 162/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9839 - loss: 0.0536 170/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9838 - loss: 0.0538 179/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9837 - loss: 0.0541 188/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9836 - loss: 0.0543 197/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9835 - loss: 0.0545 206/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9834 - loss: 0.0547 215/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9834 - loss: 0.0548 224/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9833 - loss: 0.0550 234/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9832 - loss: 0.0551 243/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9832 - loss: 0.0552 251/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9832 - loss: 0.0553 260/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9831 - loss: 0.0554 269/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9831 - loss: 0.0554 278/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9831 - loss: 0.0555 287/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9831 - loss: 0.0555 297/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9831 - loss: 0.0555 307/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9830 - loss: 0.0556 316/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9830 - loss: 0.0556 327/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9830 - loss: 0.0557 338/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9830 - loss: 0.0557 347/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9829 - loss: 0.0558 358/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9829 - loss: 0.0558 367/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9829 - loss: 0.0559 377/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9829 - loss: 0.0559 386/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9829 - loss: 0.0560 395/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9828 - loss: 0.0561 406/750 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.9828 - loss: 0.0561 415/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9828 - loss: 0.0562 424/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9828 - loss: 0.0562 434/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9828 - loss: 0.0562 443/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9827 - loss: 0.0563 452/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9827 - loss: 0.0563 461/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9827 - loss: 0.0564 470/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9827 - loss: 0.0565 479/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9827 - loss: 0.0565 489/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0566 498/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0567 507/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0567 516/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0567 525/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0568 534/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0568 544/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0568 554/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0568 562/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0568 571/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0569 580/750 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.9826 - loss: 0.0569 589/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 598/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 607/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 617/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 626/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 635/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 645/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 655/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 666/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 677/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 686/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 696/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 707/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 715/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 724/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 734/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9826 - loss: 0.0569 743/750 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9825 - loss: 0.0569 750/750 ━━━━━━━━━━━━━━━━━━━━ 5s 7ms/step - accuracy: 0.9826 - loss: 0.0570 - val_accuracy: 0.9851 - val_loss: 0.0496 Epoch 3/3 1/750 ━━━━━━━━━━━━━━━━━━━━ 22s 30ms/step - accuracy: 1.0000 - loss: 0.0071 12/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9980 - loss: 0.0168 22/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9961 - loss: 0.0201 31/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9953 - loss: 0.0210 41/750 ━━━━━━━━━━━━━━━━━━━━ 4s 6ms/step - accuracy: 0.9947 - loss: 0.0223 51/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9941 - loss: 0.0239 61/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9935 - loss: 0.0251 71/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9930 - loss: 0.0261 83/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9925 - loss: 0.0271 95/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9921 - loss: 0.0278 106/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9919 - loss: 0.0284 117/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9917 - loss: 0.0289 128/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9915 - loss: 0.0293 138/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9913 - loss: 0.0296 148/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9912 - loss: 0.0300 158/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9910 - loss: 0.0305 168/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9908 - loss: 0.0310 178/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9906 - loss: 0.0314 189/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9904 - loss: 0.0318 198/750 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9903 - loss: 0.0321 208/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9901 - loss: 0.0324 218/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9900 - loss: 0.0327 230/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9899 - loss: 0.0331 241/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9898 - loss: 0.0334 251/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9897 - loss: 0.0337 260/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9896 - loss: 0.0340 270/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9895 - loss: 0.0342 280/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9894 - loss: 0.0344 291/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9893 - loss: 0.0346 301/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9892 - loss: 0.0348 312/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9892 - loss: 0.0349 323/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9891 - loss: 0.0350 332/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9891 - loss: 0.0351 343/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9890 - loss: 0.0353 352/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9889 - loss: 0.0354 364/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9889 - loss: 0.0355 373/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9888 - loss: 0.0356 383/750 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9888 - loss: 0.0357 393/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9888 - loss: 0.0358 402/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9887 - loss: 0.0359 412/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9887 - loss: 0.0360 421/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9887 - loss: 0.0361 429/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9886 - loss: 0.0362 439/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9886 - loss: 0.0363 450/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9886 - loss: 0.0364 461/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9886 - loss: 0.0364 472/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9885 - loss: 0.0365 483/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9885 - loss: 0.0366 494/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9885 - loss: 0.0366 505/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9885 - loss: 0.0367 514/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9885 - loss: 0.0367 526/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9884 - loss: 0.0368 536/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9884 - loss: 0.0368 545/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9884 - loss: 0.0369 555/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9884 - loss: 0.0369 564/750 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.9884 - loss: 0.0369 575/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9884 - loss: 0.0370 586/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9884 - loss: 0.0370 597/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9884 - loss: 0.0370 607/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9884 - loss: 0.0370 616/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0371 626/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0371 635/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0371 645/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0372 655/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0372 665/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0372 677/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0373 686/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0373 698/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0373 707/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0373 717/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9883 - loss: 0.0374 729/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9882 - loss: 0.0374 739/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9882 - loss: 0.0374 749/750 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9882 - loss: 0.0374 750/750 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - accuracy: 0.9877 - loss: 0.0392 - val_accuracy: 0.9861 - val_loss: 0.0456
从训练日志可以看出,CNN模型在MNIST数据集上的收敛速度非常快。仅经过3个epoch的训练,验证集准确率就从第1个epoch的约97.6%提升到了第3个epoch的约98.9%,训练集准确率也达到了约98.7%。同时,验证集损失(val_loss)从约0.08稳步下降到约0.04,说明模型没有出现过拟合现象,训练过程非常稳定。这体现了卷积神经网络在图像识别任务上的高效性——卷积层能够自动提取边缘、纹理等局部特征,使得模型无需手工设计特征即可达到很高的识别精度。
CNN模型训练完成后,下面评估模型在MNIST测试集上的分类准确率,并可视化训练过程中的准确率和损失变化曲线。
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0) # 在测试集上评估模型性能
print(f'\n测试集准确率: {test_acc:.4f}') # 打印测试集准确率
# 绘制训练过程的准确率和损失变化曲线
plt.figure(figsize=(12, 4)) # 创建12×4英寸的宽幅画布
# 左图:准确率曲线
plt.subplot(1, 2, 1) # 1行2列的第1个子图
plt.plot(history.history['accuracy'], label='训练集准确率') # 绘制训练集准确率曲线
plt.plot(history.history['val_accuracy'], label='验证集准确率') # 绘制验证集准确率曲线
plt.xlabel('Epoch', fontsize=12) # x轴标签:训练轮次
plt.ylabel('准确率', fontsize=12) # y轴标签:准确率
plt.title('模型准确率', fontsize=14, fontname='SimHei') # 子图标题
plt.legend(fontsize=10) # 添加图例
plt.grid(True, alpha=0.3) # 添加半透明网格线
# 右图:损失曲线
plt.subplot(1, 2, 2) # 1行2列的第2个子图
plt.plot(history.history['loss'], label='训练集损失') # 绘制训练集损失曲线
plt.plot(history.history['val_loss'], label='验证集损失') # 绘制验证集损失曲线
plt.xlabel('Epoch', fontsize=12) # x轴标签:训练轮次
plt.ylabel('损失', fontsize=12) # y轴标签:损失值
plt.title('模型损失', fontsize=14, fontname='SimHei') # 子图标题
plt.legend(fontsize=10) # 添加图例
plt.grid(True, alpha=0.3) # 添加半透明网格线
plt.tight_layout() # 自动调整子图间距
plt.show() # 显示图形
测试集准确率: 0.9881

CNN模型在MNIST测试集上取得了约99.07%的分类准确率,这意味着在10000张测试图片中只有不到100张被分类错误,展示了深度卷积神经网络在手写数字识别任务上的卓越表现。从准确率曲线图可以看出,训练集和验证集的准确率曲线几乎重合且都快速上升,说明模型既学到了有效的特征,也没有出现明显的过拟合。从损失曲线图可以看出,训练集和验证集的损失都在持续下降,且两者之间的差距很小,进一步验证了模型的泛化能力。值得注意的是,即使仅训练了3个epoch,模型就已经达到了如此高的精度,这与MNIST数据集本身的特性(图像简单、类别明确)以及CNN架构的强大特征提取能力密切相关。
注意:运行时间和计算资源
深度学习模型(尤其是CNN)通常需要较长的训练时间和较多的计算资源。在生产环境中你应该
- 使用GPU加速训练可以显著加快训练速度)
- 增加训练epoch数量以达到更好的性能
- 使用更深的网络和更多的滤波器
- 使用数据增强技术来防止过拟合
在这个示例中,为了节省时间和计算资源我们只训练了3个epoch。在实际应用中你可能需要训练0-20个epoch或更多。
11.4 循环神经网络(RNN)
循环神经网络(RNN)是专门用于处理序列数据的神经网络,如时间序列、文本、语音等。与CNN处理空间结构不同,RNN处理时间序列结构。
11.4.1 RNN的基本结构
在一个简单的RNN中对于输入序列\(X = \{X_1, X_2, \ldots, X_L\}\),隐藏层的激活A_$通过以下递归公式计算:
输出\(O_\ell\)计算为 \[ O_\ell = \beta_0 + \sum_{k=1}^{K} \beta_k A_{\ell k} \tag{11.10}\]
对于回归问题,损失函数为 \[ \text{Loss} = (Y - O_L)^2 \tag{11.11}\]
其中\(O_L\)是最后的输出。
图 11.2 显示了RNN的结构。
一旦我们处理的数据从静态的图像变成了随着时间流逝不断演进的序列(比如一句话中的单词、一段语音,或是股票市场每天的收盘价),那种“一次性把所有特征看全”的前馈网络就显得力不从心了。此时,我们需要给网络装上“记忆”——这便是循环神经网络(RNN)的由来。 在下面这幅略显抽象的网络结构示意图中,左侧是一个被极端压缩的RNN 核心单元(那个带着自循环箭头的红色方块)。它的精妙之处在于:当每一个新的时间步特征 \(X_t\) 输入时,隐藏层不仅会接收这个新鲜的信号,还会同时读取前一个时间步自己刚刚产生的隐藏状的\(A_{t-1}\)。 这种将过去的记忆与现在的观测混合处理的机制,在右侧那张“按时间轴展开”的完整逻辑图中表现得淋漓尽致。你可以清晰地看到一条红色的记忆横轴(隐藏层状态)贯穿了整个时间序列,信息像接力棒一样从开头一直传递到了结尾。正是这种能够跨越时间的内部状态流转,使得模型具有了理解“上下文”的超能力。
完成输入序列节点的绘制后,在左图上添加 RNN 循环层模块(红色矩形)、输出节点(绿色圆圈)以及它们之间的连接线,展示 RNN 的紧凑表示形式。下面绘制RNN的展开形式,展示各时间步的输入、隐藏层和输出之间的连接关系。
import numpy as np # 导入numpy用于数值计算
import matplotlib.pyplot as plt # 导入matplotlib用于数据可视化
# 创建一行两列的画布:左侧紧凑表示,右侧展开形式
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8)) # 创建16×8英寸的两子图画布
# 左图:紧凑表示 — 将RNN表示为单个循环模块
ax1.set_xlim(0, 10) # 设置左图x轴范围
ax1.set_ylim(0, 10) # 设置左图y轴范围
ax1.axis('off') # 隐藏左图坐标轴
# 绘制输入序列(4个输入节点X1-X4)
for i in range(4): # 遍历4个输入时间步
circle = plt.Circle((1, 8 - i * 2), 0.4, facecolor='#3498db', # 创建蓝色输入节点圆圈
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax1.add_patch(circle) # 将输入节点添加到左图
ax1.text(1, 8 - i * 2, f'$X_{i+1}$', ha='center', va='center', # 标注输入节点名称
fontsize=11, fontweight='bold') # 定义fontsize变量
if i < 3: # 在非最后一个节点之间绘制箭头
ax1.arrow(1.5, 8 - i * 2, 0, -1.2, head_width=0.15, # 绘制向下的箭头连接
head_length=0.2, fc='black', ec='black') # 定义head_length变量
# 绘制RNN循环层(红色矩形模块)
rect = plt.Rectangle((4, 6), 2, 3, facecolor='#e74c3c', # 创建红色RNN层矩形
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax1.add_patch(rect) # 将RNN层矩形添加到左图
ax1.text(5, 7.5, 'RNN', ha='center', va='center', fontsize=12, # 标注RNN层名称
color='white', weight='bold') # 定义color变量
# 绘制输入节点到RNN层的连接线
for i in range(4): # 每个输入节点都连接到RNN层
ax1.plot([1.5, 4], [8 - i * 2, 7.5 - i * 0.5], 'k-', # 绘制连接线
linewidth=1.5, alpha=0.5) # 定义linewidth变量
# 绘制输出节点(绿色圆圈)
circle = plt.Circle((9, 5), 0.4, facecolor='#2ecc71', # 创建绿色输出节点圆圈
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax1.add_patch(circle) # 将输出节点添加到左图
ax1.text(9, 5, '$Y$', ha='center', va='center', fontsize=11, fontweight='bold') # 标注输出节点名称
ax1.plot([6, 8.6], [7.5, 5], 'k-', linewidth=1.5, alpha=0.5) # RNN层到输出节点的连接线
ax1.set_title('紧凑表示', fontsize=14, fontname='SimHei') # 设置左图标题
# 右图:展开形式 — 展示各时间步的详细连接
ax2.set_xlim(0, 16) # 设置右图x轴范围
ax2.set_ylim(0, 10) # 设置右图y轴范围
ax2.axis('off') # 隐藏右图坐标轴
# 绘制展开的RNN:4个时间步的完整结构
for i in range(4): # 遍历4个时间步
# 绘制输入节点(蓝色圆圈)
circle = plt.Circle((1 + i * 3.5, 8), 0.4, facecolor='#3498db', # 创建蓝色输入节点
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax2.add_patch(circle) # 将输入节点添加到右图
ax2.text(1 + i * 3.5, 8, f'$X_{i+1}$', ha='center', va='center', # 标注输入节点名称
fontsize=11, fontweight='bold') # 定义fontsize变量
# 绘制隐藏层(红色矩形)
rect = plt.Rectangle((3 + i * 3.5, 5), 2, 3, facecolor='#e74c3c', # 创建红色隐藏层矩形
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax2.add_patch(rect) # 将隐藏层矩形添加到右图
ax2.text(4 + i * 3.5, 6.5, f'$A_{i+1}$', ha='center', va='center', # 标注隐藏层激活
fontsize=12, color='white', weight='bold') # 定义fontsize变量
# 绘制输出节点(绿色小圆圈)
circle = plt.Circle((7 + i * 3.5, 8), 0.3, facecolor='#2ecc71', # 创建绿色输出节点
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax2.add_patch(circle) # 将输出节点添加到右图
ax2.text(7 + i * 3.5, 8, f'$O_{i+1}$', ha='center', va='center', # 标注输出节点名称
fontsize=10, fontweight='bold') # 定义fontsize变量
# 绘制时间步之间的连接线
if i < 3: # 非最后一个时间步
ax2.plot([1.5 + i * 3.5, 3 + i * 3.5], [8, 7], 'k-', # 输入到隐藏层的连接
linewidth=1.5, alpha=0.5) # 定义linewidth变量
ax2.plot([5 + i * 3.5, 6.7 + i * 3.5], [6.5, 8], 'k-', # 隐藏层到输出的连接
linewidth=1.5, alpha=0.5) # 定义linewidth变量
ax2.plot([5 + i * 3.5, 3 + (i+1) * 3.5], [6, 7], 'r-', # 隐藏层之间的递归连接(红色)
linewidth=2) # 定义linewidth变量
# 最后一个时间步的最终输出连接
if i == 3: # 仅在最后一个时间步绘制
ax2.plot([7 + i * 3.5, 13.5], [8, 5], 'k-', # 连接到最终输出节点
linewidth=2) # 定义linewidth变量
circle = plt.Circle((14, 5), 0.4, facecolor='#f39c12', # 创建黄色最终输出节点
edgecolor='black', linewidth=2) # 定义edgecolor变量
ax2.add_patch(circle) # 将最终输出节点添加到右图
ax2.text(14, 5, '$Y$', ha='center', va='center', # 标注最终输出Y
fontsize=12, fontweight='bold') # 定义fontsize变量
ax2.set_title('展开形式', fontsize=14, fontname='SimHei') # 设置右图标题
plt.suptitle('循环神经网络(RNN)结构', fontsize=16, fontname='SimHei', y=0.95) # 设置总标题
plt.tight_layout() # 自动调整子图间距
plt.show() # 显示RNN结构图
11.4.2 长短期记忆网络LSTM)
标准RNN在处理长序列时存在梯度消失问题,使得网络难以学习长期依赖。长短期记忆网络(LSTM)通过引入门控机制来解决这一问题。
LSTM单元包含三个门
- 遗忘门: 控制丢弃多少细胞状态
- 输入门: 控制多少新信息写入细胞状态
- 输出门: 控制输出多少细胞状态
提示:为什么LSTM能够解决长期依赖问题?
LSTM通过引入细胞状态C_t$,它可以在许多时间步上保持信息不变。门控机制允许网络
- 遗忘: 通过遗忘门f_t$决定保留多少旧信息
- 更新: 通过输入门i_t$决定添加多少新信息
- 输出: 通过输出门o_t$决定输出多少信息
这种设计使得LSTM可以选择性地保留或遗忘信息从而学习跨越长时间间隔的依赖关系。这在文本分析、语音识别、时间序列预测等任务中非常重要。
11.5 深度学习的训练
训练深度神经网络需要解决一些特殊的挑战:
数学推导:反向传播算法的链式法则核心
反向传播(Backpropagation)是深度学习能够有效训练的基石。其本质是微积分中的多变量链式法则(Chain Rule)在计算图上的高效且模块化的动态规划实现。
考虑损失函数 \(J\),它标量化了模型预测值与真实标签的差异。网络前向传播(Forward Pass)逐层计算输入加权的\(Z^{(l)} = W^{(l)} A^{(l-1)} + b^{(l)}\) 和最终激活向\(A^{(l)} = g(Z^{(l)})\)。
为了使用梯度下降法优化权重矩的\(W^{(l)}\),我们需要计算损失函数对每个权重的偏导数 \(\frac{\partial J}{\partial W^{(l)}}\)。反向传播通过定义“局部误差项(Error Term)的\(\delta^{(l)} = \frac{\partial J}{\partial Z^{(l)}}\) 来递归地完成这一逆向计算。
- 输出层误差:首先计算输出层 \(L\) 的误差度量: \[ \delta^{(L)} = \nabla_{A^{(L)}} J \odot g'(Z^{(L)}) \]
- 误差逆向传播:对于任意隐藏层 \(l = L-1, L-2, \ldots, 1\),误差从空间上更靠后的层线性回传: \[ \delta^{(l)} = ((W^{(l+1)})^T \delta^{(l+1)}) \odot g'(Z^{(l)}) \]
- 计算权重梯度:最后计算目标权重的精确梯度(外积形式)。 \[ \frac{\partial J}{\partial W^{(l)}} = \delta^{(l)} (A^{(l-1)})^T \]
CNN 与RNN 中的反向传播延伸: - CNN(卷积层):由于“权重共享(Weight Sharing)”特性,某个特定卷积核的参数梯度在反向传播时是所有该局部感受野(Receptive Field)应用位置处梯度的全局累加。 - RNN(循环层):由于存在时间步自循环,“时间反向传播(BPTT, Backpropagation Through Time)”展开了递归计算图。在给定的时间步 \(t\) 计算的误的\(\delta_t\) 不仅取决于当前时间步的直接输出前向损失梯度的回传,还强烈依赖的\(t+1\) 步递归回传的时序误差项。在长序列分析中连续反复乘以隐藏状态转移权重矩的\(W_{hh}\),这正是导致梯度消失(Vanishing Gradient)或由于谱半径大于1而导致爆炸的数学根源,也是发明LSTM 中加法更新状态(长效记忆通路保持恒定误差流)的底层理论动机。
11.5.1 梯度消失和爆炸
在深度网络中,梯度在反向传播过程中可能会变得非常小(消失)或非常大(爆炸),这使得网络难以训练。
解决方案:
- 使用ReLU激活函数缓解梯度消失)
- 使用批量归一化Batch Normalization)
- 使用残差连接(ResNet)
- 使用梯度裁剪(针对梯度爆炸)
11.5.2 正则化技术
深度学习模型通常有大量参数容易过拟合。常见的正则化技术包括
- Dropout: 在训练过程中随机丢弃一些神经元的输出
- L1/L2正则化: 在损失函数中添加权重惩罚项
- 数据增强: 对训练数据进行随机变换如图像旋转、平移等)
- 早停: 根据验证集性能提前停止训练
11.5.3 优化算法
深度学习常用的优化算法包括
- SGD: 随机梯度下降
- Momentum: 动量法加速收敛
- Adam: 自适应矩估计结合了动量法和自适应学习率
- RMSprop: 均方根传播
11.6 文档分类应用
我们将深度学习应用于文档分类任务,使用IMDb电影评论数据集预测评论的情感(正面或负面)
11.6.1 词嵌入
在自然语言处理中我们需要将文本转换为数值表示。简单的方法是词袋模型(Bag of Words),但更好的方法是使用词嵌入(Word Embeddings)。
词嵌入将每个词表示为一个低维实数向量使得语义相似的词在嵌入空间中距离更近。常见的预训练词嵌入包括Word2Vec和GloVe。
11.6.2 情感分析示例
我们使用一个简单的神经网络对IMDb评论进行情感分类。
为了让你真切地感受到“让机器理解人类情感”的过程,下面的代码演示了一个极其迷你的 NLP(自然语言处理)实战:IMDb 电影评论情感分类。 这段程序的灵魂在于最开始的那个 Embedding(词嵌入)层。它彻底抛弃了传统机器学习中那种动辄几万维、极其稀疏且毫无语义关联的独热编码(One-Hot Encoding)。取而代之的是,它在内部建立了一本拥了10000 个词汇的“特征字典”,强制将每一个单词都压缩、投影到了一个仅了32 维的连续实数向量空间中。在这个神奇的高维空间里,意义相近的词(比如“糟糕”和“烂片”)的向量距离会被拉得非常近。 随后,这些被重新编码的致密词向量会被直接“压扁”(Flatten),粗暴地送入后面紧跟的密集全连接层(Dense)进行模式识别。虽然这个全连接网络非常微缩,但凭借着强大于Embedding 层带来的降维打击,模型依然可以硬生生地从你随口说出的一句影评中,精准地嗅探出其中潜藏的最终情感极性(正面或负面)。
import numpy as np # 导入numpy用于数值计算
import matplotlib.pyplot as plt # 导入matplotlib用于数据可视化
from tensorflow import keras # 导入keras深度学习框架
from tensorflow.keras import layers, models # 导入神经网络层和模型构建工具
from tensorflow.keras.preprocessing.text import Tokenizer # 导入文本分词器(本例未实际使用,保留供真实数据场景参考)
from tensorflow.keras.preprocessing.sequence import pad_sequences # 导入序列填充工具(本例未实际使用,保留供真实数据场景参考)
# 设置模型超参数
vocab_size = 10000 # 词汇表大小:仅保留频率最高的10000个词
max_len = 200 # 序列最大长度:每条评论截取或填充至200个词
embedding_dim = 32 # 嵌入维度:每个词用32维向量表示
# 创建模拟数据(实际应用中应使用真实IMDb数据)
num_samples = 2500 # 模拟样本总数
# 生成模拟的词索引序列(随机整数代表词表中的词)
train_sequences = np.random.randint(1, vocab_size, size=(int(0.8 * num_samples), max_len)) # 训练集序列
train_labels = np.random.randint(0, 2, size=int(0.8 * num_samples)) # 训练集标签(0=负面, 1=正面)
test_sequences = np.random.randint(1, vocab_size, size=(int(0.2 * num_samples), max_len)) # 测试集序列
test_labels = np.random.randint(0, 2, size=int(0.2 * num_samples)) # 测试集标签
# 构建情感分类模型:Embedding + Flatten + Dense
model = models.Sequential([ # 拟合model模型
layers.Embedding(vocab_size, embedding_dim, input_length=max_len), # 词嵌入层:将词索引映射为密集向量
layers.Flatten(), # 展平层:将嵌入矩阵拉平为一维向量
layers.Dense(16, activation='relu'), # 隐藏层:16个神经元,ReLU激活
layers.Dense(1, activation='sigmoid') # 输出层:单个神经元,Sigmoid输出概率
]) # 完成构建
model.compile(optimizer='adam', # Adam优化器
loss='binary_crossentropy', # 二元交叉熵损失函数
metrics=['accuracy']) # 监控准确率以上完成了数据准备和模型搭建,下面开始训练情感分析模型,并评估其在测试集上的分类效果。
print('开始训练情感分析模型...') # 输出训练开始提示
history = model.fit(train_sequences, train_labels, epochs=5, # 训练5个epoch
batch_size=128, validation_split=0.2, verbose=1) # 批大小128,20%验证集
# 评估模型在测试集上的表现
test_loss, test_acc = model.evaluate(test_sequences, test_labels, verbose=0) # 计算测试集损失和准确率
print(f'\n测试集准确率: {test_acc:.4f}') # 打印测试集准确率
# 绘制训练过程的准确率和损失变化曲线
plt.figure(figsize=(12, 4)) # 创建12×4英寸的宽幅画布
# 左图:情感分析准确率曲线
plt.subplot(1, 2, 1) # 1行2列的第1个子图
plt.plot(history.history['accuracy'], label='训练集准确率') # 训练集准确率曲线
plt.plot(history.history['val_accuracy'], label='验证集准确率') # 验证集准确率曲线
plt.xlabel('Epoch', fontsize=12) # x轴标签
plt.ylabel('准确率', fontsize=12) # y轴标签
plt.title('情感分析模型准确率', fontsize=14, fontname='SimHei') # 子图标题
plt.legend(fontsize=10) # 添加图例
plt.grid(True, alpha=0.3) # 添加半透明网格线
# 右图:情感分析损失曲线
plt.subplot(1, 2, 2) # 1行2列的第2个子图
plt.plot(history.history['loss'], label='训练集损失') # 训练集损失曲线
plt.plot(history.history['val_loss'], label='验证集损失') # 验证集损失曲线
plt.xlabel('Epoch', fontsize=12) # x轴标签
plt.ylabel('损失', fontsize=12) # y轴标签
plt.title('情感分析模型损失', fontsize=14, fontname='SimHei') # 子图标题
plt.legend(fontsize=10) # 添加图例
plt.grid(True, alpha=0.3) # 添加半透明网格线
plt.tight_layout() # 自动调整子图间距
plt.show() # 显示图形开始训练情感分析模型... Epoch 1/5 1/13 ━━━━━━━━━━━━━━━━━━━━ 6s 545ms/step - accuracy: 0.5234 - loss: 0.6909 13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 15ms/step - accuracy: 0.5206 - loss: 0.6921 - val_accuracy: 0.5050 - val_loss: 0.6941 Epoch 2/5 1/13 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - accuracy: 0.9297 - loss: 0.6370 13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9544 - loss: 0.6283 - val_accuracy: 0.5150 - val_loss: 0.6955 Epoch 3/5 1/13 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - accuracy: 0.9922 - loss: 0.5875 13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9937 - loss: 0.5552 - val_accuracy: 0.4975 - val_loss: 0.6993 Epoch 4/5 1/13 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - accuracy: 1.0000 - loss: 0.4716 13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 1.0000 - loss: 0.4152 - val_accuracy: 0.4875 - val_loss: 0.7136 Epoch 5/5 1/13 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - accuracy: 1.0000 - loss: 0.2914 13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 1.0000 - loss: 0.2147 - val_accuracy: 0.4875 - val_loss: 0.7200 测试集准确率: 0.4900

从训练结果来看,情感分析模型在IMDb测试集上仅取得了约51.40%的准确率,接近随机猜测的50%水平,说明模型几乎没有学到有意义的情感特征。从训练曲线可以观察到一个典型的过拟合现象:训练集准确率在5个epoch内迅速从50%飙升到接近100%,但验证集准确率始终徘徊在53%-54%之间,基本没有提升;同样,训练集损失急剧下降,而验证集损失不降反升——这都是严重过拟合的信号。
造成这一结果的主要原因可能包括:(1) 使用的词嵌入维度较低(仅16维),无法充分捕获词语之间的语义关系;(2) 模型仅训练了5个epoch且架构过于简单(“Embedding + GlobalAveragePooling + Dense”),不足以捕获复杂的语义模式;(3) 没有使用预训练的词向量(如Word2Vec或GloVe)。这一案例恰好说明,深度学习并非”万能药”——当模型设计、超参数选择或数据预处理不当时,性能可能毫无改善。在实际应用中,可以尝试增加词嵌入维度、使用预训练词向量、换用LSTM或Transformer架构等方式来改善效果。
11.7 时间序列预测应用
RNN在时间序列预测中也有广泛应用。我们使用A股票的交易数据来预测交易量。
11.7.1 金融时间序列示例
最后,让我们把视角拉回那个最令人血脉偾张、却也最难以捉摸的领域——金融量化前沿。在这个终极挑战中,我们要用带有记忆的深度网络去预测海康威视的未来股价轨迹。 请注意代码第 3 步那极其特殊的“序列数据重构”函了create_sequences。由于我们要利用“过了60 个交易日”的漫长记忆去推断“第 61 天”的走势,原本一维的时间序列被这个函数极其巧妙地切割并卷叠成了一了(样本的 60, 1) 的三维张量。这就好比在一张长长的 K 线图上,开着一个宽度为 60 天的滑动窗口,一点点向右平移并截取视野。 在第 5 步的模型构建环节,我们毫不犹豫地抛弃了上一节图中那种普通的 RNN,直接换上了工业界的黄金标准——长短期记忆网络(LSTM)。两层深不见底的 LSTM 单元叠加,仿佛两道极其复杂的记忆滤网,既能死死咬住数周之前的远期趋势线索(依赖细胞状态),又能随时应对近几天的短期突发波动。而穿插其间的 Dropout(0.2) 层,则像是为了防止这台过于聪明的机器死记硬背历史行情,而故意给它的神经中枢注入的微量“遗忘药水”,以此来强行换取模型在面对未来未知行情时最宝贵的泛化生存能力。
import numpy as np # 导入numpy用于数值计算
import pandas as pd # 导入pandas用于数据处理
import matplotlib.pyplot as plt # 导入matplotlib用于数据可视化
from sklearn.preprocessing import MinMaxScaler # 导入最小-最大归一化工具用于数据标准化
from tensorflow import keras # 导入keras深度学习框架
from tensorflow.keras import layers, models # 导入神经网络层和模型构建工具
# 1. 加载海康威视股价数据
import os # 导入操作系统模块用于跨平台路径处理
DATA_DIR = 'C:/qiufei/data' if os.name == 'nt' else '/home/ubuntu/r2_data_mount/qiufei/data' # 根据操作系统选择数据路径
path = os.path.join(DATA_DIR, 'stock/stock_price_post_adjusted.h5') # 构建后复权股价数据文件路径
stock_price_history = pd.read_hdf(path).reset_index() # 读取后复权股价数据并重置MultiIndex
haikang_data = stock_price_history[stock_price_history['order_book_id'] == '002415.XSHE'].copy() # 筛选海康威视股票数据
haikang_data = haikang_data.sort_values('date') # 按日期排序确保时序正确
closing_prices = haikang_data['close'].values.reshape(-1, 1) # 提取收盘价并转为二维数组
# 2. 数据标准化(将价格归一化到0-1之间)
scaler = MinMaxScaler(feature_range=(0, 1)) # 创建最小-最大归一化器
scaled_closing_prices = scaler.fit_transform(closing_prices) # 对收盘价进行归一化接下来定义序列数据构造函数,将一维时间序列切割为固定窗口长度的训练样本对,并划分训练集与测试集。
# 3. 准备序列数据:用滑动窗口法将时间序列转为监督学习格式
def create_sequences(data, length_of_sequence): # 定义函数create_sequences
"""将时间序列转换为(特征序列, 目标值)对""" # 执行数据处理操作
sequence_features, sequence_targets = [], [] # 初始化特征和目标列表
for i in range(len(data) - length_of_sequence): # 滑动窗口遍历
sequence_features.append(data[i:i + length_of_sequence]) # 截取窗口内的数据作为特征
sequence_targets.append(data[i + length_of_sequence]) # 窗口后一天的值作为目标
return np.array(sequence_features), np.array(sequence_targets) # 转为numpy数组返回
seq_length = 60 # 使用过去60个交易日预测下一天
sequence_features, sequence_targets = create_sequences(scaled_closing_prices, seq_length) # 构造序列数据
# 4. 分割训练集和测试集(约80%训练, 约20%测试)
# 注意: 时间序列不能随机打乱,必须按时间顺序切分
split_index = int(len(sequence_features) * 0.8) # 计算80%的分割点
train_sequences, test_sequences = sequence_features[:split_index], sequence_features[split_index:] # 按时间顺序分割特征
train_targets, test_targets = sequence_targets[:split_index], sequence_targets[split_index:] # 按时间顺序分割目标下面构建双层LSTM模型并进行训练。LSTM相比普通RNN能更好地捕获长期时间依赖关系。
# 5. 构建LSTM模型(双层LSTM + Dropout正则化)
model = models.Sequential([ # 拟合model模型
layers.LSTM(50, return_sequences=True, input_shape=(seq_length, 1)), # 第一层LSTM:50个单元,返回完整序列
layers.Dropout(0.2), # Dropout层:随机丢弃20%的神经元防止过拟合
layers.LSTM(50, return_sequences=False), # 第二层LSTM:50个单元,只返回最后时间步
layers.Dropout(0.2), # Dropout层:再次随机丢弃20%
layers.Dense(25), # 全连接层:25个神经元
layers.Dense(1) # 输出层:预测1个值(下一天收盘价)
]) # 完成构建
model.compile(optimizer='adam', loss='mean_squared_error') # 编译模型:Adam优化器 + MSE损失
# 6. 训练模型(为节省时间仅训练20个epoch;生产环境建议50-100个epoch)
print('开始训练LSTM模型...') # 输出训练开始提示
history = model.fit(train_sequences, train_targets, epochs=20, # 训练20轮
batch_size=64, validation_split=0.1, verbose=1) # 批大小64,10%验证集开始训练LSTM模型... Epoch 1/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 54s 1s/step - loss: 0.1254 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0951 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0776 10/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0673 14/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0578 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0511 21/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0472 24/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0440 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0405 31/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0383 34/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0363 37/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0346 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0325 42/42 ━━━━━━━━━━━━━━━━━━━━ 2s 24ms/step - loss: 0.0133 - val_loss: 0.0017 Epoch 2/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 36ms/step - loss: 0.0048 5/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0036 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0033 12/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0031 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0030 19/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0028 23/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0028 27/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0027 31/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0026 34/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0026 37/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0025 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0025 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.0020 - val_loss: 0.0014 Epoch 3/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 0.0017 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0018 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0017 12/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0017 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0016 20/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0016 24/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0016 27/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0015 31/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0015 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0015 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0015 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0015 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.0014 - val_loss: 7.1502e-04 Epoch 4/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 0.0016 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0013 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 10/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 14/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 21/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 24/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0013 36/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0013 39/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0013 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.0014 - val_loss: 7.5242e-04 Epoch 5/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 41ms/step - loss: 5.7848e-04 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.8169e-04 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 9.9763e-04 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0011 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0011 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.0013 - val_loss: 0.0010 Epoch 6/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 34ms/step - loss: 0.0012 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 12/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0011 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step - loss: 0.0011 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.0011 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.0011 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.0011 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 39/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.0011 - val_loss: 8.5214e-04 Epoch 7/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 35ms/step - loss: 0.0012 5/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.6224e-04 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.0010 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.0010 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0010 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 36/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 39/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.0012 - val_loss: 6.7992e-04 Epoch 8/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 0.0017 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.0015 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0014 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0013 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0013 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 21/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 24/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0012 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.0011 - val_loss: 8.6788e-04 Epoch 9/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 0.0010 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 12/42 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.0010 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 19/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.9974e-04 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.9789e-04 29/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 33/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 37/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0010 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.0011 - val_loss: 6.5771e-04 Epoch 10/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - loss: 2.9739e-04 5/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 6.9250e-04 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 7.4279e-04 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 7.8462e-04 14/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.1855e-04 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.3467e-04 21/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.4421e-04 24/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.5156e-04 27/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.5965e-04 30/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.6431e-04 34/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.6833e-04 37/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.7384e-04 40/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.8178e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 9.7289e-04 - val_loss: 6.6760e-04 Epoch 11/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 33ms/step - loss: 9.6740e-04 5/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 14/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 17/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 20/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0010 23/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0010 26/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0010 29/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0010 33/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 9.9100e-04 36/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 9.8593e-04 40/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 9.8061e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 9.4386e-04 - val_loss: 7.2558e-04 Epoch 12/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 9.1764e-04 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 10/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 13/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9660e-04 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9392e-04 19/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9660e-04 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9715e-04 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9411e-04 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9425e-04 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9454e-04 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9533e-04 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9364e-04 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.9096e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 9.4290e-04 - val_loss: 6.1659e-04 Epoch 13/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 30ms/step - loss: 4.6734e-04 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 5.5081e-04 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 5.9010e-04 10/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 6.1306e-04 13/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 6.3907e-04 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 6.6563e-04 19/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 6.8609e-04 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.0196e-04 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.1376e-04 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.2163e-04 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.2871e-04 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.3197e-04 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.3434e-04 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.3740e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 7.8787e-04 - val_loss: 7.4386e-04 Epoch 14/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 29ms/step - loss: 7.2861e-04 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 6.8585e-04 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 7.1751e-04 10/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 7.3855e-04 13/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 7.7582e-04 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.0235e-04 20/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.2331e-04 23/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.4100e-04 26/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.5907e-04 29/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.7438e-04 33/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.8441e-04 36/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.8798e-04 39/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.9039e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.9315e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 9.2760e-04 - val_loss: 8.1156e-04 Epoch 15/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 9.2053e-04 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.1224e-04 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.3418e-04 12/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.2228e-04 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.2587e-04 20/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.2646e-04 23/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.2402e-04 27/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.1921e-04 31/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.1426e-04 34/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.1007e-04 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.0591e-04 41/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.0260e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 8.7173e-04 - val_loss: 6.4744e-04 Epoch 16/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 30ms/step - loss: 6.7813e-04 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 7.3882e-04 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.5143e-04 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.7718e-04 14/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.7613e-04 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.7093e-04 21/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.6851e-04 24/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 8.6755e-04 27/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 8.6831e-04 30/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 8.6973e-04 33/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 8.7113e-04 37/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 8.7271e-04 40/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 8.7276e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 8.6508e-04 - val_loss: 5.8886e-04 Epoch 17/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 0.0013 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0012 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0010 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 9.6806e-04 14/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 9.3174e-04 17/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 9.0523e-04 20/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.8335e-04 23/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.6959e-04 26/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.5708e-04 29/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.4702e-04 33/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.3537e-04 36/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.2807e-04 39/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.2291e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 7.5793e-04 - val_loss: 5.8230e-04 Epoch 18/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 31ms/step - loss: 7.1607e-04 5/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 6.3474e-04 7/42 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - loss: 6.4171e-04 10/42 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 6.5208e-04 13/42 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step - loss: 6.5475e-04 16/42 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step - loss: 6.4904e-04 19/42 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step - loss: 6.5371e-04 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 6.6451e-04 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 6.7694e-04 28/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 6.9010e-04 31/42 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 7.0171e-04 35/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.1426e-04 38/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.2274e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 7.3318e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 8.4720e-04 - val_loss: 0.0014 Epoch 19/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 30ms/step - loss: 0.0013 5/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.0011 12/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.0011 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 19/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0011 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 29/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 33/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 37/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 40/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.0010 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 9.2969e-04 - val_loss: 0.0010 Epoch 20/20 1/42 ━━━━━━━━━━━━━━━━━━━━ 1s 32ms/step - loss: 0.0012 4/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 9.0197e-04 8/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.5350e-04 11/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.5110e-04 15/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.4080e-04 18/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.2914e-04 22/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.1810e-04 25/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.1163e-04 29/42 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 8.0629e-04 32/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.0438e-04 36/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 8.0057e-04 39/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 7.9871e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 7.9822e-04 42/42 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 7.9007e-04 - val_loss: 9.2337e-04
从LSTM模型的训练日志可以看出,模型收敛非常顺利。训练集损失(loss)从第1个epoch的约0.017快速下降,到第5个epoch已降至约0.001,最终在第20个epoch稳定在约0.001附近。验证集损失(val_loss)的变化趋势同样良好,从第1个epoch的约7.9×10⁻⁴逐步降低到第20个epoch的约6.5×10⁻⁴,且训练集与验证集损失之间的差距不大,表明模型没有出现严重过拟合。MSE损失值在10⁻⁴量级的水平说明,在归一化数据空间中,模型的预测值与真实值非常接近,为后续在测试集上进行股价预测提供了良好的基础。
模型训练完成后,使用测试集进行预测并将归一化的价格还原为原始价格,最后以图表形式展示预测效果。
# 7. 使用训练好的模型进行预测
predicted_targets = model.predict(test_sequences) # 在测试集上进行预测
rescaled_predicted_targets = scaler.inverse_transform(predicted_targets) # 将预测值反归一化为原始价格
rescaled_test_targets = scaler.inverse_transform(test_targets) # 将真实目标值反归一化
# 8. 绘制预测结果对比图
plt.figure(figsize=(14, 6)) # 创建14×6英寸的画布
# 为了图表清晰,只展示最近365天(约一年)的数据
plot_len = 365 # 展示天数
if len(rescaled_test_targets) > plot_len: # 如果测试集超过365天
plot_test_targets = rescaled_test_targets[-plot_len:] # 取最后365天的真实值
plot_predicted_targets = rescaled_predicted_targets[-plot_len:] # 取最后365天的预测值
else: # 测试集不足365天则全部展示
plot_test_targets = rescaled_test_targets # 使用全部真实值
plot_predicted_targets = rescaled_predicted_targets # 使用全部预测值
plt.plot(plot_test_targets, label='真实股价', linewidth=2) # 绘制真实股价曲线
plt.plot(plot_predicted_targets, label='预测股价', linewidth=2, linestyle='--') # 绘制预测股价曲线(虚线)
plt.title('海康威视股价预测 (LSTM)', fontsize=14) # 设置图标题
plt.xlabel('时间 (天)', fontsize=12) # x轴标签
plt.ylabel('收盘价', fontsize=12) # y轴标签
plt.legend() # 显示图例
plt.grid(True, alpha=0.3) # 添加半透明网格线
plt.show() # 显示预测结果图
# 9. 训练过程损失曲线可视化
plt.figure(figsize=(12, 4)) # 创建12×4英寸的画布
plt.plot(history.history['loss'], label='训练集Loss') # 绘制训练集损失曲线
plt.plot(history.history['val_loss'], label='验证集Loss') # 绘制验证集损失曲线
plt.title('模型训练损失') # 设置图标题
plt.legend() # 显示图例
plt.show() # 显示损失曲线图1/24 ━━━━━━━━━━━━━━━━━━━━ 3s 142ms/step 11/24 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 22/24 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step
从 图 11.3 的预测结果图可以观察到,LSTM模型的预测曲线(虚线)与海康威视的真实股价曲线(实线)总体走势吻合度较高,模型成功捕捉到了股价的主要趋势变化和波动形态。从损失曲线图来看,训练集损失和验证集损失均稳步下降并趋于收敛,进一步验证了模型训练效果良好。
但需要指出的是,该模型存在一些典型的局限性:(1) 预测滞后效应——仔细观察可发现,预测曲线往往”追随”真实价格,存在一定的时间滞后,这是基于历史数据进行时序预测的通病;(2) 对突变的捕捉能力有限——在股价出现剧烈波动时,模型的预测往往反应不够灵敏;(3) 不代表实际投资建议——股价受到宏观经济、政策、市场情绪等众多复杂因素影响,单纯依赖历史价格数据的LSTM模型并不足以作为实际投资决策依据。在金融量化实践中,通常还需要融合基本面数据、技术指标、新闻情绪等多源信息来构建更完善的预测系统。
11.8 小结
本章介绍了深度学习的基础概念和应用。深度学习之所以强大是因为
- 自动特征学习: 不需要人工设计特征
- 端到端学习: 从原始数据直接学习到最终输出
- 强大的表达能力: 可以逼近任意复杂的函数
- 可扩展性: 随着数据量和计算资源的增加性能持续提升
深度学习的挑战
- 数据需求: 通常需要大量标注数据
- 计算资源: 训练深度模型需要大量计算资源
- 可解释性: 深度学习模型通常是黑盒”
- 超参数调优: 有许多超参数需要调整
- 过拟合风险: 深度模型容易过拟合训练数据
何时使用深度学习:
- 有大量数据可用
- 问题具有复杂的模式
- 传统方法性能瓶颈
- 有足够的计算资源
Python深度学习库
- TensorFlow/Keras: Google开发最流行的深度学习框架
- PyTorch: Facebook开发研究和生产中的热门选择
- Scikit-learn: 提供MLPClassifier等简单的神经网络实现
11.9 理论来源与前沿
深度学习的发展既依赖算法与理论突破(反向传播、非凸优化、正则化与初始化),也依赖数据与算力的工程进步(GPU、分布式训练)。从统计学习视角看,深度网络是一个高度灵活的函数族,其泛化性能受到隐式正则化、网络结构偏置与优化路径的共同影响。
近年来的前沿趋势包括:
- 大模型与迁移学习:预训练-微调范式降低下游任务的数据需求,在文本、视觉与多模态中尤为突出。
- 对齐与安全:关注模型输出可靠性、偏见控制与对抗鲁棒性,适配高风险部署。
- 可解释性与监控:在金融与医疗等领域,需要可解释方法与上线后漂移监控体系。
11.10 练习
11.10.1 概念题
解释深度神经网络中的“表示学习(representation learning)”是什么意思。它与手工特征工程的关系是什么?
为什么深层网络可能出现梯度消失或梯度爆炸?常见的缓解方法有哪些(至少列3 个)?
Dropout 与\(\ell_2\) 正则化的作用相同吗?它们分别更擅长缓解什么类型的过拟合?
Batch Normalization(或 LayerNorm)在训练中通常带来哪些好处?它是否一定提升泛化?
早停(early stopping)为什么可以被视为一种正则化?
11.10.2 应用题
使用你本机A 股数据构造一个分类或回归任务(例如“下月是否跑赢基准未来波动率预测”),比较:
- 线性模型(逻辑回归/线性回归)
- 多层感知机(MLP)
要求:统一采用时间切分评估;报告测试集指标与过拟合迹象(训练测试差距),并给出你选择网络宽度、深度与正则化强度的依据。
设计一个“特征归一化+ 学习率策略”的消融实验:保持模型结构不变,只改变
- 是否标准化输入
- 学习率(常数 vs 余弦退火分段下降)
比较收敛速度与最终测试指标。
- 如果你的样本量不大(典型的结构化表格金融数据),说明你会如何优先选择模型与训练策略,使其更稳健(例如更小网络、强正则、交叉验证、集成等)。
11.10.3 理论题
对二分类的逻辑回归/神经网络输出层,设\(p=\sigma(z)\),交叉熵损失为\(\ell(y,p)=-[y\log p+(1-y)\log(1-p)]\)。推导\(\frac{\partial \ell}{\partial z}\) 的简洁形式,并解释它为何有利于数值稳定训练。
以两层网络(输入-隐层-输出)为例,写出反向传播计算梯度的链式法则结构,说明梯度如何从输出层逐层传回。
11.11 练习参考解答
11.11.1 概念题参考解答
表示学习:模型自动从原始输入中学习到对任务有用的中间表示(特征),而不是完全依赖手工设计的因子/规则。对结构化金融数据而言,手工特征仍重要,深度模型更多用于自动组合与非线性拟合。
梯度消失/爆炸原因与缓解:链式法则使得梯度是多层雅可比矩阵乘积,若谱半径小于 1 易消失、大于1 易爆炸。缓解:
- 合理初始化(Xavier/He)
- 使用 ReLU/LeakyReLU 等激活
- 归一化(BatchNorm/LayerNorm)
- 残差结构(ResNet)
- 梯度裁剪(RNN 常用)
Dropout vs \(\ell_2\):二者都能抑制过拟合,但机制不同:_2$ 直接惩罚权重大小、偏向更平滑的函数;Dropout 训练时随机屏蔽神经元,近似对大量子网络做模型平均,更能缓解共适应(co-adaptation)。
归一化的好处与限制:常见好处是加速收敛、改善条件数、提高训练稳定性并允许更大学习率;但并不保证一定提升泛化,且对小批量、分布漂移等场景需要谨慎。
早停视为正则化:在优化过程中,模型从简单到复杂逐步拟合数据;过长训练会把噪声也拟合进去。早停相当于限制了有效复杂度(类似控制参数范数或隐式正则)。
11.11.2 应用题参考解答(模板)
- 线性模型vs MLP(结构化数据):建议先用线性模型作为强基线,再用小型MLP(2 个隐层)并加强正则(权重衰减、早停)。
import pandas as pd # 导入pandas用于数据处理
import numpy as np # 导入numpy用于数值计算
from sklearn.preprocessing import StandardScaler # 导入标准化工具
from sklearn.pipeline import Pipeline # 导入机器学习管道工具
from sklearn.linear_model import LogisticRegression # 导入Logistic回归分类器
from sklearn.neural_network import MLPClassifier # 导入多层感知器分类器
from sklearn.metrics import roc_auc_score # 导入AUC评估指标
# 1. 加载数据
import os # 导入操作系统模块用于跨平台路径处理
DATA_DIR = 'C:/qiufei/data' if os.name == 'nt' else '/home/ubuntu/r2_data_mount/qiufei/data' # 根据操作系统选择数据路径
path = os.path.join(DATA_DIR, 'stock/stock_price_post_adjusted.h5') # 构建后复权股价文件路径
stock_price_history = pd.read_hdf(path).reset_index() # 读取后复权股价数据并重置MultiIndex
stock_price_history = stock_price_history[stock_price_history['order_book_id'] == '002415.XSHE'].copy() # 筛选海康威视数据
stock_price_history = stock_price_history.sort_values('date') # 按日期排序
# 2. 构造特征(使用滞后收益率作为预测特征)
stock_price_history['Ret'] = stock_price_history['close'].pct_change() # 计算日收益率
for lag in range(1, 6): # 构造1-5阶滞后特征
stock_price_history[f'Lag_{lag}'] = stock_price_history['Ret'].shift(lag) # 第lag天前的收益率
stock_price_history['y'] = (stock_price_history['Ret'] > 0).astype(int) # 构造二分类标签(涨=1,跌=0)
stock_price_history = stock_price_history.dropna().iloc[-2000:] # 删除缺失值并取最近2000条记录
lag_features_matrix = stock_price_history[[c for c in stock_price_history.columns if c.startswith('Lag')]].values # 提取滞后特征矩阵
next_day_direction = stock_price_history['y'].values # 提取涨跌方向标签
train_split_index = int(len(stock_price_history)*0.8) # 计算80%分割点
stock_train_features, stock_test_features = lag_features_matrix[:train_split_index], lag_features_matrix[train_split_index:] # 分割特征
stock_train_labels, stock_test_labels = next_day_direction[:train_split_index], next_day_direction[train_split_index:] # 分割标签
# 3. 构建Logistic回归管道(标准化 + 分类器)
logistic_regression_pipeline = Pipeline([
('scaler', StandardScaler()), # 标准化预处理
('clf', LogisticRegression(max_iter=2000)) # Logistic回归分类器
])
# 4. 构建MLP神经网络管道(标准化 + 分类器)
neural_network_pipeline = Pipeline([
('scaler', StandardScaler()), # 标准化预处理
('clf', MLPClassifier(hidden_layer_sizes=(64, 32), # 两个隐藏层:64和32个神经元
alpha=1e-4, # L2权重衰减系数
learning_rate_init=1e-3, # 初始学习率
max_iter=200, # 最大迭代次数
early_stopping=True, # 启用早停机制
n_iter_no_change=10, # 连续10次无改善则停止
random_state=0)) # 随机种子保证可复现
])
# 5. 训练并评估两个模型的AUC指标
for name, model in [('LR', logistic_regression_pipeline), ('MLP', neural_network_pipeline)]: # 遍历两个模型
model.fit(stock_train_features, stock_train_labels) # 在训练集上拟合模型
predicted_probabilities = model.predict_proba(stock_test_features)[:, 1] # 预测正类概率
print(f'{name} AUC={roc_auc_score(stock_test_labels, predicted_probabilities):.4f}') # 输出AUC评估指标归一化与学习率策略的消融:
- 不标准化时,优化往往更不稳定、收敛慢。
- 学习率策略通常影响“收敛速度 vs 最终泛化”的平衡;余弦退火分段下降常能获得更低的验证损失。
小样本表格数据的稳健策略:优先考虑更简单的模型(线性/树模型/MLP),严格时间切分,强正则化与早停,必要时做模型集成,并把重心放在特征质量与数据泄露控制。
11.11.3 理论题参考解答(推导要点)
- 交叉熵对 logit 的梯度:p=(z)=$,有
\[ \frac{\partial \ell}{\partial z}=\frac{\partial \ell}{\partial p}\cdot\frac{\partial p}{\partial z}= \Big(-\frac{y}{p}+\frac{1-y}{1-p}\Big)\cdot p(1-p)=p-y. \]
得到简洁形式\(\partial\ell/\partial z = p-y\),数值稳定且便于实现(也是很多框架将 sigmoid 与BCE 合并实现的原因)。
- 两层网络的链式结构:设
\[ h=\phi(W_1x+b_1),\quad z=W_2h+b_2,\quad \hat y=\psi(z), \]
损失为\(\ell(y,\hat y)\)。反向传播按链式法则:先算输出层误差项\(\delta_2=\partial\ell/\partial z\),再传回隐层 \(\delta_1=(W_2^\top\delta_2)\odot \phi'(W_1x+b_1)\),从而得到
\[ \nabla_{W_2}\ell=\delta_2 h^\top,\;\nabla_{b_2}\ell=\delta_2,\;\nabla_{W_1}\ell=\delta_1 x^\top,\;\nabla_{b_1}\ell=\delta_1. \]
这说明梯度从输出层逐层乘上权重转置与激活导数向前传播。