ggml : rewrite silu and softmax for cpu #7154
Conversation
I haven't analysed the changes in depth, but here are some general observations in case they help other reviewers:
On AMD Ryzen 9 5950X and M2 Ultra, using the following command to benchmark: make -j tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf
I'm glad to hear that. Here are the AVX2 and AVX512 variations if you want to try them out:

inline __m256 llamafile_expf_avx2(__m256 x) {
const __m256 r = _mm256_set1_ps(0x1.8p23f);
const __m256 z = MADD256(x, _mm256_set1_ps(0x1.715476p+0f), r);
const __m256 n = _mm256_sub_ps(z, r);
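// n = round(x / ln 2); adding and subtracting 0x1.8p23f rounds the product to the nearest integer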
const __m256 b = NMADD256(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
NMADD256(n, _mm256_set1_ps(0x1.62e4p-1f), x));
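// b = x - n*ln(2), Cody-Waite style reduction using a two-part (hi/lo) ln(2)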
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
const __m256 k = _mm256_castsi256_ps(
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
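// k = 2^n, formed by adding n to the exponent field of 1.0f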
const __m256i c = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(126), _CMP_GT_OQ));
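// c flags lanes where |n| > 126, i.e. the plain 2^n construction could overflow or underflow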
const __m256 u = _mm256_mul_ps(b, b);
const __m256 j = MADD256(MADD256(MADD256(_mm256_set1_ps(0x1.0e4020p-7f), b,
_mm256_set1_ps(0x1.573e2ep-5f)),
u,
MADD256(_mm256_set1_ps(0x1.555e66p-3f), b,
_mm256_set1_ps(0x1.fffdb6p-2f))),
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
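// j approximates expm1(b) with a degree-5 polynomial in the reduced argument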
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
return MADD256(j, k, k);
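// Slow path: some lane had |n| > 126. Split the 2^n scale into two factors s1 and s2
// so each multiply stays in range; lanes with |n| > 192 saturate to 0 or inf via s1*s1.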
const __m256i g = _mm256_and_si256(
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
_mm256_set1_epi32(0x82000000u));
const __m256 s1 =
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
const __m256i d = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(192), _CMP_GT_OQ));
return _mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
_mm256_andnot_ps(
_mm256_castsi256_ps(d),
_mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(c),
_mm256_mul_ps(MADD256(s2, j, s2), s1)),
_mm256_andnot_ps(_mm256_castsi256_ps(c), MADD256(k, j, k)))));
}
inline __m512 llamafile_expf_avx512(__m512 x) {
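// Same scheme as the AVX2 version above; the range checks use AVX-512 mask registers
// (__mmask16) instead of vector compares and blends.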
const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 z = MADD512(x, _mm512_set1_ps(0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps(z, r);
const __m512 b = NMADD512(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
NMADD512(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
const __m512 k = _mm512_castsi512_ps(
_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
const __mmask16 c =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = MADD512(MADD512(MADD512(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)),
u,
MADD512(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
if (_mm512_kortestz(c, c))
return MADD512(j, k, k);
const __m512i g = _mm512_and_si512(
_mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
_mm512_set1_epi32(0x82000000u));
const __m512 s1 =
_mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
return _mm512_mask_blend_ps(
d,
_mm512_mask_blend_ps(c, MADD512(k, j, k),
_mm512_mul_ps(MADD512(s2, j, s2), s1)),
_mm512_mul_ps(s1, s1));
}

Here are the numbers I got with the script I used for developing these functions:
@ggerganov Running your command, I'm noticing the advantage here increases from 1.5x to 1.9x if we include AVX2. On znver4, if we also include AVX512, that goes up to 2.1x. I'd expect that to go even higher in the future, since znver4 only really implements the AVX512 ISA and uses 2 cycles for each vector operation. So I've gone ahead and included the code for you.
This change upstreams llamafile's vectorized expf() functions. This lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make this operation go faster. We support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. It makes make -j8 tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf go 1.5x faster for SSE2+FMA, 1.9x faster for AVX2+FMA, and 2.1x faster on AVX512.
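For anyone reading along, here is a scalar sketch of the scheme the vector kernels above implement (illustrative only; the helper name and structure are mine, not code from this PR): round x/ln 2 to an integer n, evaluate a small polynomial on the reduced argument b = x - n*ln 2, then scale the result by 2^n.

// Scalar sketch of the vectorized expf() scheme (illustrative, not the PR code).
// Overflow/underflow handling is omitted; the SIMD versions handle it with the
// s1/s2 split shown above.
#include <math.h>

static inline float expf_sketch(float x) {
    const float n = rintf(x * 0x1.715476p+0f);             // n = round(x / ln 2)
    const float b = x - n * 0x1.62e4p-1f                   // b = x - n*ln2 (hi part)
                      - n * 0x1.7f7d1cp-20f;               //               (lo part)
    const float u = b * b;
    // j ~= expm1(b), same degree-5 coefficients as the SIMD code above
    const float j = ((0x1.0e4020p-7f * b + 0x1.573e2ep-5f) * u +
                     (0x1.555e66p-3f * b + 0x1.fffdb6p-2f)) * u +
                     0x1.ffffecp-1f * b;
    return scalbnf(1.0f + j, (int)n);                      // (1 + j) * 2^n
}

The actual SIMD code replaces scalbnf with exponent-bit arithmetic and adds the s1/s2 split for lanes where 2^n alone would over- or underflow.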
With AVX512, you may want to use vscalefps. It overflows and underflows properly, letting you remove the checks and blends. I have an implementation in Julia; here is the generated loop, with 4x unrolling and interleaving:

L304:
vmovups zmm15, zmmword ptr [r11 + 4*rax]
vmovups zmm14, zmmword ptr [r11 + 4*rax + 64]
vmovups zmm13, zmmword ptr [r11 + 4*rax + 128]
vmovups zmm12, zmmword ptr [r11 + 4*rax + 192]
vmovaps zmm16, zmm1
vfmadd213ps zmm16, zmm15, zmm0 # zmm16 = (zmm15 * zmm16) + zmm0
vmovaps zmm17, zmm1
vfmadd213ps zmm17, zmm14, zmm0 # zmm17 = (zmm14 * zmm17) + zmm0
vmovaps zmm18, zmm1
vfmadd213ps zmm18, zmm13, zmm0 # zmm18 = (zmm13 * zmm18) + zmm0
vmovaps zmm19, zmm1
vfmadd213ps zmm19, zmm12, zmm0 # zmm19 = (zmm12 * zmm19) + zmm0
vaddps zmm16, zmm16, zmm2
vaddps zmm17, zmm17, zmm2
vaddps zmm18, zmm18, zmm2
vaddps zmm19, zmm19, zmm2
vfmadd231ps zmm15, zmm16, zmm3 # zmm15 = (zmm16 * zmm3) + zmm15
vfmadd231ps zmm14, zmm17, zmm3 # zmm14 = (zmm17 * zmm3) + zmm14
vfmadd231ps zmm13, zmm18, zmm3 # zmm13 = (zmm18 * zmm3) + zmm13
vfmadd231ps zmm12, zmm19, zmm3 # zmm12 = (zmm19 * zmm3) + zmm12
vfmadd231ps zmm15, zmm16, zmm4 # zmm15 = (zmm16 * zmm4) + zmm15
vfmadd231ps zmm14, zmm17, zmm4 # zmm14 = (zmm17 * zmm4) + zmm14
vfmadd231ps zmm13, zmm18, zmm4 # zmm13 = (zmm18 * zmm4) + zmm13
vfmadd231ps zmm12, zmm19, zmm4 # zmm12 = (zmm19 * zmm4) + zmm12
vmovaps zmm20, zmm6
vfmadd213ps zmm20, zmm15, zmm5 # zmm20 = (zmm15 * zmm20) + zmm5
vmovaps zmm21, zmm6
vfmadd213ps zmm21, zmm14, zmm5 # zmm21 = (zmm14 * zmm21) + zmm5
vmovaps zmm22, zmm6
vfmadd213ps zmm22, zmm13, zmm5 # zmm22 = (zmm13 * zmm22) + zmm5
vmovaps zmm23, zmm6
vfmadd213ps zmm23, zmm12, zmm5 # zmm23 = (zmm12 * zmm23) + zmm5
vfmadd213ps zmm20, zmm15, zmm7 # zmm20 = (zmm15 * zmm20) + zmm7
vfmadd213ps zmm21, zmm14, zmm7 # zmm21 = (zmm14 * zmm21) + zmm7
vfmadd213ps zmm22, zmm13, zmm7 # zmm22 = (zmm13 * zmm22) + zmm7
vfmadd213ps zmm23, zmm12, zmm7 # zmm23 = (zmm12 * zmm23) + zmm7
vfmadd213ps zmm20, zmm15, zmm8 # zmm20 = (zmm15 * zmm20) + zmm8
vfmadd213ps zmm21, zmm14, zmm8 # zmm21 = (zmm14 * zmm21) + zmm8
vfmadd213ps zmm22, zmm13, zmm8 # zmm22 = (zmm13 * zmm22) + zmm8
vfmadd213ps zmm23, zmm12, zmm8 # zmm23 = (zmm12 * zmm23) + zmm8
vfmadd213ps zmm20, zmm15, zmm9 # zmm20 = (zmm15 * zmm20) + zmm9
vfmadd213ps zmm21, zmm14, zmm9 # zmm21 = (zmm14 * zmm21) + zmm9
vfmadd213ps zmm22, zmm13, zmm9 # zmm22 = (zmm13 * zmm22) + zmm9
vfmadd213ps zmm23, zmm12, zmm9 # zmm23 = (zmm12 * zmm23) + zmm9
vfmadd213ps zmm20, zmm15, zmm10 # zmm20 = (zmm15 * zmm20) + zmm10
vfmadd213ps zmm21, zmm14, zmm10 # zmm21 = (zmm14 * zmm21) + zmm10
vfmadd213ps zmm22, zmm13, zmm10 # zmm22 = (zmm13 * zmm22) + zmm10
vfmadd213ps zmm23, zmm12, zmm10 # zmm23 = (zmm12 * zmm23) + zmm10
vfmadd213ps zmm20, zmm15, zmm11 # zmm20 = (zmm15 * zmm20) + zmm11
vfmadd213ps zmm21, zmm14, zmm11 # zmm21 = (zmm14 * zmm21) + zmm11
vfmadd213ps zmm22, zmm13, zmm11 # zmm22 = (zmm13 * zmm22) + zmm11
vfmadd213ps zmm23, zmm12, zmm11 # zmm23 = (zmm12 * zmm23) + zmm11
vfmadd213ps zmm20, zmm15, zmm11 # zmm20 = (zmm15 * zmm20) + zmm11
vfmadd213ps zmm21, zmm14, zmm11 # zmm21 = (zmm14 * zmm21) + zmm11
vfmadd213ps zmm22, zmm13, zmm11 # zmm22 = (zmm13 * zmm22) + zmm11
vfmadd213ps zmm23, zmm12, zmm11 # zmm23 = (zmm12 * zmm23) + zmm11
vscalefps zmm12, zmm20, zmm16, {rn-sae}
vscalefps zmm13, zmm21, zmm17, {rn-sae}
vscalefps zmm14, zmm22, zmm18, {rn-sae}
vscalefps zmm15, zmm23, zmm19, {rn-sae}
vmovups zmmword ptr [r14 + 4*rax], zmm12
vmovups zmmword ptr [r14 + 4*rax + 64], zmm13
vmovups zmmword ptr [r14 + 4*rax + 128], zmm14
vmovups zmmword ptr [r14 + 4*rax + 192], zmm15
add rax, 64
cmp rax, r10
jl L304

These gave me a significant performance improvement. What hardware are you on? I'm using skylake-avx512/cascadelake with 2x fma units. Note that it doesn't use a lookup table.
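For reference, an untested sketch (my own, not from the PR) of llamafile_expf_avx512 with the tail replaced by _mm512_scalef_ps: vscalefps computes a * 2^floor(b) and saturates to 0 or +inf on its own, so the mask/blend logic can be dropped.

// Untested sketch: requires <immintrin.h> and AVX-512F. Reuses the coefficients
// from llamafile_expf_avx512 above; the over/underflow tail is replaced by vscalefps,
// which applies the 2^n scale and saturates to 0 or +inf by itself.
inline __m512 llamafile_expf_avx512_scalef(__m512 x) {
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);                        // round(x / ln 2)
    const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                     _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    const __m512 u = _mm512_mul_ps(b, b);
    const __m512 j = _mm512_fmadd_ps(                            // j ~= expm1(b)
        _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                        _mm512_set1_ps(0x1.573e2ep-5f)),
                        u,
                        _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                        _mm512_set1_ps(0x1.fffdb6p-2f))),
        u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
    // expf(x) = (1 + j) * 2^n, with the scale applied by vscalefps.
    return _mm512_scalef_ps(_mm512_add_ps(j, _mm512_set1_ps(1.0f)), n);
}

Whether this stays numerically equivalent at the extremes would still need to be validated against the original.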
Since this PR was merged, the server has been producing nondeterministic results when using >1 slots. Minimal example for reproduction:

make clean && make server
./server -m models/opt/llama_2-7b-q4_0.gguf --parallel 2 --threads 1

In another shell:

curl --request POST --url http://localhost:8080/completion --header "Content-Type: application/json" --data '{"prompt": "", "n_predict":10, "n_probs": 2, "temperature": -1}' | python3 -m json.tool

The token probabilities for the last token cycle between two values with every request.
@chriselrod Could you help me modify my avx512 intrinsics to use _mm512_scalef_ps (vscalefps) like your code? I'm currently talking to ARM Limited about getting these functions into Glibc, since our code goes faster. ARM-software/optimized-routines#69 |
This change upstreams llamafile's vectorized expf() functions. This lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make this operation go faster. We can support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. I wrote avx2 and avx512 implementations as well, but they didn't offer enough of an advantage over sse2+fma to be worth the code complexity.