@@ -12,7 +12,7 @@ use std::convert::From;
1212use std:: fmt;
1313use std:: hash:: { Hash , Hasher } ;
1414use std:: rc:: Rc ;
15- use std:: sync:: { Arc , RwLock } ;
15+ use std:: sync:: { Arc , RwLock , Weak } ;
1616use std:: thread;
1717
1818#[ derive( Debug ) ]
@@ -84,12 +84,13 @@ where
8484}
8585
8686type NodeRef < T > = Arc < RwLock < _Node < T > > > ;
87+ type NodeRefWeak < T > = Weak < RwLock < _Node < T > > > ;
8788
8889#[ derive( Clone ) ]
8990pub struct _Node < T : Float > {
9091 pub name : String ,
9192 pub data : Vec < T > ,
92- pub neighbors : Vec < Vec < Node < T > > > ,
93+ pub neighbors : Vec < Vec < NodeWeak < T > > > ,
9394}
9495
9596impl < T > fmt:: Debug for _Node < T >
@@ -108,7 +109,7 @@ where
108109 . iter( )
109110 . map( |l| {
110111 l. into_iter( )
111- . map( |n| n. read( ) . name. to_owned( ) )
112+ . map( |n| n. upgrade ( ) . read( ) . name. to_owned( ) )
112113 . collect:: <Vec <String >>( )
113114 } )
114115 . collect:: <Vec <Vec <String >>>( ) ,
@@ -127,28 +128,45 @@ impl<T: Float> _Node<T> {
127128 }
128129 }
129130
130- fn add_neighbor ( & mut self , level : usize , neighbor : Node < T > , capacity : Option < usize > ) {
131+ fn add_neighbor ( & mut self , level : usize , neighbor : NodeWeak < T > , capacity : Option < usize > ) {
131132 self . push_levels ( level, capacity) ;
132133 let neighbors = & mut self . neighbors ;
133134 if !neighbors[ level] . contains ( & neighbor) {
134135 neighbors[ level] . push ( neighbor) ;
135136 }
136137 }
137138
138- fn rm_neighbor ( & mut self , level : usize , neighbor : & Node < T > ) {
139+ fn rm_neighbor ( & mut self , level : usize , neighbor : & NodeWeak < T > ) {
139140 let neighbors = & mut self . neighbors ;
140141 let index = neighbors[ level]
141142 . iter ( )
142143 . position ( |n| * n == * neighbor)
143144 . unwrap ( ) ;
144145 neighbors[ level] . remove ( index) ;
145146 }
147+ }
148+
149+ #[ derive( Debug , Clone ) ]
150+ pub struct NodeWeak < T : Float > ( pub NodeRefWeak < T > ) ;
151+
152+ impl < T : Float > PartialEq for NodeWeak < T > {
153+ fn eq ( & self , other : & Self ) -> bool {
154+ Weak :: ptr_eq ( & self . 0 , & other. 0 )
155+ }
156+ }
146157
147- // fn clear_neighbors(&mut self, level: usize) {
148- // let neighbors = &mut self.neighbors;
149- // let cap = neighbors[level].capacity();
150- // neighbors[level] = Vec::with_capacity(cap);
151- // }
158+ impl < T : Float > Eq for NodeWeak < T > { }
159+
160+ impl < T : Float > Hash for NodeWeak < T > {
161+ fn hash < H : Hasher > ( & self , state : & mut H ) {
162+ self . upgrade ( ) . read ( ) . name . hash ( state) ;
163+ }
164+ }
165+
166+ impl < T : Float > NodeWeak < T > {
167+ pub fn upgrade ( & self ) -> Node < T > {
168+ Node ( self . 0 . upgrade ( ) . unwrap ( ) )
169+ }
152170}
153171
154172#[ derive( Debug , Clone ) ]
@@ -191,20 +209,19 @@ impl<T: Float> Node<T> {
191209 node. push_levels ( level, capacity) ;
192210 }
193211
194- fn add_neighbor ( & self , level : usize , neighbor : Node < T > , capacity : Option < usize > ) {
212+ fn add_neighbor ( & self , level : usize , neighbor : NodeWeak < T > , capacity : Option < usize > ) {
195213 let node = & mut self . 0 . try_write ( ) . unwrap ( ) ;
196214 node. add_neighbor ( level, neighbor, capacity) ;
197215 }
198216
199- fn rm_neighbor ( & self , level : usize , neighbor : & Node < T > ) {
217+ fn rm_neighbor ( & self , level : usize , neighbor : & NodeWeak < T > ) {
200218 let node = & mut self . 0 . try_write ( ) . unwrap ( ) ;
201219 node. rm_neighbor ( level, neighbor) ;
202220 }
203221
204- // fn clear_neighbors(&self, level: usize) {
205- // let node = &mut self.0.try_write().unwrap();
206- // node.clear_neighbors(level);
207- // }
222+ pub fn downgrade ( & self ) -> NodeWeak < T > {
223+ NodeWeak ( Arc :: downgrade ( & self . 0 ) )
224+ }
208225}
209226
210227type SimPairRef < T , R > = Rc < RefCell < _SimPair < T , R > > > ;
@@ -291,9 +308,9 @@ pub struct Index<T: Float, R: Float> {
291308 pub level_mult : f64 , // level generation factor
292309 pub node_count : usize , // count of nodes
293310 pub max_layer : usize , // idx of top layer
294- pub layers : Vec < HashSet < Node < T > > > , // distinct nodes in each layer
311+ pub layers : Vec < HashSet < NodeWeak < T > > > , // distinct nodes in each layer
295312 pub nodes : HashMap < String , Node < T > > , // hashmap of nodes
296- pub enterpoint : Option < Node < T > > , // enterpoint node
313+ pub enterpoint : Option < NodeWeak < T > > , // enterpoint node
297314 rng_ : StdRng , // rng for level generation
298315}
299316
@@ -347,7 +364,7 @@ impl<T: Float, R: Float> fmt::Debug for Index<T, R> {
347364 self . node_count,
348365 self . max_layer,
349366 match & self . enterpoint {
350- Some ( node) => node. read( ) . name. clone( ) ,
367+ Some ( node) => node. upgrade ( ) . read( ) . name. clone( ) ,
351368 None => "null" . to_owned( ) ,
352369 } ,
353370 )
@@ -400,10 +417,10 @@ where
400417
401418 if self . node_count == 0 {
402419 let node = Node :: new ( name, data, self . m_max_0 ) ;
403- self . enterpoint = Some ( node. clone ( ) ) ;
420+ self . enterpoint = Some ( node. downgrade ( ) ) ;
404421
405422 let mut layer = HashSet :: new ( ) ;
406- layer. insert ( node. clone ( ) ) ;
423+ layer. insert ( node. downgrade ( ) ) ;
407424 self . layers . push ( layer) ;
408425
409426 self . nodes . insert ( name. to_owned ( ) , node) ;
@@ -432,7 +449,7 @@ where
432449 self . node_count -= 1 ;
433450
434451 for lc in ( 0 ..( self . max_layer + 1 ) ) . rev ( ) {
435- if self . layers [ lc] . remove ( & node) {
452+ if self . layers [ lc] . remove ( & node. downgrade ( ) ) {
436453 break ;
437454 }
438455 }
@@ -455,7 +472,7 @@ where
455472
456473 // update enterpoint if necessary
457474 match & self . enterpoint {
458- Some ( ep) if node == * ep => {
475+ Some ( ep) if node == ep . upgrade ( ) => {
459476 let mut new_ep = None ;
460477 for lc in ( 0 ..( self . max_layer + 1 ) ) . rev ( ) {
461478 match self . layers [ lc] . iter ( ) . next ( ) {
@@ -518,8 +535,8 @@ where
518535
519536 let mut lc = l_max;
520537 while lc > l {
521- w = self . search_level ( data, & ep, 1 , lc) ;
522- ep = w. pop ( ) . unwrap ( ) . read ( ) . node . clone ( ) ;
538+ w = self . search_level ( data, & ep. upgrade ( ) , 1 , lc) ;
539+ ep = w. pop ( ) . unwrap ( ) . read ( ) . node . downgrade ( ) ;
523540
524541 if lc == 0 {
525542 break ;
@@ -529,7 +546,7 @@ where
529546
530547 let mut updated = HashSet :: new ( ) ;
531548 for lc in ( 0 ..( min ( l_max, l) + 1 ) ) . rev ( ) {
532- w = self . search_level ( data, & ep, self . ef_construction , lc) ;
549+ w = self . search_level ( data, & ep. upgrade ( ) , self . ef_construction , lc) ;
533550 let mut neighbors = self . select_neighbors ( query, & w, self . m , lc, true , true , None ) ;
534551 self . connect_neighbors ( query, & neighbors, lc) ;
535552
@@ -551,10 +568,10 @@ where
551568 for n in eneighbors {
552569 let ensim = OrderedFloat :: from ( ( self . mfunc ) (
553570 & enr. data ,
554- & n. read ( ) . data ,
571+ & n. upgrade ( ) . read ( ) . data ,
555572 self . data_dim ,
556573 ) ) ;
557- let enpair = SimPair :: new ( ensim, n. to_owned ( ) ) ;
574+ let enpair = SimPair :: new ( ensim, n. upgrade ( ) ) ;
558575 econn. push ( enpair) ;
559576 }
560577 }
@@ -570,7 +587,7 @@ where
570587 }
571588 }
572589
573- ep = w. peek ( ) . unwrap ( ) . read ( ) . node . clone ( ) ;
590+ ep = w. peek ( ) . unwrap ( ) . read ( ) . node . downgrade ( ) ;
574591 }
575592
576593 // update nodes in redis
@@ -583,14 +600,14 @@ where
583600 // new enterpoint if we're in a higher layer
584601 if l > l_max {
585602 self . max_layer = l;
586- self . enterpoint = Some ( query. to_owned ( ) ) ;
603+ self . enterpoint = Some ( query. downgrade ( ) ) ;
587604 while self . layers . len ( ) < l + 1 {
588605 self . layers . push ( HashSet :: new ( ) ) ;
589606 }
590607 }
591608
592609 // add node to layer set
593- self . layers [ l] . insert ( query. to_owned ( ) ) ;
610+ self . layers [ l] . insert ( query. downgrade ( ) ) ;
594611
595612 Ok ( ( ) )
596613 }
@@ -641,6 +658,7 @@ where
641658 let cpr = cpair. read ( ) ;
642659 let neighbors = & cpr. node . read ( ) . neighbors [ level] ;
643660 for neighbor in neighbors {
661+ let neighbor = neighbor. upgrade ( ) ;
644662 if !v. contains ( & neighbor) {
645663 v. insert ( neighbor. clone ( ) ) ;
646664
@@ -699,8 +717,9 @@ where
699717 let epair = ccopy. pop ( ) . unwrap ( ) ;
700718
701719 for eneighbor in & epair. read ( ) . node . read ( ) . neighbors [ lc] {
702- if * eneighbor == * query
703- || ( ignored_node. is_some ( ) && * eneighbor == * ignored_node. unwrap ( ) )
720+ let eneighbor = eneighbor. upgrade ( ) ;
721+ if eneighbor == * query
722+ || ( ignored_node. is_some ( ) && eneighbor == * ignored_node. unwrap ( ) )
704723 {
705724 continue ;
706725 }
@@ -765,9 +784,9 @@ where
765784 let npair = neighbors. pop ( ) . unwrap ( ) ;
766785 let npr = npair. read ( ) ;
767786
768- query. add_neighbor ( level, npr. node . clone ( ) , Some ( self . m_max_0 ) ) ;
787+ query. add_neighbor ( level, npr. node . downgrade ( ) , Some ( self . m_max_0 ) ) ;
769788 npr. node
770- . add_neighbor ( level, query. clone ( ) , Some ( self . m_max_0 ) ) ;
789+ . add_neighbor ( level, query. downgrade ( ) , Some ( self . m_max_0 ) ) ;
771790 }
772791 }
773792
@@ -788,9 +807,9 @@ where
788807 while !newconn. is_empty ( ) {
789808 let newpair = newconn. pop ( ) . unwrap ( ) ;
790809 let npr = newpair. read ( ) ;
791- node. add_neighbor ( level, npr. node . clone ( ) , Some ( self . m_max_0 ) ) ;
810+ node. add_neighbor ( level, npr. node . downgrade ( ) , Some ( self . m_max_0 ) ) ;
792811 npr. node
793- . add_neighbor ( level, node. clone ( ) , Some ( self . m_max_0 ) ) ;
812+ . add_neighbor ( level, node. downgrade ( ) , Some ( self . m_max_0 ) ) ;
794813 updated. insert ( npr. node . clone ( ) ) ;
795814 // if new neighbor exists in the old set then we remove it from
796815 // the set of neighbors to be removed
@@ -806,14 +825,14 @@ where
806825 while !rmconn. is_empty ( ) {
807826 let rmpair = rmconn. pop ( ) . unwrap ( ) ;
808827 let rmpr = rmpair. read ( ) ;
809- node. rm_neighbor ( level, & rmpr. node ) ;
828+ node. rm_neighbor ( level, & rmpr. node . downgrade ( ) ) ;
810829 // if node to be removed is the ignored node then pass
811830 match ignored_node {
812831 Some ( n) if rmpr. node == * n => {
813832 continue ;
814833 }
815834 _ => {
816- rmpr. node . rm_neighbor ( level, & node) ;
835+ rmpr. node . rm_neighbor ( level, & node. downgrade ( ) ) ;
817836 updated. insert ( rmpr. node . clone ( ) ) ;
818837 }
819838 }
@@ -828,6 +847,7 @@ where
828847 let mut updated = HashSet :: new ( ) ;
829848
830849 for n in neighbors {
850+ let n = n. upgrade ( ) ;
831851 let nnewconn: BinaryHeap < SimPair < T , R > > ;
832852 let mut nconn: BinaryHeap < SimPair < T , R > > ;
833853 {
@@ -836,17 +856,18 @@ where
836856 nconn = BinaryHeap :: with_capacity ( nneighbors. len ( ) ) ;
837857
838858 for nn in nneighbors {
859+ let nn = nn. upgrade ( ) ;
839860 let nnsim =
840861 OrderedFloat :: from ( ( self . mfunc ) ( & nr. data , & nn. read ( ) . data , self . data_dim ) ) ;
841862 let nnpair = SimPair :: new ( nnsim, nn. to_owned ( ) ) ;
842863 nconn. push ( nnpair) ;
843864 }
844865
845866 let m_max = if lc == 0 { self . m_max_0 } else { self . m_max } ;
846- nnewconn = self . select_neighbors ( n, & nconn, m_max, lc, true , true , Some ( node) ) ;
867+ nnewconn = self . select_neighbors ( & n, & nconn, m_max, lc, true , true , Some ( node) ) ;
847868 }
848869 updated. insert ( n. clone ( ) ) ;
849- let up = self . update_node_connections ( n, & nnewconn, & nconn, lc, Some ( node) ) ;
870+ let up = self . update_node_connections ( & n, & nnewconn, & nconn, lc, Some ( node) ) ;
850871 for u in up {
851872 updated. insert ( u) ;
852873 }
@@ -861,12 +882,12 @@ where
861882
862883 let mut lc = l_max;
863884 while lc > 0 {
864- let w = self . search_level ( query, & ep, 1 , lc) ;
865- ep = w. peek ( ) . unwrap ( ) . read ( ) . node . clone ( ) ;
885+ let w = self . search_level ( query, & ep. upgrade ( ) , 1 , lc) ;
886+ ep = w. peek ( ) . unwrap ( ) . read ( ) . node . downgrade ( ) ;
866887 lc -= 1 ;
867888 }
868889
869- let mut w = self . search_level ( query, & ep, ef, 0 ) ;
890+ let mut w = self . search_level ( query, & ep. upgrade ( ) , ef, 0 ) ;
870891
871892 let mut res = Vec :: with_capacity ( k) ;
872893 while res. len ( ) < k && !w. is_empty ( ) {
0 commit comments