10 changes: 10 additions & 0 deletions models/retriever/enhanced_kt_retriever.py
@@ -1,3 +1,4 @@
import json
import os
import pickle
import threading
@@ -397,7 +398,16 @@ def _save_node_embedding_cache(self):
return False

def _load_node_embedding_cache(self):
"""Check if node embedding cache exists and matches the current model"""
"""Load node embedding cache from disk"""
embedding_model_info_path = f"{self.cache_dir}/{self.dataset}/embedding_model_info.json"
if os.path.exists(embedding_model_info_path):
with open(embedding_model_info_path, 'r') as f:
embedding_model_info = json.load(f)
if embedding_model_info['model_name'] != self.config.embeddings.model_name:
return False
else:
return False
cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
cache_path_npz = cache_path.replace('.pt', '.npz')
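For reference, the guard added above reduces to a small, self-contained pattern: reuse cached embeddings only when the recorded model name matches the configured one. A minimal standalone sketch of the same idea (function and argument names are illustrative, not part of this PR):

import json
import os

def cache_matches_model(cache_dir: str, dataset: str, current_model_name: str) -> bool:
    """Return True only when the cached embeddings were produced by the configured model."""
    info_path = os.path.join(cache_dir, dataset, "embedding_model_info.json")
    if not os.path.exists(info_path):
        return False  # no metadata recorded, so treat the cache as stale
    with open(info_path, 'r') as f:
        info = json.load(f)
    return info.get("model_name") == current_model_name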

164 changes: 34 additions & 130 deletions models/retriever/faiss_filter.py
@@ -22,6 +23,7 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL
:param cache_dir: cache directory for FAISS indices
"""
self.graph = graph
self.model_name = model_name
self.model = SentenceTransformer(model_name)
self.cache_dir = cache_dir
os.makedirs(cache_dir, exist_ok=True)
@@ -66,13 +67,6 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL

# Get model output dimension
self.model_dim = self.model.get_sentence_embedding_dimension()
self.dim_transform = None
if self.model_dim != 384: # If model output dimension is not 384
self.dim_transform = torch.nn.Linear(self.model_dim, 384)
if self.device.type == "cuda" and torch.cuda.is_available():
self.dim_transform = self.dim_transform.to(self.device)
else:
self.dim_transform = self.dim_transform.to("cpu")

self.name_to_id = {}
for node_id in self.graph.nodes():
@@ -269,9 +263,7 @@ def retrieve_via_triples(self, query_embed, top_k: int = 5) -> List[Tuple[str, s
query_embed = query_embed.to(self.device)
else:
query_embed = torch.FloatTensor(query_embed).to(self.device)

query_embed = self.transform_vector(query_embed)


# Create cache key and perform search
cache_key = f"triple_search_{hash(query_embed.cpu().numpy().tobytes())}_{top_k}"
D, I = self._cached_faiss_search(self.triple_index, query_embed, top_k, cache_key)
@@ -304,8 +296,6 @@ def retrieve_via_communities(self, query_embed, top_k: int = 3) -> List[str]:
else:
query_embed = torch.FloatTensor(query_embed).to(self.device)

# Apply dimension transformation
query_embed = self.transform_vector(query_embed)

# Create cache key for this search
cache_key = f"comm_search_{hash(query_embed.cpu().numpy().tobytes())}_{top_k}"
@@ -437,7 +427,6 @@ def _calculate_node_scores(self, query_embed, nodes: List[str]) -> Dict[str, flo

query_embed = query_embed.cpu().detach().numpy()
query_tensor = torch.FloatTensor(query_embed).to(self.device)
query_tensor = self.transform_vector(query_tensor)

nodes_with_embedding = []
nodes_without_embedding = []
@@ -474,9 +463,6 @@ def _calculate_node_scores(self, query_embed, nodes: List[str]) -> Dict[str, flo
if texts:
node_embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device)

if self.dim_transform is not None:
node_embeddings = self.dim_transform(node_embeddings)

similarities = F.cosine_similarity(query_tensor.unsqueeze(0), node_embeddings, dim=1)

for i, node in enumerate(nodes_to_encode):
@@ -491,9 +477,7 @@ def _calculate_node_scores_optimized(self, query_embed, nodes: List[str]) -> Dic
return {}

query_embed = query_embed.cpu().detach().numpy()
query_tensor = torch.FloatTensor(query_embed).to(self.device)
query_tensor = self.transform_vector(query_tensor)

query_tensor = torch.FloatTensor(query_embed).to(self.device)

node_embeddings = []
node_names = []
@@ -524,8 +508,6 @@ def _calculate_node_scores_optimized(self, query_embed, nodes: List[str]) -> Dic
try:
embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device)

if self.dim_transform is not None:
embeddings = self.dim_transform(embeddings)

similarities = F.cosine_similarity(query_tensor.unsqueeze(0), embeddings, dim=1)

@@ -694,8 +676,6 @@ def _compute_and_transform_embeddings(self, texts: list) -> torch.Tensor:
"""Compute embeddings and apply dimension transformation if needed."""
embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device)

if hasattr(self, 'dim_transform') and self.dim_transform is not None:
embeddings = self.dim_transform(embeddings)

return embeddings

@@ -707,10 +687,7 @@ def _process_single_node_fallback(self, node: str) -> bool:
return False

embedding = self.model.encode([text], convert_to_tensor=True, device=self.device)[0]

if hasattr(self, 'dim_transform') and self.dim_transform is not None:
embedding = self.dim_transform(embedding.unsqueeze(0)).squeeze(0)


self.node_embedding_cache[node] = embedding.detach()
return True

@@ -787,16 +764,18 @@ def _precompute_node_embeddings(self, batch_size: int = 100, force_recompute: bo
def build_indices(self):
"""Build FAISS Index only if they don't already exist and are consistent with current graph"""
# Check if all indices and embedding files already exist
embedding_model_info_path = f"{self.cache_dir}/{self.dataset}/embedding_model_info.json"
node_path = f"{self.cache_dir}/{self.dataset}/node.index"
relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
node_map_path = f"{self.cache_dir}/{self.dataset}/node_map.json"
dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"

all_exist = (os.path.exists(node_path) and
all_exist = (
os.path.exists(embedding_model_info_path) and
os.path.exists(node_path) and
os.path.exists(relation_path) and
os.path.exists(triple_path) and
os.path.exists(comm_path) and
@@ -805,42 +784,33 @@ def build_indices(self):
os.path.exists(node_map_path))

indices_consistent = False
embedding_model_info_consistent = False
if all_exist:
try:
with open(embedding_model_info_path, 'r') as f:
embedding_model_info = json.load(f)
model_name = embedding_model_info.get('model_name', "UNKNOWN")
embedding_model_info_consistent = (model_name == self.model_name)
if not embedding_model_info_consistent:
logger.warning(f"The embedding model {model_name} used to build the index does not match the model {self.model_name} used at inference time; the index needs to be rebuilt.")
except Exception as e:
logger.error(f"Error loading embedding model info: {e}")
if all_exist and embedding_model_info_consistent:
try:
with open(node_map_path, 'r') as f:
cached_node_map = json.load(f)
current_nodes = set(self.graph.nodes())
cached_nodes = set(cached_node_map.values())

# Check graph consistency
graph_consistent = current_nodes == cached_nodes

# Check model dimension consistency
dim_consistent = True
if os.path.exists(dim_transform_path):
try:
cached_dim_info = torch.load(dim_transform_path, map_location='cpu', weights_only=False)
cached_model_dim = cached_dim_info.get('model_dim')
if cached_model_dim != self.model_dim:
logger.info(f"Model dimension changed: cached {cached_model_dim}, current {self.model_dim}")
dim_consistent = False
except Exception as e:
logger.warning(f"Error checking dimension transform consistency: {e}")
dim_consistent = False

if graph_consistent and dim_consistent:
if current_nodes == cached_nodes:
indices_consistent = True
logger.info("Cached FAISS indices are consistent with current graph and model")
logger.info("Cached FAISS indices are consistent with current graph")
else:
if not graph_consistent:
logger.info(f"Graph inconsistency detected: current nodes {len(current_nodes)}, cached nodes {len(cached_nodes)}")
logger.info(f"Missing in cache: {current_nodes - cached_nodes}")
logger.info(f"Extra in cache: {cached_nodes - current_nodes}")
if not dim_consistent:
logger.info("Model dimension inconsistency detected")
logger.info(f"Graph inconsistency detected: current nodes {len(current_nodes)}, cached nodes {len(cached_nodes)}")
logger.info(f"Missing in cache: {current_nodes - cached_nodes}")
logger.info(f"Extra in cache: {cached_nodes - current_nodes}")
except Exception as e:
logger.error(f"Error checking index consistency: {e}")

if all_exist and indices_consistent:
logger.info("All FAISS indices and embeddings already exist, loading from cache...")
if not hasattr(self, 'node_index') or self.node_index is None:
@@ -856,15 +826,15 @@ def build_indices(self):
logger.info("Building FAISS indices and embeddings...")
if all_exist and not indices_consistent:
logger.info("Clearing inconsistent cache files...")
for path in [node_path, relation_path, triple_path, comm_path, node_embed_path, relation_embed_path, node_map_path, dim_transform_path]:
for path in [node_path, relation_path, triple_path, comm_path, node_embed_path, relation_embed_path, node_map_path, embedding_model_info_path]:
if os.path.exists(path):
os.remove(path)

self._build_node_index()
self._build_relation_index()
self._build_triple_index()
self._build_community_index()
self._save_dim_transform()
self._save_embedding_model_info()
logger.info("FAISS indices and embeddings built successfully!")
self._populate_embedding_maps()
try:
@@ -881,72 +851,16 @@ def build_indices(self):

self._preload_faiss_indices()
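The graph-consistency test in build_indices boils down to a set comparison between the live graph's nodes and the values stored in node_map.json. A hedged standalone sketch of the same logic (the helper name is hypothetical):

import json
import networkx as nx

def graph_matches_cache(graph: nx.MultiDiGraph, node_map_path: str) -> bool:
    """Compare the live graph's node set against the cached FAISS node map."""
    with open(node_map_path, 'r') as f:
        cached_map = json.load(f)  # assumed to map FAISS row index -> node id
    current_nodes = set(graph.nodes())
    cached_nodes = set(cached_map.values())
    # Any symmetric difference means the indices were built for a different graph.
    return current_nodes == cached_nodes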

def _save_dim_transform(self):
"""Save dimension transform state to disk"""
dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
try:
save_data = {
'model_dim': self.model_dim,
'target_dim': 384,
'has_transform': self.dim_transform is not None
}

if self.dim_transform is not None:
save_data['state_dict'] = self.dim_transform.cpu().state_dict()

torch.save(save_data, dim_transform_path)
logger.info(f"Saved dimension transform to {dim_transform_path}")
except Exception as e:
logger.error(f"Error saving dimension transform: {e}")

def _load_dim_transform(self):
"""Load dimension transform state from disk"""
dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
if not os.path.exists(dim_transform_path):
return False

try:
try:
save_data = torch.load(dim_transform_path, map_location='cpu', weights_only=False)
except TypeError:
save_data = torch.load(dim_transform_path, map_location='cpu')

cached_model_dim = save_data.get('model_dim')
has_transform = save_data.get('has_transform', False)

# Verify dimension consistency
if cached_model_dim != self.model_dim:
logger.warning(f"Model dimension mismatch: cached {cached_model_dim}, current {self.model_dim}")
return False

if has_transform and 'state_dict' in save_data:
if self.dim_transform is None:
self.dim_transform = torch.nn.Linear(self.model_dim, 384)

# Load the saved weights
self.dim_transform.load_state_dict(save_data['state_dict'])

if self.device.type == "cuda" and torch.cuda.is_available():
self.dim_transform = self.dim_transform.to(self.device)
else:
self.dim_transform = self.dim_transform.to("cpu")

logger.info(f"Loaded dimension transform from {dim_transform_path}")
return True
elif not has_transform:
if self.dim_transform is None:
logger.info("No dimension transform needed (cached and current both 384-dim)")
return True
else:
logger.warning("Dimension transform state mismatch")
return False

except Exception as e:
logger.error(f"Error loading dimension transform: {e}")
return False

return False

def _save_embedding_model_info(self):
embedding_model_info_path = f"{self.cache_dir}/{self.dataset}/embedding_model_info.json"
embedding_model_info = {
"model_name": self.model_name
}
try:
with open(embedding_model_info_path, 'w') as f:
json.dump(embedding_model_info, f)
except Exception as e:
logger.error(f"Error saving embedding model info: {e}")
def _build_node_index(self):
"""Build FAISS index for all nodes and cache embeddings"""
nodes = list(self.graph.nodes())
@@ -1115,8 +1029,6 @@ def _load_indices(self):
except Exception as e:
logger.warning(f"Warning: Failed to load relation embeddings: {e}")

# Load dimension transform if available
self._load_dim_transform()

# Populate maps if all necessary data is loaded
if self.node_map and self.node_embeddings is not None:
@@ -1329,12 +1241,6 @@ def _nodes_to_text(self, nodes: List[str]) -> str:

return "\n".join(text_parts)

def transform_vector(self, vector: torch.Tensor) -> torch.Tensor:
"""Transform vector dimensions if needed"""
if self.dim_transform is not None:
return self.dim_transform(vector)
return vector

def _calculate_triple_relevance_scores(self, query_embed: torch.Tensor, triples: List[Tuple[str, str, str]], threshold: float = 0.3, top_k: int = 10) -> List[Tuple[str, str, str, float]]:
"""
Calculate relevance scores for triples and filter out low-relevance ones using FAISS.
@@ -1356,7 +1262,6 @@ def _calculate_triple_relevance_scores(self, query_embed: torch.Tensor, triples:
return []

# Transform query embedding for FAISS search
query_embed = self.transform_vector(query_embed)
query_embed_np = query_embed.cpu().detach().numpy().reshape(1, -1)

# Normalize query embedding for FAISS search
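With the projection hop gone, the query embedding is searched at its native dimension; the L2 normalization above is what lets an inner-product FAISS index return cosine similarities. A minimal sketch using the standard FAISS API (the 384-d size assumes all-MiniLM-L6-v2; the data is synthetic):

import faiss
import numpy as np

dim = 384  # native all-MiniLM-L6-v2 output size, so no projection layer is needed
index = faiss.IndexFlatIP(dim)  # inner-product index
vectors = np.random.rand(100, dim).astype("float32")
faiss.normalize_L2(vectors)  # unit-length rows make inner product equal cosine
index.add(vectors)

query = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query)
D, I = index.search(query, 10)  # top-10 cosine scores and matching row ids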
@@ -1424,4 +1329,3 @@ def __del__(self):
self.save_embedding_cache()
except Exception as e:
logger.warning(f"Error during __del__ saving embedding cache: {type(e).__name__}: {e}")

3 changes: 3 additions & 0 deletions retriever/faiss_cache_new/demo/embedding_model_info.json
@@ -0,0 +1,3 @@
{
"model_name": "all-MiniLM-L6-v2"
}