Skip to content

Commit 573ab84

Browse files
committed
Use min to limit the index of next prefetch.
1 parent b5c2eba commit 573ab84

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

hnswlib/hnswalg.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
260260
tableint candidate_id = *(datal + j);
261261
// if (candidate_id == 0) continue;
262262
#ifdef USE_SSE
263-
_mm_prefetch((char *) (visited_array + *(datal + j)), _MM_HINT_T0);
264-
_mm_prefetch(getDataByInternalId(*(datal + j)), _MM_HINT_T0);
263+
size_t next_index = std::min(size - 1, j + 1);
264+
_mm_prefetch((char *) (visited_array + *(datal + next_index)), _MM_HINT_T0);
265+
_mm_prefetch(getDataByInternalId(*(datal + next_index)), _MM_HINT_T0);
265266
#endif
266267
if (visited_array[candidate_id] == visited_array_tag) continue;
267268
visited_array[candidate_id] = visited_array_tag;
@@ -343,8 +344,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
343344
int candidate_id = *(data + j);
344345
// if (candidate_id == 0) continue;
345346
#ifdef USE_SSE
346-
_mm_prefetch((char *) (visited_array + *(data + j)), _MM_HINT_T0);
347-
_mm_prefetch(data_level0_memory_ + (*(data + j)) * size_data_per_element_ + offsetData_,
347+
size_t next_index = std::min(size, j + 1);
348+
_mm_prefetch((char *) (visited_array + *(data + next_index)), _MM_HINT_T0);
349+
_mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
348350
_MM_HINT_T0); ////////////
349351
#endif
350352
if (!(visited_array[candidate_id] == visited_array_tag)) {
@@ -1007,7 +1009,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
10071009
#endif
10081010
for (int i = 0; i < size; i++) {
10091011
#ifdef USE_SSE
1010-
_mm_prefetch(getDataByInternalId(*(datal + i)), _MM_HINT_T0);
1012+
size_t next_index = std::min(size - 1, i + 1);
1013+
_mm_prefetch(getDataByInternalId(*(datal + next_index)), _MM_HINT_T0);
10111014
#endif
10121015
tableint cand = datal[i];
10131016
dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);

0 commit comments

Comments
 (0)