ggml : rewrite silu and softmax for cpu #7154

Merged: 1 commit merged into ggerganov:master on May 17, 2024

Conversation

jart (Contributor) commented May 9, 2024

This change upstreams llamafile's vectorized expf() functions. It lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make these operations go faster. We can support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. I wrote avx2 and avx512 implementations as well, but they didn't offer enough of an advantage over sse2+fma to be worth the code complexity.
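For reference, here is the scalar math these vectorized kernels approximate (a minimal sketch, not code from this PR; function names are illustrative):

#include <math.h>

// silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
static float silu_ref(float x) {
    return x / (1.0f + expf(-x));
}

// softmax: y[i] = exp(x[i] - max) / sum_j exp(x[j] - max),
// with the max subtracted for numerical stability
static void softmax_ref(float *y, const float *x, int n) {
    float max = x[0], sum = 0.0f;
    for (int i = 1; i < n; i++) if (x[i] > max) max = x[i];
    for (int i = 0; i < n; i++) { y[i] = expf(x[i] - max); sum += y[i]; }
    for (int i = 0; i < n; i++) y[i] /= sum;
}

Both are dominated by expf(), which is why a fast vectorized expf() speeds up SOFT_MAX and SILU directly.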

github-actions bot commented May 9, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 543 iterations 🚀

Details (shown for performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8626.19ms p(95)=21696.44ms fails=, finish reason: stop=474 truncated=69
  • Prompt processing (pp): avg=94.59tk/s p(95)=412.43tk/s
  • Token generation (tg): avg=33.43tk/s p(95)=48.33tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=expf commit=d7359a389c236193edac1c8761e6ac98844654f3

[Benchmark charts: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 543 iterations — time series for prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing]

mofosyne (Collaborator) commented May 9, 2024

Not a deep analysis of the changes, but here are some general observations that may help other reviewers:

  • Commented-out #define removed
  • Five duplicated lines extracted into ggml_vec_soft_max_f32()
  • Various functions relating to GGML_SILU_FP16 removed
  • ggml_v_expf() added
  • ggml_v_silu() added
  • ggml_vec_silu_f32() adjusted with preprocessor guards so the implementation is selected based on the SSE2 or __ARM_NEON flag (see the sketch after this list)
  • There are other changes, but these are the main things I noticed.
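A rough sketch of that dispatch pattern (illustrative only, not the PR's exact code; the 4-wide loops and scalar tail are assumptions):

void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__ARM_NEON)
    for (; i + 3 < n; i += 4) {                 // 4 floats per NEON register
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {                 // 4 floats per SSE register
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#endif
    for (; i < n; ++i) {                        // scalar remainder
        y[i] = x[i] / (1.0f + expf(-x[i]));
    }
}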

ggerganov (Owner) commented:
On AMD Ryzen 9 5950X and M2 Ultra, SOFT_MAX is ~1.5x faster than master.

Using the following command to benchmark:

make -j tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf

mofosyne added the labels refactoring and review complexity : high on May 9, 2024
jart (Contributor, Author) commented May 9, 2024

I'm glad to hear that. Here are the avx2 and avx512 variations if you want to try them out:

inline __m256 llamafile_expf_avx2(__m256 x) {
  const __m256 r = _mm256_set1_ps(0x1.8p23f);
  const __m256 z = MADD256(x, _mm256_set1_ps(0x1.715476p+0f), r);
  const __m256 n = _mm256_sub_ps(z, r);
  const __m256 b = NMADD256(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                            NMADD256(n, _mm256_set1_ps(0x1.62e4p-1f), x));
  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
  const __m256 k = _mm256_castsi256_ps(
      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
  const __m256i c = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(126), _CMP_GT_OQ));
  const __m256 u = _mm256_mul_ps(b, b);
  const __m256 j = MADD256(MADD256(MADD256(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                           _mm256_set1_ps(0x1.573e2ep-5f)),
                                   u,
                                   MADD256(_mm256_set1_ps(0x1.555e66p-3f), b,
                                           _mm256_set1_ps(0x1.fffdb6p-2f))),
                           u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
    return MADD256(j, k, k);
  const __m256i g = _mm256_and_si256(
      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
      _mm256_set1_epi32(0x82000000u));
  const __m256 s1 =
      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
  const __m256i d = _mm256_castps_si256(
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
                    _mm256_set1_ps(192), _CMP_GT_OQ));
  return _mm256_or_ps(
      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
      _mm256_andnot_ps(
          _mm256_castsi256_ps(d),
          _mm256_or_ps(
              _mm256_and_ps(_mm256_castsi256_ps(c),
                            _mm256_mul_ps(MADD256(s2, j, s2), s1)),
              _mm256_andnot_ps(_mm256_castsi256_ps(c), MADD256(k, j, k)))));
}

inline __m512 llamafile_expf_avx512(__m512 x) {
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
  const __m512 z = MADD512(x, _mm512_set1_ps(0x1.715476p+0f), r);
  const __m512 n = _mm512_sub_ps(z, r);
  const __m512 b = NMADD512(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                            NMADD512(n, _mm512_set1_ps(0x1.62e4p-1f), x));
  const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
  const __m512 k = _mm512_castsi512_ps(
      _mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
  const __mmask16 c =
      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
  const __m512 u = _mm512_mul_ps(b, b);
  const __m512 j = MADD512(MADD512(MADD512(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                           _mm512_set1_ps(0x1.573e2ep-5f)),
                                   u,
                                   MADD512(_mm512_set1_ps(0x1.555e66p-3f), b,
                                           _mm512_set1_ps(0x1.fffdb6p-2f))),
                           u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
  if (_mm512_kortestz(c, c))
    return MADD512(j, k, k);
  const __m512i g = _mm512_and_si512(
      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
      _mm512_set1_epi32(0x82000000u));
  const __m512 s1 =
      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
  const __mmask16 d =
      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
  return _mm512_mask_blend_ps(
      d,
      _mm512_mask_blend_ps(c, MADD512(k, j, k),
                           _mm512_mul_ps(MADD512(s2, j, s2), s1)),
      _mm512_mul_ps(s1, s1));
}
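The MADD256/NMADD256 (and MADD512/NMADD512) helpers are not shown here; presumably they wrap fused multiply-add. A plausible definition, with a non-FMA fallback, would be (my assumption, not the actual llamafile macros):

#if defined(__FMA__)
#define MADD256(x, y, z)  _mm256_fmadd_ps(x, y, z)    //  x*y + z
#define NMADD256(x, y, z) _mm256_fnmadd_ps(x, y, z)   // -x*y + z
#else
#define MADD256(x, y, z)  _mm256_add_ps(_mm256_mul_ps(x, y), z)
#define NMADD256(x, y, z) _mm256_sub_ps(z, _mm256_mul_ps(x, y))
#endif
// AVX512F always provides FMA:
#define MADD512(x, y, z)  _mm512_fmadd_ps(x, y, z)
#define NMADD512(x, y, z) _mm512_fnmadd_ps(x, y, z)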

Here are the numbers I got with the script I used while developing these functions:

   2.98601 ns 2000x run_expf()
   1.35154 ns 2000x run_llamafile_expf_sse2()
   1.16659 ns 2000x run_llamafile_expf_avx2()
   1.18844 ns 2000x run_llamafile_expf_avx512()

//       input          exp    llamafile   bad
//       =====          ===    =========   ===
//           0            1            1     0
//          -0            1            1     0
//         nan          nan          nan     0
//        -nan         -nan         -nan     0
//         inf          inf          inf     0
//        -inf            0            0     0
//          87  6.07603e+37  6.07603e+37     1
//          88  1.65164e+38  1.65164e+38     0
//     88.7229          inf          inf     0
//          89          inf          inf     0
//         -87  1.64581e-38  1.64581e-38     1
//         -90  8.19401e-40  8.19401e-40     0
//         -95  5.52112e-42  5.52112e-42     0
//        -100  3.78351e-44  3.78351e-44     0
//        -104            0            0     0
//    0.660001      1.93479      1.93479     1
//   -0.324231     0.723083     0.723083     0
//   0.0205384      1.02075      1.02075     0
//   -0.224604     0.798833     0.798833     1
//   -0.339606     0.712051      0.71205     1
//    0.211472       1.2355       1.2355     0
//    0.238942       1.2699       1.2699     0
//    -0.78286     0.457097     0.457097     0
4294967296 numbers tested successfully
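The "4294967296 numbers tested" line suggests an exhaustive sweep of every float bit pattern against the libm reference. A minimal harness in that spirit might look like this (illustrative only; scalar_expf_under_test is a hypothetical scalar wrapper around the vector kernel, and the ULP measure is simplified):

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static uint32_t ulp_diff(float a, float b) {  // simplified: assumes same sign, non-NaN inputs
    uint32_t ua, ub;
    memcpy(&ua, &a, 4);
    memcpy(&ub, &b, 4);
    return ua > ub ? ua - ub : ub - ua;
}

int main(void) {
    uint32_t worst = 0;
    for (uint64_t bits = 0; bits <= 0xffffffffu; ++bits) {
        float x;
        const uint32_t u = (uint32_t)bits;
        memcpy(&x, &u, 4);
        const float want = expf(x);
        const float got  = scalar_expf_under_test(x);  // hypothetical wrapper
        if (isfinite(want) && isfinite(got)) {
            const uint32_t d = ulp_diff(want, got);
            if (d > worst) worst = d;
        }
    }
    printf("%llu numbers tested, worst error %u ulp\n", 4294967296ULL, worst);
    return 0;
}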

jart (Contributor, Author) commented May 9, 2024

@ggerganov Running your command, I'm seeing the advantage here increase from 1.5x to 1.9x if we include AVX2. On znver4, if we also include avx512, that goes up to 2.1x. I'd expect that gap to grow in the future, since znver4 currently implements the AVX512 ISA on narrower execution units and takes 2 cycles for each 512-bit vector operation. So I've gone ahead and included the code for you.

Commit message:

This change upstreams llamafile's vectorized expf() functions. This lets
us compute softmax and silu more accurately than the short[65536] lookup
table that GGML previously used to make this operation go faster. We can
support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. It
makes "make -j8 tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf"
go 1.5x faster for SSE2+FMA, 1.9x faster for AVX2+FMA, and 2.1x for AVX512.
chriselrod commented May 16, 2024

With AVX512, you may want to use vscalefps. It computes zmm0 = zmm1 * 2^zmm2, where all operands are floats.

It handles overflow and underflow properly, letting you remove the checks and blends.
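For illustration, here is a minimal sketch of what an expf kernel built around _mm512_scalef_ps (the intrinsic behind vscalefps) could look like. The polynomial coefficients are taken from the AVX512 code above; the overall structure is my assumption, not chriselrod's implementation:

#include <immintrin.h>

static inline __m512 expf_scalef_sketch(__m512 x) {
    // n = round(x / ln 2); r = x - n*ln 2 using the same hi/lo split as above
    const __m512 n = _mm512_roundscale_ps(
        _mm512_mul_ps(x, _mm512_set1_ps(0x1.715476p+0f)),
        _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m512 r = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x);
    r = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), r);

    // Horner evaluation of exp(r) on the reduced range
    __m512 p = _mm512_set1_ps(0x1.0e4020p-7f);
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(0x1.573e2ep-5f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(0x1.555e66p-3f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(0x1.fffdb6p-2f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(0x1.ffffecp-1f));
    p = _mm512_fmadd_ps(p, r, _mm512_set1_ps(1.0f));

    // vscalefps computes p * 2^n and saturates to 0/inf on under/overflow,
    // so no mask/blend special-casing is needed
    return _mm512_scalef_ps(p, n);
}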

I have an implementation in Julia; here is the generated loop with 4x unrolling and interleaving:

L304:
	vmovups	zmm15, zmmword ptr [r11 + 4*rax]
	vmovups	zmm14, zmmword ptr [r11 + 4*rax + 64]
	vmovups	zmm13, zmmword ptr [r11 + 4*rax + 128]
	vmovups	zmm12, zmmword ptr [r11 + 4*rax + 192]
	vmovaps	zmm16, zmm1
	vfmadd213ps	zmm16, zmm15, zmm0      # zmm16 = (zmm15 * zmm16) + zmm0
	vmovaps	zmm17, zmm1
	vfmadd213ps	zmm17, zmm14, zmm0      # zmm17 = (zmm14 * zmm17) + zmm0
	vmovaps	zmm18, zmm1
	vfmadd213ps	zmm18, zmm13, zmm0      # zmm18 = (zmm13 * zmm18) + zmm0
	vmovaps	zmm19, zmm1
	vfmadd213ps	zmm19, zmm12, zmm0      # zmm19 = (zmm12 * zmm19) + zmm0
	vaddps	zmm16, zmm16, zmm2
	vaddps	zmm17, zmm17, zmm2
	vaddps	zmm18, zmm18, zmm2
	vaddps	zmm19, zmm19, zmm2
	vfmadd231ps	zmm15, zmm16, zmm3      # zmm15 = (zmm16 * zmm3) + zmm15
	vfmadd231ps	zmm14, zmm17, zmm3      # zmm14 = (zmm17 * zmm3) + zmm14
	vfmadd231ps	zmm13, zmm18, zmm3      # zmm13 = (zmm18 * zmm3) + zmm13
	vfmadd231ps	zmm12, zmm19, zmm3      # zmm12 = (zmm19 * zmm3) + zmm12
	vfmadd231ps	zmm15, zmm16, zmm4      # zmm15 = (zmm16 * zmm4) + zmm15
	vfmadd231ps	zmm14, zmm17, zmm4      # zmm14 = (zmm17 * zmm4) + zmm14
	vfmadd231ps	zmm13, zmm18, zmm4      # zmm13 = (zmm18 * zmm4) + zmm13
	vfmadd231ps	zmm12, zmm19, zmm4      # zmm12 = (zmm19 * zmm4) + zmm12
	vmovaps	zmm20, zmm6
	vfmadd213ps	zmm20, zmm15, zmm5      # zmm20 = (zmm15 * zmm20) + zmm5
	vmovaps	zmm21, zmm6
	vfmadd213ps	zmm21, zmm14, zmm5      # zmm21 = (zmm14 * zmm21) + zmm5
	vmovaps	zmm22, zmm6
	vfmadd213ps	zmm22, zmm13, zmm5      # zmm22 = (zmm13 * zmm22) + zmm5
	vmovaps	zmm23, zmm6
	vfmadd213ps	zmm23, zmm12, zmm5      # zmm23 = (zmm12 * zmm23) + zmm5
	vfmadd213ps	zmm20, zmm15, zmm7      # zmm20 = (zmm15 * zmm20) + zmm7
	vfmadd213ps	zmm21, zmm14, zmm7      # zmm21 = (zmm14 * zmm21) + zmm7
	vfmadd213ps	zmm22, zmm13, zmm7      # zmm22 = (zmm13 * zmm22) + zmm7
	vfmadd213ps	zmm23, zmm12, zmm7      # zmm23 = (zmm12 * zmm23) + zmm7
	vfmadd213ps	zmm20, zmm15, zmm8      # zmm20 = (zmm15 * zmm20) + zmm8
	vfmadd213ps	zmm21, zmm14, zmm8      # zmm21 = (zmm14 * zmm21) + zmm8
	vfmadd213ps	zmm22, zmm13, zmm8      # zmm22 = (zmm13 * zmm22) + zmm8
	vfmadd213ps	zmm23, zmm12, zmm8      # zmm23 = (zmm12 * zmm23) + zmm8
	vfmadd213ps	zmm20, zmm15, zmm9      # zmm20 = (zmm15 * zmm20) + zmm9
	vfmadd213ps	zmm21, zmm14, zmm9      # zmm21 = (zmm14 * zmm21) + zmm9
	vfmadd213ps	zmm22, zmm13, zmm9      # zmm22 = (zmm13 * zmm22) + zmm9
	vfmadd213ps	zmm23, zmm12, zmm9      # zmm23 = (zmm12 * zmm23) + zmm9
	vfmadd213ps	zmm20, zmm15, zmm10     # zmm20 = (zmm15 * zmm20) + zmm10
	vfmadd213ps	zmm21, zmm14, zmm10     # zmm21 = (zmm14 * zmm21) + zmm10
	vfmadd213ps	zmm22, zmm13, zmm10     # zmm22 = (zmm13 * zmm22) + zmm10
	vfmadd213ps	zmm23, zmm12, zmm10     # zmm23 = (zmm12 * zmm23) + zmm10
	vfmadd213ps	zmm20, zmm15, zmm11     # zmm20 = (zmm15 * zmm20) + zmm11
	vfmadd213ps	zmm21, zmm14, zmm11     # zmm21 = (zmm14 * zmm21) + zmm11
	vfmadd213ps	zmm22, zmm13, zmm11     # zmm22 = (zmm13 * zmm22) + zmm11
	vfmadd213ps	zmm23, zmm12, zmm11     # zmm23 = (zmm12 * zmm23) + zmm11
	vfmadd213ps	zmm20, zmm15, zmm11     # zmm20 = (zmm15 * zmm20) + zmm11
	vfmadd213ps	zmm21, zmm14, zmm11     # zmm21 = (zmm14 * zmm21) + zmm11
	vfmadd213ps	zmm22, zmm13, zmm11     # zmm22 = (zmm13 * zmm22) + zmm11
	vfmadd213ps	zmm23, zmm12, zmm11     # zmm23 = (zmm12 * zmm23) + zmm11
	vscalefps	zmm12, zmm20, zmm16, {rn-sae}
	vscalefps	zmm13, zmm21, zmm17, {rn-sae}
	vscalefps	zmm14, zmm22, zmm18, {rn-sae}
	vscalefps	zmm15, zmm23, zmm19, {rn-sae}
	vmovups	zmmword ptr [r14 + 4*rax], zmm12
	vmovups	zmmword ptr [r14 + 4*rax + 64], zmm13
	vmovups	zmmword ptr [r14 + 4*rax + 128], zmm14
	vmovups	zmmword ptr [r14 + 4*rax + 192], zmm15
	add	rax, 64
	cmp	rax, r10
	jl	L304

These gave me a significant performance improvement.
If my test is correct, I got a maximum error <1 ULP at x=47.483456f.

What hardware are you on? I'm using skylake-avx512/cascadelake with 2x fma units.
Zen4 or something like icelake-client/tigerlake likely won't benefit as much.

Note that it doesn't use a lookup table.
My Float64/double implementation uses a 16-element lookup table via vpermi2pd.
If we wanted, we could use a 32-element lookup table of floats via the same approach.
vpermi2pd is much faster than gather, the cost of course being that our table has to fit into two registers.
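A sketch of the two-register table lookup (my assumption of the shape, using _mm512_permutex2var_pd, the intrinsic behind vpermi2pd):

#include <immintrin.h>

// 16-entry table of doubles held in two 8-lane registers.
// Each 64-bit index uses bits 0-2 to pick the lane and bit 3 to pick lo vs hi.
static inline __m512d table_lookup_16(__m512i idx, __m512d lo, __m512d hi) {
    return _mm512_permutex2var_pd(lo, idx, hi);
}

Unlike a gather, the whole table stays in registers, so the lookup costs roughly one shuffle per vector.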

ggerganov merged commit 934266c into ggerganov:master on May 17, 2024
64 checks passed
ggerganov removed the merging soon label on May 17, 2024
solos added a commit to solos/booster that referenced this pull request May 17, 2024
solos added a commit to solos/booster that referenced this pull request May 18, 2024
JohannesGaessler (Collaborator) commented:
Since this PR was merged, the server has been producing nondeterministic results when using >1 slot. Minimal example for reproduction:

make clean && make server
./server -m models/opt/llama_2-7b-q4_0.gguf --parallel 2 --threads 1

In another shell:

curl --request POST --url http://localhost:8080/completion --header "Content-Type: application/json" --data '{"prompt": "", "n_predict":10, "n_probs": 2, "temperature": -1}' | python3 -m json.tool

The token probabilities for the last token cycle between two values with each curl call. When using 4 slots, the token probabilities cycle between 4 possible values.

jart (Contributor, Author) commented May 22, 2024

@chriselrod Could you help me modify my avx512 intrinsics to use _mm512_scalef_ps (vscalefps) like your code does? I'm currently talking to ARM Limited about getting these functions into Glibc, since our code is faster. See ARM-software/optimized-routines#69.

Labels: refactoring, review complexity : high

5 participants