The RNN implementation code in Section 2.5 seems to have problems #102

Open
ZhaoQianfeng opened this issue Jul 31, 2020 · 3 comments

@ZhaoQianfeng
Contributor

ZhaoQianfeng commented Jul 31, 2020

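Several points in the original code look questionable:

  1. In an RNN, the updated hidden state should be determined jointly by the current input and the previous hidden state; see the formula in the PyTorch documentation:
    (image: the RNN update formula from the PyTorch documentation)
    The step function in the Section 2.5 code, however, returns self.Whh.weight as the new hidden state. In effect it returns the Whh of the formula above as the hidden state, which cannot be right. The hidden state should correspond to the expression for out in Section 2.5, and the out variable itself is unnecessary: an RNN normally uses the updated hidden state directly as the output vector for that step.

  2. The test code further down shows that the first dimension of input is the sequence length and the second is the input vector size, so batch_size is implicitly 1. The hidden state of an RNN has shape (batch_size, hidden_size), not (seq_len, hidden_size) as in the Section 2.5 code. In other words, the shape of the hidden state does not depend on the length of the input sequence, only on batch_size and hidden_size. Since the test code assumes a batch size of 1, h_0 should have shape (1, hidden_size) or be a one-dimensional vector. See the PyTorch definitions of Input and hidden:

(image: the Input and hidden shape definitions from the PyTorch documentation)

  3. In the final for loop, h_0 is passed in as the previous hidden state on every iteration. Is that a slip?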
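
For reference, the single-layer update rule that point 1 appeals to can be written out as below. This is a reconstruction of the tanh form given in the PyTorch `torch.nn.RNN` documentation, written in column-vector convention; the symbol names follow that documentation, not the Section 2.5 code.

```latex
h_t = \tanh\left(W_{ih}\, x_t + b_{ih} + W_{hh}\, h_{t-1} + b_{hh}\right)
```

Here $h_t$ is the hidden state at step $t$, so $W_{hh}$ is a weight matrix that multiplies the previous hidden state $h_{t-1}$; it is not itself a hidden state.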

I have revised the code; here it is for reference:

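import torch

class RNN:
    """Minimal single-layer RNN cell: h_t = tanh(W1 x_t + b1 + W2 h_{t-1} + b2)."""

    def __init__(self, input_size, hidden_size):
        # Input-to-hidden and hidden-to-hidden affine maps, each with its own bias.
        self.W1x_plus_b1 = torch.nn.Linear(input_size, hidden_size)
        self.W2h_plus_b2 = torch.nn.Linear(hidden_size, hidden_size)

    def __call__(self, x, hidden):
        return self.step(x, hidden)

    def step(self, x, hidden):
        # The updated hidden state doubles as the output of this time step.
        hidden_new = torch.tanh(self.W1x_plus_b1(x) + self.W2h_plus_b2(hidden))
        return hidden_new

rnn = RNN(20, 50)
input_seq = torch.randn(32, 20)  # (seq_len, input_size); a batch size of 1 is implicit
h_0 = torch.randn(50)            # (hidden_size,)
seq_len = input_seq.shape[0]

h_n = h_0
for i in range(seq_len):
    # Feed the previous hidden state back in at every step.
    h_n = rnn(input_seq[i, :], h_n)
    print('output %d: ' % i, h_n)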
@zergtant
Owner

OK, I'll look into it. Thanks.

@HamsterCoderSim

HamsterCoderSim commented Sep 1, 2020

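I agree with the original poster: what should be updated on each step is the out of the author's original code, not hidden.
On point 2, the first input is not influenced by any earlier data, so h_0 could simply be 0, with h_n added in from the second input onwards.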
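
The zero initialisation suggested here, together with the shape argument in point 2, can be illustrated without any framework. The sketch below is a hypothetical pure-Python cell: the names `rnn_step` and `run_sequence`, the small sizes, and the random weights are illustrative assumptions, not taken from Section 2.5. It starts from h_0 = 0, carries the hidden state forward, and shows that the hidden vector keeps length hidden_size whatever the sequence length is.

```python
import math
import random

def rnn_step(x, h, w_ih, w_hh, b):
    """One step: h_new[j] = tanh(w_ih[j] . x + w_hh[j] . h + b[j])."""
    return [
        math.tanh(
            sum(w * xk for w, xk in zip(w_ih[j], x))
            + sum(w * hk for w, hk in zip(w_hh[j], h))
            + b[j]
        )
        for j in range(len(b))
    ]

def run_sequence(seq, w_ih, w_hh, b):
    """Start from a zero hidden state and feed each new state into the next step."""
    h = [0.0] * len(b)  # h_0 = 0: nothing precedes the first input
    for x in seq:
        h = rnn_step(x, h, w_ih, w_hh, b)
    return h

random.seed(0)
input_size, hidden_size = 4, 3  # small illustrative sizes
w_ih = [[random.uniform(-1, 1) for _ in range(input_size)] for _ in range(hidden_size)]
w_hh = [[random.uniform(-1, 1) for _ in range(hidden_size)] for _ in range(hidden_size)]
b = [0.0] * hidden_size

short_seq = [[random.uniform(-1, 1) for _ in range(input_size)] for _ in range(2)]
long_seq = [[random.uniform(-1, 1) for _ in range(input_size)] for _ in range(10)]

# The hidden state has hidden_size entries for both sequence lengths.
print(len(run_sequence(short_seq, w_ih, w_hh, b)))
print(len(run_sequence(long_seq, w_ih, w_hh, b)))
```

With h_0 = 0 the hidden-to-hidden term vanishes on the first step, so the first state depends on the first input alone; from the second step onwards the previous state contributes.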

@njwm

njwm commented Apr 29, 2021

The final for loop in the original poster's revised program has a syntax error: h_n is not defined on iteration 0.
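
Whether this actually raises can be checked directly. Under Python's semantics, a conditional expression `a if cond else b` evaluates only the branch that is selected, so on iteration 0 the name h_n is never looked up. The sketch below mimics the shape of that loop with a stand-in `step` function (hypothetical, using plain numbers instead of tensors) and runs without a NameError.

```python
def step(x, h):
    # Stand-in for rnn(x, h): any function of the input and the previous state works here.
    return x + 0.5 * h

inputs = [1.0, 2.0, 3.0]
h_0 = 0.0

states = []
for i in range(len(inputs)):
    # On i == 0 only h_0 is evaluated; h_n is first read on i == 1, after it was assigned.
    h_n = step(inputs[i], h_0 if i == 0 else h_n)
    states.append(h_n)

print(states)
```

The name h_n would be undefined only if the sequence were empty and h_n were read after the loop; assigning h_n = h_0 before the loop removes that corner case and makes the intent easier to see.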
