FP8 rowwise scaling #125204
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125204
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 1 Unrelated Failure
As of commit 73b3a39 with merge base 11c2d12:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 54a84cc to dac6a96.
#include <c10/core/ScalarType.h>
#include <cutlass/trace.h>
// TODO we arent actually linking against cudaruntime, probably need to get this
Removing this header include appears to work for me locally.
#define BUILD_ROWWISE_FP8_KERNEL
#endif

CUresult CUDAAPI cuTensorMapEncodeTiled(CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, const cuuint64_t *globalStrides, const cuuint32_t *boxDim, const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) {
When trying to mark this static, the build fails with:
me/drisspg/meta/pytorch/aten/src/ATen/native/cuda/RowwiseScaledMM.cu:34:17: error: ‘CUresult cuTensorMapEncodeTiled(CUtensorMap*, CUtensorMapDataType, cuuint32_t, void*, const cuuint64_t*, const cuuint64_t*, const cuuint32_t*, const cuuint32_t*, CUtensorMapInterleave, CUtensorMapSwizzle, CUtensorMapL2promotion, CUtensorMapFloatOOBfill)’ was declared ‘extern’ and later ‘static’ [-fpermissive]
Force-pushed from dac6a96 to 110261b.
}

namespace at::cuda::detail {
void f8f8bf16_rowwise(
We recently open sourced this op in FBGEMM (https://github.com/pytorch/FBGEMM/blob/39b655a5ad3933042fbec439d00894068f453932/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions.cu#L1110), originally from @jwfromm. Do you also plan to add the related quantize routine in PyTorch core (e.g., https://github.com/pytorch/FBGEMM/blob/39b655a5ad3933042fbec439d00894068f453932/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L919)?
I wasn't currently planning on adding accompanying quantization ops, as we would likely rely on Inductor to generate this casting code.
Force-pushed from 77638d7 to 7d9bc17.
Force-pushed from 7d9bc17 to 73b3a39.
Summary
This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".

The following two PRs were required to enable local builds:
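To make the scale-shape conditions from the summary concrete, here is a minimal reference sketch in plain C++. It is not the CUTLASS kernel from this PR: `is_rowwise_scaling` and `rowwise_scaled_mm_ref` are hypothetical names, `float` stands in for both the fp8 inputs and the bf16 output, `y` is stored row-major for readability rather than in the TN layout, and it assumes the scales are applied as multipliers on the accumulated result.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: true when the scale lengths match the rowwise layout
// described in the summary (x's scale has M entries, y's scale has N entries).
bool is_rowwise_scaling(std::size_t M, std::size_t N,
                        std::size_t x_scale_len, std::size_t y_scale_len) {
  return x_scale_len == M && y_scale_len == N;
}

// Reference semantics only (assumption: scales multiply the accumulator).
// x is row-major [M, K]; y is row-major [K, N]; the result is row-major [M, N].
std::vector<float> rowwise_scaled_mm_ref(
    const std::vector<float>& x, const std::vector<float>& y,
    const std::vector<float>& x_scale, const std::vector<float>& y_scale,
    std::size_t M, std::size_t K, std::size_t N) {
  assert(x.size() == M * K && y.size() == K * N);
  assert(is_rowwise_scaling(M, N, x_scale.size(), y_scale.size()));

  std::vector<float> out(M * N, 0.0f);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        acc += x[m * K + k] * y[k * N + n];
      }
      // One scale per row of x and one per column of y.
      out[m * N + n] = acc * x_scale[m] * y_scale[n];
    }
  }
  return out;
}
```

Each output element is scaled by exactly one entry from each scale vector, which is why the scale lengths have to be `M` and `N`.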
Todo
We still do not build our Python wheels with this architecture. @ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954
ifdef
I tried to use:
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \
    defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900
to gate the building of the kernel. I was having a hell of a time with this, so I am not really sure of the right way to do it.
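One likely source of trouble is that `__CUDA_ARCH__` is only defined while nvcc compiles device code, so a condition that tests it also strips any host-side code it wraps. A possible alternative, sketched below under assumptions, is to gate compilation on toolkit-level macros only and defer the architecture check to runtime. `BUILD_ROWWISE_FP8_KERNEL`, `USE_ROCM` and `CUDA_VERSION` come from the diff and condition above; `kRowwiseKernelBuilt`, `arch_supports_rowwise` and `can_use_rowwise` are hypothetical names, and the compute capability 9.0 requirement is an assumption based on the `sm_90a` discussion.

```cpp
#include <cassert>

// Compile-time gate: uses only macros that are visible in the host pass.
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
#define BUILD_ROWWISE_FP8_KERNEL
#endif

#ifdef BUILD_ROWWISE_FP8_KERNEL
constexpr bool kRowwiseKernelBuilt = true;
#else
constexpr bool kRowwiseKernelBuilt = false;
#endif

// Hypothetical runtime check. Assumes the kernel needs compute capability 9.0;
// the caller would pass the values reported for the current device.
bool arch_supports_rowwise(int major, int minor) {
  assert(major >= 0 && minor >= 0);
  return major == 9 && minor == 0;
}

// The rowwise path is usable only if it was compiled in AND the device matches.
bool can_use_rowwise(int major, int minor) {
  return kRowwiseKernelBuilt && arch_supports_rowwise(major, minor);
}
```

With this split, a build without a recent enough CUDA toolkit still compiles, and `can_use_rowwise` simply reports `false` so the caller can fall back to another `scaled_mm` path.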