
FP8 Support #2054

Open · 3 of 8 tasks
kuhar opened this issue Jan 3, 2024 · 10 comments

@kuhar (Member) commented Jan 3, 2024

This is an umbrella issue for allowing fp8 type(s) in SHARK, spanning all the required layers of the stack (Turbine, IREE, MLIR, LLVM), including backends of interest like ROCm.

Some initial research is required to scope this properly and divide it into subtasks, but the main work items are roughly:

  • Confirm the exact level of support for fp8 type(s) in the amdgpu backend in LLVM
  • Also check with the other target backends (cc @MaheshRavishankar)
    • Checked with the team; focus on gfx940 only for now
  • Intersect the above with support in llvm::APFloat
  • Make sure we can lower simple kernels in the gpu/llvm dialect all the way to amdgpu
  • Make sure Turbine / input conversion supports the relevant fp8 type(s)
  • Make sure IREE tooling can consume fp8 weights / inputs
  • Make sure the IREE runtime works with fp8 types
  • Add e2e correctness tests
@kuhar (Member, Author) commented Jan 3, 2024

This is just an umbrella issue to get started. Feel free to modify / fill in the blanks / link sub-issues and related discussions.
cc: @antiagainst @MaheshRavishankar @qedawkins @raikonenfnu @hanhanW @bjacob

@kuhar (Member, Author) commented Jan 4, 2024

The gfx940 ISA supports two fp8 formats: fp8 and bf8. Both formats are supported by mfma, including operands of mixed formats: https://llvm.org/docs/AMDGPU/AMDGPUAsmGFX940.html#vop3.

FP8 mfma is plumbed through the amdgpu LLVM backend (https://reviews.llvm.org/D129906), for example:

// CHECK-GFX940-LABEL: @test_mfma_f32_32x32x16_fp8_bf8
// CHECK-GFX940: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %a, i64 %b, <16 x float> %c, i32 0, i32 0, i32 0)
void test_mfma_f32_32x32x16_fp8_bf8(global v16f* out, long a, long b, v16f c)
{
  *out = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(a, b, c, 0, 0, 0);
}

The fp8 operands are packed as i64. The only other amdgcn intrinsics for fp8 types are the cvt ones -- type conversions: https://github.com/llvm/llvm-project/blob/cd3942059eed7b7185f26bc583ac287a995db0d0/clang/include/clang/Basic/BuiltinsAMDGPU.def#L400-L407
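
To make the i64 operand layout concrete, here is a minimal sketch (a hypothetical helper, not an existing compiler API) of packing eight 1-byte fp8 values into one of the i64 mfma operands:

#include <cstdint>
#include <cstring>

// Hypothetical helper: eight fp8 values occupy one byte each and are
// packed into the single i64 `a`/`b` operand expected by the fp8 mfma
// builtins above (presumably element 0 in the lowest byte).
uint64_t packFp8x8(const uint8_t fp8[8]) {
  uint64_t packed = 0;
  std::memcpy(&packed, fp8, sizeof(packed));
  return packed;
}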

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

The AMD CDNA 3 compute units support both variants of the FP8 data type as defined in the OCP 8-bit floating point specification.

OCP 8-bit Floating Point Specification (OFP8)

Related paper with an overview of fp8 types: FP8 FORMATS FOR DEEP LEARNING

Related blog post with an overview of fp8 support on H100: https://lambdalabs.com/blog/nvidia-hopper-h100-and-fp8-support
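
For intuition on the inference/training split, the largest finite values of the two OCP formats are

\[
\max(\text{E4M3}) = 2^{8} \times 1.75 = 448,
\qquad
\max(\text{E5M2}) = 2^{15} \times 1.75 = 57344,
\]

so E5M2 gives up one mantissa bit in exchange for roughly 128x more dynamic range, which is why it is pitched at gradients/training while E4M3 is pitched at weights and activations.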

@kuhar (Member, Author) commented Jan 4, 2024

FP8 support in LLVM/MLIR:

RFC from Sep '22 by @stellaraccident: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.

Since then, the other types plumbed all the way through MLIR are:

  Float8E4M3FNType f8E4M3FNTy;
  Float8E5M2FNUZType f8E5M2FNUZTy;
  Float8E4M3FNUZType f8E4M3FNUZTy;
  Float8E4M3B11FNUZType f8E4M3B11FNUZTy;

(https://github.com/llvm/llvm-project/blob/6af713ae170c34f0561f19e594266ce2a2af343b/mlir/lib/IR/MLIRContext.cpp#L223C27-L227)

      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })

(https://github.com/llvm/llvm-project/blob/6af713ae170c34f0561f19e594266ce2a2af343b/mlir/lib/IR/AsmPrinter.cpp#L2548C1-L2552C73)

func.func @float_attrs_pass() {
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2
    float_attr = 2. : f8E5M2
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
    float_attr = 2. : f8E4M3FN
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
    float_attr = 2. : f8E5M2FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
    float_attr = 2. : f8E4M3FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
    float_attr = 2. : f8E4M3B11FNUZ
  } : () -> ()
  "test.float_attrs

(https://github.com/llvm/llvm-project/blob/dd047c5b64944bae830b9fecf53f8d11ff41386e/mlir/test/IR/attribute.mlir#L38C1-L59C20)

static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
   15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3FN = {
   8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
   7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
   4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};

(https://github.com/llvm/llvm-project/blob/dd047c5b64944bae830b9fecf53f8d11ff41386e/llvm/lib/Support/APFloat.cpp#L132-L140)
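
A minimal sketch of exercising these semantics from C++ (assuming a recent LLVM where the fp8 fltSemantics above are exposed through the usual APFloat accessors):

#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

// Round an f32 value to f8E4M3FNUZ via APFloat and inspect the 8-bit encoding.
int main() {
  llvm::APFloat Val(3.14159f);
  bool LosesInfo = false;
  Val.convert(llvm::APFloat::Float8E4M3FNUZ(),
              llvm::APFloat::rmNearestTiesToEven, &LosesInfo);
  llvm::APInt Bits = Val.bitcastToAPInt(); // 8 bits wide
  llvm::outs() << "encoding: 0x";
  llvm::outs().write_hex(Bits.getZExtValue());
  llvm::outs() << ", losesInfo: " << (LosesInfo ? "yes" : "no") << "\n";
  return 0;
}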

@kuhar (Member, Author) commented Jan 4, 2024

amdgcn's fp8 maps to f8E4M3FNUZ, while bf8 maps to f8E5M2FNUZ.

@bjacob (Contributor) commented Jan 4, 2024

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?

@antiagainst commented:

Also, https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html talks a bit about fp8 in NVIDIA GPUs, which is a useful reference.

In general, fp8 is currently used in a very ad-hoc way -- ISAs just provide conversion and tensor/matrix core ops. For training we also have different fp8 scaling factors for different tensors and need model/framework-level handling there, so that is also quite ad-hoc.
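
For context, that per-tensor scaling usually boils down to something like the following hypothetical helper (not tied to any particular framework); fp8Max is the format's largest finite value, e.g. 448 for OCP E4M3:

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of per-tensor fp8 scaling: pick a scale so the tensor's max
// magnitude lands near the format's largest finite value, quantize as
// x / scale (rounded to fp8), and dequantize as q * scale.
float chooseFp8Scale(const std::vector<float> &tensor, float fp8Max) {
  float amax = 0.0f;
  for (float v : tensor)
    amax = std::max(amax, std::fabs(v));
  return amax > 0.0f ? amax / fp8Max : 1.0f;
}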

So, as we've discussed in the meeting, getting a minimal matmul to exercise fp8 + tensor/matrix cores in IREE/SHARK would be a good start and a foundation for everything else. We can then build the other parts on top.

@antiagainst commented:

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?

This is explained a bit in the NVIDIA doc as linked in my previous comment:

During training neural networks both of these types may be utilized. Typically forward activations and weights require more precision, so E4M3 datatype is best used during forward pass. In the backward pass, however, gradients flowing through the network typically are less susceptible to the loss of precision, but require higher dynamic range. Therefore they are best stored using E5M2 data format. H100 TensorCores provide support for any combination of these types as the inputs, enabling us to store each tensor using its preferred precision.

@qedawkins (Contributor) commented:

Support in MLIR/LLVM/AMDGPU already seems quite promising, so as discussed this morning the plan is to show a very simple example using fp8 in IREE first, something like

module {
  func.func @matmul_static(%arg0: tensor<32x32xi8>, %arg1: tensor<32x32xi8>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %1 = tensor.bitcast %arg1 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %2 = linalg.matmul ins(%0, %1 : tensor<32x32xf8E4M3FNUZ>, tensor<32x32xf8E4M3FNUZ>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %2 : tensor<32x32xf32>
  }
}

or, to avoid the need to also handle mfma at the same time, just something as simple as

#map = affine_map<(d0) -> (d0)>
module {
  func.func @extend_i8(%arg0: tensor<32xi8>) -> tensor<32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32xi8> to tensor<32xf8E4M3FNUZ>
    %1 = tensor.empty() : tensor<32xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<32xf8E4M3FNUZ>) outs(%1 : tensor<32xf32>) {
    ^bb0(%in: f8E4M3FNUZ, %out: f32):
      %3 = arith.extf %in : f8E4M3FNUZ to f32
      linalg.yield %3 : f32
    } -> tensor<32xf32>
    return %2 : tensor<32xf32>
  }
}

@kuhar (Member, Author) commented Jan 25, 2024

Explanation of the LLVM fp semantics naming convention:

F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero (no negative zero).

source: https://github.com/jax-ml/ml_dtypes?tab=readme-ov-file#float8_e5m2fnuz
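
Putting the convention together for the two types gfx940 uses (the struct below is just an illustrative sketch, not an existing LLVM API; the numbers follow the APFloat semantics quoted above):

// Name decoding: f8 E<e> M<m> [F][N][UZ]
//   f8E4M3FNUZ: 4 exponent bits, 3 mantissa bits, bias 8, finite-only,
//               NaN encoded as negative zero, largest finite value 240.
//   f8E5M2FNUZ: 5 exponent bits, 2 mantissa bits, bias 16, finite-only,
//               NaN encoded as negative zero, largest finite value 57344.
struct Fp8Format {
  const char *name;
  int exponentBits;
  int mantissaBits;
  int bias;
  float maxFinite;
};

constexpr Fp8Format kGfx940Fp8Formats[] = {
    {"f8E4M3FNUZ", 4, 3, 8, 240.0f},
    {"f8E5M2FNUZ", 5, 2, 16, 57344.0f},
};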

@MaheshRavishankar commented:

Looking through support in MLIR and lowering into NVVM/ROCDL, it seems to be already there as well.

MFMA to ROCDL intrinsics:

Tensor core instructions lowering

  • WGMMA instruction support (link)
  • I didn't find support for conversion instructions, so that's strange.

So for the examples in this comment #2054 (comment), the extension/truncation should just pass through and compile on AMD. For the mfma support, it would be great if we could take a single matmul of the exact mfma shape and have it lower directly to that operation. Literally all tile sizes would be 1... it should vectorize to vector.contract, then lower to amdgpu.mfma -> rocdl intrinsics...
