diff --git a/delphi/__main__.py b/delphi/__main__.py index 16f0c557..bf24a557 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -10,7 +10,7 @@ from simple_parsing import ArgumentParser from torch import Tensor from transformers import ( - AutoModel, + AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, @@ -27,7 +27,12 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator +from delphi.scorers import ( + DetectionScorer, + FuzzingScorer, + OpenAISimulator, + SurprisalInterventionScorer, +) from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -40,7 +45,7 @@ def load_artifacts(run_cfg: RunConfig): else: dtype = "auto" - model = AutoModel.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( run_cfg.model, device_map={"": "cuda"}, quantization_config=( @@ -118,6 +123,8 @@ async def process_cache( hookpoints: list[str], tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, latent_range: Tensor | None, + model, + hookpoint_to_sparse_encode, ): """ Converts SAE latent activations in on-disk cache in the `latents_path` directory @@ -219,6 +226,12 @@ def none_postprocessor(result): ) ) + def custom_serializer(obj): + """A custom serializer for orjson to handle specific types.""" + if isinstance(obj, Tensor): + return obj.tolist() + raise TypeError + # Builds the record from result returned by the pipeline def scorer_preprocess(result): if isinstance(result, list): @@ -230,11 +243,18 @@ def scorer_preprocess(result): return record # Saves the score to a file - def scorer_postprocess(result, score_dir): + def scorer_postprocess(result, score_dir, scorer_name=None): + if isinstance(result, list): + if not result: + return + result = result[0] + safe_latent_name = str(result.record.latent).replace("/", "--") with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) + f.write(orjson.dumps(result.score, default=custom_serializer)) scorers = [] for scorer_name in run_cfg.scorers: @@ -257,6 +277,16 @@ def scorer_postprocess(result, score_dir): verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, ) + + elif scorer_name == "surprisal_intervention": + scorer = SurprisalInterventionScorer( + model, + hookpoint_to_sparse_encode, + hookpoints=run_cfg.hookpoints, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) else: raise ValueError(f"Scorer {scorer_name} not supported") @@ -396,6 +426,8 @@ async def run( hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg) tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token) + model.tokenizer = tokenizer + nrh = assert_type( dict, non_redundant_hookpoints( @@ -412,7 +444,6 @@ async def run( transcode, ) - del model, hookpoint_to_sparse_encode if run_cfg.constructor_cfg.non_activating_source == "neighbours": nrh = assert_type( list, @@ -445,8 +476,12 @@ async def run( nrh, tokenizer, latent_range, + model, + hookpoint_to_sparse_encode, ) + del model, hookpoint_to_sparse_encode + if run_cfg.verbose: log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers) diff --git a/delphi/config.py b/delphi/config.py index 6e49b09d..b20c3324 100644 ---
a/delphi/config.py +++ b/delphi/config.py @@ -148,18 +148,14 @@ class RunConfig(Serializable): the default single token explainer, and 'none' for no explanation generation.""" scorers: list[str] = list_field( - choices=[ - "fuzz", - "detection", - "simulation", - ], + choices=["fuzz", "detection", "simulation", "surprisal_intervention"], default=[ "fuzz", "detection", ], ) - """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and - 'simulation'.""" + """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', + 'simulation' and 'surprisal_intervention'.""" name: str = "" """The name of the run. Results are saved in a directory with this name.""" diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 0f4ff94d..ca08ffaa 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -157,6 +157,13 @@ class LatentRecord: """Frequency of the latent. Number of activations in a context per total number of contexts.""" + @property + def feature_id(self) -> int: + """ + Returns the unique feature index for this latent. + """ + return self.latent.latent_index + @property def max_activation(self) -> float: """ diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 9937bd96..4852e8d6 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -14,7 +14,20 @@ def plot_firing_vs_f1( ) -> None: out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): + + if "firing_count" not in module_df.columns: + print( + f"""WARNING: 'firing_count' column not found for module {module}. + Skipping plot.""" + ) + continue + module_df = module_df.copy() + # Filter out rows where f1_score is NaN to avoid errors in plotting + module_df = module_df[module_df["f1_score"].notna()] + if module_df.empty: + continue + module_df["firing_rate"] = module_df["firing_count"] / num_tokens fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) fig.update_layout( @@ -26,30 +39,32 @@ def plot_firing_vs_f1( def import_plotly(): """Import plotly with mitigiation for MathJax bug.""" try: - import plotly.express as px # type: ignore - import plotly.io as pio # type: ignore + import plotly.express as px + import plotly.io as pio except ImportError: raise ImportError( "Plotly is not installed.\n" "Please install it using `pip install plotly`, " "or install the `[visualize]` extra." 
) - pio.kaleido.scope.mathjax = None # https://github.com/plotly/plotly.py/issues/3469 + pio.kaleido.scope.mathjax = None return px def compute_auc(df: pd.DataFrame) -> float | None: - if not df.probability.nunique(): - return None valid_df = df[df.probability.notna()] - - return roc_auc_score(valid_df.activating, valid_df.probability) # type: ignore + if valid_df.probability.nunique() <= 1: + return None + return roc_auc_score(valid_df.activating, valid_df.probability) def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): + # Filter out surprisal_intervention as 'accuracy' is not relevant for it + if label == "surprisal_intervention": + continue fig = px.histogram( df[df["score_type"] == label], x="accuracy", @@ -60,11 +75,14 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - if not df.probability.nunique(): - return - # filter out NANs valid_df = df[df.probability.notna()] + if ( + valid_df.empty + or valid_df.activating.nunique() <= 1 + or valid_df.probability.nunique() <= 1 + ): + return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) auc = roc_auc_score(valid_df.activating, valid_df.probability) @@ -85,19 +103,27 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path): def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: df_valid = df[df["prediction"].notna()] - act = df_valid["activating"].astype(bool) + if df_valid.empty: + return dict( + true_positives=0, + true_negatives=0, + false_positives=0, + false_negatives=0, + total_examples=0, + total_positives=0, + total_negatives=0, + failed_count=len(df), + ) + act = df_valid["activating"].astype(bool) total = len(df_valid) pos = act.sum() neg = total - pos - tp = ((df_valid.prediction >= threshold) & act).sum() tn = ((df_valid.prediction < threshold) & ~act).sum() fp = ((df_valid.prediction >= threshold) & ~act).sum() fn = ((df_valid.prediction < threshold) & act).sum() - assert fp <= neg and tn <= neg and tp <= pos and fn <= pos - return dict( true_positives=tp, true_negatives=tn, @@ -106,26 +132,22 @@ def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: total_examples=total, total_positives=pos, total_negatives=neg, - failed_count=len(df_valid) - total, + failed_count=len(df) - len(df_valid), ) def compute_classification_metrics(conf: dict) -> dict: - tp = conf["true_positives"] - tn = conf["true_negatives"] - fp = conf["false_positives"] - fn = conf["false_negatives"] - total = conf["total_examples"] - pos = conf["total_positives"] - neg = conf["total_negatives"] - - assert pos + neg == total, "pos + neg must equal total" + tp, tn, fp, fn = ( + conf["true_positives"], + conf["true_negatives"], + conf["false_positives"], + conf["false_negatives"], + ) + pos, neg = conf["total_positives"], conf["total_negatives"] - # accuracy = (tp + tn) / total if total > 0 else 0 balanced_accuracy = ( (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) ) / 2 - precision = tp / (tp + fp) if tp + fp > 0 else 0 recall = tp / pos if pos > 0 else 0 f1 = ( @@ -141,11 +163,6 @@ def compute_classification_metrics(conf: dict) -> dict: true_negative_rate=tn / neg if neg > 0 else 0, false_positive_rate=fp / neg if neg > 0 else 0, false_negative_rate=fn / pos if pos > 0 else 0, - total_examples=total, - total_positives=pos, - total_negatives=neg, - positive_class_ratio=pos / total if total > 0 else 0, - negative_class_ratio=neg / total if total > 0 
else 0, ) @@ -153,27 +170,35 @@ def load_data(scores_path: Path, modules: list[str]): """Load all on-disk data into a single DataFrame.""" def parse_score_file(path: Path) -> pd.DataFrame: - """ - Load a score file and return a raw DataFrame - """ try: data = orjson.loads(path.read_bytes()) except orjson.JSONDecodeError: print(f"Error decoding JSON from {path}. Skipping file.") return pd.DataFrame() + if not isinstance(data, list): + print( + f"""Warning: Expected a list of results in {path}, + but found {type(data)}. + Skipping file.""" + ) + return pd.DataFrame() + latent_idx = int(path.stem.split("latent")[-1]) + # Updated to extract all possible keys safely using .get() return pd.DataFrame( [ { - "text": "".join(ex["str_tokens"]), - "distance": ex["distance"], - "activating": ex["activating"], - "prediction": ex["prediction"], - "probability": ex["probability"], - "correct": ex["correct"], - "activations": ex["activations"], + "text": "".join(ex.get("str_tokens", [])), + "distance": ex.get("distance"), + "activating": ex.get("activating"), + "prediction": ex.get("prediction"), + "probability": ex.get("probability"), + "correct": ex.get("correct"), + "activations": ex.get("activations"), + "final_score": ex.get("final_score"), + "avg_kl_divergence": ex.get("avg_kl_divergence"), "latent_idx": latent_idx, } for ex in data @@ -187,68 +212,30 @@ def parse_score_file(path: Path) -> pd.DataFrame: print(f"Missing modules: {[m for m in modules if m not in counts]}") counts = None - # Collect per-latent data latent_dfs = [] for score_type_dir in scores_path.iterdir(): if not score_type_dir.is_dir(): continue for module in modules: for file in score_type_dir.glob(f"*{module}*"): - latent_idx = int(file.stem.split("latent")[-1]) - latent_df = parse_score_file(file) + if latent_df.empty: + continue latent_df["score_type"] = score_type_dir.name latent_df["module"] = module - latent_df["latent_idx"] = latent_idx if counts: + latent_idx = latent_df["latent_idx"].iloc[0] latent_df["firing_count"] = ( counts[module][latent_idx].item() - if latent_idx in counts[module] + if module in counts and latent_idx in counts[module] else None ) - latent_dfs.append(latent_df) - return pd.concat(latent_dfs, ignore_index=True), counts - + if not latent_dfs: + return pd.DataFrame(), counts -def frequency_weighted_f1( - df: pd.DataFrame, counts: dict[str, torch.Tensor] -) -> float | None: - rows = [] - for (module, latent_idx), grp in df.groupby(["module", "latent_idx"]): - f1 = compute_classification_metrics(compute_confusion(grp))["f1_score"] - fire = counts[module][latent_idx].item() - rows.append( - { - "module": module, - "latent_idx": latent_idx, - "f1_score": f1, - "firing_count": fire, - } - ) - - latent_df = pd.DataFrame(rows) - - per_module_f1 = [] - for module in latent_df["module"].unique(): - module_df = latent_df[latent_df["module"] == module] - - firing_weights = counts[module][module_df["latent_idx"]].float() - total_weight = firing_weights.sum() - if total_weight == 0: - continue - - f1_tensor = torch.as_tensor(module_df["f1_score"].values, dtype=torch.float32) - module_f1 = (f1_tensor * firing_weights).sum() / firing_weights.sum() - per_module_f1.append(module_f1) - - overall_frequency_weighted_f1 = torch.stack(per_module_f1).mean() - return ( - overall_frequency_weighted_f1.item() - if not overall_frequency_weighted_f1.isnan() - else None - ) + return pd.concat(latent_dfs, ignore_index=True), counts def get_agg_metrics( @@ -256,6 +243,10 @@ def get_agg_metrics( ) -> pd.DataFrame: processed_rows = [] 
for score_type, group_df in latent_df.groupby("score_type"): + # For surprisal_intervention, we don't compute classification metrics + if score_type == "surprisal_intervention": + continue + conf = compute_confusion(group_df) class_m = compute_classification_metrics(conf) auc = compute_auc(group_df) @@ -290,94 +281,66 @@ def log_results( import_plotly() latent_df, counts = load_data(scores_path, modules) - latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - latent_df = add_latent_f1(latent_df) - - plot_firing_vs_f1( - latent_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name - ) - if latent_df.empty: - print("No data found") + print("No data to analyze.") return - dead = sum((counts[m] == 0).sum().item() for m in modules) - print(f"Number of dead features: {dead}") - print(f"Number of interpreted live features: {len(latent_df)}") + latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - # Load constructor config for run - with open(scores_path.parent / "run_config.json", "r") as f: - run_cfg = orjson.loads(f.read()) - constructor_cfg = run_cfg.get("constructor_cfg", {}) - min_examples = constructor_cfg.get("min_examples", None) - print("min examples", min_examples) + # Separate the dataframes for different processing + classification_df = latent_df[latent_df["score_type"] != "surprisal_intervention"] + surprisal_df = latent_df[latent_df["score_type"] == "surprisal_intervention"] + + if not classification_df.empty: + classification_df = add_latent_f1(classification_df) + if counts: + plot_firing_vs_f1( + classification_df, + num_tokens=10_000_000, + out_dir=viz_path, + run_label=scores_path.name, + ) + plot_roc_curve(classification_df, viz_path) + processed_df = get_agg_metrics(classification_df, counts) + plot_accuracy_hist(processed_df, viz_path) - if min_examples is not None: - uninterpretable_features = sum( - [(counts[m] < min_examples).sum() for m in modules] - ) - print( - f"Number of features below the interpretation firing" - f" count threshold: {uninterpretable_features}" - ) + if counts: + dead = sum((counts[m] == 0).sum().item() for m in modules) + print(f"Number of dead features: {dead}") - plot_roc_curve(latent_df, viz_path) + for score_type in latent_df["score_type"].unique(): - processed_df = get_agg_metrics(latent_df, counts) + if score_type == "surprisal_intervention": + # Drop duplicates since score is per-latent, not per-example + unique_latents = surprisal_df.drop_duplicates( + subset=["module", "latent_idx"] + ) + avg_score = unique_latents["final_score"].mean() + avg_kl = unique_latents["avg_kl_divergence"].mean() - plot_accuracy_hist(processed_df, viz_path) + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Average Normalized Score: {avg_score:.3f}") + print(f"Average KL Divergence: {avg_kl:.3f}") - for score_type in processed_df.score_type.unique(): - score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] - print(f"\n--- {score_type.title()} Metrics ---") - print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") - print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") - print( - "Note: the frequency-weighted F1 score is computed over each" - " hookpoint and averaged" - ) - print(f"Precision: {score_type_summary['precision']:.3f}") - print(f"Recall: {score_type_summary['recall']:.3f}") - # Only print AUC if unbalanced AUC is not -1. 
- if score_type_summary["auc"] is not None: - print(f"AUC: {score_type_summary['auc']:.3f}") else: - print("Logits not available.") - - fractions_failed = [ - score_type_summary["failed_count"] - / ( - ( - score_type_summary["total_examples"] - + score_type_summary["failed_count"] - ) - ) - ] - print( - f"""Average fraction of failed examples: \ -{sum(fractions_failed) / len(fractions_failed)}""" - ) + if not classification_df.empty: + score_type_summary = processed_df[ + processed_df.score_type == score_type + ].iloc[0] + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") + print(f"F1 Score: {score_type_summary['f1_score']:.3f}") + + if counts and score_type_summary["weighted_f1"] is not None: + print( + f"""Frequency-Weighted F1 Score: + {score_type_summary['weighted_f1']:.3f}""" + ) - print("\nConfusion Matrix:") - print( - f"True Positive Rate: {score_type_summary['true_positive_rate']:.3f} " - f"({score_type_summary['true_positives'].sum()})" - ) - print( - f"True Negative Rate: {score_type_summary['true_negative_rate']:.3f} " - f"({score_type_summary['true_negatives'].sum()})" - ) - print( - f"False Positive Rate: {score_type_summary['false_positive_rate']:.3f} " - f"({score_type_summary['false_positives'].sum()})" - ) - print( - f"False Negative Rate: {score_type_summary['false_negative_rate']:.3f} " - f"({score_type_summary['false_negatives'].sum()})" - ) + print(f"Precision: {score_type_summary['precision']:.3f}") + print(f"Recall: {score_type_summary['recall']:.3f}") - print("\nClass Distribution:") - print(f"""Positives: {score_type_summary['total_positives'].sum():.0f}""") - print(f"""Negatives: {score_type_summary['total_negatives'].sum():.0f}""") - print(f"Total: {score_type_summary['total_examples'].sum():.0f}") + if score_type_summary["auc"] is not None: + print(f"AUC: {score_type_summary['auc']:.3f}") + else: + print("AUC not available.") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 747db837..1191688c 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -3,6 +3,7 @@ from .classifier.intruder import IntruderScorer from .embedding.embedding import EmbeddingScorer from .embedding.example_embedding import ExampleEmbeddingScorer +from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer @@ -16,4 +17,5 @@ "EmbeddingScorer", "IntruderScorer", "ExampleEmbeddingScorer", + "SurprisalInterventionScorer", ] diff --git a/delphi/scorers/intervention/__init__.py b/delphi/scorers/intervention/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py new file mode 100644 index 00000000..b47d61a7 --- /dev/null +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -0,0 +1,411 @@ +import copy +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer + +from ...latents import LatentRecord +from ..scorer import Scorer, ScorerResult + + +@dataclass +class SurprisalInterventionResult: + """ + Detailed results from the SurprisalInterventionScorer. + + Attributes: + score: The final computed score. 
+ avg_kl: The average KL divergence between clean & intervened + next-token distributions. + explanation: The explanation string that was scored. + """ + + score: float + avg_kl: float + explanation: str + + +class SurprisalInterventionScorer(Scorer): + """ + Implements the Surprisal / Log-Probability Intervention Scorer. + + This scorer evaluates an explanation for a model's latent feature by measuring + how much an intervention in the feature's direction increases the model's belief + (log-probability) in the explanation. The change in log-probability is normalized + by the intervention's strength, measured by the KL divergence between the clean + and intervened next-token distributions. + + Reference: Paulo et al., "Automatically Interpreting Millions of Features in LLMs" + (https://arxiv.org/pdf/2410.13928), Section 3.3.5. + + Pipeline: + 1. For a small set of activating prompts: + a. Generate a continuation and get the next-token distribution ("clean"). + b. Add directional vector for the feature to the activations ("intervened"). + 2. Compute the log-probability of the explanation conditioned on both the clean + and intervened generated texts: log P(explanation | text). + 3. Compute KL divergence between the clean & intervened next-token distributions. + 4. The final score is the mean change in explanation log-prob, divided by the + mean KL divergence: + score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε). + """ + + name = "surprisal_intervention" + + def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): + """ + Args: + subject_model: The language model to generate from and score with. + explainer_model: A model (e.g., an SAE) used to get feature directions. + **kwargs: Configuration options. + strength (float): The magnitude of the intervention. Default: 5.0. + num_prompts (int): Number of activating examples to test. Default: 3. + max_new_tokens (int): Max tokens to generate for continuations. + hookpoints (list[str]): Hookpoint names (e.g., 'layers.6.mlp'); + the first entry is used for the intervention. + """ + self.subject_model = subject_model + self.explainer_model = explainer_model + self.strength = float(kwargs.get("strength", 5.0)) + self.num_prompts = int(kwargs.get("num_prompts", 3)) + self.max_new_tokens = int(kwargs.get("max_new_tokens", 20)) + self.hookpoints = kwargs.get("hookpoints") + + # Fall back to None so the attribute exists even without configured hookpoints. + self.hookpoint_str = self.hookpoints[0] if self.hookpoints else None + + if hasattr(subject_model, "tokenizer"): + self.tokenizer = subject_model.tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained("gpt2") + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id + + def _get_device(self) -> torch.device: + """Safely gets the device of the subject model.""" + try: + return next(self.subject_model.parameters()).device + except StopIteration: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def _find_layer(self, model: Any, name: str) -> torch.nn.Module: + """Resolves a module by its dotted path name.""" + if name is None: + raise ValueError("Hookpoint name is not configured.") + current = model + for part in name.split("."): + if part.isdigit(): + current = current[int(part)] + else: + current = getattr(current, part) + return current + + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: + """ + Dynamically finds the correct model prefix and resolves the full hookpoint path.
+ + This makes the scorer agnostic to different transformer architectures. + """ + parts = hookpoint_str.split(".") + + is_valid_format = ( + len(parts) == 3 + and parts[0] in ["layers", "h"] + and parts[1].isdigit() + and parts[2] in ["mlp", "attention", "attn"] + ) + + if not is_valid_format: + if len(parts) == 1 and hasattr(model, hookpoint_str): + return getattr(model, hookpoint_str) + raise ValueError( + f"""Hookpoint string '{hookpoint_str}' is not in a recognized format + like 'layers.6.mlp'.""" + ) + + # Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(model, p): + candidate_body = getattr(model, p) + if hasattr(candidate_body, parts[0]): + prefix = p + break + + full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + + try: + return self._find_layer(model, full_path) + except AttributeError as e: + raise AttributeError( + f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. + Original error: {e}""" + ) + + def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + """ + Function used for formatting results to run smoothly in the delphi pipeline + """ + sanitized = [] + for ex in examples: + if hasattr(ex, "str_tokens") and ex.str_tokens is not None: + sanitized.append({"str_tokens": ex.str_tokens}) + + elif isinstance(ex, dict) and "str_tokens" in ex: + sanitized.append(ex) + + elif isinstance(ex, str): + sanitized.append({"str_tokens": [ex]}) + + elif isinstance(ex, (list, tuple)): + sanitized.append({"str_tokens": [str(t) for t in ex]}) + + else: + sanitized.append({"str_tokens": [str(ex)]}) + + return sanitized + + async def __call__(self, record: LatentRecord) -> ScorerResult: + + record_copy = copy.deepcopy(record) + + raw_examples = getattr(record_copy, "test", []) or [] + + if not raw_examples: + result = SurprisalInterventionResult( + score=0.0, avg_kl=0.0, explanation=record_copy.explanation + ) + return ScorerResult(record=record, score=result) + + examples = self._sanitize_examples(raw_examples) + + record_copy.test = examples + record_copy.examples = examples + record_copy.train = examples + + prompts = ["".join(ex["str_tokens"]) for ex in examples[: self.num_prompts]] + + total_diff = 0.0 + total_kl = 0.0 + n = 0 + + for prompt in prompts: + clean_text, clean_logp_dist = ( + await self._generate_with_and_without_intervention( + prompt, record_copy, intervene=False + ) + ) + int_text, int_logp_dist = ( + await self._generate_with_and_without_intervention( + prompt, record_copy, intervene=True + ) + ) + + logp_clean = await self._score_explanation( + clean_text, record_copy.explanation + ) + logp_int = await self._score_explanation(int_text, record_copy.explanation) + + p_clean = torch.exp(clean_logp_dist) + kl_div = F.kl_div( + int_logp_dist, p_clean, reduction="sum", log_target=False + ).item() + + total_diff += logp_int - logp_clean + total_kl += kl_div + n += 1 + + avg_diff = total_diff / n if n > 0 else 0.0 + avg_kl = total_kl / n if n > 0 else 0.0 + final_score = avg_diff / (avg_kl + 1e-9) if n > 0 else 0.0 + + final_output_list = [] + for ex in examples[: self.num_prompts]: + final_output_list.append( + { + "str_tokens": ex["str_tokens"], + "final_score": final_score, + "avg_kl_divergence": avg_kl, + # Add placeholder keys that the parser expects, with default values. 
+ "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, + } + ) + return ScorerResult(record=record_copy, score=final_output_list) + + async def _generate_with_and_without_intervention( + self, prompt: str, record: LatentRecord, intervene: bool + ) -> Tuple[str, torch.Tensor]: + """ + Generates a text continuation and returns the next-token log-probabilities. + + If `intervene` is True, it adds a feature direction to the activations at the + specified hookpoint before generation. + + Returns: + A tuple containing: + - The generated text (string). + - The log-probability distribution for the token immediately following + the prompt (Tensor). + """ + device = self._get_device() + enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) + input_ids = enc["input_ids"].to(device) + + hooks = [] + if intervene: + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + if hookpoint_str is None: + raise ValueError("No hookpoint string specified for intervention.") + + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + + direction = self._get_intervention_direction(record).to(device) + direction = direction.unsqueeze(0).unsqueeze( + 0 + ) # Shape for broadcasting: [1, 1, D] + + def hook_fn(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + + # Apply intervention to the last token's hidden state + hidden_states[:, -1:, :] += self.strength * direction + + # Return the modified activations in their original format + if isinstance(out, tuple): + return (hidden_states,) + out[1:] + return hidden_states + + hooks.append(layer_to_hook.register_forward_hook(hook_fn)) + + try: + with torch.no_grad(): + # 1. Get next-token logits for KL divergence calculation + outputs = self.subject_model(input_ids) + next_token_logits = outputs.logits[0, -1, :] + log_probs_next_token = F.log_softmax(next_token_logits, dim=-1) + + # 2. 
Generate the full text continuation + gen_ids = self.subject_model.generate( + input_ids, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, + ) + generated_text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) + finally: + for h in hooks: + h.remove() + + return generated_text, log_probs_next_token.cpu() + + async def _score_explanation(self, generated_text: str, explanation: str) -> float: + """Computes log P(explanation | generated_text) under the subject model.""" + device = self._get_device() + + # Create the full input sequence: context + explanation + context_enc = self.tokenizer(generated_text, return_tensors="pt") + explanation_enc = self.tokenizer(explanation, return_tensors="pt") + + full_input_ids = torch.cat( + [context_enc.input_ids, explanation_enc.input_ids], dim=1 + ).to(device) + + with torch.no_grad(): + outputs = self.subject_model(full_input_ids) + logits = outputs.logits + + # We only need to score the explanation part + context_len = context_enc.input_ids.shape[1] + # Get logits for positions that predict the explanation tokens + explanation_logits = logits[:, context_len - 1 : -1, :] + + # Get the target token IDs for the explanation + target_ids = explanation_enc.input_ids.to(device) + + log_probs = F.log_softmax(explanation_logits, dim=-1) + + # Gather the log-probabilities of the actual explanation tokens + token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) + + return token_log_probs.sum().item() + + def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: + """ + Gets the feature direction vector, preferring an SAE if available, + otherwise falling back to estimating it from activations. + """ + # --- Fast Path: Try to get vector from an SAE-like explainer model --- + if self.explainer_model: + sae = None + candidate = self.explainer_model + if isinstance(self.explainer_model, dict): + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + candidate = self.explainer_model.get(hookpoint_str) + + if hasattr(candidate, "get_feature_vector"): + sae = candidate + elif hasattr(candidate, "sae") and hasattr( + candidate.sae, "get_feature_vector" + ): + sae = candidate.sae + + if sae: + direction = sae.get_feature_vector(record.feature_id) + if not isinstance(direction, torch.Tensor): + direction = torch.tensor(direction, dtype=torch.float32) + direction = direction.squeeze() + return F.normalize(direction, p=2, dim=0) + + # --- Fallback: Estimate direction from activating examples --- + return self._estimate_direction_from_examples(record) + + def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: + """Estimates an intervention direction by averaging activations.""" + device = self._get_device() + + examples = self._sanitize_examples(getattr(record, "test", []) or []) + if not examples: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + captured_activations = [] + + def capture_hook(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + + captured_activations.append(hidden_states[:, -1, :].detach().cpu()) + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + handle = layer_to_hook.register_forward_hook(capture_hook) + + try: + for ex in examples[: min(8, self.num_prompts)]: + prompt = "".join(ex["str_tokens"]) + input_ids = 
self.tokenizer(prompt, return_tensors="pt").input_ids.to( + device + ) + with torch.no_grad(): + self.subject_model(input_ids) + finally: + handle.remove() + + if not captured_activations: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + activations = torch.cat(captured_activations, dim=0).to(device) + direction = activations.mean(dim=0) + + return F.normalize(direction, p=2, dim=0)
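
As a quick illustration of the scoring rule described in the new scorer's docstring (mean change in explanation log-prob divided by mean KL divergence), here is a minimal, self-contained sketch of the aggregation step. The helper name and inputs are hypothetical; it only mirrors the arithmetic in SurprisalInterventionScorer.__call__.

import torch
import torch.nn.functional as F


def aggregate_surprisal_score(
    logp_clean: list[float],
    logp_intervened: list[float],
    clean_logp_dists: list[torch.Tensor],
    intervened_logp_dists: list[torch.Tensor],
    eps: float = 1e-9,
) -> tuple[float, float]:
    """Mean change in explanation log-prob, normalized by mean KL divergence."""
    diffs = [i - c for i, c in zip(logp_intervened, logp_clean)]
    # F.kl_div(input, target) with log-prob input and prob target computes
    # KL(target || input), i.e. KL(clean || intervened) here.
    kls = [
        F.kl_div(q_log, p_log.exp(), reduction="sum", log_target=False).item()
        for q_log, p_log in zip(intervened_logp_dists, clean_logp_dists)
    ]
    if not diffs:
        return 0.0, 0.0
    avg_diff = sum(diffs) / len(diffs)
    avg_kl = sum(kls) / len(kls)
    return avg_diff / (avg_kl + eps), avg_kl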
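
The __main__.py branch above constructs the scorer from the run config. For readers who want to exercise it outside the pipeline, a minimal wiring sketch might look like the following; it assumes `model` is the AutoModelForCausalLM returned by load_artifacts (with `model.tokenizer` attached) and `hookpoint_to_sparse_encode` is the dict returned alongside it, and the hookpoint value is only an example of the expected 'layers.N.mlp' format.

from delphi.scorers import SurprisalInterventionScorer

scorer = SurprisalInterventionScorer(
    model,                        # subject model used for generation and scoring
    hookpoint_to_sparse_encode,   # optional SAE lookup for feature directions
    hookpoints=["layers.6.mlp"],  # first entry is used for the intervention
    strength=5.0,                 # intervention magnitude (default 5.0)
    num_prompts=3,                # number of activating prompts to test (default 3)
    max_new_tokens=20,            # continuation length (default 20)
)
# score_result = await scorer(record)  # record: LatentRecord with .test examples and .explanation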
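
The custom_serializer added to scorer_postprocess exists because score payloads may now carry torch tensors (for example the activations field), which orjson cannot serialize natively. A small round-trip sketch with made-up values:

import orjson
import torch


def custom_serializer(obj):
    # Same behaviour as the helper in __main__.py: tensors become plain lists.
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    raise TypeError


payload = [{"final_score": 0.5, "activations": torch.tensor([1.0, 0.0, 0.5])}]
blob = orjson.dumps(payload, default=custom_serializer)
assert orjson.loads(blob)[0]["activations"] == [1.0, 0.0, 0.5]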
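
Finally, the surprisal_intervention branch in log_results reports per-latent averages instead of classification metrics, deduplicating the per-example rows first. A toy pandas sketch of that summary, with invented numbers:

import pandas as pd

surprisal_df = pd.DataFrame(
    {
        "module": ["layers.6.mlp"] * 4,
        "latent_idx": [0, 0, 1, 1],
        "final_score": [0.8, 0.8, 0.2, 0.2],
        "avg_kl_divergence": [1.5, 1.5, 0.9, 0.9],
    }
)
# Scores are per-latent, so duplicates coming from per-example rows are dropped first.
unique_latents = surprisal_df.drop_duplicates(subset=["module", "latent_idx"])
print(f"Average Normalized Score: {unique_latents['final_score'].mean():.3f}")      # 0.500
print(f"Average KL Divergence: {unique_latents['avg_kl_divergence'].mean():.3f}")   # 1.200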