57 长短期记忆网络(LSTM)_by《李沐:动手学深度学习v2》pytorch版

系列文章目录


文章目录


长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。
解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM)它有许多与门控循环单元(GRU)一样的属性。有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些,却比门控循环单元早诞生了近20年。

门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。
长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。
有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。
为了控制记忆元,我们需要许多门。
其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。
另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。
我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图所示。它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。
因此,这三个门的值都在 ( 0 , 1 ) (0, 1) (0,1)的范围内。

在这里插入图片描述label:lstm_0

我们来细化一下长短期记忆网络的数学表达。
假设有 h h h个隐藏单元,批量大小为 n n n,输入数为 d d d
因此,输入为 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} XtRn×d,前一时间步的隐状态为 H t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h
相应地,时间步 t t t的门被定义如下:
输入门是 I t ∈ R n × h \mathbf{I}_t \in \mathbb{R}^{n \times h} ItRn×h
遗忘门是 F t ∈ R n × h \mathbf{F}_t \in \mathbb{R}^{n \times h} FtRn×h
输出门是 O t ∈ R n × h \mathbf{O}_t \in \mathbb{R}^{n \times h} OtRn×h
它们的计算方法如下:

I t = σ ( X t W x i + H t − 1 W h i + b i ) , F t = σ ( X t W x f + H t − 1 W h f + b f ) , O t = σ ( X t W x o + H t − 1 W h o + b o ) , \begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned} ItFtOt=σ(XtWxi+Ht1Whi+bi),=σ(XtWxf+Ht1Whf+bf),=σ(XtWxo+Ht1Who+bo),

其中 W x i , W x f , W x o ∈ R d × h \mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h} Wxi,Wxf,WxoRd×h W h i , W h f , W h o ∈ R h × h \mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h} Whi,Whf,WhoRh×h是权重参数, b i , b f , b o ∈ R 1 × h \mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h} bi,bf,boR1×h是偏置参数。

候选记忆元 (相当于RNN中计算 H t H_t Ht)

由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell) C ~ t ∈ R n × h \tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h} C~tRn×h
它的计算与上面描述的三个门的计算类似,但是使用 tanh ⁡ \tanh tanh函数作为激活函数,函数的值范围为 ( − 1 , 1 ) (-1, 1) (1,1)
下面导出在时间步 t t t处的方程:

C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) , \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c), C~t=tanh(XtWxc+Ht1Whc+bc),

其中 W x c ∈ R d × h \mathbf{W}_{xc} \in \mathbb{R}^{d \times h} WxcRd×h W h c ∈ R h × h \mathbf{W}_{hc} \in \mathbb{R}^{h \times h} WhcRh×h是权重参数, b c ∈ R 1 × h \mathbf{b}_c \in \mathbb{R}^{1 \times h} bcR1×h是偏置参数。

候选记忆元的如下图 :numref:lstm_1所示。

在这里插入图片描述label:lstm_1

记忆元

在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。
类似地,在长短期记忆网络中,也有两个门用于这样的目的:
输入门 I t \mathbf{I}_t It控制采用多少来自 C ~ t \tilde{\mathbf{C}}_t C~t的新数据,而遗忘门 F t \mathbf{F}_t Ft控制保留多少过去的记忆元 C t − 1 ∈ R n × h \mathbf{C}_{t-1} \in \mathbb{R}^{n \times h} Ct1Rn×h的内容。
使用按元素乘法,得出:

C t = F t ⊙ C t − 1 + I t ⊙ C ~ t . \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t. Ct=FtCt1+ItC~t.

如果遗忘门始终为 1 1 1且输入门始终为 0 0 0,则过去的记忆元 C t − 1 \mathbf{C}_{t-1} Ct1将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。

这样我们就得到了计算记忆元的流程图,如 :numref:lstm_2

在这里插入图片描述label:lstm_2

隐状态

最后,我们需要定义如何计算隐状态 H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} HtRn×h,这就是输出门发挥作用的地方。
在长短期记忆网络中,它仅仅是记忆元的 tanh ⁡ \tanh tanh的门控版本。
这就确保了 H t \mathbf{H}_t Ht的值始终在区间 ( − 1 , 1 ) (-1, 1) (1,1)内:

H t = O t ⊙ tanh ⁡ ( C t ) . \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t). Ht=Ottanh(Ct).

只要输出门接近 1 1 1,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近 0 0 0,我们只保留记忆元内的所有信息,而不需要更新隐状态(相当于重置隐状态)。

下图 :numref:lstm_3提供了数据流的图形化演示。

在这里插入图片描述label:lstm_3
在这里插入图片描述

从零开始实现

现在,我们从零开始实现长短期记忆网络。我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

接下来,我们需要定义和初始化模型参数。
如前所述,超参数num_hiddens定义隐藏单元的数量。
我们按照标准差 0.01 0.01 0.01的高斯分布初始化权重,并将偏置项设为 0 0 0

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

定义模型

在[初始化函数]中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)。因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

[实际模型]的定义与我们前面讨论的一样:
提供三个门和一个额外的记忆元。
请注意,只有隐状态才会传递到输出层,而记忆元 C t \mathbf{C}_t Ct不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y) #Y的shape是(批量大小,词表长度)只有这里输出了批量大小的预测,之后才能用来计算损失
    return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化RNN从零实现中引入的RNNModelScratch类来训练一个长短期记忆网络,就如我们在GRU中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 14.5, 27965.3 tokens/sec on cuda:0
time traveller te at at at at at at at at at at at at at at at a
traveller te at at at at at at at at at at at at at at at a

<Figure size 350x250 with 1 Axes>

在这里插入图片描述

简洁实现

使用高级API,我们可以直接实例化LSTM模型。
高级API封装了前文介绍的所有配置细节。
这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 11.2, 233619.5 tokens/sec on cuda:0
time traveller the the the the the the the the the the the the t
traveller the the the the the the the the the the the the t

<Figure size 350x250 with 1 Axes>

在这里插入图片描述
长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。
多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。
然而,由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(例如门控循环单元)的成本是相当高的。
在后面的内容中,我们将讲述更高级的替代模型,如Transformer。

小结

  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

练习

  1. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
  2. 如何更改模型以生成适当的单词,而不是字符序列?
  3. 在给定隐藏层维度的情况下,比较门控循环单元、长短期记忆网络和常规循环神经网络的计算成本。要特别注意训练和推断成本。
  4. 既然候选记忆元通过使用 tanh ⁡ \tanh tanh函数来确保值范围在 ( − 1 , 1 ) (-1,1) (1,1)之间,那么为什么隐状态需要再次使用 tanh ⁡ \tanh tanh函数来确保输出值范围在 ( − 1 , 1 ) (-1,1) (1,1)之间呢?
  5. 实现一个能够基于时间序列进行预测而不是基于字符序列进行预测的长短期记忆网络模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/884623.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【d53】【Java】【力扣】24.两两交换链表中的节点

思路 定义一个指针cur, 先指向头节点&#xff0c; 1.判断后一个节点是否为空&#xff0c;不为空则交换值&#xff0c; 2.指针向后走两次 代码 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}*…

[数据集][目标检测]辣椒缺陷检测数据集VOC+YOLO格式695张5类别

重要说明&#xff1a;数据集图片里面都是一个辣椒&#xff0c;请仔细查看图片预览&#xff0c;确认符合要求下载 数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文…

Nacos 是阿里巴巴开源的一款动态服务发现、配置管理和服务管理平台,旨在帮助开发者更轻松地构建、部署和管理微服务应用。

Nacos 是阿里巴巴开源的一款动态服务发现、配置管理和服务管理平台&#xff0c;旨在帮助开发者更轻松地构建、部署和管理微服务应用。Nacos 提供了一系列的功能来支持服务注册与发现、配置管理、服务元数据管理、流量管理、服务健康检查等&#xff0c;是构建云原生应用和服务网…

SpringCloud 2023各依赖版本选择、核心功能与组件、创建项目(注意事项、依赖)

目录 1. 各依赖版本选择2. 核心功能与组件3. 创建项目3.1 注意事项3.2 依赖 1. 各依赖版本选择 SpringCloud: 2023.0.1SpringBoot: 3.2.4。参考Spring Cloud Train Reference Documentation选择版本 SpringCloud Alibaba: 2023.0.1.0*: 参考Spring Cloud Alibaba选择版本。同时…

【软考】高速缓存的组成

目录 1. 说明2. 组成 1. 说明 1.高速缓存用来存放当前最活跃的程序和数据。2.高速缓存位于CPU 与主存之间。3.容量般在几千字节到几兆字节之间。4.速度一般比主存快 5~10 倍&#xff0c;由快速半导体存储器构成。5.其内容是主存局部域的副本&#xff0c;对程序员来说是透明的。…

Java:选择排序

目录 直接选择排序 堆排序 基本思想&#xff1a; 每一次从待排序的数据元素中选出最小(或最大)的一个元素&#xff0c;存放在序列的起始位置&#xff0c;直到全部待排序的数据元素排完。 直接选择排序 思路1&#xff1a; 在元素集合array[i]--array[n-1]中选择关键码最大(小…

​fl studio21.2.3.4004中文版永久2024最新下载安装图文详细使用教程​

随着数字音乐制作的快速发展&#xff0c;越来越多的音乐制作软件涌现出来&#xff0c;而FL Studio无疑是其中的佼佼者。作为一款功能强大、易于上手的音乐制作软件&#xff0c;FL Studio V21中文版在继承了前代版本优秀基因的基础上&#xff0c;进一步提升了用户体验&#xff0…

什么是原生IP?

代理IP的各个类型称呼有很多&#xff0c;且它们在网络使用和隐私保护方面扮演着不同的角色。今天将探讨什么是原生IP以及原生IP和住宅IP之间的区别&#xff0c;帮助大家更好地理解这两者的概念和实际应用&#xff0c;并选择适合自己的IP类型。 一、什么是原生IP&#xff1f; 原…

FPGA-Vivado-IP核-逻辑分析仪(ILA)

ILA IP核 背景介绍 在用FPGA做工程项目时&#xff0c;当Verilog代码写好&#xff0c;我们需要对代码里面的一些关键信号进行上板验证查看。首先&#xff0c;我们可以把需要查看的这些关键信号引出来&#xff0c;接好线通过示波器进行实时监测&#xff0c;但这会用到大量的线材…

【英特尔IA-32架构软件开发者开发手册第3卷:系统编程指南】2001年版翻译,1-1

文件下载与邀请翻译者 学习英特尔开发手册&#xff0c;最好手里这个手册文件。原版是PDF文件。点击下方链接了解下载方法。 讲解下载英特尔开发手册的文章 翻译英特尔开发手册&#xff0c;会是一件耗时费力的工作。如果有愿意和我一起来做这件事的&#xff0c;那么&#xff…

.NET 工具库高效生成 PDF 文档

QuestPDF 是一个开源 .NET 库&#xff0c;用于生成 PDF 文档。使用了C# Fluent API方式可简化开发、减少错误并提高工作效率。利用它可以轻松生成 PDF 报告、发票、导出文件等。 QuestPDF 是一个革命性的开源 .NET 库&#xff0c;它彻底改变了我们生成 PDF 文档的方式。 Ques…

[Admin] Things Need to Know

List View Bulk Actions Highlight: To take bulk actions on all of the available records in a list, you click the bulk action button without selecting any records.

优雅使用 MapStruct 进行类复制

前言 在项目中&#xff0c;常常会遇到从数据库读取数据后不能直接返回给前端展示的情况&#xff0c;因为还需要对字段进行加工&#xff0c;比如去除时间戳记录、隐藏敏感数据等。传统的处理方式是创建一个新类&#xff0c;然后编写大量的 get/set 方法进行赋值&#xff0c;若字…

鸿蒙开发(NEXT/API 12)【硬件(传感器开发)】传感器服务

使用场景 Sensor Service Kit&#xff08;传感器服务&#xff09;使应用程序能够从传感器获取原始数据&#xff0c;并提供振感控制能力。 Sensor&#xff08;传感器&#xff09;模块是应用访问底层硬件传感器的一种设备抽象概念。开发者可根据传感器提供的相关接口订阅传感器…

The 2024 CCPC Online Contest (C I J三题思路)

写在前面 因为学弟已经问了几个题了&#xff0c;于是乎这场没有vp&#xff0c;准备直接开写了 题目 C. 种树&#xff08;树形dp&#xff09; 题解 只有两种情况&#xff0c; 一种是1-2-3&#xff0c;1是2的父亲&#xff0c;2是3的父亲 另一种是1-2-3&#xff0c;2同时是1…

【网络安全】-访问控制-burp(1~6)

文章目录 前言   1.Lab: Unprotected admin functionality  2.Lab: Unprotected admin functionality with unpredictable URL   3.Lab: User role controlled by request parameter   4.Lab:User role can be modified in user profile  5.Lab: User ID controlled by…

校园二手交易平台的小程序+ssm(lw+演示+源码+运行)

摘 要 随着社会的发展&#xff0c;社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…

【JavaEE】——线程池大总结

阿华代码&#xff0c;不是逆风&#xff0c;就是我疯&#xff0c; 你们的点赞收藏是我前进最大的动力&#xff01;&#xff01;希望本文内容能够帮助到你&#xff01; 目录 引入&#xff1a;问题引入 一&#xff1a;解决方案 1&#xff1a;方案一——协程/纤程 &#xff08;1…

多输入多输出预测 | NGO-BP北方苍鹰算法优化BP神经网络多输入多输出预测(Matlab)

多输入多输出预测 | NGO-BP北方苍鹰算法优化BP神经网络多输入多输出预测&#xff08;Matlab&#xff09; 目录 多输入多输出预测 | NGO-BP北方苍鹰算法优化BP神经网络多输入多输出预测&#xff08;Matlab&#xff09;预测效果基本介绍程序设计往期精彩参考资料 预测效果 基本介…

计算机毕业设计 在线问诊系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…