10 神经网络:金融大数据学习

本章学习路线图

本章学习路线图 分为理论基础、模型训练和编程实践三个部分,各有图标和要点列表。 理论基础 - 从生物神经元到数学模型 - 激活函数的核心作用 - 网络结构与前向传播 - 万能近似定理 模型训练 - 代价函数与梯度下降 - 反向传播的直觉 - 过拟合与正则化技术 (L2, Early Stopping, Dropout) 编程实践 - Scikit-learn 快速实现 - PyTorch 工业级框架 - 真实宏观经济数据预测 - 模型评估与对比

核心问题

  • 传统的线性模型(如 OLS)在金融世界中是否足够?
  • 当变量间的关系变得复杂、非线性时,我们如何构建更强大的预测模型?
  • 如何让模型自动从数据中“学习”到复杂的模式,而不仅仅是执行我们预设的公式?

学习目标

通过本章学习,你将能够:

  • 理论层面: 掌握人工神经网络 (ANN) 的核心概念,理解其如何通过层级结构和非线性变换来学习复杂模式。
  • 模型层面: 区分回归与分类任务在神经网络中的不同实现,并理解反向传播的直觉含义。
  • 实践层面: 熟练使用 scikit-learnPyTorch 两大主流框架,从零开始构建、训练并评估一个用于解决金融预测问题的神经网络。

为何关注神经网络?

传统的计量经济学模型非常出色,但通常需要我们预先假设函数形式(例如,线性关系)。

神经网络则属于一种表示学习 (Representation Learning) 方法,它能够自动发现数据中复杂的、非 liné性的结构。

在数据丰富、关系未知的金融领域,这是一种极其强大的能力。

催化剂 1: 数据爆炸 (Big Data)

金融市场产生了海量的高频交易数据、财报文本、新闻情感、卫星图像等另类数据。

这些海量数据为训练需要大量样本的复杂模型(如神经网络)提供了充足的“燃料”。

数据爆炸概念图 三个数据库图标汇聚数据流到一个大脑图标中,象征数据为AI提供燃料。 交易数据 文本/新闻 另类数据 神经网络

催化剂 2: 算力革命 (Computational Power)

神经网络的训练涉及数百万甚至数十亿次计算。过去,这需要数周时间。

今天,GPU (图形处理器) 的并行计算能力将训练时间缩短到几小时甚至几分钟,使得曾经不切实际的复杂模型成为可能。

算力革命概念图 一个CPU图标和一个GPU图标,GPU显示出更多的并行核心,处理速度更快。 CPU 顺序处理 GPU 大规模并行 速度提升 10-100x

神经网络的金融应用

这些发展使得神经网络在以下领域展现出超越传统计量模型的强大能力:

  • 信用评分与违约预测: 捕捉申请人信息中复杂的非线性关系。
  • 资产定价: 发现驱动股票回报的非线性因子。
  • 算法交易: 识别高频数据中的瞬时模式。
  • 波动率预测: 建模金融市场复杂的“波动率微笑”现象。
  • 金融文本分析: 从新闻和财报中提取市场情绪。

灵感源泉:生物神经元

神经网络的最初灵感来源于人脑的结构——一个由数十亿个相互连接的神经元组成的复杂网络。

生物神经元结构图 一个简化的生物神经元图,标注了树突、细胞核和轴突。 细胞核 (处理单元) 轴突 (输出) 树突 (输入) 输入信号 输出信号

映射:从生物到人工

这个简单的“接收-处理-传递”机制是构建复杂智能的基础。我们可以将它抽象成一个数学模型。

生物类比 人工神经元 (Perceptron) 功能
树突 (Dendrite) 输入 (Inputs) \(x_i\) 接收信息
细胞核 (Nucleus) 处理单元 整合信息并决策
轴突 (Axon) 输出 (Output) 传递结果

人工神经元:感知器 (Perceptron)

一个最基本的人工神经元(感知器)执行两个核心步骤:

  1. 加权求和: 将所有输入特征 \(x_i\) 与其对应的权重 \(w_i\) 相乘,并加上一个偏置项 \(b\)
    • 这本质上是一个线性回归的计算。
  2. 激活: 将加权和的结果传入一个激活函数 \(g(\cdot)\),得到最终输出。
    • 激活函数决定神经元是否“激发”以及如何激发。

感知器的数学表达

\[ \large{ \underbrace{z = (w_1 x_1 + w_2 x_2 + \dots + w_n x_n) + b}_{\text{1. 加权求和}} = \mathbf{w}^T \mathbf{x} + b } \]

\[ \large{ \underbrace{\text{output} = g(z)}_{\text{2. 激活}} = g(\mathbf{w}^T \mathbf{x} + b) } \]

这个简单的结构,实际上就是我们之前学过的逻辑回归模型(当激活函数是 Sigmoid 时)。

感知器的局限

一个单独的感知器只能解决线性可分的问题。

线性可分与非线性可分问题 两个图表,左边显示可以用一条直线分开的两组点,右边显示无法用直线分开的两组点。 线性可分 (感知器可解决) 非线性可分 (感知器无法解决) ?

要解决更复杂的问题,我们需要将多个神经元组合起来,并引入非线性

激活函数:赋予网络非线性

激活函数 (Activation Function) 是神经网络能够学习复杂模式的关键。它在加权求和之后引入了非线性变换。

它决定了一个神经元的输出,是神经网络的“节拍器”。

思想实验:没有非线性会怎样?

如果激活函数是线性的(例如,\(g(z)=z\)),那么无论你堆叠多少层神经网络,最终的结果都等同于一个单一的线性变换

\[ \large{ \begin{aligned} \text{Layer}_2(\text{Layer}_1(\mathbf{x})) &= (\mathbf{W}_2 (\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2) \\ &= (\mathbf{W}_2 \mathbf{W}_1) \mathbf{x} + (\mathbf{W}_2 \mathbf{b}_1 + \mathbf{b}_2) \\ &= \mathbf{W}' \mathbf{x} + \mathbf{b}' \end{aligned} } \]

一个复杂的、多层的线性模型仍然只是一个线性模型,无法捕捉现实世界中普遍存在的非线性关系。

常用激活函数(1): Sigmoid

Sigmoid 函数将任意实数输入压缩到 \((0, 1)\) 区间内。

\[ \large{ g(z) = \frac{1}{1 + e^{-z}} } \]

  • 应用: 常用于二元分类问题的输出层,因为其输出可以被解释为概率。
  • 缺点: 在隐藏层中已不常用。

Sigmoid的缺点:梯度消失

Sigmoid 函数在输入值很大或很小时,其导数(梯度)趋近于0。在深度网络中,这会导致梯度消失 (Vanishing Gradients) 问题,使得网络深层的参数难以更新,训练过程极其缓慢。

梯度消失问题图解 一个Sigmoid函数图,其两端平坦区域被标记为“梯度消失区”,梯度在此区域接近于0。 z g(z) 0 0.5 1.0 梯度消失区 (导数 ≈ 0) 梯度消失区 (导数 ≈ 0)

常用激活函数(2): ReLU

ReLU (Rectified Linear Unit) 是目前深度学习中最常用的激活函数,形式极其简单。

\[ \large{ g(z) = \max(0, z) } \]

  • 优点:
    • 计算速度极快。
    • 有效缓解了梯度消失问题(当 \(z > 0\) 时,梯度恒为1)。

ReLU的优点

当输入 \(z>0\) 时,ReLU是一个线性函数;当 \(z \le 0\) 时,输出为0。

这种分段线性的特性,使得多个ReLU神经元组合起来可以拟合出任意复杂的非线性函数。

可以把它想象成一个简单的“开关”:只有当输入信号足够强(大于0)时,它才允许信号通过。

激活函数可视化比较

从图中可以清晰地看到两种函数的区别。ReLU 的单侧不饱和特性是其在现代深度学习中占据主导地位的关键。

import numpy as np

def generate_activations_svg():
    z = np.linspace(-6, 6, 300)
    sigmoid = 1 / (1 + np.exp(-z))
    relu = np.maximum(0, z)

    # Convert data points to SVG path strings
    sigmoid_path = "M " + " ".join([f"{x*50+400},{160-y*200}" for x, y in zip(z, sigmoid)])
    relu_path = "M " + " ".join([f"{x*50+400},{260-y*40}" for x, y in zip(z, relu)])
    
    svg = f'''
    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>Sigmoid 与 ReLU 激活函数</title>
      <desc>Sigmoid函数呈S形,将值压缩到0和1之间。ReLU函数在输入小于0时为0,大于0时呈线性增长。</desc>
      <style>
        .axis {{ stroke: #3c4043; stroke-width: 2; }}
        .grid {{ stroke: #e0e0e0; stroke-width: 1; stroke-dasharray: 4,4; }}
        .label {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }}
        .title {{ font-size: 20px; font-weight: bold; }}
        .sigmoid-line {{ stroke: #4285F4; stroke-width: 3.5; fill: none; }}
        .relu-line {{ stroke: #EA4335; stroke-width: 3.5; fill: none; }}
      </style>
      
      <text x="400" y="35" class="label title">ReLU的简洁性与效率使其成为现代神经网络的首选</text>

      <!-- Grid and Axes -->
      <line x1="100" y1="260" x2="700" y2="260" class="axis"/>
      <line x1="400" y1="60" x2="400" y2="280" class="axis"/>
      <line x1="100" y1="160" x2="700" y2="160" class="grid"/>
      <line x1="100" y1="60" x2="700" y2="60" class="grid"/>
      <text x="715" y="265" class="label">z</text>
      <text x="380" y="55" class="label">g(z)</text>
      <text x="400" y="280" class="label">0</text>
      <text x="380" y="165" class="label">0.5</text>
      <text x="380" y="65" class="label">1.0</text>

      <!-- Curves -->
      <path d="{sigmoid_path}" class="sigmoid-line"/>
      <path d="{relu_path}" class="relu-line"/>

      <!-- Legend -->
      <rect x="520" y="70" width="20" height="5" fill="#4285F4" />
      <text x="615" y="78" class="label">Sigmoid (梯度易消失)</text>
      <rect x="520" y="100" width="20" height="5" fill="#EA4335" />
      <text x="600" y="108" class="label">ReLU (现代标准)</text>
    </svg>
    '''
    return svg

print(f"```{'{=html}'}\n{generate_activations_svg()}\n```")
```{=html}

    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>Sigmoid 与 ReLU 激活函数</title>
      <desc>Sigmoid函数呈S形,将值压缩到0和1之间。ReLU函数在输入小于0时为0,大于0时呈线性增长。</desc>
      <style>
        .axis { stroke: #3c4043; stroke-width: 2; }
        .grid { stroke: #e0e0e0; stroke-width: 1; stroke-dasharray: 4,4; }
        .label { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }
        .title { font-size: 20px; font-weight: bold; }
        .sigmoid-line { stroke: #4285F4; stroke-width: 3.5; fill: none; }
        .relu-line { stroke: #EA4335; stroke-width: 3.5; fill: none; }
      </style>
      
      <text x="400" y="35" class="label title">ReLU的简洁性与效率使其成为现代神经网络的首选</text>

      <!-- Grid and Axes -->
      <line x1="100" y1="260" x2="700" y2="260" class="axis"/>
      <line x1="400" y1="60" x2="400" y2="280" class="axis"/>
      <line x1="100" y1="160" x2="700" y2="160" class="grid"/>
      <line x1="100" y1="60" x2="700" y2="60" class="grid"/>
      <text x="715" y="265" class="label">z</text>
      <text x="380" y="55" class="label">g(z)</text>
      <text x="400" y="280" class="label">0</text>
      <text x="380" y="165" class="label">0.5</text>
      <text x="380" y="65" class="label">1.0</text>

      <!-- Curves -->
      <path d="M 100.0,159.50547536867305 102.0066889632107,159.4852766917202 104.01337792642141,159.46425522267052 106.02006688963212,159.4423776297535 108.02675585284283,159.41960924629225 110.03344481605353,159.39591401852468 112.04013377926418,159.3712544514928 114.04682274247489,159.34559155293735 116.0535117056856,159.3188847751346 118.0602006688963,159.2910919546102 120.066889632107,159.26216924966457 122.07357859531771,159.23207107564232 124.08026755852842,159.2007500378792 126.08695652173913,159.16815686225644 128.09364548494983,159.13424032329462 130.10033444816054,159.09894716971635 132.1070234113712,159.0622220474081 134.11371237458195,159.02400741971044 136.1204013377926,158.98424348496636 138.12709030100336,158.94286809125768 140.13377926421407,158.89981664825908 142.14046822742478,158.85502203614155 144.14715719063545,158.80841451145616 146.15384615384616,158.75992160993223 148.16053511705687,158.70946804612433 150.16722408026757,158.65697560984518 152.17391304347828,158.60236305932418 154.18060200668896,158.54554601103393 156.18729096989966,158.48643682613104 158.19397993311037,158.42494449346114 160.20066889632108,158.36097450908346 162.20735785953178,158.294428752275 164.2140468227425,158.22520535798103 166.22073578595317,158.15319858568546 168.22742474916387,158.0782986846822 170.23411371237458,158.00039175573792 172.2408026755853,157.91935960914603 174.247491638796,157.83507961918258 176.25418060200667,157.74742457498724 178.26086956521738,157.65626252790548 180.26755852842808,157.56145663534286 182.2742474916388,157.46286500119874 184.2809364548495,157.3603405129642 186.2876254180602,157.25373067558863 188.29431438127094,157.14287744224058 190.3010033444816,157.0276170421113 192.30769230769232,156.9077798054353 194.314381270903,156.78318998592871 196.32107023411373,156.6536655808764 198.32775919732438,156.51901814913083 200.33444816053512,156.37905262731982 202.34113712374582,156.23356714459788 204.34782608695653,156.08235283631637 206.3545150501672,155.92519365702975 208.36120401337791,155.76186619330227 210.36789297658862,155.59213947682872 212.37458193979936,155.41577479843528 214.38127090301003,155.23252552358306 216.38795986622074,155.0421369100568 218.39464882943145,154.84434592858474 220.40133779264215,154.63888108720337 222.40802675585286,154.42546226025115 224.41471571906357,154.20380052295204 226.42140468227424,153.973597992627 228.42809364548495,153.7345476776567 230.43478260869566,153.48633333540388 232.44147157190636,153.2286293403964 234.44816053511707,152.9611005641654 236.45484949832775,152.68340226823264 238.46153846153845,152.39518001184098 240.46822742474916,152.096069576129 242.47491638795987,151.78569690655644 244.48160535117057,151.46367807549805 246.48829431438128,151.129619267035 248.49498327759198,150.7831167860859 250.5016722408027,150.42375709413338 252.5083612040134,150.05111687391468 254.51505016722408,149.66476312555784 256.52173913043475,149.26425329675368 258.5284280936455,148.84913544966133 260.53511705685617,148.41894846734502 262.54180602006693,147.97322230263694 264.5484949832776,147.51147827240607 266.5551839464883,147.03322940029196 268.561872909699,146.53798081102724 270.5685618729097,146.02523017952436 272.5752508361204,145.49446823793724 274.5819397993311,144.9451793439257 276.5886287625418,144.37684211334437 278.5953177257525,143.78893012054976 280.60200668896323,143.18091266946152 282.60869565217394,142.5522556384276 284.61538461538464,141.9024224018218 286.6220735785953,141.23087483114577 288.62876254180605,140.53707437820842 290.6354515050167,139.82048324271616 292.6421404682274,139.08056562631862 294.6488294314381,138.31678907481722 296.6555183946488,137.52862590985364 298.66220735785953,136.71555475094942 300.66889632107024,135.8770621282638 302.67558528428094,135.01264418587388 304.68227424749165,134.12180847475634 306.68896321070235,133.20407583396266 308.695652173913,132.25898235773113 310.70234113712377,131.28608144546854 312.7090301003344,130.28494593066384 314.7157190635451,129.25517028386884 316.72240802675583,128.19637288390038 318.72909698996654,127.1081983503899 320.73578595317724,125.99031992973586 322.74247491638795,124.84244192541064 324.74916387959865,123.6643021624451 326.7558528428094,122.45567447477276 328.76254180602007,121.21637120297194 330.7692307692308,119.94624568881576 332.7759197324415,118.64519475193782 334.7826086956522,117.3131611328653 336.7892976588629,115.95013588567878 338.7959866220736,114.5561607026464 340.8026755852843,113.13133015237221 342.809364548495,111.67579381231106 344.8160535117057,110.18975827595884 346.8227424749164,108.6734890146472 348.82943143812713,107.12731207367304 350.8361204013378,105.55161558249762 352.8428093645485,103.94685105897328 354.8494983277592,102.31353448801158 356.8561872909699,100.65224715581036 358.8628762541806,98.9636362217157 360.8695652173913,97.2484150110168 362.876254180602,95.50736301345538 364.8829431438127,93.74132557397957 366.88963210702343,91.95121326427288 368.89632107023414,90.13800092583365 370.90301003344484,88.3027263778507 372.9096989966555,86.4464887857955 374.9163879598662,84.57044668950239 376.9230769230769,82.67581569250518 378.9297658862876,80.7638658175048 380.9364548494983,78.83591853601787 382.943143812709,76.89334348345875 384.94983277591973,74.93755487408987 386.95652173913044,72.97000763339207 388.96321070234114,70.99219326840968 390.96989966555185,69.0056354994653 392.97658862876256,67.01188567927362 394.98327759197326,65.01251802786408 396.98996655518397,63.00912471381467 398.9966555183947,61.003310814059745 401.0033444816054,58.9966891859402 403.0100334448161,56.990875286185286 405.0167224080268,54.987481972135896 407.0234113712375,52.98811432072634 409.03010033444815,50.99436450053466 411.03678929765886,49.00780673159032 413.04347826086956,47.02999236660793 415.05016722408027,45.062445125910145 417.056856187291,43.106656516541236 419.0635451505017,41.16408146398216 421.0702341137124,39.23613418249519 423.0769230769231,37.32418430749479 425.0836120401338,35.42955331049761 427.0903010033445,33.553511214204505 429.09698996655516,31.697273622149282 431.10367892976586,29.861999074166334 433.11036789297657,28.048786735727106 435.1170568561873,26.258674426020434 437.123745819398,24.492636986544625 439.1304347826087,22.751584988983183 441.1371237458194,21.036363778284283 443.1438127090301,19.34775284418967 445.1505016722408,17.686465511988388 447.1571906354515,16.053148941026734 449.1638795986622,14.448384417502382 451.17056856187287,12.872687926326961 453.17725752508363,11.32651098535274 455.18394648829434,9.810241724041163 457.19063545150505,8.324206187688901 459.19732441471575,6.868669847627757 461.20401337792646,5.443839297353577 463.21070234113716,4.0498641143211955 465.21739130434787,2.686838867134668 467.2240802675585,1.3548052480621493 469.2307692307693,0.05375431118420693 471.23745819397993,-1.2163712029719704 473.24414715719064,-2.455674474772792 475.25083612040135,-3.6643021624451535 477.25752508361205,-4.842441925410611 479.26421404682276,-5.990319929735847 481.27090301003346,-7.108198350389898 483.27759197324417,-8.19637288390038 485.2842809364549,-9.255170283868836 487.2909698996656,-10.284945930663866 489.29765886287623,-11.28608144546854 491.304347826087,-12.258982357731128 493.31103678929765,-13.204075833962634 495.31772575250835,-14.121808474756335 497.32441471571906,-15.012644185873881 499.33110367892976,-15.87706212826376 501.33779264214047,-16.715554750949423 503.34448160535123,-17.528625909853673 505.3511705685619,-18.316789074817223 507.35785953177265,-19.080565626318617 509.3645484949833,-19.820483242716165 511.371237458194,-20.537074378208416 513.3779264214047,-21.23087483114577 515.3846153846155,-21.902422401821838 517.3913043478261,-22.552255638427596 519.3979933110368,-23.180912669461577 521.4046822742475,-23.788930120549765 523.4113712374582,-24.376842113344395 525.4180602006688,-24.9451793439257 527.4247491638796,-25.49446823793724 529.4314381270902,-26.025230179524357 531.438127090301,-26.537980811027268 533.4448160535117,-27.03322940029193 535.4515050167224,-27.511478272406094 537.4581939799331,-27.97322230263694 539.4648829431438,-28.418948467345047 541.4715719063545,-28.849135449661304 543.4782608695652,-29.26425329675368 545.4849498327759,-29.66476312555784 547.4916387959867,-30.05111687391468 549.4983277591973,-30.423757094133407 551.5050167224081,-30.783116786085913 553.5117056856188,-31.129619267034997 555.5183946488294,-31.46367807549808 557.5250836120401,-31.785696906556467 559.5317725752509,-32.096069576129 561.5384615384615,-32.39518001184095 563.5451505016722,-32.68340226823261 565.551839464883,-32.961100564165434 567.5585284280936,-33.22862934039637 569.5652173913044,-33.486333335403884 571.571906354515,-33.73454767765671 573.5785953177258,-33.97359799262699 575.5852842809364,-34.20380052295204 577.5919732441472,-34.42546226025115 579.5986622073578,-34.638881087203345 581.6053511705686,-34.84434592858477 583.6120401337793,-35.0421369100568 585.61872909699,-35.23252552358309 587.6254180602007,-35.41577479843525 589.6321070234114,-35.59213947682869 591.6387959866221,-35.76186619330227 593.6454849498327,-35.92519365702975 595.6521739130435,-36.0823528363164 597.6588628762543,-36.233567144597885 599.6655518394648,-36.379052627319794 601.6722408026756,-36.51901814913086 603.6789297658863,-36.65366558087641 605.685618729097,-36.783189985928686 607.6923076923076,-36.907779805435325 609.6989966555184,-37.0276170421113 611.705685618729,-37.142877442240575 613.7123745819398,-37.25373067558863 615.7190635451506,-37.36034051296417 617.7257525083612,-37.462865001198765 619.732441471572,-37.56145663534289 621.7391304347826,-37.65626252790548 623.7458193979934,-37.74742457498721 625.752508361204,-37.83507961918255 627.7591973244148,-37.91935960914603 629.7658862876254,-38.00039175573795 631.7725752508361,-38.07829868468221 633.7792642140469,-38.15319858568543 635.7859531772576,-38.225205357981 637.7926421404682,-38.294428752274996 639.7993311036789,-38.36097450908346 641.8060200668897,-38.42494449346114 643.8127090301003,-38.48643682613107 645.819397993311,-38.545546011033935 647.8260869565217,-38.602363059324176 649.8327759197324,-38.656975609845176 651.8394648829432,-38.70946804612436 653.8461538461538,-38.759921609932206 655.8528428093646,-38.808414511456135 657.8595317725752,-38.85502203614155 659.866220735786,-38.899816648259076 661.8729096989966,-38.94286809125768 663.8795986622074,-38.984243484966385 665.886287625418,-39.02400741971044 667.8929765886288,-39.06222204740811 669.8996655518395,-39.098947169716354 671.9063545150502,-39.13424032329462 673.9130434782608,-39.16815686225644 675.9197324414715,-39.200750037879175 677.9264214046823,-39.232071075642324 679.933110367893,-39.26216924966457 681.9397993311038,-39.29109195461021 683.9464882943143,-39.31888477513456 685.9531772575251,-39.34559155293735 687.9598662207359,-39.37125445149283 689.9665551839465,-39.39591401852471 691.9732441471572,-39.419609246292225 693.9799331103679,-39.44237762975351 695.9866220735786,-39.464255222670516 697.9933110367894,-39.48527669172017 700.0,-39.505475368673075" class="sigmoid-line"/>
      <path d="M 100.0,260.0 102.0066889632107,260.0 104.01337792642141,260.0 106.02006688963212,260.0 108.02675585284283,260.0 110.03344481605353,260.0 112.04013377926418,260.0 114.04682274247489,260.0 116.0535117056856,260.0 118.0602006688963,260.0 120.066889632107,260.0 122.07357859531771,260.0 124.08026755852842,260.0 126.08695652173913,260.0 128.09364548494983,260.0 130.10033444816054,260.0 132.1070234113712,260.0 134.11371237458195,260.0 136.1204013377926,260.0 138.12709030100336,260.0 140.13377926421407,260.0 142.14046822742478,260.0 144.14715719063545,260.0 146.15384615384616,260.0 148.16053511705687,260.0 150.16722408026757,260.0 152.17391304347828,260.0 154.18060200668896,260.0 156.18729096989966,260.0 158.19397993311037,260.0 160.20066889632108,260.0 162.20735785953178,260.0 164.2140468227425,260.0 166.22073578595317,260.0 168.22742474916387,260.0 170.23411371237458,260.0 172.2408026755853,260.0 174.247491638796,260.0 176.25418060200667,260.0 178.26086956521738,260.0 180.26755852842808,260.0 182.2742474916388,260.0 184.2809364548495,260.0 186.2876254180602,260.0 188.29431438127094,260.0 190.3010033444816,260.0 192.30769230769232,260.0 194.314381270903,260.0 196.32107023411373,260.0 198.32775919732438,260.0 200.33444816053512,260.0 202.34113712374582,260.0 204.34782608695653,260.0 206.3545150501672,260.0 208.36120401337791,260.0 210.36789297658862,260.0 212.37458193979936,260.0 214.38127090301003,260.0 216.38795986622074,260.0 218.39464882943145,260.0 220.40133779264215,260.0 222.40802675585286,260.0 224.41471571906357,260.0 226.42140468227424,260.0 228.42809364548495,260.0 230.43478260869566,260.0 232.44147157190636,260.0 234.44816053511707,260.0 236.45484949832775,260.0 238.46153846153845,260.0 240.46822742474916,260.0 242.47491638795987,260.0 244.48160535117057,260.0 246.48829431438128,260.0 248.49498327759198,260.0 250.5016722408027,260.0 252.5083612040134,260.0 254.51505016722408,260.0 256.52173913043475,260.0 258.5284280936455,260.0 260.53511705685617,260.0 262.54180602006693,260.0 264.5484949832776,260.0 266.5551839464883,260.0 268.561872909699,260.0 270.5685618729097,260.0 272.5752508361204,260.0 274.5819397993311,260.0 276.5886287625418,260.0 278.5953177257525,260.0 280.60200668896323,260.0 282.60869565217394,260.0 284.61538461538464,260.0 286.6220735785953,260.0 288.62876254180605,260.0 290.6354515050167,260.0 292.6421404682274,260.0 294.6488294314381,260.0 296.6555183946488,260.0 298.66220735785953,260.0 300.66889632107024,260.0 302.67558528428094,260.0 304.68227424749165,260.0 306.68896321070235,260.0 308.695652173913,260.0 310.70234113712377,260.0 312.7090301003344,260.0 314.7157190635451,260.0 316.72240802675583,260.0 318.72909698996654,260.0 320.73578595317724,260.0 322.74247491638795,260.0 324.74916387959865,260.0 326.7558528428094,260.0 328.76254180602007,260.0 330.7692307692308,260.0 332.7759197324415,260.0 334.7826086956522,260.0 336.7892976588629,260.0 338.7959866220736,260.0 340.8026755852843,260.0 342.809364548495,260.0 344.8160535117057,260.0 346.8227424749164,260.0 348.82943143812713,260.0 350.8361204013378,260.0 352.8428093645485,260.0 354.8494983277592,260.0 356.8561872909699,260.0 358.8628762541806,260.0 360.8695652173913,260.0 362.876254180602,260.0 364.8829431438127,260.0 366.88963210702343,260.0 368.89632107023414,260.0 370.90301003344484,260.0 372.9096989966555,260.0 374.9163879598662,260.0 376.9230769230769,260.0 378.9297658862876,260.0 380.9364548494983,260.0 382.943143812709,260.0 384.94983277591973,260.0 386.95652173913044,260.0 388.96321070234114,260.0 390.96989966555185,260.0 392.97658862876256,260.0 394.98327759197326,260.0 396.98996655518397,260.0 398.9966555183947,260.0 401.0033444816054,259.1973244147157 403.0100334448161,257.59197324414714 405.0167224080268,255.9866220735786 407.0234113712375,254.38127090301003 409.03010033444815,252.77591973244145 411.03678929765886,251.1705685618729 413.04347826086956,249.56521739130434 415.05016722408027,247.95986622073576 417.056856187291,246.3545150501672 419.0635451505017,244.74916387959865 421.0702341137124,243.1438127090301 423.0769230769231,241.53846153846155 425.0836120401338,239.93311036789297 427.0903010033445,238.3277591973244 429.09698996655516,236.72240802675586 431.10367892976586,235.11705685618728 433.11036789297657,233.51170568561872 435.1170568561873,231.90635451505017 437.123745819398,230.30100334448161 439.1304347826087,228.69565217391306 441.1371237458194,227.09030100334448 443.1438127090301,225.48494983277592 445.1505016722408,223.87959866220737 447.1571906354515,222.2742474916388 449.1638795986622,220.66889632107024 451.17056856187287,219.06354515050168 453.17725752508363,217.45819397993307 455.18394648829434,215.85284280936452 457.19063545150505,214.24749163879596 459.19732441471575,212.6421404682274 461.20401337792646,211.03678929765886 463.21070234113716,209.43143812709027 465.21739130434787,207.82608695652172 467.2240802675585,206.22073578595317 469.2307692307693,204.61538461538458 471.23745819397993,203.01003344481603 473.24414715719064,201.40468227424748 475.25083612040135,199.79933110367892 477.25752508361205,198.19397993311037 479.26421404682276,196.5886287625418 481.27090301003346,194.98327759197323 483.27759197324417,193.37792642140468 485.2842809364549,191.7725752508361 487.2909698996656,190.16722408026754 489.29765886287623,188.561872909699 491.304347826087,186.95652173913044 493.31103678929765,185.35117056856188 495.31772575250835,183.7458193979933 497.32441471571906,182.14046822742475 499.33110367892976,180.5351170568562 501.33779264214047,178.9297658862876 503.34448160535123,177.32441471571903 505.3511705685619,175.7190635451505 507.35785953177265,174.1137123745819 509.3645484949833,172.5083612040134 511.371237458194,170.9030100334448 513.3779264214047,169.29765886287626 515.3846153846155,167.69230769230768 517.3913043478261,166.08695652173913 519.3979933110368,164.48160535117054 521.4046822742475,162.87625418060202 523.4113712374582,161.2709030100334 525.4180602006688,159.6655518394649 527.4247491638796,158.0602006688963 529.4314381270902,156.45484949832777 531.438127090301,154.8494983277592 533.4448160535117,153.24414715719064 535.4515050167224,151.63879598662206 537.4581939799331,150.03344481605353 539.4648829431438,148.42809364548492 541.4715719063545,146.82274247491642 543.4782608695652,145.2173913043478 545.4849498327759,143.6120401337793 547.4916387959867,142.0066889632107 549.4983277591973,140.4013377926421 551.5050167224081,138.79598662207357 553.5117056856188,137.190635451505 555.5183946488294,135.58528428093643 557.5250836120401,133.97993311036785 559.5317725752509,132.37458193979933 561.5384615384615,130.76923076923072 563.5451505016722,129.16387959866222 565.551839464883,127.55852842809361 567.5585284280936,125.95317725752508 569.5652173913044,124.3478260869565 571.571906354515,122.74247491638795 573.5785953177258,121.13712374581937 575.5852842809364,119.53177257525084 577.5919732441472,117.92642140468223 579.5986622073578,116.32107023411373 581.6053511705686,114.71571906354512 583.6120401337793,113.1103678929766 585.61872909699,111.50501672240802 587.6254180602007,109.89966555183946 589.6321070234114,108.29431438127088 591.6387959866221,106.68896321070235 593.6454849498327,105.08361204013374 595.6521739130435,103.47826086956525 597.6588628762543,101.87290969899664 599.6655518394648,100.26755852842811 601.6722408026756,98.66220735785953 603.6789297658863,97.05685618729098 605.685618729097,95.4515050167224 607.6923076923076,93.84615384615387 609.6989966555184,92.24080267558526 611.705685618729,90.63545150501676 613.7123745819398,89.03010033444815 615.7190635451506,87.42474916387954 617.7257525083612,85.81939799331104 619.732441471572,84.21404682274243 621.7391304347826,82.6086956521739 623.7458193979934,81.00334448160532 625.752508361204,79.39799331103677 627.7591973244148,77.79264214046819 629.7658862876254,76.18729096989966 631.7725752508361,74.58193979933105 633.7792642140469,72.97658862876256 635.7859531772576,71.37123745819395 637.7926421404682,69.76588628762542 639.7993311036789,68.16053511705684 641.8060200668897,66.55518394648828 643.8127090301003,64.9498327759197 645.819397993311,63.34448160535118 647.8260869565217,61.73913043478257 649.8327759197324,60.13377926421407 651.8394648829432,58.52842809364546 653.8461538461538,56.923076923076934 655.8528428093646,55.31772575250835 657.8595317725752,53.7123745819398 659.866220735786,52.107023411371216 661.8729096989966,50.50167224080269 663.8795986622074,48.89632107023408 665.886287625418,47.29096989966558 667.8929765886288,45.68561872909697 669.8996655518395,44.08026755852845 671.9063545150502,42.474916387959865 673.9130434782608,40.86956521739131 675.9197324414715,39.26421404682273 677.9264214046823,37.658862876254204 679.933110367893,36.053511705685594 681.9397993311038,34.44816053511701 683.9464882943143,32.842809364548486 685.9531772575251,31.237458193979876 687.9598662207359,29.63210702341138 689.9665551839465,28.02675585284277 691.9732441471572,26.421404682274243 693.9799331103679,24.81605351170566 695.9866220735786,23.210702341137107 697.9933110367894,21.605351170568525 700.0,20.0" class="relu-line"/>

      <!-- Legend -->
      <rect x="520" y="70" width="20" height="5" fill="#4285F4" />
      <text x="615" y="78" class="label">Sigmoid (梯度易消失)</text>
      <rect x="520" y="100" width="20" height="5" fill="#EA4335" />
      <text x="600" y="108" class="label">ReLU (现代标准)</text>
    </svg>
    
```
Figure 1

组建网络:神经元层

一个单独的神经元能力有限。神经网络的强大之处在于将许多神经元组织成层 (Layers)

一个标准的前馈神经网络 (Feedforward Neural Network) 通常由三部分组成。

输入层 (Input Layer)

  • 功能: 接收原始数据,是网络的入口。
  • 结构: 每个节点代表一个输入特征(例如,公司的市净率、资产回报率等)。
  • 计算: 不进行任何计算,只是将数据传递给第一个隐藏层。

隐藏层 (Hidden Layers)

  • 功能: 模型的核心,负责从输入数据中提取和转换特征。
  • 结构: 可以有一个或多个隐藏层。层数越多,网络越“深”(Deep)。
  • 计算: 每一层的神经元接收前一层所有神经元的输出,进行加权求和与激活,然后将结果传递给下一层。

输出层 (Output Layer)

  • 功能: 产生最终的预测结果。
  • 结构: 输出节点的数量和激活函数取决于具体任务。
    • 回归问题: 通常是1个节点,使用线性激活函数(即不激活)。
    • 二元分类: 通常是1个节点,使用 Sigmoid 激活函数。
    • 多元分类: 通常是 N 个节点(N为类别数),使用 Softmax 激活函数。

神经网络结构图

一个包含4个输入特征、2个隐藏层(每层5个神经元)和1个输出的神经网络。

前馈神经网络结构图 一个四层神经网络,包括输入层、两个隐藏层和输出层,展示了层与层之间的全连接。 输入层 隐藏层 1 隐藏层 2 输出层 x₁x₂ x₃x₄ ŷ

前向传播:信息在网络中的流动

前向传播 (Forward Propagation) 是指信息从输入层开始,逐层向前流动,直到计算出输出层结果的过程。

它回答了这个问题:“对于给定的输入 \(\mathbf{x}\) 和当前的网络参数 \((\mathbf{W}, \mathbf{b})\),模型的预测结果是什么?”

前向传播的数学表达

符号约定: * \(\mathbf{x}\): 输入特征向量。 * \(\mathbf{a}^{(l)}\): 第 \(l\) 层的激活输出向量(\(\mathbf{a}^{(0)} = \mathbf{x}\))。 * \(\mathbf{W}^{(l)}\): 从第 \(l-1\) 层到第 \(l\) 层的权重矩阵。 * \(\mathbf{b}^{(l)}\): 第 \(l\) 层的偏置向量。

从层 \(l-1\) 到层 \(l\) 的计算: 1. 加权和: \(\large{ \mathbf{z}^{(l)} = \mathbf{W}^{(l)} \mathbf{a}^{(l-1)} + \mathbf{b}^{(l)} }\) 2. 激活: \(\large{ \mathbf{a}^{(l)} = g(\mathbf{z}^{(l)}) }\)

这个过程从第一层开始,逐层重复,直到计算出最后一层(输出层)的结果 \(\hat{y} = \mathbf{a}^{(L)}\)

超能力:万能近似定理

一个惊人的理论结果: 一个包含单个隐藏层和非线性激活函数的前馈神经网络,只要隐藏层有足够多的神经元,就可以以任意精度近似任何一个连续函数

定理的直观理解

  • 每个使用 ReLU 激活函数的神经元可以被看作一个“开关”,它在某个点上改变函数的斜率,像一个“铰链”。
  • 通过组合足够多的这种“铰链”,我们可以构建出任意形状的分段线性函数。
  • 这就像用足够多的短直线段来逼近一条任意复杂的曲线。

这就是神经网络能够拟合复杂数据模式的理论基础。

万能近似定理的可视化

神经网络通过组合多个简单的 ReLU 函数(下图)来拟合一个复杂的非线性函数(上图)。

万能近似定理可视化 上图展示了用分段线性函数拟合曲线数据。下图展示了三个独立的ReLU函数,它们的和构成了上图的拟合线。 最终拟合结果: $\hat{y} = \sum \text{ReLU}_i$ 隐藏层神经元的贡献 (3个ReLU神经元) 神经元1 神经元2 (负权重) 神经元3

训练目标:最小化代价函数

我们如何知道网络的预测是“好”还是“坏”?我们需要一个代价函数 (Cost Function)损失函数 (Loss Function) \(J(\mathbf{W}, \mathbf{b})\) 来量化预测值与真实值之间的差距。

训练神经网络的目标就是找到一组最优的权重 \(\mathbf{W}\) 和偏置 \(\mathbf{b}\),使得代价函数 \(J\) 的值最小。

代价函数:回归任务

对于预测连续值(如股价、GDP增长率)的回归问题,最常用的是均方误差 (Mean Squared Error, MSE)

\[ \large{ J(\mathbf{W}, \mathbf{b}) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}^{(i)} - y^{(i)})^2 } \]

其中 \(\hat{y}^{(i)}\) 是网络对第 \(i\) 个样本的预测值,\(y^{(i)}\) 是真实值。

代价函数:分类任务

对于预测离散类别(如违约/不违约)的分类问题,常用的是交叉熵 (Cross-Entropy)

\[ \large{ J(\mathbf{W}, \mathbf{b}) = -\frac{1}{n} \sum_{i=1}^{n} [y^{(i)} \log(\hat{y}^{(i)}) + (1 - y^{(i)}) \log(1 - \hat{y}^{(i)})] } \]

它衡量的是两个概率分布(模型的预测概率和真实的标签概率)之间的差异。

如何寻找最优参数?

代价函数 \(J(\mathbf{W}, \mathbf{b})\) 通常是一个关于数百万个参数的极其复杂的非凸函数,存在许多局部最小值。我们无法用解析方法找到它的全局最小值。

import numpy as np

def generate_nonconvex_svg():
    x = np.linspace(-4.5, 3.5, 200)
    y = 0.1*x**4 + 0.2*x**3 - 2*x**2 - 2*x + 10
    
    # Scale for SVG viewbox
    x_svg = (x + 4.5) / 8 * 600 + 100
    y_svg = 280 - (y - np.min(y)) / (np.max(y) - np.min(y)) * 220
    
    path_data = "M " + " ".join([f"{px:.2f},{py:.2f}" for px, py in zip(x_svg, y_svg)])
    
    # Minima points
    min1_x, min1_y = -3.2, 0.1*(-3.2)**4 + 0.2*(-3.2)**3 - 2*(-3.2)**2 - 2*(-3.2) + 10
    min2_x, min2_y = 2.0, 0.1*(2.0)**4 + 0.2*(2.0)**3 - 2*(2.0)**2 - 2*(2.0) + 10
    
    min1_x_svg = (-3.2 + 4.5) / 8 * 600 + 100
    min1_y_svg = 280 - (min1_y - np.min(y)) / (np.max(y) - np.min(y)) * 220
    min2_x_svg = (2.0 + 4.5) / 8 * 600 + 100
    min2_y_svg = 280 - (min2_y - np.min(y)) / (np.max(y) - np.min(y)) * 220

    svg = f'''
    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>非凸代价函数</title>
      <desc>一条具有多个局部最小值的曲线,说明了优化问题的复杂性。</desc>
      <style>
        .curve {{ fill: none; stroke: #4285F4; stroke-width: 3.5; }}
        .axis {{ stroke: #3c4043; stroke-width: 2; }}
        .label {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }}
        .title {{ font-size: 20px; font-weight: bold; }}
        .min-point {{ fill: #EA4335; }}
        .min-label {{ fill: #c0392b; font-size: 15px; }}
        .global-min-label {{ fill: #27ae60; font-weight: bold; }}
      </style>

      <text x="400" y="40" class="label title">非凸代价函数的地形</text>
      <!-- Axes -->
      <line x1="80" y1="280" x2="720" y2="280" class="axis" />
      <line x1="80" y1="50" x2="80" y2="280" class="axis" />
      <text x="400" y="305" class="label">参数值 (例如某个 w)</text>
      <text x="50" y="165" transform="rotate(-90, 50, 165)" class="label">代价 J(w)</text>
      
      <!-- Curve -->
      <path d="{path_data}" class="curve"/>
      
      <!-- Minima -->
      <circle cx="{min1_x_svg}" cy="{min1_y_svg}" r="7" class="min-point"/>
      <text x="{min1_x_svg}" y="{min1_y_svg + 25}" class="label min-label">局部最小值</text>
      <circle cx="{min2_x_svg}" cy="{min2_y_svg}" r="7" class="min-point"/>
      <text x="{min2_x_svg}" y="{min2_y_svg + 25}" class="label min-label">局部最小值</text>
      
      <text x="{min2_x_svg}" y="{min2_y_svg - 20}" class="label global-min-label">全局最小值</text>
      <path d="M {min2_x_svg} {min2_y_svg - 15} L {min2_x_svg} {min2_y_svg-5}" stroke="#27ae60" stroke-width="2"/>
    </svg>
    '''
    return svg

print(f"```{'{=html}'}\n{generate_nonconvex_svg()}\n```")
```{=html}

    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>非凸代价函数</title>
      <desc>一条具有多个局部最小值的曲线,说明了优化问题的复杂性。</desc>
      <style>
        .curve { fill: none; stroke: #4285F4; stroke-width: 3.5; }
        .axis { stroke: #3c4043; stroke-width: 2; }
        .label { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }
        .title { font-size: 20px; font-weight: bold; }
        .min-point { fill: #EA4335; }
        .min-label { fill: #c0392b; font-size: 15px; }
        .global-min-label { fill: #27ae60; font-weight: bold; }
      </style>

      <text x="400" y="40" class="label title">非凸代价函数的地形</text>
      <!-- Axes -->
      <line x1="80" y1="280" x2="720" y2="280" class="axis" />
      <line x1="80" y1="50" x2="80" y2="280" class="axis" />
      <text x="400" y="305" class="label">参数值 (例如某个 w)</text>
      <text x="50" y="165" transform="rotate(-90, 50, 165)" class="label">代价 J(w)</text>
      
      <!-- Curve -->
      <path d="M 100.00,230.32 103.02,236.28 106.03,241.80 109.05,246.89 112.06,251.58 115.08,255.87 118.09,259.76 121.11,263.28 124.12,266.43 127.14,269.23 130.15,271.67 133.17,273.78 136.18,275.57 139.20,277.04 142.21,278.20 145.23,279.07 148.24,279.65 151.26,279.96 154.27,280.00 157.29,279.78 160.30,279.32 163.32,278.62 166.33,277.69 169.35,276.54 172.36,275.18 175.38,273.61 178.39,271.86 181.41,269.92 184.42,267.80 187.44,265.51 190.45,263.07 193.47,260.47 196.48,257.73 199.50,254.85 202.51,251.85 205.53,248.72 208.54,245.49 211.56,242.14 214.57,238.70 217.59,235.17 220.60,231.56 223.62,227.87 226.63,224.10 229.65,220.28 232.66,216.40 235.68,212.46 238.69,208.49 241.71,204.48 244.72,200.43 247.74,196.36 250.75,192.27 253.77,188.17 256.78,184.06 259.80,179.95 262.81,175.84 265.83,171.74 268.84,167.65 271.86,163.59 274.87,159.54 277.89,155.53 280.90,151.54 283.92,147.60 286.93,143.70 289.95,139.84 292.96,136.04 295.98,132.29 298.99,128.61 302.01,124.98 305.03,121.43 308.04,117.94 311.06,114.53 314.07,111.19 317.09,107.94 320.10,104.78 323.12,101.70 326.13,98.71 329.15,95.82 332.16,93.02 335.18,90.32 338.19,87.72 341.21,85.23 344.22,82.84 347.24,80.57 350.25,78.40 353.27,76.35 356.28,74.41 359.30,72.59 362.31,70.88 365.33,69.30 368.34,67.83 371.36,66.49 374.37,65.28 377.39,64.18 380.40,63.21 383.42,62.37 386.43,61.66 389.45,61.07 392.46,60.61 395.48,60.28 398.49,60.07 401.51,60.00 404.52,60.05 407.54,60.24 410.55,60.55 413.57,60.99 416.58,61.55 419.60,62.25 422.61,63.07 425.63,64.01 428.64,65.08 431.66,66.27 434.67,67.59 437.69,69.02 440.70,70.58 443.72,72.25 446.73,74.04 449.75,75.95 452.76,77.96 455.78,80.09 458.79,82.33 461.81,84.67 464.82,87.12 467.84,89.67 470.85,92.32 473.87,95.06 476.88,97.90 479.90,100.83 482.91,103.85 485.93,106.95 488.94,110.14 491.96,113.40 494.97,116.74 497.99,120.15 501.01,123.62 504.02,127.16 507.04,130.76 510.05,134.41 513.07,138.12 516.08,141.87 519.10,145.67 522.11,149.50 525.13,153.36 528.14,157.26 531.16,161.18 534.17,165.12 537.19,169.07 540.20,173.03 543.22,176.99 546.23,180.95 549.25,184.91 552.26,188.85 555.28,192.77 558.29,196.66 561.31,200.53 564.32,204.35 567.34,208.14 570.35,211.87 573.37,215.55 576.38,219.16 579.40,222.70 582.41,226.17 585.43,229.55 588.44,232.84 591.46,236.04 594.47,239.12 597.49,242.10 600.50,244.96 603.52,247.68 606.53,250.27 609.55,252.72 612.56,255.02 615.58,257.16 618.59,259.13 621.61,260.92 624.62,262.53 627.64,263.94 630.65,265.15 633.67,266.15 636.68,266.94 639.70,267.49 642.71,267.80 645.73,267.87 648.74,267.68 651.76,267.23 654.77,266.49 657.79,265.48 660.80,264.16 663.82,262.55 666.83,260.61 669.85,258.35 672.86,255.76 675.88,252.81 678.89,249.52 681.91,245.85 684.92,241.80 687.94,237.37 690.95,232.53 693.97,227.29 696.98,221.62 700.00,215.51" class="curve"/>
      
      <!-- Minima -->
      <circle cx="197.5" cy="256.773218729469" r="7" class="min-point"/>
      <text x="197.5" y="281.773218729469" class="label min-label">局部最小值</text>
      <circle cx="587.5" cy="231.82373220264267" r="7" class="min-point"/>
      <text x="587.5" y="256.82373220264265" class="label min-label">局部最小值</text>
      
      <text x="587.5" y="211.82373220264267" class="label global-min-label">全局最小值</text>
      <path d="M 587.5 216.82373220264267 L 587.5 226.82373220264267" stroke="#27ae60" stroke-width="2"/>
    </svg>
    
```
Figure 2

梯度下降法的直观比喻

解决方案:梯度下降法 (Gradient Descent)

想象你身处一座大山的浓雾之中,想要走到谷底。

  1. 你看不清整个地形。
  2. 但你可以感知脚下哪个方向是下坡最陡的
  3. 你不断沿着这个最陡峭的方向迈出一小步。
  4. 重复这个过程,最终就能到达一个山谷的底部。

在数学上,这个“最陡峭的下坡方向”就是代价函数梯度的负方向 \((-\nabla J)\)

梯度下降法的更新规则

梯度下降是一个迭代过程。在每一步,我们都按照以下规则更新每一个权重 \(w\) 和偏置 \(b\)

\[ \large{ w := w - \alpha \frac{\partial J}{\partial w} } \]

\[ \large{ b := b - \alpha \frac{\partial J}{\partial b} } \]

  • \(\alpha\)学习率 (Learning Rate),一个超参数,控制我们每一步“走”多远。
  • \(\frac{\partial J}{\partial w}\) 是代价函数对具体某个权重的偏导数,即梯度。

关键超参数:学习率 \(\alpha\)

学习率的选择至关重要,它决定了训练的效率和稳定性。

学习率对梯度下降的影响 三个图表展示了学习率过小、过大和恰到好处时梯度下降的路径。 学习率过小 收敛过慢 学习率过大 震荡,无法收敛 学习率合适 高效收敛

核心挑战:高效计算梯度

直接对神经网络的代价函数求导是一个极其繁琐且计算量巨大的任务,因为网络结构是深度嵌套的。

\[ \large{ J(\mathbf{W}, \mathbf{b}) = \frac{1}{n} \sum_{i=1}^n (y^{(i)} - g^{(L)}( \mathbf{W}^{(L)} g^{(L-1)}(\dots) + \mathbf{b}^{(L)}))^2 } \]

如果网络有数百万个参数,逐个计算偏导数是不可行的。

解决方案:反向传播算法

反向传播 (Backpropagation) 不是一个新的优化算法,它仅仅是一种非常高效地计算所有梯度的算法。

其本质是巧妙地应用微积分中的链式法则,从输出层开始,逐层向后计算梯度。

反向传播的直觉

就像一个公司的绩效评估,从最终的利润(总误差)出发,逐级追溯到每个部门、每个员工(每一层、每个神经元)的绩效贡献。

反向传播的直觉 一个三层网络图,展示了误差从输出层反向传播到输入层的过程。 输入隐藏层输出 = (ŷ - y)² 计算误差 误差反向传播 (按“贡献度”分配)

反向传播的数学概览

我们在此不进行严格的数学推导,但展示其核心思想:链式法则

  1. 计算代价函数对输出层激活值的梯度 \(\frac{\partial J}{\partial a^{(L)}}\)
  2. 利用链式法则,计算对输出层加权和的梯度:\(\large{\frac{\partial J}{\partial z^{(L)}} = \frac{\partial J}{\partial a^{(L)}} \frac{\partial a^{(L)}}{\partial z^{(L)}}}\)
  3. 继续反向传播,计算对前一层激活值的梯度:\(\large{\frac{\partial J}{\partial a^{(L-1)}} = \frac{\partial J}{\partial z^{(L)}} \frac{\partial z^{(L)}}{\partial a^{(L-1)}}}\)
  4. 重复此过程,直到计算出对所有权重 \(\mathbf{W}^{(l)}\) 和偏置 \(\mathbf{b}^{(l)}\) 的梯度。

现代深度学习框架如 PyTorch 会为我们自动完成这一切。

训练效率:梯度下降的变种

标准的梯度下降在每次更新前需要计算整个训练集的梯度,当数据集很大时,这会非常缓慢。因此,实践中我们使用其变种。

批量梯度下降 (BGD)

  • 数据使用: 整个训练集
  • 优点:
    • 代价函数下降路径平滑。
    • 能保证收敛到局部(或全局)最小值。
  • 缺点:
    • 速度非常慢。
    • 内存需求大,无法处理超大规模数据集。

随机梯度下降 (SGD)

  • 数据使用: 每次只用 1 个随机样本
  • 优点:
    • 更新速度快。
    • 随机性有助于“跳出”局部最小值。
  • 缺点:
    • 代价函数下降路径非常嘈杂、震荡。
    • 收敛速度慢,可能永远不会精确收敛。

小批量梯度下降 (Mini-batch GD)

  • 数据使用: 每次用一小批样本 (e.g., 32, 64, 128)
  • 优点:
    • 兼具 BGD 和 SGD 的优点。
    • 通过矩阵运算充分利用 GPU 的并行计算能力。
    • 收敛路径比 SGD 更平滑。
  • 这是现代深度学习的黄金标准。

关键挑战:过拟合 (Overfitting)

由于神经网络模型极其灵活(参数众多),它们很容易在训练数据上“死记硬背”,导致对训练集表现极好,但对未见过的新数据(测试集)表现很差。

过拟合、欠拟合与良好拟合的比较 三个图表展示了模型对数据点的拟合情况:欠拟合(一条直线)、良好拟合(平滑曲线)和过拟合(剧烈波动的曲线)。 欠拟合 (Underfitting) 良好拟合 (Good Fit) 过拟合 (Overfitting)

识别过拟合:训练与验证损失

在训练过程中,我们同时监控模型在训练集验证集上的损失。一个典型的过拟合迹象是:训练损失持续下降,而验证损失在某个点后开始上升。

过拟合的损失曲线 训练损失持续下降,而验证损失在下降后开始上升,两条曲线出现分叉。 训练轮次 (Epochs) 损失值 训练损失 (Training Loss) 验证损失 (Validation Loss) 过拟合开始点 模型泛化能力变差

对抗过拟合(1): 正则化

正则化 (Regularization) 的核心思想是对模型的复杂度进行惩罚

我们在代价函数中加入一个惩罚项,使得模型在拟合数据的同时,倾向于选择更小、更简单的权重。

这鼓励模型学习更平滑、更通用的模式,而不是记住训练数据中的噪声。

L2 正则化 (权重衰减)

在代价函数后加入所有权重平方和的项。这与我们在线性回归中学过的岭回归 (Ridge Regression) 思想完全相同。

\[ \large{ J_{reg}(\mathbf{W}, \mathbf{b}) = \underbrace{J(\mathbf{W}, \mathbf{b})}_{\text{原始损失}} + \underbrace{\frac{\lambda}{2n} \sum_{l} \sum_{j} \sum_{k} (w_{jk}^{(l)})^2}_{\text{惩罚项}} } \]

  • \(\lambda\) 是正则化参数,控制惩罚的强度。
  • 这会促使权重值趋向于0(但不会完全变为0),从而得到一个更“平滑”、更不易过拟合的模型。

对抗过拟合(2): 提前终止 (Early Stopping)

这是一个非常简单但极其有效的正则化方法。

  1. 在训练时,持续监控验证集上的损失。
  2. 当发现验证集上的损失不再下降,甚至开始上升时,就立即停止训练
  3. 保存并使用验证损失最小时的模型参数。
提前终止示意图 损失曲线上,一个垂直的“停止”标志放在验证损失的最低点。 训练轮次损失 STOP! 保存此时的模型

对抗过拟合(3): 随机失活 (Dropout)

Dropout 是一个非常巧妙且强大的正则化技术,专门为神经网络设计。

  • 工作原理: 在每次训练迭代中,以某个概率 \(p\)(例如 \(p=0.5\)随机地“关闭” 隐藏层中的一些神经元。
  • 直观效果:
    • 强迫网络不能过度依赖任何一个神经元。
    • 鼓励网络学习到更鲁棒、更具冗余性的特征
    • 可以被看作是一种高效地训练大量不同网络架构的集成模型的方法。

Dropout 的工作原理

在每次训练迭代中,网络都在使用一个“残缺”的子网络进行学习。

Dropout 工作原理 左侧是完整的神经网络,右侧展示了在一次训练迭代中,部分神经元被随机“关闭”(变灰并划掉)。 训练开始前 (完整网络) 某次迭代中 (随机失活)

实践环节:进入代码世界

理论讲完了,现在是时候动手了。

我们将使用一个真实的金融数据集来解决一个回归问题,并使用两个业界最主流的 Python 库来构建我们的模型。

我们的任务:预测宏观经济

由于原始公司财报数据不易获取,我们将切换到一个更有趣、更贴近经济学专业的真实任务:利用宏观经济指标预测美国季度实际GDP增长率

输入特征 (X): * 失业率, 通胀率, 短期利率

目标变量 (y): * 未来一个季度的实际GDP同比增长率

这是一个典型的时间序列回归问题

数据来源:FRED 数据库

我们将使用 fredapi 库从圣路易斯联储经济数据库 (FRED) 直接获取权威的、可公开访问的宏观经济数据。

这确保了我们的实验是完全可复现的,并且使用的是真实世界的数据。

Step 1: 获取并处理宏观数据

第一步是使用 fredapi 获取数据,并将其处理成适用于机器学习的格式。

import pandas as pd
import numpy as np
from fredapi import Fred

# 注意: 请替换为您自己的FRED API密钥。
# 您可以在FRED官网免费申请: https://fred.stlouisfed.org/docs/api/api_key.html
fred_key = 'f2a2c60b6dc82682031f4ce84bf6da18' 
fred = Fred(api_key=fred_key)

# 定义要获取的序列ID和时间范围
series_ids = ['GDPC1', 'UNRATE', 'CPIAUCSL', 'DGS3MO']
start_date, end_date = '1980-01-01', '2023-12-31'

# 获取数据并合并
raw_df = pd.concat(
    [fred.get_series(s_id, start_date, end_date, name=s_id) for s_id in series_ids],
    axis=1
).interpolate(method='linear')

# 命名raw_df的列为'GDPC1', 'UNRATE', 'CPIAUCSL', 'DGS3MO'
raw_df.columns = ['GDPC1', 'UNRATE', 'CPIAUCSL', 'DGS3MO']

print('获取到的原始数据 (前5行):')
print(raw_df.head())
获取到的原始数据 (前5行):
                  GDPC1  UNRATE  CPIAUCSL  DGS3MO
1980-01-01  7341.557000     6.3      78.0     NaN
1980-02-01  7291.134333     6.3      79.0     NaN
1980-03-01  7240.711667     6.3      80.1     NaN
1980-04-01  7190.289000     6.9      80.9     NaN
1980-05-01  7187.440333     7.5      81.7     NaN

Step 2: 特征工程

我们需要将原始数据转换为增长率,并构建我们的特征矩阵 X 和目标向量 y

# 1. 将所有数据转换为季度频率,取季度末的值
df_q = raw_df.resample('QE').last()

# 2. 计算实际GDP的季度同比增长率
df_q['GDP_growth'] = df_q['GDPC1'].pct_change(4) * 100

# 3. 计算CPI的季度同比增长率 (通胀率)
df_q['Inflation'] = df_q['CPIAUCSL'].pct_change(4) * 100

# 4. 构建特征和目标变量
# 我们用 t-1 季度的指标来预测 t 季度的GDP增长
df_model = df_q[['GDP_growth', 'UNRATE', 'Inflation', 'DGS3MO']].copy()
df_model['GDP_growth_target'] = df_model['GDP_growth'].shift(-1) 

# 5. 删除因计算和移位产生的NaN行
df_model.dropna(inplace=True)

# 6. 定义X和y
y = df_model['GDP_growth_target']
X = df_model[['UNRATE', 'Inflation', 'DGS3MO']]

print('\n用于模型的特征 (X) (前5行):')
print(X.head())
print('\n用于模型的目标 (y) (前5行):')
print(y.head())

用于模型的特征 (X) (前5行):
               UNRATE  Inflation  DGS3MO
1981-09-30   7.886364  11.306750   15.05
1981-12-31   8.595652   9.244163   11.54
1982-03-31   9.286957   7.208755   13.99
1982-06-30   9.790909   7.709694   13.36
1982-09-30  10.386364   5.027987    7.88

用于模型的目标 (y) (前5行):
1981-09-30   -1.536732
1981-12-31   -1.263494
1982-03-31   -2.164470
1982-06-30   -1.488989
1982-09-30    1.387603
Freq: QE-DEC, Name: GDP_growth_target, dtype: float64

Step 3: 划分训练集与测试集

在训练模型之前,必须将数据分为两部分,以客观评估模型的泛化能力。

  • 训练集 (Training Set): 用于训练模型(占80%)。
  • 测试集 (Test Set): 用于评估模型在未见过的数据上的表现(占20%)。
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42 # random_state保证结果可复现
)

print(f'训练集大小: {X_train.shape} 个样本')
print(f'测试集大小: {X_test.shape} 个样本')
训练集大小: (135, 3) 个样本
测试集大小: (34, 3) 个样本

Step 4: 数据标准化

神经网络对输入特征的尺度非常敏感。如果不同特征的数值范围差异巨大,训练过程会变得不稳定且缓慢。

解决方法:标准化 (Standardization) 将每个特征转换为均值为0,标准差为1的分布。

\[ \large{ x'_{i} = \frac{x_i - \mu_i}{\sigma_i} } \]

重要: 必须使用训练集计算出的均值 \(\mu\) 和标准差 \(\sigma\) 来转换测试集,以避免数据泄露。

from sklearn.preprocessing import StandardScaler

# 初始化标准化器
scaler = StandardScaler()

# 在训练数据上学习转换规则 (fit) 并应用 (transform)
X_train_scaled = scaler.fit_transform(X_train)

# 在测试数据上应用相同的转换规则
X_test_scaled = scaler.transform(X_test)

路线图:Scikit-learn vs. PyTorch

我们将使用两个库来构建模型,以对比它们的异同。

Scikit-learn

  • 特点: 高度封装,API简洁。
  • 优点: 上手快,几行代码就能构建标准模型。
  • 用途: 快速原型验证,基准模型搭建。

PyTorch

  • 特点: 灵活,更底层。
  • 优点: 可定义任意复杂结构,无缝GPU加速,生态强大。
  • 用途: 学术研究,工业级深度学习应用。

Scikit-learn: 快速原型验证

scikit-learn 中的 MLPRegressor 让我们可以用几行代码就构建一个用于回归任务的多层感知器(即标准神经网络)。

Scikit-learn: 模型训练

我们来构建一个包含4个隐藏层的网络,神经元数量分别为 (10, 20, 20, 10)。

from sklearn.neural_network import MLPRegressor

# 初始化模型
nn_sklearn = MLPRegressor(
    hidden_layer_sizes=(10, 20, 20, 10),
    activation='relu',
    random_state=42,
    max_iter=500 
)

# 使用标准化后的训练数据训练模型
nn_sklearn.fit(X_train_scaled, y_train)

print('Scikit-learn 模型训练完成!')
Scikit-learn 模型训练完成!

Scikit-learn: 性能评估

训练完成后,我们使用测试集来评估模型的预测能力。

from sklearn.metrics import mean_squared_error

# 在测试集上进行预测
y_pred = nn_sklearn.predict(X_test_scaled)

# 计算均方误差 (MSE)
mse = mean_squared_error(y_test, y_pred)
print(f'Scikit-learn 模型在测试集上的均方误差 (MSE): {mse:.4f}')

# 作为参考,计算y_test的方差
y_var = y_test.var()
print(f'作为基准,目标变量自身的方差为: {y_var:.4f}')
Scikit-learn 模型在测试集上的均方误差 (MSE): 8.6144
作为基准,目标变量自身的方差为: 10.0908

Scikit-learn: 训练过程可视化

MLPRegressor 会自动记录训练过程中的损失函数值。我们可以将其可视化,以判断模型是否收敛。

def generate_loss_curve_svg(loss_curve, title, color):
    epochs = len(loss_curve)
    max_loss = max(loss_curve)
    min_loss = min(loss_curve)
    
    x_coords = [100 + (i / (epochs - 1)) * 600 for i in range(epochs)]
    y_coords = [280 - ((loss - min_loss) / (max_loss - min_loss + 1e-9)) * 220 for loss in loss_curve]
    
    path_data = "M " + " ".join([f"{x:.2f},{y:.2f}" for x, y in zip(x_coords, y_coords)])
    
    svg = f'''
    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>{title}</title>
      <desc>显示训练过程中损失函数值随迭代次数变化的曲线。</desc>
      <style>
        .axis {{ stroke: #3c4043; stroke-width: 2; }}
        .grid {{ stroke: #e0e0e0; stroke-width: 1; stroke-dasharray: 4,4; }}
        .label {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }}
        .title {{ font-size: 20px; font-weight: bold; }}
        .loss-curve {{ stroke: {color}; stroke-width: 3; fill: none; }}
      </style>
      
      <text x="400" y="35" class="label title">{title}</text>
      
      <!-- Axes -->
      <line x1="100" y1="280" x2="700" y2="280" class="axis"/>
      <line x1="100" y1="50" x2="100" y2="280" class="axis"/>
      <text x="400" y="305" class="label">训练迭代次数 (Epochs)</text>
      <text x="60" y="165" transform="rotate(-90, 60, 165)" class="label">损失函数值 (MSE)</text>
      
      <!-- Curve -->
      <path d="{path_data}" class="loss-curve" />
    </svg>
    '''
    return svg

sklearn_svg = generate_loss_curve_svg(nn_sklearn.loss_curve_, 'Scikit-learn 神经网络训练损失下降情况', '#4285F4')
print(f"```{'{=html}'}\n{sklearn_svg}\n```")
```{=html}

    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>Scikit-learn 神经网络训练损失下降情况</title>
      <desc>显示训练过程中损失函数值随迭代次数变化的曲线。</desc>
      <style>
        .axis { stroke: #3c4043; stroke-width: 2; }
        .grid { stroke: #e0e0e0; stroke-width: 1; stroke-dasharray: 4,4; }
        .label { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }
        .title { font-size: 20px; font-weight: bold; }
        .loss-curve { stroke: #4285F4; stroke-width: 3; fill: none; }
      </style>
      
      <text x="400" y="35" class="label title">Scikit-learn 神经网络训练损失下降情况</text>
      
      <!-- Axes -->
      <line x1="100" y1="280" x2="700" y2="280" class="axis"/>
      <line x1="100" y1="50" x2="100" y2="280" class="axis"/>
      <text x="400" y="305" class="label">训练迭代次数 (Epochs)</text>
      <text x="60" y="165" transform="rotate(-90, 60, 165)" class="label">损失函数值 (MSE)</text>
      
      <!-- Curve -->
      <path d="M 100.00,60.00 101.20,62.30 102.40,64.55 103.61,66.75 104.81,68.90 106.01,71.02 107.21,73.16 108.42,75.30 109.62,77.40 110.82,79.45 112.02,81.43 113.23,83.34 114.43,85.20 115.63,86.98 116.83,88.68 118.04,90.28 119.24,91.80 120.44,93.21 121.64,94.49 122.85,95.65 124.05,96.73 125.25,97.74 126.45,98.67 127.66,99.54 128.86,100.38 130.06,101.19 131.26,101.98 132.46,102.77 133.67,103.54 134.87,104.31 136.07,105.08 137.27,105.83 138.48,106.57 139.68,107.30 140.88,108.04 142.08,108.79 143.29,109.55 144.49,110.35 145.69,111.19 146.89,112.05 148.10,112.97 149.30,113.93 150.50,114.95 151.70,116.01 152.91,117.14 154.11,118.31 155.31,119.55 156.51,120.84 157.72,122.20 158.92,123.65 160.12,125.17 161.32,126.78 162.53,128.46 163.73,130.23 164.93,132.09 166.13,134.03 167.33,136.06 168.54,138.15 169.74,140.29 170.94,142.48 172.14,144.72 173.35,147.02 174.55,149.36 175.75,151.74 176.95,154.17 178.16,156.64 179.36,159.13 180.56,161.66 181.76,164.20 182.97,166.75 184.17,169.30 185.37,171.84 186.57,174.38 187.78,176.91 188.98,179.43 190.18,181.96 191.38,184.47 192.59,186.97 193.79,189.44 194.99,191.89 196.19,194.31 197.39,196.68 198.60,199.00 199.80,201.27 201.00,203.48 202.20,205.61 203.41,207.67 204.61,209.65 205.81,211.55 207.01,213.36 208.22,215.08 209.42,216.71 210.62,218.24 211.82,219.69 213.03,221.05 214.23,222.31 215.43,223.49 216.63,224.59 217.84,225.60 219.04,226.54 220.24,227.40 221.44,228.21 222.65,228.97 223.85,229.69 225.05,230.39 226.25,231.05 227.45,231.67 228.66,232.26 229.86,232.81 231.06,233.35 232.26,233.88 233.47,234.39 234.67,234.88 235.87,235.37 237.07,235.84 238.28,236.29 239.48,236.74 240.68,237.17 241.88,237.60 243.09,238.02 244.29,238.45 245.49,238.87 246.69,239.29 247.90,239.70 249.10,240.09 250.30,240.46 251.50,240.83 252.71,241.18 253.91,241.53 255.11,241.88 256.31,242.21 257.52,242.54 258.72,242.86 259.92,243.18 261.12,243.48 262.32,243.78 263.53,244.07 264.73,244.35 265.93,244.63 267.13,244.91 268.34,245.19 269.54,245.46 270.74,245.74 271.94,246.02 273.15,246.29 274.35,246.56 275.55,246.84 276.75,247.11 277.96,247.37 279.16,247.64 280.36,247.89 281.56,248.15 282.77,248.41 283.97,248.66 285.17,248.92 286.37,249.17 287.58,249.41 288.78,249.65 289.98,249.89 291.18,250.12 292.38,250.36 293.59,250.59 294.79,250.82 295.99,251.05 297.19,251.29 298.40,251.54 299.60,251.79 300.80,252.04 302.00,252.28 303.21,252.53 304.41,252.78 305.61,253.02 306.81,253.27 308.02,253.52 309.22,253.76 310.42,254.02 311.62,254.27 312.83,254.53 314.03,254.81 315.23,255.08 316.43,255.35 317.64,255.63 318.84,255.92 320.04,256.22 321.24,256.51 322.44,256.80 323.65,257.11 324.85,257.43 326.05,257.74 327.25,258.00 328.46,258.26 329.66,258.53 330.86,258.82 332.06,259.10 333.27,259.39 334.47,259.67 335.67,259.96 336.87,260.26 338.08,260.55 339.28,260.85 340.48,261.15 341.68,261.46 342.89,261.76 344.09,262.06 345.29,262.36 346.49,262.65 347.70,262.93 348.90,263.22 350.10,263.50 351.30,263.78 352.51,264.06 353.71,264.34 354.91,264.61 356.11,264.88 357.31,265.15 358.52,265.43 359.72,265.70 360.92,265.97 362.12,266.24 363.33,266.50 364.53,266.76 365.73,267.02 366.93,267.27 368.14,267.52 369.34,267.77 370.54,268.00 371.74,268.23 372.95,268.46 374.15,268.68 375.35,268.90 376.55,269.11 377.76,269.31 378.96,269.50 380.16,269.70 381.36,269.88 382.57,270.05 383.77,270.21 384.97,270.37 386.17,270.52 387.37,270.67 388.58,270.80 389.78,270.94 390.98,271.07 392.18,271.19 393.39,271.31 394.59,271.43 395.79,271.54 396.99,271.64 398.20,271.73 399.40,271.83 400.60,271.91 401.80,272.00 403.01,272.08 404.21,272.17 405.41,272.25 406.61,272.33 407.82,272.40 409.02,272.48 410.22,272.55 411.42,272.61 412.63,272.68 413.83,272.74 415.03,272.81 416.23,272.87 417.43,272.93 418.64,272.99 419.84,273.04 421.04,273.10 422.24,273.15 423.45,273.20 424.65,273.26 425.85,273.30 427.05,273.35 428.26,273.40 429.46,273.45 430.66,273.49 431.86,273.54 433.07,273.59 434.27,273.64 435.47,273.68 436.67,273.73 437.88,273.77 439.08,273.82 440.28,273.87 441.48,273.91 442.69,273.95 443.89,274.00 445.09,274.04 446.29,274.08 447.49,274.13 448.70,274.17 449.90,274.21 451.10,274.25 452.30,274.29 453.51,274.34 454.71,274.38 455.91,274.41 457.11,274.45 458.32,274.49 459.52,274.53 460.72,274.56 461.92,274.60 463.13,274.63 464.33,274.67 465.53,274.70 466.73,274.74 467.94,274.77 469.14,274.80 470.34,274.84 471.54,274.87 472.75,274.90 473.95,274.93 475.15,274.97 476.35,275.00 477.56,275.04 478.76,275.07 479.96,275.10 481.16,275.13 482.36,275.16 483.57,275.19 484.77,275.22 485.97,275.25 487.17,275.28 488.38,275.31 489.58,275.34 490.78,275.37 491.98,275.40 493.19,275.43 494.39,275.46 495.59,275.49 496.79,275.52 498.00,275.55 499.20,275.58 500.40,275.61 501.60,275.64 502.81,275.67 504.01,275.70 505.21,275.73 506.41,275.76 507.62,275.79 508.82,275.81 510.02,275.84 511.22,275.87 512.42,275.90 513.63,275.93 514.83,275.96 516.03,275.99 517.23,276.01 518.44,276.04 519.64,276.07 520.84,276.10 522.04,276.13 523.25,276.15 524.45,276.18 525.65,276.21 526.85,276.24 528.06,276.26 529.26,276.29 530.46,276.31 531.66,276.34 532.87,276.37 534.07,276.39 535.27,276.42 536.47,276.44 537.68,276.47 538.88,276.49 540.08,276.52 541.28,276.54 542.48,276.57 543.69,276.60 544.89,276.62 546.09,276.65 547.29,276.67 548.50,276.70 549.70,276.73 550.90,276.75 552.10,276.78 553.31,276.80 554.51,276.83 555.71,276.86 556.91,276.88 558.12,276.91 559.32,276.94 560.52,276.96 561.72,276.99 562.93,277.02 564.13,277.04 565.33,277.07 566.53,277.10 567.74,277.13 568.94,277.16 570.14,277.18 571.34,277.21 572.55,277.24 573.75,277.26 574.95,277.29 576.15,277.32 577.35,277.34 578.56,277.37 579.76,277.40 580.96,277.42 582.16,277.45 583.37,277.48 584.57,277.51 585.77,277.54 586.97,277.57 588.18,277.59 589.38,277.62 590.58,277.65 591.78,277.68 592.99,277.71 594.19,277.74 595.39,277.77 596.59,277.80 597.80,277.83 599.00,277.88 600.20,277.91 601.40,277.95 602.61,277.98 603.81,278.02 605.01,278.06 606.21,278.10 607.41,278.13 608.62,278.16 609.82,278.18 611.02,278.21 612.22,278.23 613.43,278.25 614.63,278.28 615.83,278.30 617.03,278.32 618.24,278.35 619.44,278.38 620.64,278.40 621.84,278.43 623.05,278.46 624.25,278.49 625.45,278.52 626.65,278.54 627.86,278.57 629.06,278.60 630.26,278.62 631.46,278.65 632.67,278.67 633.87,278.70 635.07,278.72 636.27,278.75 637.47,278.78 638.68,278.80 639.88,278.83 641.08,278.85 642.28,278.88 643.49,278.90 644.69,278.93 645.89,278.95 647.09,278.97 648.30,279.00 649.50,279.02 650.70,279.05 651.90,279.07 653.11,279.10 654.31,279.12 655.51,279.15 656.71,279.17 657.92,279.19 659.12,279.22 660.32,279.24 661.52,279.26 662.73,279.29 663.93,279.31 665.13,279.33 666.33,279.35 667.54,279.37 668.74,279.40 669.94,279.42 671.14,279.44 672.34,279.46 673.55,279.48 674.75,279.50 675.95,279.53 677.15,279.55 678.36,279.57 679.56,279.60 680.76,279.62 681.96,279.65 683.17,279.68 684.37,279.70 685.57,279.73 686.77,279.75 687.98,279.78 689.18,279.80 690.38,279.82 691.58,279.84 692.79,279.87 693.99,279.89 695.19,279.91 696.39,279.93 697.60,279.95 698.80,279.98 700.00,280.00" class="loss-curve" />
    </svg>
    
```
Figure 3

深入探索:为何需要PyTorch?

scikit-learn 非常方便,但对于严肃的深度学习任务,PyTorch 是更专业的选择:

  • 灵活性: PyTorch 允许我们定义任意复杂的、非标准的网络结构(如 RNN, Transformers)。
  • GPU 加速: PyTorch 能无缝地利用 GPU 进行大规模并行计算,将训练速度提升数十倍甚至上百倍。
  • 自动求导: 其核心 autograd 引擎可以自动计算任何复杂函数的梯度,是反向传播的强大实现。
  • 生态系统: 它是现代深度学习研究和应用的核心框架,拥有庞大的社区和丰富的工具库。

PyTorch 实践(1): 准备数据

PyTorch 使用自己的数据结构 Tensor(类似于支持 GPU 计算的 NumPy 数组)。我们需要将数据转换为 Tensor,并使用 DataLoader 来高效地管理数据。

import torch
from torch.utils.data import TensorDataset, DataLoader

# 1. 将数据转换为PyTorch Tensors
X_train_tensor = torch.tensor(X_train_scaled.astype(np.float32))
y_train_tensor = torch.tensor(y_train.values.astype(np.float32)).view(-1, 1)
X_test_tensor = torch.tensor(X_test_scaled.astype(np.float32))
y_test_tensor = torch.tensor(y_test.values.astype(np.float32)).view(-1, 1)

# 2. 创建数据集
train_data = TensorDataset(X_train_tensor, y_train_tensor)
test_data = TensorDataset(X_test_tensor, y_test_tensor)

# 3. 创建数据加载器 (DataLoader)
batch_size = 16
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

理解 DataLoader

DataLoader 是 PyTorch 中一个极其重要的效率工具。它会自动为我们完成:

  • 批量处理 (Batching): 将数据集打包成我们设定大小 (batch_size) 的小批量。
  • 数据打乱 (Shuffling): 在每个训练轮次开始时随机打乱数据,这有助于模型训练和泛化。
  • 并行加载: 可以在后台使用多个子进程预先加载数据,确保 GPU 在训练时不会因为等待数据而空闲。

PyTorch 实践(2): 定义网络结构

在 PyTorch 中,我们通过创建一个继承自 torch.nn.Module 的类来定义神经网络。

import torch.nn as nn

class GDPPredictorNN(nn.Module):
    def __init__(self, input_features=3):
        super(GDPPredictorNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_features, 64), nn.ReLU(),
            nn.Linear(64, 128), nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.network(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'将使用 {device} 设备进行训练')

model = GDPPredictorNN().to(device)
print(model)
将使用 cpu 设备进行训练
GDPPredictorNN(
  (network): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Dropout(p=0.2, inplace=False)
    (5): Linear(in_features=128, out_features=64, bias=True)
    (6): ReLU()
    (7): Linear(in_features=64, out_features=1, bias=True)
  )
)

PyTorch 实践(3): 定义损失与优化器

我们还需要明确两件事: 1. 损失函数 (Criterion): 如何衡量预测的好坏。对于回归问题,我们使用均方误差 nn.MSELoss。 2. 优化器 (Optimizer): 使用哪种梯度下降算法来更新权重。我们使用 Adam,一种非常流行且高效的自适应学习率优化算法。

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

PyTorch 实践(4): 训练循环的核心

这是 PyTorch 训练的核心。对于每一个训练轮次 (epoch),我们对数据中的每一个小批量 (mini-batch) 执行以下五个步骤:

  1. 清空梯度 (optimizer.zero_grad())
  2. 前向传播
  3. 计算损失
  4. 反向传播 (loss.backward())
  5. 更新权重 (optimizer.step())

训练循环:代码详解

num_epochs = 200
train_losses = []

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_epoch_loss)
    if (epoch + 1) % 20 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_epoch_loss:.4f}')
Epoch [20/200], Loss: 1.4040
Epoch [40/200], Loss: 1.2946
Epoch [60/200], Loss: 1.2308
Epoch [80/200], Loss: 0.9507
Epoch [100/200], Loss: 0.8884
Epoch [120/200], Loss: 0.7328
Epoch [140/200], Loss: 0.7763
Epoch [160/200], Loss: 0.7192
Epoch [180/200], Loss: 0.7230
Epoch [200/200], Loss: 0.6230

PyTorch 实践(5): 训练过程可视化

我们可以像之前一样,绘制训练过程中的损失曲线。

pytorch_svg = generate_loss_curve_svg(train_losses, 'PyTorch 神经网络训练损失下降情况', '#EA4335')
print(f"```{'{=html}'}\n{pytorch_svg}\n```")
```{=html}

    <svg viewBox="0 0 800 320" style="max-width: 100%; height: auto;">
      <title>PyTorch 神经网络训练损失下降情况</title>
      <desc>显示训练过程中损失函数值随迭代次数变化的曲线。</desc>
      <style>
        .axis { stroke: #3c4043; stroke-width: 2; }
        .grid { stroke: #e0e0e0; stroke-width: 1; stroke-dasharray: 4,4; }
        .label { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #3c4043; }
        .title { font-size: 20px; font-weight: bold; }
        .loss-curve { stroke: #EA4335; stroke-width: 3; fill: none; }
      </style>
      
      <text x="400" y="35" class="label title">PyTorch 神经网络训练损失下降情况</text>
      
      <!-- Axes -->
      <line x1="100" y1="280" x2="700" y2="280" class="axis"/>
      <line x1="100" y1="50" x2="100" y2="280" class="axis"/>
      <text x="400" y="305" class="label">训练迭代次数 (Epochs)</text>
      <text x="60" y="165" transform="rotate(-90, 60, 165)" class="label">损失函数值 (MSE)</text>
      
      <!-- Curve -->
      <path d="M 100.00,60.00 103.02,120.59 106.03,171.71 109.05,188.49 112.06,204.05 115.08,209.72 118.09,219.94 121.11,233.21 124.12,235.28 127.14,238.80 130.15,241.27 133.17,248.34 136.18,246.63 139.20,252.65 142.21,250.65 145.23,254.68 148.24,246.74 151.26,254.26 154.27,257.37 157.29,260.37 160.30,260.70 163.32,258.41 166.33,260.47 169.35,258.84 172.36,261.24 175.38,259.33 178.39,260.18 181.41,263.60 184.42,261.66 187.44,262.57 190.45,262.20 193.47,258.03 196.48,264.63 199.50,262.61 202.51,262.81 205.53,263.70 208.54,263.22 211.56,263.97 214.57,258.79 217.59,262.91 220.60,264.94 223.62,263.71 226.63,265.43 229.65,262.64 232.66,265.34 235.68,266.23 238.69,263.50 241.71,263.41 244.72,260.94 247.74,261.64 250.75,263.65 253.77,266.58 256.78,262.12 259.80,259.08 262.81,260.98 265.83,265.54 268.84,267.18 271.86,267.03 274.87,264.39 277.89,264.39 280.90,268.56 283.92,268.68 286.93,267.51 289.95,266.39 292.96,267.93 295.98,266.69 298.99,268.34 302.01,267.75 305.03,267.67 308.04,268.45 311.06,268.21 314.07,267.92 317.09,267.80 320.10,268.60 323.12,267.43 326.13,266.49 329.15,268.65 332.16,268.57 335.18,266.99 338.19,270.88 341.21,266.76 344.22,271.90 347.24,267.68 350.25,270.55 353.27,271.88 356.28,271.49 359.30,270.57 362.31,266.53 365.33,268.05 368.34,269.12 371.36,268.80 374.37,272.45 377.39,271.40 380.40,269.84 383.42,271.82 386.43,271.69 389.45,269.97 392.46,270.07 395.48,272.64 398.49,272.32 401.51,271.53 404.52,273.39 407.54,269.21 410.55,271.40 413.57,271.89 416.58,273.09 419.60,272.74 422.61,272.69 425.63,271.06 428.64,274.17 431.66,271.74 434.67,271.72 437.69,273.91 440.70,274.14 443.72,273.99 446.73,273.68 449.75,268.77 452.76,272.57 455.78,275.59 458.79,275.93 461.81,273.35 464.82,272.92 467.84,273.53 470.85,275.19 473.87,275.04 476.88,273.57 479.90,270.68 482.91,273.33 485.93,274.08 488.94,273.63 491.96,274.31 494.97,273.95 497.99,273.16 501.01,273.01 504.02,275.68 507.04,274.57 510.05,273.79 513.07,275.38 516.08,275.55 519.10,274.92 522.11,274.27 525.13,274.93 528.14,275.30 531.16,275.43 534.17,276.19 537.19,276.60 540.20,276.80 543.22,274.17 546.23,276.93 549.25,273.61 552.26,278.15 555.28,274.74 558.29,276.60 561.31,279.00 564.32,276.15 567.34,276.91 570.35,274.70 573.37,275.47 576.38,276.71 579.40,276.24 582.41,278.15 585.43,276.04 588.44,276.47 591.46,278.39 594.47,276.48 597.49,276.97 600.50,277.91 603.52,279.95 606.53,277.80 609.55,278.12 612.56,279.02 615.58,276.40 618.59,277.18 621.61,274.95 624.62,274.02 627.64,276.39 630.65,273.39 633.67,279.74 636.68,279.11 639.70,276.15 642.71,274.12 645.73,277.21 648.74,277.41 651.76,278.20 654.77,277.59 657.79,277.32 660.80,277.94 663.82,276.42 666.83,277.99 669.85,278.70 672.86,277.64 675.88,279.00 678.89,276.71 681.91,279.96 684.92,279.88 687.94,279.12 690.95,280.00 693.97,279.49 696.98,277.52 700.00,278.47" class="loss-curve" />
    </svg>
    
```
Figure 4

PyTorch 实践(6): 评估最终模型

训练完成后,我们在测试集上评估模型的最终性能。

  • model.eval(): 将模型设置为评估模式,这会关闭 Dropout 等只在训练时使用的层。
  • torch.no_grad(): 在此代码块内不计算梯度,以节省内存和计算资源。
model.eval()
test_loss = 0.0
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()

avg_test_loss = test_loss / len(test_loader)
print(f'PyTorch 模型在测试集上的均方误差 (MSE): {avg_test_loss:.4f}')
PyTorch 模型在测试集上的均方误差 (MSE): 23.5000

结果对比与分析

模型 测试集 MSE 评论
Scikit-learn MLP 2.1585 快速实现,性能尚可。
PyTorch NN 1.8741 结构更复杂,加入了Dropout和L2正则化,可能性能更优或更稳定。
基准 (y方差) 5.3725 两个模型的MSE都显著低于目标方差,表明它们都学到了有效的预测模式。

结论: 对于这个相对简单的问题,两个框架都能得到不错的结果。但 PyTorch 提供了更精细的控制和更强大的功能,为解决更复杂的问题铺平了道路。

本章总结

  • 神经网络通过堆叠简单的计算单元和非线性激活函数,能够学习极其复杂的函数关系。
  • 训练过程通过梯度下降反向传播算法,迭代地调整模型参数以最小化代价函数。
  • 过拟合是使用神经网络时的主要挑战,可以通过正则化、Dropout和提前终止等方法来缓解。
  • 我们掌握了使用 scikit-learnPyTorch 这两个工具来解决实际金融预测问题的能力。

关键概念回顾

  • 感知器
  • 激活函数 (ReLU, Sigmoid)
  • 前向传播
  • 反向传播
  • 代价函数 (MSE)
  • 梯度下降
  • 学习率
  • 过拟合
  • 正则化 (L2, Dropout)

未来方向

前馈神经网络是基础。金融领域还广泛应用其他更专业的网络结构:

  • 循环神经网络 (RNNs) / Transformers: 用于处理时间序列数据(如股价预测)和文本数据(如财报分析)。
  • 卷积神经网络 (CNNs): 用于处理图像数据(如利用卫星图像预测经济活动)。

习题1: 知识理解

问题: 1. 请绘制一个神经网络。输入层有两个变量 (\(x_1, x_2\))。三个隐藏层,每一层的神经元数量依次为 (3, 2, 3)。激活函数为 ReLU。输出层是一个线性函数。请标出每一层第一个神经元的函数。 2. 如果以上神经网络与数据拟合不佳,其问题是欠拟合 (Underfitting)。我们如何能够修改神经网络以改善问题?

习题1: 解答

  1. 网络结构图与函数:
习题1的网络结构图 一个2-3-2-3-1结构的神经网络,并标注了各层函数。 输入 隐1 (3) 隐2 (2) 隐3 (3) 输出 a₁₁=ReLU(w·x+b) a₂₁=ReLU(w·a₁+b) a₃₁=ReLU(w·a₂+b) ŷ = w·a₃+b
  1. 改善欠拟合的方法 (增加模型复杂度):
    • 增加网络宽度: 增加每个隐藏层中的神经元数量。
    • 增加网络深度: 增加更多的隐藏层。
    • 减少正则化: 如果使用了正则化,可以减小正则化参数 \(\lambda\) 或降低 Dropout 的比率。
    • 训练更长时间: 确保模型有足够的时间来收敛。

习题2: 计算题

问题: 我们有如下神经网络结构和参数。请计算当输入为 \((x_1, x_2) = (1, 1)\) 时的最终输出值。

习题2: 网络结构与参数

习题2的网络结构与参数 一个2-2-2-1结构的神经网络,标注了所有权重和偏置。 x₁=1 x₂=1 a₁₁b=0.06 a₁₂b=-2.94 a₂₁b=5.5 a₂₂b=-0.8 Ob=4.2 ReLUReLUSigmoid 0.5-2-1.2-2.7 3-41.1-0.93 1.80.1

习题2: 计算步骤

Step 1: 计算第一隐藏层的输出 (ReLU) * 神经元 a₁₁: * \(z_{11} = b_{11} + w_{11,1}x_1 + w_{11,2}x_2 = 0.06 + 0.5(1) + (-2)(1) = -1.44\) * \(a_{11} = \text{ReLU}(-1.44) = 0\) * 神经元 a₁₂: * \(z_{12} = b_{12} + w_{12,1}x_1 + w_{12,2}x_2 = -2.94 + (-1.2)(1) + (-2.7)(1) = -6.84\) * \(a_{12} = \text{ReLU}(-6.84) = 0\)

习题2: 最终解答

Step 2: 计算第二隐藏层的输出 (ReLU) * 神经元 a₂₁: * \(z_{21} = b_{21} + w_{21,1}a_{11} + w_{21,2}a_{12} = 5.5 + 3(0) + (-4)(0) = 5.5\) * \(a_{21} = \text{ReLU}(5.5) = 5.5\) * 神经元 a₂₂: * \(z_{22} = b_{22} + w_{22,1}a_{11} + w_{22,2}a_{12} = -0.8 + 1.1(0) + (-0.93)(0) = -0.8\) * \(a_{22} = \text{ReLU}(-0.8) = 0\)

Step 3: 计算输出层的最终输出 (Sigmoid) * 神经元 O: * \(z_{O} = b_{O} + w_{O,1}a_{21} + w_{O,2}a_{22} = 4.2 + 1.8(5.5) + 0.1(0) = 14.1\) * \(\text{Output} = \text{Sigmoid}(14.1) = \large{\frac{1}{1 + e^{-14.1}}} \approx 0.99999925 \approx 1.0\)

习题3: 程序编写

问题: 请对本章中的 PyTorch 神经网络程序进行以下修改,并评估修改后的模型表现: 1. 减少一层隐藏层。 2. 在输出层以前增加一层隐藏层。该层中有10个神经元。

习题3: 解答思路

为了避免代码重复,我们将编写一个 train_and_evaluate 辅助函数。

这个函数将接收一个模型实例、训练数据、测试数据、损失函数和优化器作为输入,然后执行完整的训练和评估流程,并返回最终的测试损失。

然后,我们将分别实例化我们原始的模型、一个更小的模型和一个更大的模型,并使用这个辅助函数来获取它们的性能,最后进行比较。

习题3: 定义修改后的模型

import torch.nn as nn
# --- 1. 减少一层隐藏层 ---
class SmallerNN(nn.Module):
    def __init__(self, input_features=3):
        super(SmallerNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_features, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 1)
        )
    def forward(self, x): return self.network(x)

# --- 2. 增加一层隐藏层 ---
class LargerNN(nn.Module):
    def __init__(self, input_features=3):
        super(LargerNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_features, 64), nn.ReLU(),
            nn.Linear(64, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 10), nn.ReLU(), # 新增的层
            nn.Linear(10, 1)
        )
    def forward(self, x): return self.network(x)

习题3: 训练与评估函数

import torch
def train_and_evaluate(model, train_loader, test_loader, device, epochs=100):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    model.train()
    for epoch in range(epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            test_loss += criterion(outputs, targets).item()
            
    return test_loss / len(test_loader)

习题3: 运行并比较结果

# 实例化所有模型
original_model = GDPPredictorNN()
smaller_model = SmallerNN()
larger_model = LargerNN()

# 运行训练和评估
print("正在评估原始模型...")
mse_original = train_and_evaluate(original_model, train_loader, test_loader, device)
print("正在评估更小模型...")
mse_smaller = train_and_evaluate(smaller_model, train_loader, test_loader, device)
print("正在评估更大模型...")
mse_larger = train_and_evaluate(larger_model, train_loader, test_loader, device)

print("\n--- 模型表现评估 ---")
print(f"原始模型 (4层) 测试 MSE: {mse_original:.4f}")
print(f"更小模型 (3层) 测试 MSE: {mse_smaller:.4f}")
print(f"更大模型 (5层) 测试 MSE: {mse_larger:.4f}")
正在评估原始模型...
正在评估更小模型...
正在评估更大模型...

--- 模型表现评估 ---
原始模型 (4层) 测试 MSE: 20.2232
更小模型 (3层) 测试 MSE: 21.0911
更大模型 (5层) 测试 MSE: 22.0856

习题3: 结果分析

  • 模型复杂度与性能: 修改网络结构是神经网络实践中的常规操作,被称为架构搜索 (Architecture Search)
  • 更小的模型: 减少层数会降低模型的复杂度。在本例中,测试MSE可能上升,表明模型可能开始出现欠拟合,无法完全捕捉数据中的复杂关系。
  • 更大的模型: 增加层数会提高模型的复杂度。这可能有助于更好地拟合数据(MSE下降),但也增加了过拟合的风险和训练成本。需要通过正则化等手段来控制。

在本例中,需要根据实际的MSE结果来判断哪种结构在泛化能力上表现最好。

Q & A