diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py new file mode 100644 index 00000000..2b793672 --- /dev/null +++ b/llms/mlx_lm/sample_utils.py @@ -0,0 +1,39 @@ +import mlx.core as mx + + +def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + logits: The logits from the model's output. + top_p: The cumulative probability threshold for top-p filtering. + temperature: Temperature parameter for softmax distribution reshaping. + Returns: + token selected based on the top-p criterion. + """ + if ( + logits.dtype == mx.bfloat16 + ): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16 + logits = logits.astype(mx.float32) + + # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 + probs = mx.softmax(logits / temperature, axis=-1) + + # sort probs in ascending order + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = probs[..., sorted_indices] + + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + + # select tokens with cumulative probs below threshold + top_probs = mx.where( + cumulative_probs > 1 - top_p, + sorted_probs, + mx.zeros_like(sorted_probs), + ) + + sorted_token = mx.random.categorical(mx.log(top_probs)) + token = sorted_indices.squeeze(0)[sorted_token] + + return token diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 7b0e2da7..03e0fbd3 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -17,6 +17,8 @@ from mlx.utils import tree_flatten from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer +from .sample_utils import top_p_sampling + # Local imports from .tuner.utils import apply_lora_layers from .tuner.utils import dequantize as dequantize_model @@ -144,23 +146,7 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: token = mx.argmax(logits, axis=-1) else: if top_p > 0 and top_p < 1.0: - if ( - logits.dtype == mx.bfloat16 - ): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16 - logits = logits.astype(mx.float32) - probs = mx.softmax(logits / temp, axis=-1) - - sorted_probs = mx.sort(probs)[::-1] - sorted_indices = mx.argsort(probs)[::-1] - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) - - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - mx.zeros_like(sorted_probs), - ) - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] + token = top_p_sampling(logits, top_p, temp) else: token = mx.random.categorical(logits * (1 / temp)) diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py new file mode 100644 index 00000000..f02560a6 --- /dev/null +++ b/llms/tests/test_sample_utils.py @@ -0,0 +1,37 @@ +import unittest +from unittest.mock import patch + +import mlx.core as mx +from mlx_lm.sample_utils import top_p_sampling + + +class TestLora(unittest.TestCase): + @patch("mlx.core.random.categorical") + def test_top_p_sampling(self, mock_categorical): + logits = mx.array([[1.0, 2.0, 3.0, 4.0]]) + top_p = 0.3 + temperature = 1.0 + expected_token = mx.array([3]) + mock_categorical.return_value = expected_token + + token = top_p_sampling(logits, top_p, temperature) + expected_top_probs = mx.array([[0.0, 0.0, 0.0, 0.643914]]) + self.assertTrue(mx.allclose(token, expected_token)) + args, _ = mock_categorical.call_args + self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs))) + + logits = mx.array([[1.0, 2.0, 3.0, 4.0]]) + top_p = 0.9 + temperature = 1.0 + expected_token = mx.array([3]) + mock_categorical.return_value = expected_token + + token = top_p_sampling(logits, top_p, temperature) + expected_top_probs = mx.array([[0.0, 0.0871443, 0.236883, 0.643914]]) + self.assertTrue(mx.allclose(token, expected_token)) + args, _ = mock_categorical.call_args + self.assertTrue(mx.allclose(args[0], mx.log(expected_top_probs))) + + +if __name__ == "__main__": + unittest.main()