diff --git a/models/retriever/enhanced_kt_retriever.py b/models/retriever/enhanced_kt_retriever.py index 768f42de..b981a07c 100644 --- a/models/retriever/enhanced_kt_retriever.py +++ b/models/retriever/enhanced_kt_retriever.py @@ -1,3 +1,4 @@ +import json import os import pickle import threading @@ -397,7 +398,16 @@ def _save_node_embedding_cache(self): return False def _load_node_embedding_cache(self): + """Check if node embedding cache exists and matches the current model""" """Load node embedding cache from disk""" + embedding_model_info = f"{self.cache_dir}/{self.dataset}/embedding_model_info.json" + if os.path.exists(embedding_model_info): + with open(embedding_model_info, 'r') as f: + embedding_model_info = json.load(f) + if embedding_model_info['model_name'] != self.config.embeddings.model_name: + return False + else: + return False cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt" cache_path_npz = cache_path.replace('.pt', '.npz') diff --git a/models/retriever/faiss_filter.py b/models/retriever/faiss_filter.py index b0497b67..70ffa4de 100644 --- a/models/retriever/faiss_filter.py +++ b/models/retriever/faiss_filter.py @@ -22,6 +22,7 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL :param cache_dir: cache directory for FAISS indices """ self.graph = graph + self.model_name = model_name self.model = SentenceTransformer(model_name) self.cache_dir = cache_dir os.makedirs(cache_dir, exist_ok=True) @@ -66,13 +67,6 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL # Get model output dimension self.model_dim = self.model.get_sentence_embedding_dimension() - self.dim_transform = None - if self.model_dim != 384: # If model output dimension is not 384 - self.dim_transform = torch.nn.Linear(self.model_dim, 384) - if self.device.type == "cuda" and torch.cuda.is_available(): - self.dim_transform = self.dim_transform.to(self.device) - else: - self.dim_transform = self.dim_transform.to("cpu") self.name_to_id = {} for node_id in self.graph.nodes(): @@ -269,9 +263,7 @@ def retrieve_via_triples(self, query_embed, top_k: int = 5) -> List[Tuple[str, s query_embed = query_embed.to(self.device) else: query_embed = torch.FloatTensor(query_embed).to(self.device) - - query_embed = self.transform_vector(query_embed) - + # Create cache key and perform search cache_key = f"triple_search_{hash(query_embed.cpu().numpy().tobytes())}_{top_k}" D, I = self._cached_faiss_search(self.triple_index, query_embed, top_k, cache_key) @@ -304,8 +296,6 @@ def retrieve_via_communities(self, query_embed, top_k: int = 3) -> List[str]: else: query_embed = torch.FloatTensor(query_embed).to(self.device) - # Apply dimension transformation - query_embed = self.transform_vector(query_embed) # Create cache key for this search cache_key = f"comm_search_{hash(query_embed.cpu().numpy().tobytes())}_{top_k}" @@ -437,7 +427,6 @@ def _calculate_node_scores(self, query_embed, nodes: List[str]) -> Dict[str, flo query_embed = query_embed.cpu().detach().numpy() query_tensor = torch.FloatTensor(query_embed).to(self.device) - query_tensor = self.transform_vector(query_tensor) nodes_with_embedding = [] nodes_without_embedding = [] @@ -474,9 +463,6 @@ def _calculate_node_scores(self, query_embed, nodes: List[str]) -> Dict[str, flo if texts: node_embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device) - if self.dim_transform is not None: - node_embeddings = self.dim_transform(node_embeddings) - similarities = F.cosine_similarity(query_tensor.unsqueeze(0), node_embeddings, dim=1) for i, node in enumerate(nodes_to_encode): @@ -491,9 +477,7 @@ def _calculate_node_scores_optimized(self, query_embed, nodes: List[str]) -> Dic return {} query_embed = query_embed.cpu().detach().numpy() - query_tensor = torch.FloatTensor(query_embed).to(self.device) - query_tensor = self.transform_vector(query_tensor) - + query_tensor = torch.FloatTensor(query_embed).to(self.device) node_embeddings = [] node_names = [] @@ -524,8 +508,6 @@ def _calculate_node_scores_optimized(self, query_embed, nodes: List[str]) -> Dic try: embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device) - if self.dim_transform is not None: - embeddings = self.dim_transform(embeddings) similarities = F.cosine_similarity(query_tensor.unsqueeze(0), embeddings, dim=1) @@ -694,8 +676,6 @@ def _compute_and_transform_embeddings(self, texts: list) -> torch.Tensor: """Compute embeddings and apply dimension transformation if needed.""" embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device) - if hasattr(self, 'dim_transform') and self.dim_transform is not None: - embeddings = self.dim_transform(embeddings) return embeddings @@ -707,10 +687,7 @@ def _process_single_node_fallback(self, node: str) -> bool: return False embedding = self.model.encode([text], convert_to_tensor=True, device=self.device)[0] - - if hasattr(self, 'dim_transform') and self.dim_transform is not None: - embedding = self.dim_transform(embedding.unsqueeze(0)).squeeze(0) - + self.node_embedding_cache[node] = embedding.detach() return True @@ -787,6 +764,7 @@ def _precompute_node_embeddings(self, batch_size: int = 100, force_recompute: bo def build_indices(self): """Build FAISS Index only if they don't already exist and are consistent with current graph""" # Check if all indices and embedding files already exist + embedding_model_info = f"{self.cache_dir}/{self.dataset}/embedding_model_info.json" node_path = f"{self.cache_dir}/{self.dataset}/node.index" relation_path = f"{self.cache_dir}/{self.dataset}/relation.index" triple_path = f"{self.cache_dir}/{self.dataset}/triple.index" @@ -794,9 +772,10 @@ def build_indices(self): node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt" relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt" node_map_path = f"{self.cache_dir}/{self.dataset}/node_map.json" - dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt" - all_exist = (os.path.exists(node_path) and + all_exist = ( + os.path.exists(embedding_model_info) and + os.path.exists(node_path) and os.path.exists(relation_path) and os.path.exists(triple_path) and os.path.exists(comm_path) and @@ -805,42 +784,33 @@ def build_indices(self): os.path.exists(node_map_path)) indices_consistent = False + embedding_model_info_consistent = False if all_exist: + try: + with open(embedding_model_info, 'r') as f: + embedding_model_info = json.load(f) + model_name = embedding_model_info.get('model_name',"UNKNOWN") + embedding_model_info_consistent = (model_name == self.model_name) + if not embedding_model_info_consistent: + logger.warning(f"The embedding model {model_name} used when creating the index is inconsistent with the model {self.model_name} used during inference. The index needs to be rebuilt~") + except Exception as e: + logger.error(f"Error loading embedding model info {e}") + if all_exist and embedding_model_info_consistent: try: with open(node_map_path, 'r') as f: cached_node_map = json.load(f) current_nodes = set(self.graph.nodes()) cached_nodes = set(cached_node_map.values()) - # Check graph consistency - graph_consistent = current_nodes == cached_nodes - - # Check model dimension consistency - dim_consistent = True - if os.path.exists(dim_transform_path): - try: - cached_dim_info = torch.load(dim_transform_path, map_location='cpu', weights_only=False) - cached_model_dim = cached_dim_info.get('model_dim') - if cached_model_dim != self.model_dim: - logger.info(f"Model dimension changed: cached {cached_model_dim}, current {self.model_dim}") - dim_consistent = False - except Exception as e: - logger.warning(f"Error checking dimension transform consistency: {e}") - dim_consistent = False - - if graph_consistent and dim_consistent: + if current_nodes == cached_nodes: indices_consistent = True - logger.info("Cached FAISS indices are consistent with current graph and model") + logger.info("Cached FAISS indices are consistent with current graph") else: - if not graph_consistent: - logger.info(f"Graph inconsistency detected: current nodes {len(current_nodes)}, cached nodes {len(cached_nodes)}") - logger.info(f"Missing in cache: {current_nodes - cached_nodes}") - logger.info(f"Extra in cache: {cached_nodes - current_nodes}") - if not dim_consistent: - logger.info("Model dimension inconsistency detected") + logger.info(f"Graph inconsistency detected: current nodes {len(current_nodes)}, cached nodes {len(cached_nodes)}") + logger.info(f"Missing in cache: {current_nodes - cached_nodes}") + logger.info(f"Extra in cache: {cached_nodes - current_nodes}") except Exception as e: logger.error(f"Error checking index consistency: {e}") - if all_exist and indices_consistent: logger.info("All FAISS indices and embeddings already exist, loading from cache...") if not hasattr(self, 'node_index') or self.node_index is None: @@ -856,7 +826,7 @@ def build_indices(self): logger.info("Building FAISS indices and embeddings...") if all_exist and not indices_consistent: logger.info("Clearing inconsistent cache files...") - for path in [node_path, relation_path, triple_path, comm_path, node_embed_path, relation_embed_path, node_map_path, dim_transform_path]: + for path in [node_path, relation_path, triple_path, comm_path, node_embed_path, relation_embed_path, node_map_path, embedding_model_info]: if os.path.exists(path): os.remove(path) @@ -864,7 +834,7 @@ def build_indices(self): self._build_relation_index() self._build_triple_index() self._build_community_index() - self._save_dim_transform() + self._save_embedding_model_info() logger.info("FAISS indices and embeddings built successfully!") self._populate_embedding_maps() try: @@ -881,72 +851,16 @@ def build_indices(self): self._preload_faiss_indices() - def _save_dim_transform(self): - """Save dimension transform state to disk""" - dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt" - try: - save_data = { - 'model_dim': self.model_dim, - 'target_dim': 384, - 'has_transform': self.dim_transform is not None - } - - if self.dim_transform is not None: - save_data['state_dict'] = self.dim_transform.cpu().state_dict() - - torch.save(save_data, dim_transform_path) - logger.info(f"Saved dimension transform to {dim_transform_path}") - except Exception as e: - logger.error(f"Error saving dimension transform: {e}") - - def _load_dim_transform(self): - """Load dimension transform state from disk""" - dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt" - if not os.path.exists(dim_transform_path): - return False - + def _save_embedding_model_info(self): + embedding_model_info_path = f"{self.cache_dir}/{self.dataset}/embedding_model_info.json" + embedding_model_info = { + "model_name":self.model_name + } try: - try: - save_data = torch.load(dim_transform_path, map_location='cpu', weights_only=False) - except TypeError: - save_data = torch.load(dim_transform_path, map_location='cpu') - - cached_model_dim = save_data.get('model_dim') - has_transform = save_data.get('has_transform', False) - - # Verify dimension consistency - if cached_model_dim != self.model_dim: - logger.warning(f"Model dimension mismatch: cached {cached_model_dim}, current {self.model_dim}") - return False - - if has_transform and 'state_dict' in save_data: - if self.dim_transform is None: - self.dim_transform = torch.nn.Linear(self.model_dim, 384) - - # Load the saved weights - self.dim_transform.load_state_dict(save_data['state_dict']) - - if self.device.type == "cuda" and torch.cuda.is_available(): - self.dim_transform = self.dim_transform.to(self.device) - else: - self.dim_transform = self.dim_transform.to("cpu") - - logger.info(f"Loaded dimension transform from {dim_transform_path}") - return True - elif not has_transform: - if self.dim_transform is None: - logger.info("No dimension transform needed (cached and current both 384-dim)") - return True - else: - logger.warning("Dimension transform state mismatch") - return False - + with open(embedding_model_info_path,'w') as f: + json.dump(embedding_model_info, f) except Exception as e: - logger.error(f"Error loading dimension transform: {e}") - return False - - return False - + logger.error(f"Error saving embedding model info: {e}") def _build_node_index(self): """Build FAISS index for all nodes and cache embeddings""" nodes = list(self.graph.nodes()) @@ -1115,8 +1029,6 @@ def _load_indices(self): except Exception as e: logger.warning(f"Warning: Failed to load relation embeddings: {e}") - # Load dimension transform if available - self._load_dim_transform() # Populate maps if all necessary data is loaded if self.node_map and self.node_embeddings is not None: @@ -1329,12 +1241,6 @@ def _nodes_to_text(self, nodes: List[str]) -> str: return "\n".join(text_parts) - def transform_vector(self, vector: torch.Tensor) -> torch.Tensor: - """Transform vector dimensions if needed""" - if self.dim_transform is not None: - return self.dim_transform(vector) - return vector - def _calculate_triple_relevance_scores(self, query_embed: torch.Tensor, triples: List[Tuple[str, str, str]], threshold: float = 0.3, top_k: int = 10) -> List[Tuple[str, str, str, float]]: """ Calculate relevance scores for triples and filter out low-relevance ones using FAISS. @@ -1356,7 +1262,6 @@ def _calculate_triple_relevance_scores(self, query_embed: torch.Tensor, triples: return [] # Transform query embedding for FAISS search - query_embed = self.transform_vector(query_embed) query_embed_np = query_embed.cpu().detach().numpy().reshape(1, -1) # Normalize query embedding for FAISS search @@ -1424,4 +1329,3 @@ def __del__(self): self.save_embedding_cache() except Exception as e: logger.warning(f"Error during __del__ saving embedding cache: {type(e).__name__}: {e}") - diff --git a/retriever/faiss_cache_new/demo/embedding_model_info.json b/retriever/faiss_cache_new/demo/embedding_model_info.json new file mode 100644 index 00000000..d877c519 --- /dev/null +++ b/retriever/faiss_cache_new/demo/embedding_model_info.json @@ -0,0 +1,3 @@ +{ + "model_name": "all-MiniLM-L6-v2" +}