undefined symbol still
drisspg committed May 2, 2024
1 parent 4d41015 commit dac6a96
Showing 7 changed files with 659 additions and 11 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -482,6 +482,7 @@ endif()

if(USE_CUDA AND NOT USE_ROCM)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
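# The tools/util headers are presumably needed by the new CUTLASS-based
# row-wise scaled-MM kernel (see RowwiseScaledMM.h referenced in Blas.cpp below).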
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}
34 changes: 34 additions & 0 deletions aten/src/ATen/cuda/detail/LazyNVRTC.cpp
@@ -170,6 +170,40 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction);

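// cuTensorMapEncodeTiled takes twelve arguments, more than the CUDA_STUB
// helper macros above can express, so its lazy stub is written out by hand:
// resolve the real driver symbol on first use, cache it in lazyNVRTC, then
// forward the call.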
CUresult CUDAAPI
cuTensorMapEncodeTiled(
CUtensorMap* tensorMap,
CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank,
void* globalAddress,
const cuuint64_t* globalDim,
const cuuint64_t* globalStrides,
const cuuint32_t* boxDim,
const cuuint32_t* elementStrides,
CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill) {
auto fn = reinterpret_cast<decltype(&cuTensorMapEncodeTiled)>(
getCUDALibrary().sym(__func__));
if (!fn)
throw std::runtime_error("Can't get cuTensorMapEncodeTiled");
lazyNVRTC.cuTensorMapEncodeTiled = fn;
return fn(
tensorMap,
tensorDataType,
tensorRank,
globalAddress,
globalDim,
globalStrides,
boxDim,
elementStrides,
interleave,
swizzle,
l2Promotion,
oobFill);
}

// Irregularly shaped functions
CUresult CUDAAPI cuLaunchKernel(CUfunction f,
unsigned int gridDimX,
3 changes: 2 additions & 1 deletion aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
@@ -59,7 +59,8 @@ namespace at { namespace cuda {
_(cuLinkAddData) \
_(cuLinkComplete) \
_(cuFuncSetAttribute) \
_(cuFuncGetAttribute)
_(cuFuncGetAttribute) \
_(cuTensorMapEncodeTiled)

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
#define AT_FORALL_NVRTC(_) \
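For context on why this one-line header change matters: the list above is an X-macro, and each entry is typically expanded into a function-pointer member of the lazily bound driver/NVRTC table that LazyNVRTC.cpp populates above. If cuTensorMapEncodeTiled is missing from the list, the lazyNVRTC.cuTensorMapEncodeTiled member assigned in LazyNVRTC.cpp has nothing to bind to, which is the kind of mismatch behind the "undefined symbol" in the commit title. A minimal sketch of the pattern, with illustrative names rather than the actual ATen definitions, assuming a CUDA 12+ cuda.h so cuTensorMapEncodeTiled is declared:

#include <cuda.h>

// Illustrative sketch only; the real list and struct live in ATenNVRTC.h.
#define FORALL_DRIVER_API(_) \
  _(cuFuncGetAttribute)      \
  _(cuTensorMapEncodeTiled)

struct DriverStubs {
// Each list entry becomes a member function pointer whose type matches the
// corresponding driver function exactly.
#define DECLARE_MEMBER(name) decltype(&name) name;
  FORALL_DRIVER_API(DECLARE_MEMBER)
#undef DECLARE_MEMBER
};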
72 changes: 65 additions & 7 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -1,3 +1,4 @@
#include <cstdint>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
@@ -9,6 +10,7 @@
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -799,6 +801,52 @@ static bool _scaled_mm_allowed_device() {
#endif
}

namespace {

// Validates the scale tensors passed to scaled_mm
void validate_scales(
const c10::optional<at::Tensor>& scale_a,
const c10::optional<at::Tensor>& scale_b,
int64_t dim_m,
int64_t dim_n) {
TORCH_CHECK(
scale_a.has_value() == scale_b.has_value(),
"Both scale_a and scale_b must be present or absent");
// When scale tensors are provided, validate their dtype and shape
if (scale_a.has_value()) {
TORCH_CHECK(scale_b.has_value(), "scale_b must be present if scale_a is!");

// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
scale_a->scalar_type() == kFloat && scale_b->scalar_type() == kFloat,
"Both scale_a and scale_b must be float tensors");

// Single-element scales: the per-tensor scaling case
if (scale_a->numel() == 1) {
TORCH_CHECK(
scale_b->numel() == 1,
"Per-tensor scaling is only supported for k >= 1");
return;
} else if (scale_a->numel() == dim_m) {
TORCH_CHECK(
scale_b->numel() == dim_n,
"Per-row scaling only supported for both matrices");
TORCH_CHECK(
scale_a->is_contiguous() && scale_b->is_contiguous(),
"scale_a and scale_b must be contiguous");
TORCH_CHECK(
scale_a->dim() == 1 && scale_b->dim() == 1,
"scale tensors must be scalars");

} else {
TORCH_CHECK(
false, "scale_a must be size ", dim_m, "but got ", scale_a->numel(), "and scale_b must be size ", dim_n, "but got ", scale_b->numel());
}
}
}

} // namespace

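To make the accepted shapes concrete, here is a small sketch (not part of the commit; the helper name is made up, and device placement is omitted even though real inputs would be CUDA tensors) of scale tensors that pass validate_scales for an (M x K) x (K x N) float8 matmul:

#include <ATen/ATen.h>

// Hypothetical helper, for illustration only.
void make_example_scales(int64_t M, int64_t N) {
  // Per-tensor scaling: both scales are single-element fp32 tensors.
  at::Tensor scale_a = at::ones({1}, at::kFloat);          // numel() == 1
  at::Tensor scale_b = at::ones({1}, at::kFloat);          // numel() == 1

  // Row-wise scaling: one fp32 scale per row of mat1 and one per column of
  // mat2; both must be 1-D and contiguous.
  at::Tensor scale_a_rowwise = at::ones({M}, at::kFloat);  // numel() == dim_m
  at::Tensor scale_b_rowwise = at::ones({N}, at::kFloat);  // numel() == dim_n
}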
// Computes matrix multiply + bias while applying scaling to the input and output matrices, and computes amax.
// Scales are only applicable when the matrices are of Float8 type and are assumed to equal 1.0 by default.
// If the output matrix is a 16- or 32-bit type, scale_result is not applied and amax is not computed.
@@ -813,10 +861,10 @@ static bool _scaled_mm_allowed_device() {
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `out`: a reference to the output tensor
// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace
// - `amax`: a reference to the amax tensor of the output, only mutated if the output is a float8 type and will be updated inplace

std::tuple<Tensor&, Tensor&>
_scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
@@ -835,10 +883,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat),
"scale_a must be float scalar");
TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat),
"scale_b must be a float scalar");
validate_scales(scale_a, scale_b, mat1.size(0), mat2.size(1));
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
@@ -881,12 +926,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
{scale_result_, "scale_result", 7}};
checkAllSameGPU(__func__, targs);
}

// Validation checks have passed; resize the output to its actual size
IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});

// Row-wise scaling: dispatch to the dedicated kernel and return early
if (scale_a.has_value() && scale_a->numel() != 1) {
at::cuda::detail::f8f8bf16_rowwise(
mat1,
mat2,
scale_a.value(),
scale_b.value(),
bias,
use_fast_accum,
out);
return {out, amax};
}

#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
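RowwiseScaledMM.h and its implementation are among the changed files not expanded in this view (and presumably account for most of the 659 added lines). Judging only from the call site above, the declaration it provides looks roughly like the sketch below; the parameter names, the per-row/per-column shapes, and the bf16-output guess (suggested by the f8f8bf16 name) are inferences, not the actual header:

namespace at::cuda::detail {

// Hypothetical declaration inferred from the call in _scaled_mm_out_cuda;
// see RowwiseScaledMM.h in this commit for the real one.
void f8f8bf16_rowwise(
    at::Tensor mat1,                 // float8, M x K
    at::Tensor mat2,                 // float8, K x N
    at::Tensor scale_a,              // fp32, M elements (per-row)
    at::Tensor scale_b,              // fp32, N elements (per-column)
    c10::optional<at::Tensor> bias,  // optional bias
    bool use_fast_accum,
    at::Tensor& out);                // output, presumably bf16, M x N

} // namespace at::cuda::detail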
