
Commit

Fixed phi-3 with Su Rotary Embedding (#499)
tgaddair committed Jun 5, 2024
1 parent b2ea56e commit c71861a
Showing 1 changed file with 87 additions and 1 deletion.
server/lorax_server/utils/layers.py: 87 additions & 1 deletion
@@ -428,11 +428,11 @@ def static(cls, config, dim, base, device, dtype):
         rope_scaling = _get_rope_config(config)
         if rope_scaling is not None:
             rope_scaling = rope_scaling.copy()
-            scaling_factor = rope_scaling["factor"]
             rope_type = rope_scaling.pop("type")
             if rope_type == "linear":
                 pass
             elif rope_type == "dynamic":
+                scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
                     dim=dim,
                     max_position_embeddings=config.max_position_embeddings,
@@ -442,6 +442,7 @@ def static(cls, config, dim, base, device, dtype):
                     scaling_factor=scaling_factor,
                 )
             elif rope_type == "yarn":
+                scaling_factor = rope_scaling["factor"]
                 return YarnPositionRotaryEmbedding(
                     dim=dim,
                     max_position_embeddings=config.max_position_embeddings,
@@ -450,6 +451,48 @@ def static(cls, config, dim, base, device, dtype):
                     dtype=dtype,
                     **rope_scaling,
                 )
+            elif rope_type == "su":
+                short_factor = torch.tensor(
+                    rope_scaling["short_factor"], dtype=torch.float32, device=device
+                )
+                short_inv_freq = 1.0 / (
+                    short_factor
+                    * base
+                    ** (
+                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                        / dim
+                    )
+                )
+                long_factor = torch.tensor(
+                    rope_scaling["long_factor"], dtype=torch.float32, device=device
+                )
+                long_inv_freq = 1.0 / (
+                    long_factor
+                    * base
+                    ** (
+                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                        / dim
+                    )
+                )
+
+                original_max_position_embeddings = (
+                    config.original_max_position_embeddings
+                )
+                max_position_embeddings = config.max_position_embeddings
+                if max_position_embeddings <= original_max_position_embeddings:
+                    scaling_factor = 1.0
+                else:
+                    scale = max_position_embeddings / original_max_position_embeddings
+                    scaling_factor = math.sqrt(
+                        1 + math.log(scale) / math.log(original_max_position_embeddings)
+                    )
+
+                return SuRotaryEmbedding(
+                    short_inv_freq=short_inv_freq,
+                    long_inv_freq=long_inv_freq,
+                    scaling_factor=scaling_factor,
+                    original_max_position_embeddings=original_max_position_embeddings,
+                )
             else:
                 raise NotImplementedError(f"rope scaling type {rope_type} is not implemented or invalid")
         return cls(inv_freq, scaling_factor, config.max_position_embeddings, device, dtype)
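
Note: the "su" branch above derives its scaling factor from how far the configured context window exceeds the original one, via sqrt(1 + log(scale) / log(original_max_position_embeddings)). The snippet below is a minimal standalone sketch of that arithmetic only; the config values (4096 original, 131072 extended) are illustrative assumptions, not read from this commit.

import math

# Illustrative values (assumed, not taken from the commit): a model whose
# context window was extended from 4k to 128k positions.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

if max_position_embeddings <= original_max_position_embeddings:
    scaling_factor = 1.0
else:
    scale = max_position_embeddings / original_max_position_embeddings  # 32.0
    # Same formula as the "su" branch in the diff above.
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings)
    )

print(round(scale, 1), round(scaling_factor, 4))  # 32.0 1.1902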
@@ -620,6 +663,49 @@ def yarn(self, device, scaling_factor):
         self.mscale = float(
             get_mscale(scaling_factor) * self.attn_factor
         )  # Get n-d magnitude scaling corrected for interpolation
+
+class SuRotaryEmbedding(PositionRotaryEmbedding):
+    def __init__(
+        self,
+        short_inv_freq,
+        long_inv_freq,
+        scaling_factor,
+        original_max_position_embeddings,
+    ):
+        super(PositionRotaryEmbedding, self).__init__()
+        self.short_inv_freq = short_inv_freq
+        self.long_inv_freq = long_inv_freq
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
+        self.dynamic_args = None
+
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+        ):
+            self._seq_len_cached = seqlen
+            if seqlen > self.original_max_position_embeddings:
+                inv_freq = self.long_inv_freq
+            else:
+                inv_freq = self.short_inv_freq
+            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
+            if self.scaling_factor is not None:
+                t /= self.scaling_factor
+            # Don't do einsum, it converts fp32 to fp16
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+            freqs = torch.outer(t, inv_freq.to(device=t.device))
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)

 # Inverse dim formula to find dim based on number of rotations
 def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
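
Usage note on the class added above: it caches two inverse-frequency tables and only switches to the long-context table once the requested sequence length exceeds original_max_position_embeddings. The sketch below replays that selection outside the class with made-up sizes (dim=8, base=10000.0, a 4096-position original window); the factor tensors are placeholder assumptions, not values from any real config.

import torch

# Placeholder inputs, assumed for illustration only.
dim, base = 8, 10000.0
original_max_position_embeddings = 4096
short_factor = torch.ones(dim // 2)         # stand-in for rope_scaling["short_factor"]
long_factor = torch.full((dim // 2,), 4.0)  # stand-in for rope_scaling["long_factor"]

exponent = torch.arange(0, dim, 2, dtype=torch.float32) / dim
short_inv_freq = 1.0 / (short_factor * base**exponent)
long_inv_freq = 1.0 / (long_factor * base**exponent)

def cos_sin(seqlen, dtype=torch.float32):
    # Pick the long table only once the sequence outgrows the original window,
    # mirroring SuRotaryEmbedding._update_cos_sin_cache in the diff above.
    inv_freq = long_inv_freq if seqlen > original_max_position_embeddings else short_inv_freq
    t = torch.arange(seqlen, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)  # shape: (seqlen, dim // 2)
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

cos_short, _ = cos_sin(128)   # within the original window -> short_inv_freq
cos_long, _ = cos_sin(8192)   # beyond it -> long_inv_freq
print(cos_short.shape, cos_long.shape)  # torch.Size([128, 4]) torch.Size([8192, 4])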
