|
16 | 16 | #include <thread>
|
17 | 17 | #include <VecSim/algorithms/hnsw/hnsw_single.h>
|
18 | 18 | #include <VecSim/algorithms/brute_force/brute_force_single.h>
|
19 |
| -#include "tiered_index_mock.h" |
| 19 | +#include "mock_thread_pool.h" |
20 | 20 |
|
21 | 21 | namespace py = pybind11;
|
22 |
| -using namespace tiered_index_mock; |
23 | 22 |
|
24 | 23 | // Helper function that iterates query results and wrap them in python numpy object -
|
25 | 24 | // a tuple of two 2D arrays: (labels, distances)
|
@@ -366,92 +365,64 @@ class PyHNSWLibIndex : public PyVecSimIndex {
|
366 | 365 | };
|
367 | 366 |
|
368 | 367 | class PyTieredIndex : public PyVecSimIndex {
|
369 |
| -private: |
| 368 | +protected: |
| 369 | + tieredIndexMock mock_thread_pool; |
| 370 | + |
370 | 371 | VecSimIndexAbstract<float> *getFlatBuffer() {
|
371 | 372 | return reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get())
|
372 |
| - ->getFlatbufferIndex(); |
| 373 | + ->getFlatBufferIndex(); |
373 | 374 | }
|
374 | 375 |
|
375 |
| -protected: |
376 |
| - JobQueue jobQueue; // External queue that holds the jobs. |
377 |
| - IndexExtCtx jobQueueCtx; // External context to be sent to the submit callback. |
378 |
| - SubmitCB submitCb; // A callback that submits an array of jobs into a given jobQueue. |
379 |
| - size_t flatBufferLimit; // Maximum size allowed for the flat buffer. If flat buffer is full, use |
380 |
| - // in-place insertion. |
381 |
| - bool run_thread; |
382 |
| - std::bitset<MAX_POOL_SIZE> executions_status; |
383 |
| - |
384 |
| - TieredIndexParams TieredIndexParams_Init() { |
385 |
| - TieredIndexParams ret = { |
386 |
| - .jobQueue = &this->jobQueue, |
387 |
| - .jobQueueCtx = &this->jobQueueCtx, |
388 |
| - .submitCb = this->submitCb, |
389 |
| - .flatBufferLimit = this->flatBufferLimit, |
| 376 | + TieredIndexParams getTieredIndexParams(size_t buffer_limit) { |
| 377 | + // Create TieredIndexParams using the mock thread pool. |
| 378 | + return TieredIndexParams{ |
| 379 | + .jobQueue = &(this->mock_thread_pool.jobQ), |
| 380 | + .jobQueueCtx = this->mock_thread_pool.ctx, |
| 381 | + .submitCb = tieredIndexMock::submit_callback, |
| 382 | + .flatBufferLimit = buffer_limit, |
390 | 383 | };
|
391 |
| - |
392 |
| - return ret; |
393 | 384 | }
|
394 | 385 |
|
395 | 386 | public:
|
396 |
| - explicit PyTieredIndex(size_t BufferLimit = 3000000) |
397 |
| - : submitCb(submit_callback), flatBufferLimit(BufferLimit), run_thread(true) { |
398 |
| - |
399 |
| - for (size_t i = 0; i < THREAD_POOL_SIZE; i++) { |
400 |
| - ThreadParams params(run_thread, executions_status, i, jobQueue); |
401 |
| - thread_pool.emplace_back(thread_main_loop, params); |
402 |
| - } |
403 |
| - } |
404 |
| - |
405 |
| - virtual ~PyTieredIndex() = 0; |
| 387 | + explicit PyTieredIndex() { mock_thread_pool.init_threads(); } |
406 | 388 |
|
407 | 389 | void WaitForIndex(size_t waiting_duration = 10) {
|
408 |
| - bool keep_wating = true; |
409 |
| - while (keep_wating) { |
410 |
| - std::this_thread::sleep_for(std::chrono::milliseconds(waiting_duration)); |
411 |
| - std::unique_lock<std::mutex> lock(queue_guard); |
412 |
| - if (jobQueue.empty()) { |
413 |
| - while (true) { |
414 |
| - if (executions_status.count() == 0) { |
415 |
| - keep_wating = false; |
416 |
| - break; |
417 |
| - } |
418 |
| - std::this_thread::sleep_for(std::chrono::milliseconds(waiting_duration)); |
419 |
| - } |
420 |
| - } |
421 |
| - } |
| 390 | + mock_thread_pool.thread_pool_wait(waiting_duration); |
422 | 391 | }
|
423 | 392 |
|
424 | 393 | size_t getFlatIndexSize() { return getFlatBuffer()->indexLabelCount(); }
|
425 | 394 |
|
426 |
| - static size_t GetThreadsNum() { return THREAD_POOL_SIZE; } |
| 395 | + size_t getThreadsNum() { return mock_thread_pool.thread_pool_size; } |
427 | 396 |
|
428 |
| - size_t getBufferLimit() { return flatBufferLimit; } |
| 397 | + size_t getBufferLimit() { |
| 398 | + return reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get()) |
| 399 | + ->getFlatBufferLimit(); |
| 400 | + } |
429 | 401 | };
|
430 | 402 |
|
431 |
| -PyTieredIndex::~PyTieredIndex() { thread_pool_terminate(jobQueue, run_thread); } |
432 | 403 | class PyTiered_HNSWIndex : public PyTieredIndex {
|
433 | 404 | public:
|
434 | 405 | explicit PyTiered_HNSWIndex(const HNSWParams &hnsw_params,
|
435 |
| - const TieredHNSWParams &tiered_hnsw_params) { |
| 406 | + const TieredHNSWParams &tiered_hnsw_params, size_t buffer_limit) { |
436 | 407 |
|
437 | 408 | // Create primaryIndexParams and specific params for hnsw tiered index.
|
438 | 409 | VecSimParams primary_index_params = {.algo = VecSimAlgo_HNSWLIB,
|
439 | 410 | .algoParams = {.hnswParams = HNSWParams{hnsw_params}}};
|
440 | 411 |
|
441 |
| - // create TieredIndexParams |
442 |
| - TieredIndexParams tiered_params = TieredIndexParams_Init(); |
443 |
| - |
| 412 | + auto tiered_params = this->getTieredIndexParams(buffer_limit); |
444 | 413 | tiered_params.primaryIndexParams = &primary_index_params;
|
445 | 414 | tiered_params.specificParams.tieredHnswParams = tiered_hnsw_params;
|
446 | 415 |
|
447 |
| - // create VecSimParams for TieredIndexParams |
| 416 | + // Create VecSimParams for TieredIndexParams |
448 | 417 | VecSimParams params = {.algo = VecSimAlgo_TIERED,
|
449 | 418 | .algoParams = {.tieredParams = TieredIndexParams{tiered_params}}};
|
450 | 419 |
|
451 | 420 | this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(¶ms), VecSimIndex_Free);
|
| 421 | + |
452 | 422 | // Set the created tiered index in the index external context.
|
453 |
| - this->jobQueueCtx.index_strong_ref = this->index; |
| 423 | + this->mock_thread_pool.ctx->index_strong_ref = this->index; |
454 | 424 | }
|
| 425 | + |
455 | 426 | size_t HNSWLabelCount() {
|
456 | 427 | return this->index->info().tieredInfo.backendCommonInfo.indexLabelCount;
|
457 | 428 | }
|
@@ -568,17 +539,17 @@ PYBIND11_MODULE(VecSim, m) {
|
568 | 539 | py::arg("radius"), py::arg("query_param") = nullptr, py::arg("num_threads") = -1);
|
569 | 540 |
|
570 | 541 | py::class_<PyTieredIndex, PyVecSimIndex>(m, "TieredIndex")
|
571 |
| - .def("wait_for_index", &PyTiered_HNSWIndex::WaitForIndex, py::arg("waiting_duration") = 10) |
572 |
| - .def("get_curr_bf_size", &PyTiered_HNSWIndex::getFlatIndexSize) |
573 |
| - .def("get_buffer_limit", &PyTiered_HNSWIndex::getBufferLimit) |
574 |
| - .def_static("get_threads_num", &PyTieredIndex::GetThreadsNum); |
| 542 | + .def("wait_for_index", &PyTieredIndex::WaitForIndex, py::arg("waiting_duration") = 10) |
| 543 | + .def("get_curr_bf_size", &PyTieredIndex::getFlatIndexSize) |
| 544 | + .def("get_buffer_limit", &PyTieredIndex::getBufferLimit) |
| 545 | + .def("get_threads_num", &PyTieredIndex::getThreadsNum); |
575 | 546 |
|
576 | 547 | py::class_<PyTiered_HNSWIndex, PyTieredIndex>(m, "Tiered_HNSWIndex")
|
577 |
| - .def( |
578 |
| - py::init([](const HNSWParams &hnsw_params, const TieredHNSWParams &tiered_hnsw_params) { |
579 |
| - return new PyTiered_HNSWIndex(hnsw_params, tiered_hnsw_params); |
580 |
| - }), |
581 |
| - py::arg("hnsw_params"), py::arg("tiered_hnsw_params")) |
| 548 | + .def(py::init([](const HNSWParams &hnsw_params, const TieredHNSWParams &tiered_hnsw_params, |
| 549 | + size_t flat_buffer_size = DEFAULT_BLOCK_SIZE) { |
| 550 | + return new PyTiered_HNSWIndex(hnsw_params, tiered_hnsw_params, flat_buffer_size); |
| 551 | + }), |
| 552 | + py::arg("hnsw_params"), py::arg("tiered_hnsw_params"), py::arg("flat_buffer_size")) |
582 | 553 | .def("hnsw_label_count", &PyTiered_HNSWIndex::HNSWLabelCount);
|
583 | 554 |
|
584 | 555 | py::class_<PyBFIndex, PyVecSimIndex>(m, "BFIndex")
|
|
0 commit comments