Skip to content

Commit

Permalink
radeon fix - compile with same launch params as instinct (#419)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajassani committed Jun 30, 2024
1 parent c809b2b commit 22c0c87
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 30 deletions.
6 changes: 0 additions & 6 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,6 @@ template<typename input_t, typename weight_t>
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {

#ifndef USE_ROCM
#define warp_size 32
#else
#define warp_size ROCM_WARP_SIZE
#endif

#if warp_size == 32
if (params.seqlen <= 128) {
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 256) {
Expand Down
6 changes: 0 additions & 6 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,6 @@ template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

#ifndef USE_ROCM
#define warp_size 32
#else
#define warp_size ROCM_WARP_SIZE
#endif

#if warp_size == 32
if (params.seqlen <= 128) {
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 256) {
Expand Down
18 changes: 0 additions & 18 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,6 @@ def append_nvcc_threads(nvcc_extra_args):

if HIP_BUILD:

try:
# set warp size based on gcn architecure
gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
# radeon
warp_size = 32
else:
# instinct
warp_size = 64
except AttributeError as e:
# fall back to crude method to set warp size
device_name = torch.cuda.get_device_properties(0).name
if 'instinct' in device_name.lower():
warp_size = 64
else:
warp_size = 32

extra_compile_args = {
"cxx": ["-O3", "-std=c++17"],
"nvcc": [
Expand All @@ -226,7 +209,6 @@ def append_nvcc_threads(nvcc_extra_args):
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-DCK_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
f"-DROCM_WARP_SIZE={warp_size}"
]
+ cc_flag,
}
Expand Down

0 comments on commit 22c0c87

Please sign in to comment.