Skip to content

Commit

Permalink
[INTERPRETER] Support sort (#3426)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren committed Mar 21, 2024
1 parent 16d3f6e commit 6b72f57
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/test/unit/language/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_maximum_minium(dtype, op):
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32'])
Expand Down
27 changes: 27 additions & 0 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,30 @@ def to_ir(self, builder: ir.builder):
# pointer types
pi32_t = pointer_type(int32)


def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
if bitwidth == 1:
return int1
elif bitwidth == 8 and signed:
return int8
elif bitwidth == 8 and not signed:
return uint8
elif bitwidth == 16 and signed:
return int16
elif bitwidth == 16 and not signed:
return uint16
elif bitwidth == 32 and signed:
return int32
elif bitwidth == 32 and not signed:
return uint32
elif bitwidth == 64 and signed:
return int64
elif bitwidth == 64 and not signed:
return uint64
else:
raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')


# -----------------------
# constexpr
# -----------------------
Expand Down Expand Up @@ -665,6 +689,9 @@ def __invert__(self):
def __pow__(self, other):
return constexpr(self.value**_constexpr_to_value(other))

def __rpow__(self, other):
return constexpr(_constexpr_to_value(other)**self.value)

def __rshift__(self, other):
return constexpr(self.value >> _constexpr_to_value(other))

Expand Down
2 changes: 1 addition & 1 deletion python/triton/language/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
left = core.reshape(left, x.shape)
right = core.reshape(right, x.shape)
# actual compare-and-swap
idtype = core.dtype(f'int{core.constexpr(x.dtype.primitive_bitwidth)}')
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
ileft = left.to(idtype, bitcast=True)
iright = right.to(idtype, bitcast=True)
ix = x.to(idtype, bitcast=True)
Expand Down

0 comments on commit 6b72f57

Please sign in to comment.