首页 > 分享 > PyTorch LSTM,batch

PyTorch LSTM,batch

class LSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.output_size = output_size self.num_directions = 1 # 单向LSTM self.batch_size = batch_size self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) self.linear = nn.Linear(self.hidden_size, self.output_size) def forward(self, input_seq): batch_size, seq_len = input_seq[0], input_seq[1] h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device) # output(batch_size, seq_len, num_directions * hidden_size) output, _ = self.lstm(input_seq, (h_0, c_0)) pred = self.linear(output) pred = pred[:, -1, :] return pred这些代码分别是什么意思

相关知识

如何搭建LSTM(pytorch版)
详解pytorch实现猫狗识别98%附代码
使用PyTorch实现鸟类音频检测卷积网络模型
Pytorch与深度学习自查手册4
pytorch单机多卡训练 logger日志记录和wandb可视化
CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类
Pytorch使用cuda后,任务管理器GPU的利用率还是为0?
安装 pytorch
Pytorch 使用Pytorch Lightning DDP时记录日志的正确方法
使用PyTorch进行城市声音分类:PyTorch音频识别

网址: PyTorch LSTM,batch https://m.mcbbbk.com/newsview505065.html

所属分类:萌宠日常
上一篇: Q字体/他加禄文字体/TTC字体
下一篇: 为狂热的读者和书籍爱好者批发最好