triton generates unnecessary shared memory stores/loads #3491

Open
isuruf opened this issue Mar 28, 2024 · 8 comments · May be fixed by #3940

@isuruf

isuruf commented Mar 28, 2024

For the following Triton kernels generated by PyTorch, Triton emits shared memory stores and loads in the LLVM IR and PTX just before the atomic add operations.

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_isuruf/5e/c5ehw64oxeoeqqjnqn6v3gfy6z5ukksktwihp7jgzg6sujz5umto.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8750000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_isuruf/qg/cqgmsmgdzivumf2gmksclwbmyrwpfpouuv3s5suqkeg4j4cdmpjr.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[67108864], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_1', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '173ccbefad6764ffc6a32cfd80b0e0decca95dcaaab807475db0bd6fd7f94813'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 35000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 1000) % 1000
    x0 = xindex % 1000
    x3 = xindex
    x2 = (xindex // 1000000)
    tmp22 = tl.load(in_ptr0 + (x3), xmask)
    tmp0 = x1
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = tmp3 * tmp2
    tmp5 = tmp4 - tmp2
    tmp6 = tmp5.to(tl.int32)
    tmp7 = x0
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp8 + tmp2
    tmp10 = tmp9 * tmp2
    tmp11 = tmp10 - tmp2
    tmp12 = tmp11.to(tl.int32)
    tmp13 = tmp6.to(tl.float32)
    tmp14 = tmp5 - tmp13
    tmp15 = 1.0
    tmp16 = tmp15 - tmp14
    tmp17 = tmp15 * tmp16
    tmp18 = tmp12.to(tl.float32)
    tmp19 = tmp11 - tmp18
    tmp20 = tmp15 - tmp19
    tmp21 = tmp17 * tmp20
    tmp23 = tmp21 * tmp22
    tmp24 = tl.full([1], 1, tl.int32)
    tmp25 = tmp12 + tmp24
    tmp26 = tl.full([1], 499, tl.int32)
    tmp27 = triton_helpers.minimum(tmp25, tmp26)
    tmp28 = tmp17 * tmp19
    tmp29 = tmp28 * tmp22
    tmp30 = tmp6 + tmp24
    tmp31 = triton_helpers.minimum(tmp30, tmp26)
    tmp32 = tmp15 * tmp14
    tmp33 = tmp32 * tmp20
    tmp34 = tmp33 * tmp22
    tmp35 = tmp32 * tmp19
    tmp36 = tmp35 * tmp22
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp6) + (250000*x2)), tmp23, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp6) + (250000*x2)), tmp29, xmask)
    tl.atomic_add(out_ptr0 + (tmp12 + (500*tmp31) + (250000*x2)), tmp34, xmask)
    tl.atomic_add(out_ptr0 + (tmp27 + (500*tmp31) + (250000*x2)), tmp36, xmask)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    args_1, = args
    args.clear()
    assert_size_stride(args_1, (7, 5, 1000, 1000), (5000000, 1000000, 1000, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((7, 5, 500, 500), (1250000, 250000, 500, 1), torch.float32)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(buf0, 8750000, grid=grid(8750000), stream=stream0)
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_1.run(args_1, buf0, 35000000, grid=grid(35000000), stream=stream0)
        del args_1
    return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    args_1 = rand_strided((7, 5, 1000, 1000), (5000000, 1000000, 1000, 1), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([args_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Shared memory loads/stores are unnecessary in this case. cc @peterbell10
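
For reference, below is a minimal standalone sketch of the same pattern (a plain global load followed by tl.atomic_add at computed offsets). It is not part of the original report; the kernel name, shapes, and indexing are made up, and whether it triggers the same layout conversion depends on the layouts Triton picks, but compiling something like this makes it easier to inspect the emitted PTX for shared memory traffic around the atomics.

import torch
import triton
import triton.language as tl

@triton.jit
def atomic_scatter_kernel(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    val = tl.load(in_ptr + xindex, mask=xmask)
    # Scatter into a smaller output with atomic adds, mirroring the
    # tl.atomic_add calls in triton_poi_fused_1 above.
    tl.atomic_add(out_ptr + (xindex // 2), val, mask=xmask)

x = torch.randn(1 << 20, device='cuda')
out = torch.zeros(1 << 19, device='cuda')
grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']),)
atomic_scatter_kernel[grid](x, out, x.numel(), XBLOCK=1024)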

@isuruf
Author

isuruf commented Mar 28, 2024

Based on a suggestion from @peterbell10, I removed AtomicRMWOp at https://github.com/openai/triton/blob/0ba87e2ff35f703f84040400554702ee55476cdb/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp#L192, which resulted in PTX with no shared memory loads/stores. With that change, the Triton-generated kernel matches the performance of the PyTorch eager backend, whereas it was previously 50% slower with the shared memory stores and loads.
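
As a side note, one way to check whether a change like this actually removes the shared memory traffic is to point Triton's compilation cache at a fresh directory and scan the emitted PTX. This is a rough sketch, assuming Triton honors the TRITON_CACHE_DIR environment variable and writes .ptx artifacts into that directory; the helper name is made up.

import pathlib

def count_shared_memory_ops(cache_dir):
    # Count ld.shared/st.shared instructions across all cached PTX files.
    counts = {'ld.shared': 0, 'st.shared': 0}
    for ptx in pathlib.Path(cache_dir).rglob('*.ptx'):
        text = ptx.read_text()
        for op in counts:
            counts[op] += text.count(op)
    return counts

# Usage: set TRITON_CACHE_DIR=/tmp/triton_cache, run the reproducer above,
# then call count_shared_memory_ops('/tmp/triton_cache').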

@isuruf
Author

isuruf commented Mar 28, 2024

Is there a case where removing AtomicRMWOp as a layout anchor can result in incorrect code?

@manman-ren
Collaborator

I don't think it will result in incorrect code, but I may be wrong. It can affect performance, though, so it will likely need to go through the benchmark suites to verify the performance impact.
Which version of pytorch are you on? I tried to run your code, but failed.
AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'

@isuruf
Author

isuruf commented Mar 29, 2024

I'm using pytorch v2.3.0-rc6

@peterbell10
Contributor

> Which version of pytorch are you on? I tried to run your code, but failed.
> AttributeError: type object 'torch._C.Generator' has no attribute 'graphsafe_set_state'

Given that graphsafe_set_state doesn't appear in the generated code, you probably just need to rebuild pytorch.

@manman-ren
Collaborator

You are right. I thought I built it after the source pull.

@manman-ren
Collaborator

I looked at this, but I'm not sure what the best solution is :]
In the meantime, I noticed a few things that I will try to figure out:
1. It is not clear to me why the atomic op has a different layout, sizePerThread = [1] (the load op has sizePerThread = [4]).
2. It is not clear why the atomic op is an anchor for remove-layout-conversions.
3. With both sizePerThread = [1] and sizePerThread = [4], at the PTX level the atomic op uses the same instruction, atom.global.gpu.acq_rel.add.f32, 8 times. In the first case there are two different predicates, while in the latter there is only one predicate, so it looks like sizePerThread = [4] is more efficient?
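
For what it's worth, the per-op layouts (sizePerThread, threadsPerWarp, warpsPerCTA) can be seen in the cached TTGIR. A rough sketch, assuming the .ttgir artifacts land in Triton's cache directory next to the PTX; the helper name is made up.

import pathlib

def dump_blocked_layouts(cache_dir):
    # Print every #triton_gpu.blocked layout attribute found in the cached TTGIR.
    for ttgir in pathlib.Path(cache_dir).rglob('*.ttgir'):
        for line in ttgir.read_text().splitlines():
            if 'triton_gpu.blocked' in line:
                print(ttgir.name, line.strip())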

@lezcano
Contributor

lezcano commented Apr 3, 2024

cc @ThomasRaoux @Jokeren for visibility.

isuruf linked a pull request (#3940) on May 17, 2024 that will close this issue.