perf: bf16 kernels #1664

Draft
wants to merge 17 commits into
base: main
Conversation

rok
Contributor

@rok rok commented Nov 27, 2023

See #1651

@rok rok linked an issue Nov 27, 2023 that may be closed by this pull request
2 tasks
rust/lance-linalg/build.rs (outdated, resolved)
rust/lance-linalg/build.rs (outdated, resolved)
@rok rok requested a review from wjones127 December 15, 2023 18:08
Contributor

@wjones127 wjones127 left a comment

I don't quite understand the instructions involved, but it looks like you are frequently casting / reinterpreting bfloat16 as normal float16 and using float16 add / subtract instructions, which I'm worried isn't valid.
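
To illustrate the concern, here is a minimal sketch (not code from this PR) that decodes the same 16-bit patterns under both formats. bfloat16 uses 1 sign, 8 exponent, and 7 mantissa bits, while IEEE float16 uses 1 sign, 5 exponent, and 10 mantissa bits, so a reinterpreted value generally means something different. The helper names `bf16_bits_to_f32` and `f16_bits_to_f32` are hypothetical and exist only for this example.

```rust
/// Decode a bfloat16 bit pattern: it is the upper 16 bits of an IEEE-754 f32.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Decode an IEEE-754 binary16 (float16) bit pattern:
/// 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = if bits & 0x8000 != 0 { -1.0f32 } else { 1.0 };
    let exp = ((bits >> 10) & 0x1f) as i32;
    let frac = (bits & 0x03ff) as f32 / 1024.0;
    match exp {
        // Zero and subnormals have no implicit leading 1.
        0 => sign * frac * 2f32.powi(-14),
        // All-ones exponent encodes infinity or NaN.
        0x1f => {
            if frac == 0.0 {
                sign * f32::INFINITY
            } else {
                f32::NAN
            }
        }
        _ => sign * (1.0 + frac) * 2f32.powi(exp - 15),
    }
}

fn main() {
    let a = 0x3f80u16; // 1.0 when read as bf16
    let b = 0x4040u16; // 3.0 when read as bf16

    println!("as bf16: {} and {}", bf16_bits_to_f32(a), bf16_bits_to_f32(b));
    println!("as f16:  {} and {}", f16_bits_to_f32(a), f16_bits_to_f32(b));

    // A float16 add of these patterns would produce float16 4.0 (0x4400).
    // Read back as bf16, those bits are a very different number.
    let f16_sum_bits = 0x4400u16;
    println!("f16 sum bits read as bf16: {}", bf16_bits_to_f32(f16_sum_bits));
}
```

In this particular case the float16 sum happens to be 4.0, but its bit pattern is not the bf16 encoding of 4.0, so writing it back into a bf16 buffer would store the wrong value.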

rust/lance-linalg/src/simd/bf16.c (outdated, resolved)
rust/lance-linalg/src/simd/bf16.c (outdated, resolved)
rust/lance-linalg/src/simd/bf16.c (outdated, resolved)
rust/lance-linalg/src/simd/bf16.c (outdated, resolved)
@rok rok force-pushed the 1615_bf16_kernels branch 2 times, most recently from 32447d1 to 88bd113 December 19, 2023 18:48
@eddyxu
Contributor

eddyxu commented Dec 19, 2023

For bf16, we have https://doc.rust-lang.org/core/arch/x86_64/fn._mm512_dpbf16_ps.html. Should we just use Rust code instead?

L2 = dot(x, x) + dot(y, y) - 2 * dot(x, y)
Cosine = 1 - dot(x, y)
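
As a quick check of these identities, the following scalar sketch in plain f32 compares the dot-product form against a direct computation. It assumes that "L2" means the squared Euclidean distance, and that the short cosine form `1 - dot(x, y)` is meant for unit-normalized vectors; the general version below divides by the norms. All function names are hypothetical.

```rust
/// Plain dot product over f32 slices.
fn dot(x: &[f32], y: &[f32]) -> f32 {
    x.iter().zip(y).map(|(a, b)| a * b).sum()
}

/// Squared L2 distance computed directly as sum((x_i - y_i)^2).
fn l2_direct(x: &[f32], y: &[f32]) -> f32 {
    x.iter().zip(y).map(|(a, b)| (a - b) * (a - b)).sum()
}

/// Squared L2 distance via dot(x, x) + dot(y, y) - 2 * dot(x, y).
fn l2_via_dot(x: &[f32], y: &[f32]) -> f32 {
    dot(x, x) + dot(y, y) - 2.0 * dot(x, y)
}

/// Cosine distance for arbitrary non-zero vectors.
/// For unit-length inputs this reduces to 1 - dot(x, y).
fn cosine_distance(x: &[f32], y: &[f32]) -> f32 {
    1.0 - dot(x, y) / (dot(x, x).sqrt() * dot(y, y).sqrt())
}

fn main() {
    let x = [1.0f32, 2.0, 3.0];
    let y = [4.0f32, 6.0, 8.0];
    println!("direct L2:  {}", l2_direct(&x, &y));
    println!("via dot L2: {}", l2_via_dot(&x, &y));
    println!("cosine distance: {}", cosine_distance(&x, &y));
}
```

One caveat of the dot-product form is that it subtracts large, nearly equal terms when x and y are close, so it can lose precision compared with summing squared differences directly.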

@rok
Contributor Author

rok commented Dec 19, 2023

For bf16, we have https://doc.rust-lang.org/core/arch/x86_64/fn._mm512_dpbf16_ps.html , should we just use rust code instead?

L2 = dot(x, x) + dot(y, y) - 2 * dot(x, y)

Cosine = 1 - dot(x, y)

This would indeed simplify things a lot. However, I would first compare performance to see whether loading from memory (e.g. with _mm512_maskz_loadu_epi16) incurs a significant cost (the Rust option would load 3 times, whereas in C we'd load once).
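
The load-count difference can be sketched in scalar Rust, with bf16 values stored as raw `u16` bit patterns. The first function calls a dot kernel three times and so reads the inputs repeatedly; the second widens each element once and accumulates all three sums in one pass. This is an illustration only, not the PR's kernel, and the names `dot_bf16`, `l2_three_passes`, and `l2_fused` are hypothetical. Whether the extra reads matter in practice depends on cache behavior and would need benchmarking.

```rust
/// Widen a bf16 bit pattern to f32 (bf16 is the top 16 bits of an f32).
fn widen(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Dot product over bf16 bit patterns, accumulated in f32.
fn dot_bf16(x: &[u16], y: &[u16]) -> f32 {
    x.iter().zip(y).map(|(&a, &b)| widen(a) * widen(b)).sum()
}

/// Squared L2 via three separate dot calls: each input is read more than once.
fn l2_three_passes(x: &[u16], y: &[u16]) -> f32 {
    dot_bf16(x, x) + dot_bf16(y, y) - 2.0 * dot_bf16(x, y)
}

/// Squared L2 in a single pass: each element is loaded and widened once.
fn l2_fused(x: &[u16], y: &[u16]) -> f32 {
    let (mut xx, mut yy, mut xy) = (0.0f32, 0.0f32, 0.0f32);
    for (&a, &b) in x.iter().zip(y) {
        let (a, b) = (widen(a), widen(b));
        xx += a * a;
        yy += b * b;
        xy += a * b;
    }
    xx + yy - 2.0 * xy
}

fn main() {
    // bf16 bit patterns for [1.0, 2.0, 3.0] and [4.0, 6.0, 8.0].
    let x = [0x3f80u16, 0x4000, 0x4040];
    let y = [0x4080u16, 0x40c0, 0x4100];
    println!("three passes: {}", l2_three_passes(&x, &y));
    println!("fused:        {}", l2_fused(&x, &y));
}
```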

@eddyxu
Contributor

eddyxu commented Dec 19, 2023

For bf16, we have https://doc.rust-lang.org/core/arch/x86_64/fn._mm512_dpbf16_ps.html , should we just use rust code instead?
L2 = dot(x, x) + dot(y, y) - 2 * dot(x, y)

Cosine = 1 - dot(x, y)

This would indeed simplify things a lot. However, I would first compare performance to see whether loading from memory (e.g. with _mm512_maskz_loadu_epi16) incurs a significant cost (the Rust option would load 3 times, whereas in C we'd load once).

Why do we need to load 3 times in Rust? Not sure I follow.

@rok
Contributor Author

rok commented Dec 19, 2023

@eddyxu Oh sorry I misunderstood your question. I thought you were suggesting calling the dot kernel to calculate L2.
If we can use nightly Rust intrinsics (e.g. _mm512_dpbf16_ps), that would be preferable, and I'll refactor.
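
For readers unfamiliar with the intrinsic, the sketch below emulates its lane structure in portable scalar Rust: each of the 16 f32 accumulator lanes receives the sum of two adjacent bf16 products. It does not call the real intrinsic, and the hardware instruction's rounding and denormal handling may differ from this emulation. The names `dpbf16_emulated` and `dot_bf16` are hypothetical, and the zero-padded tail stands in for what a zeroing masked load would provide.

```rust
/// Widen a bf16 bit pattern to f32 (bf16 is the top 16 bits of an f32).
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Scalar emulation of the pairwise bf16 dot-product accumulate:
/// acc[i] += a[2i] * b[2i] + a[2i + 1] * b[2i + 1] for 16 f32 lanes.
fn dpbf16_emulated(acc: &mut [f32; 16], a: &[u16; 32], b: &[u16; 32]) {
    for i in 0..16 {
        let lo = bf16_to_f32(a[2 * i]) * bf16_to_f32(b[2 * i]);
        let hi = bf16_to_f32(a[2 * i + 1]) * bf16_to_f32(b[2 * i + 1]);
        acc[i] += lo + hi;
    }
}

/// Dot product of two bf16 slices, processed 32 elements at a time.
fn dot_bf16(x: &[u16], y: &[u16]) -> f32 {
    assert_eq!(x.len(), y.len());
    let mut acc = [0.0f32; 16];
    for (cx, cy) in x.chunks(32).zip(y.chunks(32)) {
        // Zero-pad a short final chunk; zeros contribute nothing to the sum.
        let mut a = [0u16; 32];
        let mut b = [0u16; 32];
        a[..cx.len()].copy_from_slice(cx);
        b[..cy.len()].copy_from_slice(cy);
        dpbf16_emulated(&mut acc, &a, &b);
    }
    // Horizontal reduction of the 16 lanes.
    acc.iter().sum()
}

fn main() {
    let x = vec![0x3f80u16; 40]; // forty bf16 values of 1.0
    let y = vec![0x4000u16; 40]; // forty bf16 values of 2.0
    println!("dot = {}", dot_bf16(&x, &y));
}
```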

@rok rok force-pushed the 1615_bf16_kernels branch 2 times, most recently from 9ba418f to b41de61 December 21, 2023 01:08
@rok
Contributor Author

rok commented Dec 21, 2023

@eddyxu I've looked at using _mm512_dpbf16_ps via Rust but couldn't quite get it right today (e.g. _mm512_castsi512_ph is missing from). Please feel free to take over. I can also do some more work in the first week of January.

Successfully merging this pull request may close these issues.

perf: bf16 kernels
3 participants