Skip to content

Commit

Permalink
Implement Q8_0 quantization fully in PyTorch.
Browse files Browse the repository at this point in the history
This is equivalent to gguf.quantize_q8_0 but doesn't round-trip through
NumPy.
  • Loading branch information
heiner committed May 23, 2024
1 parent 4eaa704 commit e974a68
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions convert_grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,21 @@ def get_weights(fn):
assert len(arrays) in (1, 2)


def torch_roundf(t: torch.Tensor) -> torch.Tensor:
    """Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py.

    torch.round() rounds halves to even, so the away-from-zero behaviour is
    rebuilt from the magnitude instead.
    """
    magnitude = abs(t)
    whole = torch.floor(magnitude)
    frac = magnitude - whole
    # floor(2 * frac) is 1 exactly when frac >= 0.5; this avoids the
    # floating-point hazard of computing floor(magnitude + 0.5) directly.
    rounded = whole + torch.floor(2 * frac)
    return torch.sign(t) * rounded


def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
    """Quantize a 2-D float tensor to GGML Q8_0 blocks, fully in PyTorch.

    Equivalent to ggml_quantize_q8_0 in ggml.c (including rounding halfway
    cases away from zero via torch_roundf). Each QK8_0-element block is
    stored as its fp16 scale (reinterpreted as two int8 bytes) followed by
    QK8_0 int8 quants.

    Fix: the block previously contained a stale pre-rewrite line that
    divided by `scale`, rounded, and cast to int8 *before* the
    `iscale`-based path ran, so the data was quantized twice (and the
    second pass operated on an int8 tensor). The stale line is removed;
    only the iscale/torch_roundf path remains.
    """
    assert tensor.shape[1] % QK8_0 == 0
    tensor = tensor.reshape(-1, QK8_0)
    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
    # All-zero blocks have scale == 0; map their inverse scale to 0 so the
    # quants come out as 0 instead of NaN/inf from a division by zero.
    iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
    tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char()
    # Prepend the fp16 scale, byte-reinterpreted as int8, to each block.
    tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
    return tensor
Expand Down Expand Up @@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type):
elif ggml_type == gguf.GGMLQuantizationType.F16:
return tensor.half()
elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
if tensor.device.type == "meta":
return quantize_q8_0(tensor) # Cannot convert into numpy array.
return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
return quantize_q8_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
return quantize_q4_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_1:
Expand Down

0 comments on commit e974a68

Please sign in to comment.