Faster attention calculation in 4-2.Seq2Seq? #75

Open
shouldsee opened this issue Jul 23, 2022 · 1 comment
shouldsee commented Jul 23, 2022

Thanks for sharing! I just noticed that `Attention.get_att_weight` computes the attention scores in a Python for-loop, which looks rather slow. Is there a reason for that?

4-2.Seq2Seq(Attention)/Seq2Seq(Attention).ipynb

    def get_att_weight(self, dec_output, enc_outputs):  # get attention weight one 'dec_output' with 'enc_outputs'
        n_step = len(enc_outputs)
        attn_scores = torch.zeros(n_step)  # attn_scores : [n_step]

        for i in range(n_step):
            attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])

        # Normalize scores to weights in range 0 to 1
        return F.softmax(attn_scores).view(1, 1, -1)

    def get_att_score(self, dec_output, enc_output):  # enc_outputs [batch_size, num_directions(=1) * n_hidden]
        score = self.attn(enc_output)  # score : [batch_size, n_hidden]
        return torch.dot(dec_output.view(-1), score.view(-1))  # inner product make scalar value
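For reference, the loop appears to be computing Luong-style "general" attention: reading `self.attn` as a learned matrix $W_a$ applied to each encoder output (bias omitted for brevity; this is my interpretation, not something stated in the notebook), with decoder state $s_t$ and encoder outputs $h_i$ the weights are

    e_{t,i} = s_t^\top W_a h_i, \qquad
    \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{n_{\mathrm{step}}} \exp(e_{t,j})}

Since each $e_{t,i}$ depends only on $s_t$ and a single $h_i$, all of them can be computed in one matrix product instead of a per-step loop.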

Suggested parallel version

    def get_att_weight(self, dec_output, enc_outputs):  # get attention weight of one 'dec_output' with 'enc_outputs'
        # dec_output : [1, batch_size, n_hidden], enc_outputs : [n_step, batch_size, n_hidden]
        enc_t = self.attn(enc_outputs)  # [n_step, batch_size, n_hidden]
        # batched inner products over the hidden dimension replace the per-step loop
        score = dec_output.transpose(1, 0).bmm(enc_t.transpose(1, 0).transpose(2, 1))  # [batch_size, 1, n_step]
        return score.softmax(-1)  # normalized attention weights
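Here is a minimal, self-contained sketch that checks the batched computation against the original loop. The class name, tensor sizes, and the assumption that `self.attn` is `nn.Linear(n_hidden, n_hidden)` are mine, chosen to mirror the notebook's shapes for batch_size = 1:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    n_step, batch_size, n_hidden = 5, 1, 128

    class Attention(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = nn.Linear(n_hidden, n_hidden)  # assumed projection, as in the notebook

        def get_att_weight_loop(self, dec_output, enc_outputs):
            # original version: one dot product per encoder step
            attn_scores = torch.zeros(len(enc_outputs))
            for i in range(len(enc_outputs)):
                score = self.attn(enc_outputs[i])  # [batch_size, n_hidden]
                attn_scores[i] = torch.dot(dec_output.view(-1), score.view(-1))
            return F.softmax(attn_scores, dim=0).view(1, 1, -1)  # [1, 1, n_step]

        def get_att_weight_batched(self, dec_output, enc_outputs):
            # proposed version: one batched matrix multiply over all encoder steps
            enc_t = self.attn(enc_outputs)  # [n_step, batch_size, n_hidden]
            score = dec_output.transpose(1, 0).bmm(enc_t.transpose(1, 0).transpose(2, 1))  # [batch_size, 1, n_step]
            return score.softmax(-1)

    torch.manual_seed(0)
    attn = Attention()
    enc_outputs = torch.randn(n_step, batch_size, n_hidden)
    dec_output = torch.randn(1, batch_size, n_hidden)

    w_loop = attn.get_att_weight_loop(dec_output, enc_outputs)
    w_batch = attn.get_att_weight_batched(dec_output, enc_outputs)
    print(torch.allclose(w_loop, w_batch, atol=1e-6))  # expected: True

The batched version performs the same projection and dot products, just fused into a single `bmm` over all encoder steps, so it avoids one kernel launch (or one Python iteration) per time step.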
shouldsee changed the title from "Faster attention calculation?" to "Faster attention calculation in 4-2.Seq2Seq?" on Jul 23, 2022
Ekundayo39283 commented
You can create a pull request to update the code.
