
Add AMX support to speed up Faiss Inner-Product #535

Open · wants to merge 17 commits into main
Conversation

mellonyou

Use Intel AMX to speed up the inner-product algorithm of knowhere::BruteForce::Search(); it can bring a more than 10x performance boost.

Build parameter: use "-o with_dnnl=True/False" to enable or disable the AMX feature.
This feature depends on libdnnl.so.3; you can install it by running scripts/install_deps.sh.

Runtime parameter: to use the AMX feature, you must first set the environment variable "DNNL_ENABLE=1"; otherwise the AMX feature will not work.
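For example, a typical build-and-run flow (a sketch only; the exact conan invocation depends on your local setup, but the option and environment variable are the ones described above):

    # build with the oneDNN/AMX path compiled in
    conan build .. -o with_dnnl=True
    # opt in to the AMX path at runtime
    DNNL_ENABLE=1 ./Release/tests/ut/knowhere_tests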

@sre-ci-robot
Collaborator

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by: mellonyou
To complete the pull request process, please assign zhengbuqian after the PR has been reviewed.
You can assign the PR to them by writing /assign @zhengbuqian in a comment when ready.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment


mergify bot commented Apr 28, 2024

@mellonyou 🔍 Important: PR Classification Needed!

For efficient project management and a seamless review process, it's essential to classify your PR correctly. Here's how:

  1. If you're fixing a bug, label it as kind/bug.
  2. For small tweaks (less than 20 lines without altering any functionality), please use kind/improvement.
  3. Significant changes that don't modify existing functionalities should be tagged as kind/enhancement.
  4. Adjusting APIs or changing functionality? Go with kind/feature.

For any PR outside the kind/improvement category, ensure you link to the associated issue using the format: “issue: #”.

Thanks for your efforts and contribution to the community!

@mellonyou
Author

issue: #541

@mellonyou mellonyou marked this pull request as ready for review May 6, 2024 02:58
@mellonyou
Author

I can't edit the labels; do I need any access permissions?

@liliu-z
Collaborator

liliu-z commented May 6, 2024

/kind enhancement


codecov bot commented May 6, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 71.59%. Comparing base (3c46f4c) to head (7b6f49a).
Report is 95 commits behind head on main.

Current head 7b6f49a differs from pull request most recent head b420761

Please upload reports for the commit b420761 to get more accurate results.

Additional details and impacted files


@@            Coverage Diff            @@
##           main     #535       +/-   ##
=========================================
+ Coverage      0   71.59%   +71.59%     
=========================================
  Files         0       67       +67     
  Lines         0     4446     +4446     
=========================================
+ Hits          0     3183     +3183     
- Misses        0     1263     +1263     

see 67 files with indirect coverage changes

BaseData::getState().store(BASE_DATA_STATE::MODIFIED);
}

void execut(float** out_f32) {
Collaborator

nit: execute?

Author

yes, it's a typo

Comment on lines 164 to 166
// inner memory bf16
bf16_md1 = dnnl::memory::desc({xrow, xcol}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);
bf16_md2 = dnnl::memory::desc({yrow, ycol}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);
Collaborator

Noob Q: why do we use bf16 here?

Author

Because AMX has native support for bf16/int8 compute, which can significantly improve performance. We have tested this, and it has little impact on accuracy.
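For reference, the f32 -> bf16 -> f32 flow looks roughly like this (a minimal sketch using the public oneDNN C++ API; the name ip_f32_via_bf16 is illustrative, not the PR's actual helper):

    #include <oneapi/dnnl/dnnl.hpp>

    // x: (nx, d) row-major queries, y: (ny, d) row-major base vectors,
    // out: (nx, ny) row-major inner products.
    void ip_f32_via_bf16(const float* x, const float* y, float* out,
                         int64_t nx, int64_t ny, int64_t d) {
        dnnl::engine eng(dnnl::engine::kind::cpu, 0);
        dnnl::stream strm(eng);

        auto x_f32_md = dnnl::memory::desc({nx, d}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
        // y is stored (ny, d); describing it as a transposed (d, ny) matrix gives x * y^T.
        auto y_f32_md = dnnl::memory::desc({d, ny}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ba);
        auto dst_md   = dnnl::memory::desc({nx, ny}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

        // bf16 descriptors with format_tag::any let oneDNN pick an AMX-friendly layout.
        auto x_bf16_md = dnnl::memory::desc({nx, d}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);
        auto y_bf16_md = dnnl::memory::desc({d, ny}, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::any);

        auto pd = dnnl::matmul::primitive_desc(eng, x_bf16_md, y_bf16_md, dst_md);

        dnnl::memory x_f32(x_f32_md, eng, const_cast<float*>(x));
        dnnl::memory y_f32(y_f32_md, eng, const_cast<float*>(y));
        dnnl::memory x_bf16(pd.src_desc(), eng);
        dnnl::memory y_bf16(pd.weights_desc(), eng);
        dnnl::memory dst(pd.dst_desc(), eng, out);

        // Down-convert f32 -> bf16; this is the step that trades a little accuracy for speed.
        dnnl::reorder(x_f32, x_bf16).execute(strm, x_f32, x_bf16);
        dnnl::reorder(y_f32, y_bf16).execute(strm, y_f32, y_bf16);

        // bf16 x bf16 -> f32 matmul; on Sapphire Rapids this dispatches to AMX kernels.
        dnnl::matmul(pd).execute(strm, {{DNNL_ARG_SRC, x_bf16}, {DNNL_ARG_WEIGHTS, y_bf16}, {DNNL_ARG_DST, dst}});
        strm.wait();
    }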

BASE_DATA_STATE expected = BASE_DATA_STATE::MODIFIED;

if (BaseData::getState().compare_exchange_strong(expected, BASE_DATA_STATE::PREPARE)) {
    pthread_rwlock_wrlock(&rwlock);
Collaborator

Noob Q: why do we need to lock this? Is it because only one AMX instruction can run at a time?

Author

The lock is designed for the multi-thread scenario: if two threads operate on the same base dataset with different query datasets, the lock prevents the base dataset from being modified by one thread while another is working on it.

dnnl::reorder(f32_mem1, bf16_mem1).execute(engine_stream, f32_mem1, bf16_mem1);
BASE_DATA_STATE expected = BASE_DATA_STATE::MODIFIED;

if (BaseData::getState().compare_exchange_strong(expected, BASE_DATA_STATE::PREPARE)) {
Collaborator

Plz CMIIW. In the first call, expected will be BASE_DATA_STATE::MODIFIED and will be changed into BASE_DATA_STATE::PREPARE on this line, returning false. Then it will loop at line 196 forever.

Author

The state is also designed for the multi-thread scenario; the state transitions are INIT -> MODIFIED -> PREPARE -> READY. When the first thread has finished the initialization, the other threads will see the READY state and skip line 196.
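To make the intended flow concrete, here is a minimal sketch of the pattern as I read it from the excerpts above (the member layout and the spin/yield policy are assumptions, not the PR's verbatim code):

    #include <atomic>
    #include <pthread.h>

    enum class BASE_DATA_STATE { INIT, MODIFIED, PREPARE, READY };

    std::atomic<BASE_DATA_STATE> state{BASE_DATA_STATE::MODIFIED};
    pthread_rwlock_t rwlock = PTHREAD_RWLOCK_INITIALIZER;

    void prepare_and_search() {
        BASE_DATA_STATE expected = BASE_DATA_STATE::MODIFIED;
        // Exactly one thread wins the CAS and converts the base data to bf16...
        if (state.compare_exchange_strong(expected, BASE_DATA_STATE::PREPARE)) {
            pthread_rwlock_wrlock(&rwlock);   // exclusive: base data is being rewritten
            // ... reorder the f32 base data into bf16 here ...
            state.store(BASE_DATA_STATE::READY);
            pthread_rwlock_unlock(&rwlock);
        } else {
            // ...while the other threads wait until the winner publishes READY.
            while (state.load() != BASE_DATA_STATE::READY) { /* spin or yield */ }
        }
        pthread_rwlock_rdlock(&rwlock);       // shared: concurrent searches may read together
        // ... run the AMX inner product against the prepared bf16 base data ...
        pthread_rwlock_unlock(&rwlock);
    }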

if (is_dnnl_enabled()) {
    float* res_arr = NULL;

    comput_f32bf16f32_inner_product(nx, d, ny, d, const_cast<float*>(x), const_cast<float*>(y), &res_arr);
Collaborator

Can we implement a dynamic hook like all the other SIMD code in Knowhere?

Author

We also considered following the other SIMD interfaces, but due to the implementation of AMX, it is a bit incompatible with the current interface:

  1. AMX prefers batch computation, and its library schedules multiple threads on its own.
  2. The return value is an array for the batch operation.

So if we use a dynamic hook, we may need to add a new interface for batch operations and call that interface when AMX is available (a possible shape is sketched below).
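A hypothetical batch-style hook could look like this (illustrative only; batch_ip_hook_t and knowhere_batch_inner_product are made-up names, not part of the PR):

    #include <cstddef>

    // Computes all nx * ny inner products in one call and writes them to out
    // (row-major: nx rows of ny scores), so the AMX/oneDNN backend can batch
    // the work and schedule its own threads internally.
    using batch_ip_hook_t = void (*)(size_t nx, size_t d, size_t ny,
                                     const float* x, const float* y, float* out);

    // Chosen once at startup: the AMX path when available, otherwise a SIMD
    // loop that fills the same output array.
    extern batch_ip_hook_t knowhere_batch_inner_product;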

Author

@liliu-z We are planning to port the code to adopt the dynamic hook; do you have any other suggestions?

@@ -211,30 +214,59 @@ void exhaustive_inner_product_seq_impl(
using SingleResultHandler = typename BlockResultHandler::SingleResultHandler;
int nt = std::min(int(nx), omp_get_max_threads());

#ifdef FAISS_WITH_DNNL
Collaborator

The problem here is that this code is inserted into the function that computes inner products according to a filter. So, if the filter filters out 90% of the samples, then 9 out of 10 computed distances will not be used, costing quite a lot of extra memory bandwidth.
Benchmarks are needed for this PR.


@alexanderguzhva Is the filter inside Knowhere or in Milvus?

Collaborator

@xtangxtang an external filter (in the form of a bitset), provided by Milvus

@mergify mergify bot removed the ci-passed label May 15, 2024
@mergify mergify bot removed the dco-passed label May 15, 2024
@mellonyou
Copy link
Author

  1. Ported the code to Knowhere to follow the dynamic hook interface.
  2. About the filter: I wrote a simple benchmark comparing the no-filter AMX inner product against the SIMD inner product with filters (0.1f, 0.5f, 0.9f). It shows that AMX still has a performance advantage even when the filter percentage reaches 0.9:

     filter percentage    0.1      0.5      0.9      amx (no filter)
     result (s)           0.432    0.208    0.043    0.033
     dnnl perf. boost     13.1x    6.3x     1.3x

For the code: the AMX inner-product interface is better suited to batches of vectors, and it doesn't support a filter interface. I have two ideas:

  1. The AMX inner product handles only the no-filter scenario.
  2. Add a percentage parameter to the interface; when it is less than 0.9, choose the AMX inner product (see the sketch below).
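A minimal sketch of idea 2 (hypothetical; the filter_ratio parameter is illustrative, the 0.9f threshold comes from the benchmark above, and the helper names mirror earlier excerpts in this PR):

    if (is_dnnl_enabled() && filter_ratio < 0.9f) {
        // Batch AMX/oneDNN path: compute all nx * ny scores at once,
        // then apply the bitset filter to the results.
        float* res_arr = NULL;
        comput_f32bf16f32_inner_product(nx, d, ny, d,
                                        const_cast<float*>(x), const_cast<float*>(y), &res_arr);
        // ... drop filtered-out entries from res_arr ...
    } else {
        // Per-vector SIMD path: skip filtered-out vectors before computing.
    }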

@alexanderguzhva Looking forward to your suggestions.

@alexanderguzhva
Collaborator

alexanderguzhva commented May 15, 2024

@mellonyou Could you please include a benchmark or, at least, its details?
The numbers that you've provided cannot be interpreted properly without knowing

  • the exact number of samples
  • the dimensionality
  • whether it is a single-query or a batched-query request
  • whether it is a test for this particular function or for the whole index
  • etc.

The results are potentially interesting and are definitely worth checking on my end.

@mellonyou
Author

#include "simd/distances_onednn.h"

#define MAX_LOOP 20
TEST_CASE("Test Brute Force", "[float vector]") {
using Catch::Approx;

const int64_t nb = 2000000;
const int64_t nq = 10;
const int64_t dim = 512;
const int64_t k = 100;

auto metric = GENERATE(as<std::string>{}, knowhere::metric::IP );

const auto train_ds = GenDataSet(nb, dim);
const auto query_ds = CopyDataSet(train_ds, nq);

const knowhere::Json conf = {
    {knowhere::meta::DIM, dim},
    {knowhere::meta::METRIC_TYPE, metric},
    {knowhere::meta::TOPK, k},
    {knowhere::meta::RADIUS, knowhere::IsMetricType(metric, knowhere::metric::IP) ? 10.0 : 0.99},
};

SECTION("Test Search Batch") {
 faiss::BaseData::getState().store(faiss::BASE_DATA_STATE::MODIFIED);
 struct timeval t1,t2;
 double timeuse;
 gettimeofday(&t1,NULL);

     std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
             GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
     const auto bitset_percentages = {0.1f, 0.5f, 0.9f};
     for (const float percentage : bitset_percentages) {
             for (const auto& gen_func : gen_bitset_funcs) {
                     auto bitset_data = gen_func(nb, percentage * nb);
                     knowhere::BitsetView bitset(bitset_data.data(), nb);

                     for (int i = 0; i < MAX_LOOP; i++)
                     {
                             gettimeofday(&t1,NULL);

                             //    threads.emplace_back(WrapSearch, queryvar1);
                             auto res = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, conf, bitset);
                             gettimeofday(&t2,NULL);
                             timeuse = (t2.tv_sec - t1.tv_sec) + (double)(t2.tv_usec - t1.tv_usec)/1000000.0;
                             std::cout << "elpased: " << timeuse << std::endl;
                     }

             }
     }

     gettimeofday(&t2,NULL);
     timeuse = (t2.tv_sec - t1.tv_sec) + (double)(t2.tv_usec - t1.tv_usec)/1000000.0;

     std::cout << "All thread finished." << std::endl;

    }

}

@mellonyou
Copy link
Author

@alexanderguzhva I just added this code to the unit tests as a temporary benchmark, built it with "-o with_dnnl=True", and then ran the test:
DNNL_ENABLE=0/1 ./Release/tests/ut/knowhere_tests
The test runs 20 rounds, and the results above are the averages after discarding the best 20% and the worst 20%. I ran the test on an Intel SPR platform with Ubuntu 22.04.

@alexanderguzhva
Collaborator

@mellonyou I'll take a look. Thanks!

BaseData::getState().store(BASE_DATA_STATE::MODIFIED);
}

void execut(float** out_f32) {


There is a typo.

@mellonyou
Author

Added the searchwithbuf and rangesearch interface implementations with AMX oneDNN. The related build config will be submitted to Milvus later.

@mellonyou
Author

I am trying to do a manual filter with multithreading before the AMX inner product.
@liliu-z @alexanderguzhva @godchen0212 Do you have any other opinions on the current interface implementation?
