bf16 result mismatch for Conv2D op #2090
Can you give me the run module command that shows the error... That will help repro the error for me. For example the |
I am attaching the
|
Running conv2d with different precisions, keeping all the constants (weight and bias) the same. |
I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here? |
contraction-flag-linear-fp32-iree_compile.log
inference_input.0.bin.txt
Not sure how PyTorch handles the bf16 computation, but it's close to an f32 computation whose result is then demoted to bf16. This problem requires attention to how IREE handles bf16 loads in CPU backends. Thoughts? I can create an ONNX linear module with the same inputs and run it on ONNX Runtime to have one more reference output if it helps. Thanks! |
I think we should not close this unless we can conclude on handling bf16 on CPU. I mean, how do we verify the model is producing the correct outputs through the ONNX pipeline? |
Really I don't know what the compiler itself can decide here... this is always going to mismatch because the reference is doing different things (and different references will do different things). What IREE is doing is basically "do what it is told to do". If the linalg op says the input type is bf16 and the output type is bf16, it is actually doing the accumulation in f32. IMO that is actually being too smart. It should be doing the accumulation in bf16 as well (because that's what it was told to do). So the only AI I can think of is to actually make IREE less smart. |
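The accumulation-precision point is easy to demonstrate by emulating bf16 at the bit level. This is a rough sketch (the helper names are made up, and this is not IREE's actual lowering) comparing step-by-step bf16 accumulation against f32 accumulation with a single final demotion:

```python
import struct

def f32_to_bf16_bits(x):
    # Demote an f32 to bf16 bits with round-to-nearest-even.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    return (bits + 0x7FFF + lsb) >> 16

def bf16_to_f32(bits16):
    # bf16 is just the top 16 bits of an f32.
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

def quantize(x):
    # Round an f32 value to the nearest representable bf16 value.
    return bf16_to_f32(f32_to_bf16_bits(x))

xs = [1.0] + [1.0 / 256.0] * 8  # addends too small to survive bf16 accumulation

# Accumulate in bf16: each partial sum is demoted, so the small addends vanish.
acc_bf16 = 0.0
for x in xs:
    acc_bf16 = quantize(acc_bf16 + quantize(x))

# Accumulate in f32 and demote once at the end.
acc_f32 = quantize(sum(xs))

print(acc_bf16, acc_f32)  # 1.0 vs 1.03125
```

The two strategies disagree even though every input is exactly representable in bf16, which is why a reference that accumulates in f32 will never consistently match a backend that accumulates in bf16 (or vice versa).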
Usually when I've seen these kinds of issues resolved before, it requires a much more careful drill down vs a high level this vs that. There is no abstract answer to these things at that precision: a rounding/truncation mode difference or 1ULP difference at any stage is enough to result in a 1% error for a datatype like this. If trying to get complete correspondence, then none of that can be ignored. You'll have to dig deeper, and likely if you are still looking at results as textual floating point, you'll miss the difference. |
One advantage to having an ONNX reference is that it is much easier to hack on at that level than PyTorch (i.e. you can build it, set a breakpoint or print specific values in a kernel, etc). |
It seems the mismatch is due to different rounding mechanisms used by PyTorch and IREE. I ran a few simple add/mul tests, and it's mostly a 1-bit difference in the outputs. PyTorch simply truncates the last 16 bits after computing the result, while IREE seems to be rounding it.
Their binary representation differs exactly in the last bit.
|
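The one-bit difference described above can be reproduced directly by emulating both demotion modes (a sketch with made-up helper names, not either framework's actual code):

```python
import struct

def f32_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_trunc(x):
    # Round-toward-zero: simply drop the low 16 bits.
    return f32_bits(x) >> 16

def bf16_rne(x):
    # Round-to-nearest-even: bias by 0x7FFF plus the would-be LSB of the result.
    bits = f32_bits(x)
    return (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16

x = 1.0 / 3.0  # f32 bits 0x3EAAAAAB; the discarded half 0xAAAB is >= 0x8000
print(hex(bf16_trunc(x)), hex(bf16_rne(x)))  # 0x3eaa 0x3eab: last bit differs
```

Any input whose discarded low half is at or above 0x8000 produces exactly this last-bit disagreement between the two modes.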
Great analysis, thanks! Indeed, the |
I don't know which code path is used in the code that you ran, but I checked this PyTorch and it does return the correct result,
prints
And not |
This implements the same incorrect rounding-towards-zero as we discussed above. Just dropping the bottom 16 bits like this fails to account for the possibility that their value might be >= 0x8000, requiring rounding upwards to the next representable value. (Side note: this also has undefined behavior in C++, as an |
@bjacob Thanks for the explanation! For the multiplication example, the IREE output is closer, as PyTorch is simply truncating the result. Can we add functionality to explicitly specify the rounding mechanism we want in IREE, as we need PyTorch results as our reference for the model outputs? And I did the following for PyTorch bf16 multiplication:
|
I also got a weird example, @bjacob
|
Wow, funny bug that you found here! It appears to be a parsing bug in how
FYI @benvanik |
The parsing itself is correct, though. And yet, something is producing incorrect results only when the |
The bug reproduces whenever the specified |
And the other operand, which is hardcoded as a constant in the above testcase, also matters. Here is a testcase taking both operands as arguments:
With that, I find that for this to reproduce, that other operand needs to be bfloat16 |
This actually minimizes down to a testcase that performs no bfloat16 arithmetic and only an f32->bfloat16 truncf:
#map = affine_map<(d0) -> (d0)>
module {
func.func @main_graph(%arg0: tensor<1xf32>) -> tensor<1xbf16> {
%0 = tensor.empty() : tensor<1xbf16>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs(%0 : tensor<1xbf16>) {
^bb0(%in0: f32, %out: bf16):
%3 = arith.truncf %in0 : f32 to bf16
linalg.yield %3 : bf16
} -> tensor<1xbf16>
return %1 : tensor<1xbf16>
}
}
|
@rsuderman this might be for you :-) What
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @main_graph_dispatch_0_generic() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<f32>
memref.assume_alignment %0, 64 : memref<f32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<i16>
memref.assume_alignment %1, 64 : memref<i16>
%2 = memref.load %0[] : memref<f32>
%3 = arith.truncf %2 : f32 to bf16
%4 = arith.bitcast %3 : bf16 to i16
memref.store %4, %1[] : memref<i16>
return
}
}
// -----// IR Dump After ConvertToLLVM (iree-convert-to-llvm) //----- //
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-unknown-eabi-elf"} {
llvm.func @main_graph_dispatch_0_generic(%arg0: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg1: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg2: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}) -> i32 {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.mlir.constant(16 : i32) : i32
%2 = llvm.mlir.constant(32768 : i32) : i32
%3 = llvm.mlir.constant(2130706432 : i32) : i32
%4 = llvm.mlir.constant(2139095040 : i32) : i32
%5 = llvm.mlir.constant(8388607 : i32) : i32
%6 = llvm.mlir.constant(31 : i32) : i32
%7 = llvm.mlir.constant(23 : i32) : i32
%8 = llvm.mlir.constant(63 : index) : i64
%9 = llvm.mlir.constant(0 : index) : i64
%10 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%11 = llvm.extractvalue %10[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%12 = llvm.load %11 : !llvm.ptr -> !llvm.ptr
%13 = llvm.ptrtoint %12 : !llvm.ptr to i64
%14 = llvm.and %13, %8 : i64
%15 = llvm.icmp "eq" %14, %9 : i64
"llvm.intr.assume"(%15) : (i1) -> ()
%16 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%17 = llvm.extractvalue %16[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%18 = llvm.getelementptr %17[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
%19 = llvm.load %18 : !llvm.ptr -> !llvm.ptr
%20 = llvm.ptrtoint %19 : !llvm.ptr to i64
%21 = llvm.and %20, %8 : i64
%22 = llvm.icmp "eq" %21, %9 : i64
"llvm.intr.assume"(%22) : (i1) -> ()
%23 = llvm.load %12 : !llvm.ptr -> f32
%24 = llvm.bitcast %23 : f32 to i32
%25 = llvm.lshr %24, %6 : i32
%26 = llvm.sub %2, %25 : i32
%27 = llvm.and %24, %5 : i32
%28 = llvm.add %27, %26 : i32
%29 = llvm.lshr %28, %7 : i32
%30 = llvm.lshr %28, %29 : i32
%31 = llvm.and %24, %4 : i32
%32 = llvm.add %31, %28 : i32
%33 = llvm.and %32, %4 : i32
%34 = llvm.icmp "uge" %31, %3 : i32
%35 = llvm.select %34, %31, %33 : i1, i32
%36 = llvm.trunc %29 : i32 to i1
%37 = llvm.and %34, %36 : i1
%38 = llvm.select %37, %27, %30 : i1, i32
%39 = llvm.shl %25, %6 : i32
%40 = llvm.or %39, %35 : i32
%41 = llvm.or %40, %38 : i32
%42 = llvm.lshr %41, %1 : i32
%43 = llvm.trunc %42 : i32 to i16
llvm.store %43, %19 : i16, !llvm.ptr
llvm.return %0 : i32
}
}
In the first part of the above log, our |
@rsuderman , here is what the equivalent f32->bf16 truncation code does in the runtime (it is actually generic in bit-widths, but in particular it does f32->bf16), specifically the fix-up for this case: https://github.com/openxla/iree/blob/01c4c57/runtime/src/iree/base/internal/math.h#L389-L390 |
@rsuderman , here is the much more concise and optimized way that the PyTorch runtime does it (I think that part was written by Marat and carried over from XNNPACK or some predecessor of it):
In the above-linked runtime code, I didn't bother to implement this magic trick because I wanted genericity and didn't need to chase performance. But in the compiler lowering, it would make sense to do the concise and efficient thing. The link to IREE math.h in the previous comment has a comment explaining the magic trick here.
|
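That rounding-bias trick can be sketched in a few lines (an illustration, not the actual PyTorch/XNNPACK code). Note that a NaN input needs a separate path, since adding the bias could carry into the exponent and turn a NaN into infinity:

```python
import struct

def f32_to_bf16_bits_rne(x):
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0:
        # NaN: quiet it instead of rounding, so it cannot become infinity.
        return (bits >> 16) | 0x0040
    # Round-to-nearest-even: bias of 0x7FFF plus the LSB of the truncated result.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return (bits + rounding_bias) >> 16

print(hex(f32_to_bf16_bits_rne(1.0 / 3.0)))     # 0x3eab (rounded up)
print(hex(f32_to_bf16_bits_rne(float("nan"))))  # still a NaN bit pattern
```

Adding the LSB of the truncated result to the 0x7FFF bias is what makes exact ties round to the even neighbor rather than always upward.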
Great :/, I was pretty sure I had managed to implement the rounding behavior correctly but I did not have an aggressive test case to evaluate with. I assume this means there is an error in our |
It's ok, I think I'll have the patch ready soon. |
@Shukla-Gaurav , this seems to work. I'll fix up any unit test that fails and send that for review @rsuderman . |
Thanks a lot @bjacob for actively working on this. Will try the patch with other test cases/models as well. |
llvm/llvm-project#83180 is merged, so you'll get it in the next integrate or can cherry-pick it locally until then to verify it fixed your issue. |
Thank you!
|
Running the above module through the IREE CPU backend generates incorrect results with respect to the PyTorch output.