Skip to content

Commit 331ab0e

Browse files
authored
Merge pull request #9 from zhao-lang/use-weak-arc
Use weak arc
2 parents 3db9092 + 1b571bb commit 331ab0e

File tree

5 files changed

+105
-54
lines changed

5 files changed

+105
-54
lines changed

cmd.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ data=$(printf "${i} %.0s" {1..128})
66
redis-cli hnsw.node.add test1 node${i-1} ${data}
77
done
88

9-
redis-cli bgsave
9+
# redis-cli bgsave
1010

1111
redis-cli hnsw.get test1
1212
redis-cli hnsw.node.get test1 node1
@@ -17,6 +17,7 @@ redis-cli hnsw.search test1 5 ${data}
1717
for i in {1..100}
1818
do
1919
redis-cli hnsw.node.del test1 node${i-1}
20+
sleep 0.1
2021
done
2122

2223
redis-cli hnsw.del test1

src/hnsw/hnsw.rs

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::convert::From;
1212
use std::fmt;
1313
use std::hash::{Hash, Hasher};
1414
use std::rc::Rc;
15-
use std::sync::{Arc, RwLock};
15+
use std::sync::{Arc, RwLock, Weak};
1616
use std::thread;
1717

1818
#[derive(Debug)]
@@ -84,12 +84,13 @@ where
8484
}
8585

8686
type NodeRef<T> = Arc<RwLock<_Node<T>>>;
87+
type NodeRefWeak<T> = Weak<RwLock<_Node<T>>>;
8788

8889
#[derive(Clone)]
8990
pub struct _Node<T: Float> {
9091
pub name: String,
9192
pub data: Vec<T>,
92-
pub neighbors: Vec<Vec<Node<T>>>,
93+
pub neighbors: Vec<Vec<NodeWeak<T>>>,
9394
}
9495

9596
impl<T> fmt::Debug for _Node<T>
@@ -108,7 +109,7 @@ where
108109
.iter()
109110
.map(|l| {
110111
l.into_iter()
111-
.map(|n| n.read().name.to_owned())
112+
.map(|n| n.upgrade().read().name.to_owned())
112113
.collect::<Vec<String>>()
113114
})
114115
.collect::<Vec<Vec<String>>>(),
@@ -127,28 +128,45 @@ impl<T: Float> _Node<T> {
127128
}
128129
}
129130

130-
fn add_neighbor(&mut self, level: usize, neighbor: Node<T>, capacity: Option<usize>) {
131+
fn add_neighbor(&mut self, level: usize, neighbor: NodeWeak<T>, capacity: Option<usize>) {
131132
self.push_levels(level, capacity);
132133
let neighbors = &mut self.neighbors;
133134
if !neighbors[level].contains(&neighbor) {
134135
neighbors[level].push(neighbor);
135136
}
136137
}
137138

138-
fn rm_neighbor(&mut self, level: usize, neighbor: &Node<T>) {
139+
fn rm_neighbor(&mut self, level: usize, neighbor: &NodeWeak<T>) {
139140
let neighbors = &mut self.neighbors;
140141
let index = neighbors[level]
141142
.iter()
142143
.position(|n| *n == *neighbor)
143144
.unwrap();
144145
neighbors[level].remove(index);
145146
}
147+
}
148+
149+
#[derive(Debug, Clone)]
150+
pub struct NodeWeak<T: Float>(pub NodeRefWeak<T>);
151+
152+
impl<T: Float> PartialEq for NodeWeak<T> {
153+
fn eq(&self, other: &Self) -> bool {
154+
Weak::ptr_eq(&self.0, &other.0)
155+
}
156+
}
146157

147-
// fn clear_neighbors(&mut self, level: usize) {
148-
// let neighbors = &mut self.neighbors;
149-
// let cap = neighbors[level].capacity();
150-
// neighbors[level] = Vec::with_capacity(cap);
151-
// }
158+
impl<T: Float> Eq for NodeWeak<T> {}
159+
160+
impl<T: Float> Hash for NodeWeak<T> {
161+
fn hash<H: Hasher>(&self, state: &mut H) {
162+
self.upgrade().read().name.hash(state);
163+
}
164+
}
165+
166+
impl<T: Float> NodeWeak<T> {
167+
pub fn upgrade(&self) -> Node<T> {
168+
Node(self.0.upgrade().unwrap())
169+
}
152170
}
153171

154172
#[derive(Debug, Clone)]
@@ -191,20 +209,19 @@ impl<T: Float> Node<T> {
191209
node.push_levels(level, capacity);
192210
}
193211

194-
fn add_neighbor(&self, level: usize, neighbor: Node<T>, capacity: Option<usize>) {
212+
fn add_neighbor(&self, level: usize, neighbor: NodeWeak<T>, capacity: Option<usize>) {
195213
let node = &mut self.0.try_write().unwrap();
196214
node.add_neighbor(level, neighbor, capacity);
197215
}
198216

199-
fn rm_neighbor(&self, level: usize, neighbor: &Node<T>) {
217+
fn rm_neighbor(&self, level: usize, neighbor: &NodeWeak<T>) {
200218
let node = &mut self.0.try_write().unwrap();
201219
node.rm_neighbor(level, neighbor);
202220
}
203221

204-
// fn clear_neighbors(&self, level: usize) {
205-
// let node = &mut self.0.try_write().unwrap();
206-
// node.clear_neighbors(level);
207-
// }
222+
pub fn downgrade(&self) -> NodeWeak<T> {
223+
NodeWeak(Arc::downgrade(&self.0))
224+
}
208225
}
209226

210227
type SimPairRef<T, R> = Rc<RefCell<_SimPair<T, R>>>;
@@ -291,9 +308,9 @@ pub struct Index<T: Float, R: Float> {
291308
pub level_mult: f64, // level generation factor
292309
pub node_count: usize, // count of nodes
293310
pub max_layer: usize, // idx of top layer
294-
pub layers: Vec<HashSet<Node<T>>>, // distinct nodes in each layer
311+
pub layers: Vec<HashSet<NodeWeak<T>>>, // distinct nodes in each layer
295312
pub nodes: HashMap<String, Node<T>>, // hashmap of nodes
296-
pub enterpoint: Option<Node<T>>, // enterpoint node
313+
pub enterpoint: Option<NodeWeak<T>>, // enterpoint node
297314
rng_: StdRng, // rng for level generation
298315
}
299316

@@ -347,7 +364,7 @@ impl<T: Float, R: Float> fmt::Debug for Index<T, R> {
347364
self.node_count,
348365
self.max_layer,
349366
match &self.enterpoint {
350-
Some(node) => node.read().name.clone(),
367+
Some(node) => node.upgrade().read().name.clone(),
351368
None => "null".to_owned(),
352369
},
353370
)
@@ -400,10 +417,10 @@ where
400417

401418
if self.node_count == 0 {
402419
let node = Node::new(name, data, self.m_max_0);
403-
self.enterpoint = Some(node.clone());
420+
self.enterpoint = Some(node.downgrade());
404421

405422
let mut layer = HashSet::new();
406-
layer.insert(node.clone());
423+
layer.insert(node.downgrade());
407424
self.layers.push(layer);
408425

409426
self.nodes.insert(name.to_owned(), node);
@@ -432,7 +449,7 @@ where
432449
self.node_count -= 1;
433450

434451
for lc in (0..(self.max_layer + 1)).rev() {
435-
if self.layers[lc].remove(&node) {
452+
if self.layers[lc].remove(&node.downgrade()) {
436453
break;
437454
}
438455
}
@@ -455,7 +472,7 @@ where
455472

456473
// update enterpoint if necessary
457474
match &self.enterpoint {
458-
Some(ep) if node == *ep => {
475+
Some(ep) if node == ep.upgrade() => {
459476
let mut new_ep = None;
460477
for lc in (0..(self.max_layer + 1)).rev() {
461478
match self.layers[lc].iter().next() {
@@ -518,8 +535,8 @@ where
518535

519536
let mut lc = l_max;
520537
while lc > l {
521-
w = self.search_level(data, &ep, 1, lc);
522-
ep = w.pop().unwrap().read().node.clone();
538+
w = self.search_level(data, &ep.upgrade(), 1, lc);
539+
ep = w.pop().unwrap().read().node.downgrade();
523540

524541
if lc == 0 {
525542
break;
@@ -529,7 +546,7 @@ where
529546

530547
let mut updated = HashSet::new();
531548
for lc in (0..(min(l_max, l) + 1)).rev() {
532-
w = self.search_level(data, &ep, self.ef_construction, lc);
549+
w = self.search_level(data, &ep.upgrade(), self.ef_construction, lc);
533550
let mut neighbors = self.select_neighbors(query, &w, self.m, lc, true, true, None);
534551
self.connect_neighbors(query, &neighbors, lc);
535552

@@ -551,10 +568,10 @@ where
551568
for n in eneighbors {
552569
let ensim = OrderedFloat::from((self.mfunc)(
553570
&enr.data,
554-
&n.read().data,
571+
&n.upgrade().read().data,
555572
self.data_dim,
556573
));
557-
let enpair = SimPair::new(ensim, n.to_owned());
574+
let enpair = SimPair::new(ensim, n.upgrade());
558575
econn.push(enpair);
559576
}
560577
}
@@ -570,7 +587,7 @@ where
570587
}
571588
}
572589

573-
ep = w.peek().unwrap().read().node.clone();
590+
ep = w.peek().unwrap().read().node.downgrade();
574591
}
575592

576593
// update nodes in redis
@@ -583,14 +600,14 @@ where
583600
// new enterpoint if we're in a higher layer
584601
if l > l_max {
585602
self.max_layer = l;
586-
self.enterpoint = Some(query.to_owned());
603+
self.enterpoint = Some(query.downgrade());
587604
while self.layers.len() < l + 1 {
588605
self.layers.push(HashSet::new());
589606
}
590607
}
591608

592609
// add node to layer set
593-
self.layers[l].insert(query.to_owned());
610+
self.layers[l].insert(query.downgrade());
594611

595612
Ok(())
596613
}
@@ -641,6 +658,7 @@ where
641658
let cpr = cpair.read();
642659
let neighbors = &cpr.node.read().neighbors[level];
643660
for neighbor in neighbors {
661+
let neighbor = neighbor.upgrade();
644662
if !v.contains(&neighbor) {
645663
v.insert(neighbor.clone());
646664

@@ -699,8 +717,9 @@ where
699717
let epair = ccopy.pop().unwrap();
700718

701719
for eneighbor in &epair.read().node.read().neighbors[lc] {
702-
if *eneighbor == *query
703-
|| (ignored_node.is_some() && *eneighbor == *ignored_node.unwrap())
720+
let eneighbor = eneighbor.upgrade();
721+
if eneighbor == *query
722+
|| (ignored_node.is_some() && eneighbor == *ignored_node.unwrap())
704723
{
705724
continue;
706725
}
@@ -765,9 +784,9 @@ where
765784
let npair = neighbors.pop().unwrap();
766785
let npr = npair.read();
767786

768-
query.add_neighbor(level, npr.node.clone(), Some(self.m_max_0));
787+
query.add_neighbor(level, npr.node.downgrade(), Some(self.m_max_0));
769788
npr.node
770-
.add_neighbor(level, query.clone(), Some(self.m_max_0));
789+
.add_neighbor(level, query.downgrade(), Some(self.m_max_0));
771790
}
772791
}
773792

@@ -788,9 +807,9 @@ where
788807
while !newconn.is_empty() {
789808
let newpair = newconn.pop().unwrap();
790809
let npr = newpair.read();
791-
node.add_neighbor(level, npr.node.clone(), Some(self.m_max_0));
810+
node.add_neighbor(level, npr.node.downgrade(), Some(self.m_max_0));
792811
npr.node
793-
.add_neighbor(level, node.clone(), Some(self.m_max_0));
812+
.add_neighbor(level, node.downgrade(), Some(self.m_max_0));
794813
updated.insert(npr.node.clone());
795814
// if new neighbor exists in the old set then we remove it from
796815
// the set of neighbors to be removed
@@ -806,14 +825,14 @@ where
806825
while !rmconn.is_empty() {
807826
let rmpair = rmconn.pop().unwrap();
808827
let rmpr = rmpair.read();
809-
node.rm_neighbor(level, &rmpr.node);
828+
node.rm_neighbor(level, &rmpr.node.downgrade());
810829
// if node to be removed is the ignored node then pass
811830
match ignored_node {
812831
Some(n) if rmpr.node == *n => {
813832
continue;
814833
}
815834
_ => {
816-
rmpr.node.rm_neighbor(level, &node);
835+
rmpr.node.rm_neighbor(level, &node.downgrade());
817836
updated.insert(rmpr.node.clone());
818837
}
819838
}
@@ -828,6 +847,7 @@ where
828847
let mut updated = HashSet::new();
829848

830849
for n in neighbors {
850+
let n = n.upgrade();
831851
let nnewconn: BinaryHeap<SimPair<T, R>>;
832852
let mut nconn: BinaryHeap<SimPair<T, R>>;
833853
{
@@ -836,17 +856,18 @@ where
836856
nconn = BinaryHeap::with_capacity(nneighbors.len());
837857

838858
for nn in nneighbors {
859+
let nn = nn.upgrade();
839860
let nnsim =
840861
OrderedFloat::from((self.mfunc)(&nr.data, &nn.read().data, self.data_dim));
841862
let nnpair = SimPair::new(nnsim, nn.to_owned());
842863
nconn.push(nnpair);
843864
}
844865

845866
let m_max = if lc == 0 { self.m_max_0 } else { self.m_max };
846-
nnewconn = self.select_neighbors(n, &nconn, m_max, lc, true, true, Some(node));
867+
nnewconn = self.select_neighbors(&n, &nconn, m_max, lc, true, true, Some(node));
847868
}
848869
updated.insert(n.clone());
849-
let up = self.update_node_connections(n, &nnewconn, &nconn, lc, Some(node));
870+
let up = self.update_node_connections(&n, &nnewconn, &nconn, lc, Some(node));
850871
for u in up {
851872
updated.insert(u);
852873
}
@@ -861,12 +882,12 @@ where
861882

862883
let mut lc = l_max;
863884
while lc > 0 {
864-
let w = self.search_level(query, &ep, 1, lc);
865-
ep = w.peek().unwrap().read().node.clone();
885+
let w = self.search_level(query, &ep.upgrade(), 1, lc);
886+
ep = w.peek().unwrap().read().node.downgrade();
866887
lc -= 1;
867888
}
868889

869-
let mut w = self.search_level(query, &ep, ef, 0);
890+
let mut w = self.search_level(query, &ep.upgrade(), ef, 0);
870891

871892
let mut res = Vec::with_capacity(k);
872893
while res.len() < k && !w.is_empty() {

src/hnsw/hnsw_tests.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::hnsw::hnsw::*;
22
use crate::hnsw::metrics::euclidean;
33
use std::sync::Arc;
4+
use std::{thread, time};
45

56
#[test]
67
fn hnsw_test() {
@@ -22,6 +23,18 @@ fn hnsw_test() {
2223
let data = vec![i as f32; 4];
2324
index.add_node(&name, &data, mock_fn).unwrap();
2425
}
26+
// sleep for a brief period to make sure all threads are done
27+
let ten_millis = time::Duration::from_millis(10);
28+
thread::sleep(ten_millis);
29+
for i in 0..100 {
30+
let node_name = format!("node{}", i);
31+
let node = index.nodes.get(&node_name).unwrap();
32+
let sc = Arc::strong_count(&node.0);
33+
if sc > 1 {
34+
println!("{:?}", node);
35+
}
36+
assert_eq!(sc, 1);
37+
}
2538
assert_eq!(index.node_count, 100);
2639
assert_ne!(index.enterpoint, None);
2740

@@ -44,15 +57,22 @@ fn hnsw_test() {
4457
assert_eq!(index.node_count, 100 - i - 1);
4558
assert_eq!(index.nodes.get(&node_name).is_none(), true);
4659
for l in &index.layers {
47-
assert_eq!(l.contains(&node), false);
60+
assert_eq!(l.contains(&node.downgrade()), false);
4861
}
4962
for n in index.nodes.values() {
5063
for l in &n.read().neighbors {
5164
for nn in l {
52-
assert_ne!(*nn, node);
65+
assert_ne!(nn.upgrade(), node);
5366
}
5467
}
5568
}
56-
assert_eq!(Arc::strong_count(&node.0), 1);
69+
// sleep for a brief period to make sure all threads are done
70+
let ten_millis = time::Duration::from_millis(10);
71+
thread::sleep(ten_millis);
72+
let sc = Arc::strong_count(&node.0);
73+
if sc > 1 {
74+
println!("Delete {:?}", node);
75+
}
76+
assert_eq!(sc, 1);
5777
}
5878
}

0 commit comments

Comments
 (0)