Skip to content

Commit

Permalink
Fix output shape of greedy sampler with batch > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 22, 2023
1 parent e611c9b commit 1e9b510
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion outlines/generate/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def greedy(logits: torch.DoubleTensor, samples: int, *_) -> torch.DoubleTensor:
"""
if samples == 1:
next_token_ids = torch.argmax(logits, dim=-1, keepdim=True).T
next_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
else:
next_token_ids = torch.topk(
logits, samples, dim=-1, largest=True, sorted=True
Expand Down
2 changes: 1 addition & 1 deletion tests/generate/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_greedy():

logits = torch.tensor([[10.0, 0.0, 3.0], [-math.inf, 2.0, 5.0]])
next_token_ids = greedy(logits, samples=1)
assert next_token_ids.equal(torch.tensor([[0, 2]]))
assert next_token_ids.equal(torch.tensor([[0], [2]]))

next_token_ids = greedy(logits, samples=2)
assert next_token_ids.equal(torch.tensor([[0, 2], [2, 1]]))
Expand Down

0 comments on commit 1e9b510

Please sign in to comment.