diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py index cdaf74ed892..5aa1a53b5d5 100644 --- a/python/test/unit/language/test_standard.py +++ b/python/test/unit/language/test_standard.py @@ -25,6 +25,7 @@ def test_maximum_minium(dtype, op): # --------------- +@pytest.mark.interpreter @pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) @pytest.mark.parametrize("descending", [False, True]) @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 7d7296bc1c9..42b9fbe434b 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -535,6 +535,30 @@ def to_ir(self, builder: ir.builder): # pointer types pi32_t = pointer_type(int32) + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + # ----------------------- # constexpr # ----------------------- @@ -665,6 +689,9 @@ def __invert__(self): def __pow__(self, other): return constexpr(self.value**_constexpr_to_value(other)) + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + def __rshift__(self, other): return constexpr(self.value >> _constexpr_to_value(other)) diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index ecd11b6ba12..c3b63eafd60 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -331,7 +331,7 @@ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): left = core.reshape(left, x.shape) right = core.reshape(right, x.shape) # actual compare-and-swap - idtype = core.dtype(f'int{core.constexpr(x.dtype.primitive_bitwidth)}') + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) ileft = left.to(idtype, bitcast=True) iright = right.to(idtype, bitcast=True) ix = x.to(idtype, bitcast=True)