@@ -187,7 +187,7 @@ class TieredHNSWIndex : public VecSimTieredIndex<DataType, DistType> {
187
187
size_t indexSize () const override ;
188
188
size_t indexLabelCount () const override ;
189
189
size_t indexCapacity () const override ;
190
- double getDistanceFrom (labelType label, const void *blob) const override ;
190
+ double getDistanceFrom_Unsafe (labelType label, const void *blob) const override ;
191
191
// Do nothing here, each tier (flat buffer and HNSW) should increase capacity for itself when
192
192
// needed.
193
193
VecSimIndexInfo info () const override ;
@@ -210,6 +210,17 @@ class TieredHNSWIndex : public VecSimTieredIndex<DataType, DistType> {
210
210
" running asynchronous GC for tiered HNSW index" );
211
211
this ->executeReadySwapJobs (this ->pendingSwapJobsThreshold );
212
212
}
213
+ void acquireSharedLocks () override {
214
+ this ->flatIndexGuard .lock_shared ();
215
+ this ->mainIndexGuard .lock_shared ();
216
+ this ->getHNSWIndex ()->lockSharedIndexDataGuard ();
217
+ }
218
+
219
+ void releaseSharedLocks () override {
220
+ this ->flatIndexGuard .unlock_shared ();
221
+ this ->mainIndexGuard .unlock_shared ();
222
+ this ->getHNSWIndex ()->unlockSharedIndexDataGuard ();
223
+ }
213
224
#ifdef BUILD_TESTS
214
225
void getDataByLabel (labelType label, std::vector<std::vector<DataType>> &vectors_output) const ;
215
226
#endif
@@ -621,9 +632,9 @@ TieredHNSWIndex<DataType, DistType>::~TieredHNSWIndex() {
621
632
template <typename DataType, typename DistType>
622
633
size_t TieredHNSWIndex<DataType, DistType>::indexSize() const {
623
634
this ->flatIndexGuard .lock_shared ();
624
- this ->getHNSWIndex ()->lockIndexDataGuard ();
635
+ this ->getHNSWIndex ()->lockSharedIndexDataGuard ();
625
636
size_t res = this ->backendIndex ->indexSize () + this ->frontendIndex ->indexSize ();
626
- this ->getHNSWIndex ()->unlockIndexDataGuard ();
637
+ this ->getHNSWIndex ()->unlockSharedIndexDataGuard ();
627
638
this ->flatIndexGuard .unlock_shared ();
628
639
return res;
629
640
}
@@ -803,14 +814,18 @@ int TieredHNSWIndex<DataType, DistType>::deleteVector(labelType label) {
803
814
// 3. label exists in both indexes - we may have some of the vectors with the same label in the flat
804
815
// buffer only and some in the Main index only (and maybe temporal duplications).
805
816
// So, we get the distance from both indexes and return the minimum.
817
+
818
+ // IMPORTANT: this should be called when the *tiered index locks are locked for shared ownership*,
819
+ // along with the HNSW index data guard lock. This is because the internal getDistanceFrom_Unsafe calls access
820
+ // the indexes' data, and it is not safe to run insert/delete operations in parallel. Also, we avoid
821
+ // acquiring the locks internally, since this is usually called for every vector individually, and
822
+ // the overhead of acquiring and releasing the locks is significant in that case.
806
823
template <typename DataType, typename DistType>
807
- double TieredHNSWIndex<DataType, DistType>::getDistanceFrom (labelType label,
808
- const void *blob) const {
824
+ double TieredHNSWIndex<DataType, DistType>::getDistanceFrom_Unsafe (labelType label,
825
+ const void *blob) const {
809
826
// Try to get the distance from the flat buffer.
810
827
// If the label doesn't exist, the distance will be NaN.
811
- this ->flatIndexGuard .lock_shared ();
812
- auto flat_dist = this ->frontendIndex ->getDistanceFrom (label, blob);
813
- this ->flatIndexGuard .unlock_shared ();
828
+ auto flat_dist = this ->frontendIndex ->getDistanceFrom_Unsafe (label, blob);
814
829
815
830
// Optimization. TODO: consider having different implementations for single and multi indexes,
816
831
// to avoid checking the index type on every query.
@@ -821,9 +836,7 @@ double TieredHNSWIndex<DataType, DistType>::getDistanceFrom(labelType label,
821
836
}
822
837
823
838
// Try to get the distance from the Main index.
824
- this ->mainIndexGuard .lock_shared ();
825
- auto hnsw_dist = getHNSWIndex ()->safeGetDistanceFrom (label, blob);
826
- this ->mainIndexGuard .unlock_shared ();
839
+ auto hnsw_dist = getHNSWIndex ()->getDistanceFrom_Unsafe (label, blob);
827
840
828
841
// Return the minimum distance that is not NaN.
829
842
return std::fmin (flat_dist, hnsw_dist);
0 commit comments