@@ -21,15 +21,16 @@ use redis_module::{
2121use redismodule_cmd:: { Command , ArgType , Collection } ;
2222use std:: collections:: { HashMap , HashSet } ;
2323use std:: collections:: hash_map:: Entry ;
24- use std:: sync:: { Arc , RwLock , RwLockWriteGuard } ;
24+ use std:: sync:: { Arc , RwLock } ;
2525use types:: * ;
2626
2727static PREFIX : & str = "hnsw" ;
2828
29+ type IndexArc = Arc < RwLock < IndexT > > ;
2930type IndexT = Index < f32 , f32 > ;
3031
3132lazy_static ! {
32- static ref INDICES : Arc <RwLock <HashMap <String , IndexT >>> =
33+ static ref INDICES : Arc <RwLock <HashMap <String , IndexArc >>> =
3334 Arc :: new( RwLock :: new( HashMap :: new( ) ) ) ;
3435}
3536
@@ -118,20 +119,20 @@ fn new_index(ctx: &Context, args: Vec<String>) -> RedisResult {
118119 }
119120 None => {
120121 // create index
121- let mut index = Index :: new (
122+ let index = Index :: new (
122123 & index_name,
123124 Box :: new ( hnsw:: metrics:: euclidean) ,
124125 data_dim,
125126 m,
126127 ef_construction,
127128 ) ;
128129 ctx. log_debug ( format ! ( "{:?}" , index) . as_str ( ) ) ;
129- key. set_value :: < IndexRedis > ( & HNSW_INDEX_REDIS_TYPE , ( & mut index) . into ( ) ) ?;
130+ key. set_value :: < IndexRedis > ( & HNSW_INDEX_REDIS_TYPE , ( & index) . into ( ) ) ?;
130131 // Add index to global hashmap
131132 INDICES
132133 . write ( )
133134 . unwrap ( )
134- . insert ( index_name, index) ;
135+ . insert ( index_name, Arc :: new ( RwLock :: new ( index) ) ) ;
135136 }
136137 }
137138
@@ -148,14 +149,17 @@ fn get_index(ctx: &Context, args: Vec<String>) -> RedisResult {
148149 let name_suffix = parsed. remove ( "name" ) . unwrap ( ) . as_string ( ) ?;
149150 let index_name = format ! ( "{}.{}" , PREFIX , name_suffix) ;
150151
151- let mut indices = INDICES . write ( ) . unwrap ( ) ;
152+ let index = load_index ( ctx, & index_name) ?;
153+ let index = match index. try_read ( ) {
154+ Ok ( index) => index,
155+ Err ( e) => return Err ( e. to_string ( ) . into ( ) )
156+ } ;
152157
153- let index = load_index ( ctx, & mut indices, & index_name) ?;
154158 ctx. log_debug ( format ! ( "Index: {:?}" , index) . as_str ( ) ) ;
155159 ctx. log_debug ( format ! ( "Layers: {:?}" , index. layers. len( ) ) . as_str ( ) ) ;
156160 ctx. log_debug ( format ! ( "Nodes: {:?}" , index. nodes. len( ) ) . as_str ( ) ) ;
157161
158- let index_redis: IndexRedis = index. into ( ) ;
162+ let index_redis: IndexRedis = ( & * index) . into ( ) ;
159163
160164 Ok ( index_redis. as_redisvalue ( ) )
161165}
@@ -193,7 +197,8 @@ fn delete_index(ctx: &Context, args: Vec<String>) -> RedisResult {
193197 Ok ( 1_usize . into ( ) )
194198}
195199
196- fn load_index < ' a > ( ctx : & Context , indices : & ' a mut RwLockWriteGuard < HashMap < String , IndexT > > , index_name : & str ) -> Result < & ' a mut IndexT , RedisError > {
200+ fn load_index ( ctx : & Context , index_name : & str ) -> Result < IndexArc , RedisError > {
201+ let mut indices = INDICES . write ( ) . unwrap ( ) ;
197202 // check if index is in global hashmap
198203 let index = match indices. entry ( index_name. to_string ( ) ) {
199204 Entry :: Occupied ( o) => o. into_mut ( ) ,
@@ -208,11 +213,11 @@ fn load_index<'a>(ctx: &Context, indices: &'a mut RwLockWriteGuard<HashMap<Strin
208213 None => return Err ( format ! ( "Index: {} does not exist" , index_name) . into ( ) ) ,
209214 } ;
210215 let index = make_index ( ctx, index_redis) ?;
211- v. insert ( index)
216+ v. insert ( Arc :: new ( RwLock :: new ( index) ) )
212217 }
213218 } ;
214219
215- Ok ( index)
220+ Ok ( index. clone ( ) )
216221}
217222
218223fn make_index ( ctx : & Context , ir : & IndexRedis ) -> Result < IndexT , RedisError > {
@@ -284,13 +289,13 @@ fn make_index(ctx: &Context, ir: &IndexRedis) -> Result<IndexT, RedisError> {
284289fn update_index (
285290 ctx : & Context ,
286291 index_name : & str ,
287- index : & mut IndexT ,
292+ index : & IndexT ,
288293) -> Result < ( ) , RedisError > {
289294 let key = ctx. open_key_writable ( index_name) ;
290295 match key. get_value :: < IndexRedis > ( & HNSW_INDEX_REDIS_TYPE ) ? {
291296 Some ( _) => {
292297 ctx. log_debug ( format ! ( "update index: {}" , index_name) . as_str ( ) ) ;
293- key. set_value :: < IndexRedis > ( & HNSW_INDEX_REDIS_TYPE , index. into ( ) ) ?;
298+ key. set_value :: < IndexRedis > ( & HNSW_INDEX_REDIS_TYPE , ( & * index) . into ( ) ) ?;
294299 }
295300 None => {
296301 return Err ( RedisError :: String ( format ! (
@@ -318,8 +323,11 @@ fn add_node(ctx: &Context, args: Vec<String>) -> RedisResult {
318323 let dataf64 = parsed. remove ( "data" ) . unwrap ( ) . as_f64vec ( ) ?;
319324 let data = dataf64. iter ( ) . map ( |d| * d as f32 ) . collect :: < Vec < f32 > > ( ) ;
320325
321- let mut indices = INDICES . write ( ) . unwrap ( ) ;
322- let index = load_index ( ctx, & mut indices, & index_name) ?;
326+ let index = load_index ( ctx, & index_name) ?;
327+ let mut index = match index. try_write ( ) {
328+ Ok ( index) => index,
329+ Err ( e) => return Err ( e. to_string ( ) . into ( ) )
330+ } ;
323331
324332 let up = |name : String , node : Node < f32 > | {
325333 write_node ( ctx, & name, ( & node) . into ( ) ) . unwrap ( ) ;
@@ -335,7 +343,7 @@ fn add_node(ctx: &Context, args: Vec<String>) -> RedisResult {
335343 write_node ( ctx, & node_name, node. into ( ) ) ?;
336344
337345 // update index in redis
338- update_index ( ctx, & index_name, index) ?;
346+ update_index ( ctx, & index_name, & index) ?;
339347
340348 Ok ( "OK" . into ( ) )
341349}
@@ -350,12 +358,15 @@ fn delete_node(ctx: &Context, args: Vec<String>) -> RedisResult {
350358 let index_suffix = parsed. remove ( "index" ) . unwrap ( ) . as_string ( ) ?;
351359 let node_suffix = parsed. remove ( "node" ) . unwrap ( ) . as_string ( ) ?;
352360
353- let mut indices = INDICES . write ( ) . unwrap ( ) ;
354361 let index_name = format ! ( "{}.{}" , PREFIX , index_suffix) ;
355- let index = load_index ( ctx, & mut indices, & index_name) ?;
356-
357362 let node_name = format ! ( "{}.{}.{}" , PREFIX , index_suffix, node_suffix) ;
358363
364+ let index = load_index ( ctx, & index_name) ?;
365+ let mut index = match index. try_write ( ) {
366+ Ok ( index) => index,
367+ Err ( e) => return Err ( e. to_string ( ) . into ( ) )
368+ } ;
369+
359370 // TODO return error if node has more than 1 strong_count
360371 let node = index. nodes . get ( & node_name) . unwrap ( ) ;
361372 if Arc :: strong_count ( & node. 0 ) > 1 {
@@ -387,7 +398,7 @@ fn delete_node(ctx: &Context, args: Vec<String>) -> RedisResult {
387398 } ;
388399
389400 // update index in redis
390- update_index ( ctx, & index_name, index) ?;
401+ update_index ( ctx, & index_name, & index) ?;
391402
392403 Ok ( 1_usize . into ( ) )
393404}
@@ -449,9 +460,12 @@ fn search_knn(ctx: &Context, args: Vec<String>) -> RedisResult {
449460 let dataf64 = parsed. remove ( "query" ) . unwrap ( ) . as_f64vec ( ) ?;
450461 let data = dataf64. iter ( ) . map ( |d| * d as f32 ) . collect :: < Vec < f32 > > ( ) ;
451462
452- let mut indices = INDICES . write ( ) . unwrap ( ) ;
453463 let index_name = format ! ( "{}.{}" , PREFIX , index_suffix) ;
454- let index = load_index ( ctx, & mut indices, & index_name) ?;
464+ let index = load_index ( ctx, & index_name) ?;
465+ let index = match index. try_read ( ) {
466+ Ok ( index) => index,
467+ Err ( e) => return Err ( e. to_string ( ) . into ( ) )
468+ } ;
455469
456470 ctx. log_debug (
457471 format ! (
0 commit comments