Question about tensor.view operation in Bi-LSTM(Attention) #38

Open
iamxpy opened this issue Sep 29, 2019 · 1 comment

Comments

@iamxpy

iamxpy commented Sep 29, 2019

hidden = final_state.view(-1, n_hidden * 2, 1) # hidden : [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]

Hi, this repo is awesome, but there might be something wrong in the code above. According to its comment, this line is meant to change a tensor from shape [num_layers(=1) * num_directions(=2), batch_size, n_hidden] to shape [batch_size, n_hidden * num_directions(=2), 1(=n_layer)], i.e. to concatenate the two hidden vectors from the two directions for every data example in the batch (by "data example" I mean one of the batch_size examples in a batch). However, I think the code above mixes up the data examples in a batch and leads to unexpected results.
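To make the shapes in that comment concrete, here is a minimal sketch of what a bidirectional `nn.LSTM` returns. It assumes a single-layer, seq-first LSTM (PyTorch's default `batch_first=False`); the sizes are arbitrary and are not the repo's hyperparameters.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not taken from the repo)
batch_size, seq_len, embedding_dim, n_hidden = 3, 4, 6, 5

torch.manual_seed(0)
# Single-layer bidirectional LSTM; input is [seq_len, batch_size, embedding_dim]
lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
x = torch.randn(seq_len, batch_size, embedding_dim)

output, (final_state, final_cell) = lstm(x)

# output: [seq_len, batch_size, n_hidden * num_directions]
print(output.shape)       # torch.Size([4, 3, 10])
# final_state: [num_layers * num_directions, batch_size, n_hidden]
print(final_state.shape)  # torch.Size([2, 3, 5])
# final_state[0] is the forward direction, final_state[1] the backward one
```

Because the first dimension of `final_state` indexes direction (and layer), a plain `view` over it walks through all forward states of the batch before it reaches any backward state.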

For example, we can use IPython to check the effect of the snippet above.

# create a tensor with shape [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
In [10]: a=torch.arange(2*3*5).reshape(2,3,5)

In [11]: a
Out[11]:
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14]],

        [[15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]]])

In [12]: a.view(-1,10,1)
Out[12]:
tensor([[[ 0],
         [ 1],
         [ 2],
         [ 3],
         [ 4],
         [ 5],
         [ 6],
         [ 7],
         [ 8],
         [ 9]],

        [[10],
         [11],
         [12],
         [13],
         [14],
         [15],
         [16],
         [17],
         [18],
         [19]],

        [[20],
         [21],
         [22],
         [23],
         [24],
         [25],
         [26],
         [27],
         [28],
         [29]]])

As you can see, we create a tensor with batch_size=3 and n_hidden=5. For example, [ 0, 1, 2, 3, 4] and [15, 16, 17, 18, 19] belong to the same data example in the batch but come from different directions, so what we want is to concatenate them in the resulting tensor. What the code actually does, however, is concatenate [ 0, 1, 2, 3, 4] and [ 5, 6, 7, 8, 9], which come from different data examples in the batch.
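The same mixing can be predicted without running PyTorch. For a contiguous tensor of shape [num_directions, batch_size, n_hidden], `view` keeps the flat row-major order, so row r of the result just takes the next n_hidden * num_directions elements. The hypothetical helper below (not part of the repo) lists which (direction, batch index) slice each n_hidden-long chunk of a row comes from:

```python
def view_row_sources(batch_size, n_hidden, num_directions=2):
    """For each row of final_state.view(-1, n_hidden * num_directions, 1),
    return the (direction, batch_index) pairs whose n_hidden-long slices of
    the contiguous [num_directions, batch_size, n_hidden] tensor end up in
    that row, in order."""
    row_len = n_hidden * num_directions
    rows = []
    for r in range(batch_size):
        sources = []
        # Row r covers flat offsets r*row_len .. (r+1)*row_len - 1,
        # i.e. num_directions consecutive blocks of n_hidden elements.
        for start in range(r * row_len, (r + 1) * row_len, n_hidden):
            direction, rest = divmod(start, batch_size * n_hidden)
            sources.append((direction, rest // n_hidden))
        rows.append(sources)
    return rows


# What the concatenation is supposed to produce: row b = [(0, b), (1, b)]
intended = [[(d, b) for d in range(2)] for b in range(3)]

print(view_row_sources(batch_size=3, n_hidden=5))
# [[(0, 0), (0, 1)], [(0, 2), (1, 0)], [(1, 1), (1, 2)]]
print(intended)
# [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]]
```

This matches `Out[12]`: the second row is [10..14] (forward, example 2) followed by [15..19] (backward, example 0). For batch_size == 1 the helper returns [[(0, 0), (1, 0)]], so the original `view` is only correct when the batch holds a single example.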

I think it can be fixed by changing that line to `hidden = torch.cat([final_state[0], final_state[1]], 1).view(-1, n_hidden * 2, 1)`.

The effect of the new code can be shown as follows:

In [13]: torch.cat([a[0],a[1]],1).view(-1,10,1)
Out[13]:
tensor([[[ 0],
         [ 1],
         [ 2],
         [ 3],
         [ 4],
         [15],
         [16],
         [17],
         [18],
         [19]],

        [[ 5],
         [ 6],
         [ 7],
         [ 8],
         [ 9],
         [20],
         [21],
         [22],
         [23],
         [24]],

        [[10],
         [11],
         [12],
         [13],
         [14],
         [25],
         [26],
         [27],
         [28],
         [29]]])
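A slightly more general version of this fix is sketched below as a hypothetical helper, `last_layer_hidden` (not from the repo). It also handles num_layers > 1 by taking the last two entries of h_n, which PyTorch orders as (layer, direction). The sketch checks the result against the LSTM output itself: for unpadded input, the forward final state equals `output[-1, :, :n_hidden]` and the backward final state equals `output[0, :, n_hidden:]`.

```python
import torch
import torch.nn as nn


def last_layer_hidden(final_state):
    """Hypothetical helper: turn h_n of a bidirectional nn.LSTM, shape
    [num_layers * 2, batch_size, n_hidden], into [batch_size, n_hidden * 2, 1]
    holding the last layer's forward state followed by its backward state
    for each example."""
    # h_n is ordered (layer, direction), so the last layer's forward and
    # backward states are the final two entries along dim 0.
    forward, backward = final_state[-2], final_state[-1]
    return torch.cat([forward, backward], dim=1).unsqueeze(2)


torch.manual_seed(0)
seq_len, batch_size, embedding_dim, n_hidden = 4, 3, 6, 5  # arbitrary sizes
lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
x = torch.randn(seq_len, batch_size, embedding_dim)
output, (final_state, _) = lstm(x)

fixed = last_layer_hidden(final_state)
buggy = final_state.view(-1, n_hidden * 2, 1)

# Per-example [forward final, backward final], read off the LSTM output
expected = torch.cat(
    [output[-1, :, :n_hidden], output[0, :, n_hidden:]], dim=1
).unsqueeze(2)

print(torch.allclose(fixed, expected))  # True
print(torch.allclose(buggy, expected))  # False
```

On the toy tensor `a` above, `last_layer_hidden(a)` gives the same values as `torch.cat([a[0], a[1]], 1).view(-1, 10, 1)`.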
@liuxiaoqun

I think `hidden = final_state.view(batch_size, -1, 1)` needs to be changed to `hidden = final_state.transpose(0, 1).reshape(batch_size, -1, 1)`.
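For a single-layer model this alternative gives the same result as the `torch.cat` fix. A small sketch on the issue's toy tensor; note that it has to use `reshape` (or `.contiguous().view(...)`), because `transpose` returns a non-contiguous tensor that `view` cannot flatten without a copy:

```python
import torch

# Toy tensor from the issue: [num_directions(=2), batch_size(=3), n_hidden(=5)]
final_state = torch.arange(2 * 3 * 5).reshape(2, 3, 5)
batch_size = final_state.size(1)

# transpose(0, 1) -> [batch_size, 2, n_hidden]: each example's forward and
# backward vectors are now adjacent, but the tensor is no longer contiguous.
swapped = final_state.transpose(0, 1)
print(swapped.is_contiguous())  # False

# reshape copies when it has to, so this works
hidden = swapped.reshape(batch_size, -1, 1)
print(hidden[0].flatten().tolist())  # [0, 1, 2, 3, 4, 15, 16, 17, 18, 19]

# Same values as the torch.cat-based fix from the first comment
same = torch.equal(hidden, torch.cat([final_state[0], final_state[1]], 1).view(-1, 10, 1))
print(same)  # True

# view cannot merge the transposed dims and raises instead
try:
    swapped.view(batch_size, -1, 1)
except RuntimeError as err:
    print("view failed:", type(err).__name__)  # view failed: RuntimeError
```

One caveat, assuming a stacked model: with num_layers > 1, `transpose(0, 1).reshape(batch_size, -1, 1)` concatenates every layer's states and yields [batch_size, num_layers * 2 * n_hidden, 1], which no longer matches n_hidden * 2.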
