Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swapped naive dot product attention for flash attention #24

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 2 additions & 12 deletions torchscale/architecture/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,23 +339,13 @@ def forward(
):
assert src_tokens is not None or token_embeddings is not None

if encoder_padding_mask is None:
if src_tokens is not None:
encoder_padding_mask = torch.zeros_like(
src_tokens, device=src_tokens.device
).bool()
else:
encoder_padding_mask = torch.zeros(
[token_embeddings.size(0), token_embeddings.size(1)],
device=token_embeddings.device,
).bool()

if multiway_split_position is not None:
assert self.args.multiway
self.apply(set_split_position(multiway_split_position))

x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
if encoder_padding_mask is not None:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

encoder_states = []

Expand Down
94 changes: 53 additions & 41 deletions torchscale/component/multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]

import math
from typing import Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -64,12 +65,12 @@ def reset_parameters(self):

def forward(
self,
query,
key,
value,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
incremental_state=None,
key_padding_mask=None,
attn_mask=None,
key_padding_mask: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
rel_pos=None,
):
bsz, tgt_len, embed_dim = query.size()
Expand All @@ -84,74 +85,85 @@ def forward(
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling

q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
q = q.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
k = k.reshape(bsz, self.num_heads, src_len, self.head_dim)
v = v.reshape(bsz, self.num_heads, src_len, self.head_dim)

if incremental_state is not None:
if "prev_key" in incremental_state:
prev_key = incremental_state["prev_key"].view(
bsz * self.num_heads, -1, self.head_dim
bsz, self.num_heads, -1, self.head_dim
)
prev_value = incremental_state["prev_value"].view(
bsz * self.num_heads, -1, self.head_dim
bsz, self.num_heads, -1, self.head_dim
)
k = torch.cat([prev_key, k], dim=1)
v = torch.cat([prev_value, v], dim=1)
incremental_state["prev_key"] = k.view(
bsz, self.num_heads, -1, self.head_dim
)
incremental_state["prev_value"] = v.view(
bsz, self.num_heads, -1, self.head_dim
)
incremental_state["prev_key"] = k
incremental_state["prev_value"] = v
src_len = k.size(1)

if self.xpos is not None:
if incremental_state is not None:
offset = src_len - 1
else:
offset = 0
k, q = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (k, q))
k = self.xpos(k, offset=0, downscale=True)
q = self.xpos(q, offset=offset, downscale=False)
k, q = map(lambda t: t.view(bsz, self.num_heads, -1, self.head_dim), (k, q))

attn_weights = torch.bmm(q, k.transpose(1, 2))

if attn_mask is not None:
attn_weights = torch.nan_to_num(attn_weights)
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if attn_mask is not None and attn_mask.ndim != 4:
# Add batch and heads
attn_mask = attn_mask.reshape(1, 1, *attn_mask.shape).expand(bsz, self.num_heads, -1, -1)
# else:
# attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.float32, device=k.device)

if key_padding_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# Achieve same result with an additive mask
key_padding_mask = torch.where(key_padding_mask, float("-inf"), 0.0)
# Add heads and dst_len
key_padding_mask = key_padding_mask.reshape(bsz, 1, 1, src_len).to(q.dtype).expand(-1, self.num_heads, tgt_len, -1)
if attn_mask is not None:
attn_mask = attn_mask + key_padding_mask
else:
attn_mask = key_padding_mask.expand(-1, self.num_heads, tgt_len, -1)

if rel_pos is not None:
rel_pos = rel_pos.view(attn_weights.size())
attn_weights = attn_weights + rel_pos
if attn_mask is not None:
attn_mask = attn_mask + rel_pos.view(attn_mask.size())
else:
attn_mask = rel_pos.reshape(bsz, self.num_heads, tgt_len, src_len)

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
attn_weights
)
attn_probs = self.dropout_module(attn_weights)
if hasattr(F, "scaled_dot_product_attention"):
attn = F.scaled_dot_product_attention(
q, k, v, attn_mask, self.dropout_module.p
)
# attn: B,H,T,E (Batch, Heads, Tgt_Len, Dim)
# Permute to B,T,H,E, and then flatten to B,T,D
attn = attn.permute(0, 2, 1, 3).flatten(2)
attn_weights = None
else:
q *= self.scaling
q, k, v = map(lambda t: t.view(bsz * self.num_heads, -1, self.head_dim), (q, k, v))
attn_weights = torch.bmm(q, k.transpose(1, 2))

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
attn_weights
)
attn_weights = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
attn_probs = self.dropout_module(attn_weights)

attn = torch.bmm(attn_probs, v)
attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
attn = torch.bmm(attn_probs, v)
attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

if self.inner_attn_ln is not None:
attn = self.inner_attn_ln(attn)

attn = self.out_proj(attn)
attn_weights = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)

return attn, attn_weights