From cf7ddca35122833ad862ffe9ee5da377b9c438d2 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Thu, 28 Aug 2025 19:23:43 +0000 Subject: [PATCH 01/11] Surprisal intervention and config --- delphi/__main__.py | 59 ++- delphi/config.py | 7 +- delphi/latents/latents.py | 7 + delphi/scorers/__init__.py | 7 + delphi/scorers/intervention/__init__.py | 0 .../output_based_intervention_scorer.py | 141 ++++++ .../surprisal_intervention_scorer.py | 449 ++++++++++++++++++ 7 files changed, 663 insertions(+), 7 deletions(-) create mode 100644 delphi/scorers/intervention/__init__.py create mode 100644 delphi/scorers/intervention/output_based_intervention_scorer.py create mode 100644 delphi/scorers/intervention/surprisal_intervention_scorer.py diff --git a/delphi/__main__.py b/delphi/__main__.py index 16f0c557..46111bc1 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -5,12 +5,15 @@ from pathlib import Path from typing import Callable +from dataclasses import asdict + import orjson import torch from simple_parsing import ArgumentParser from torch import Tensor from transformers import ( AutoModel, + AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, @@ -27,7 +30,7 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator +from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, InterventionScorer, LogProbInterventionScorer, SurprisalInterventionScorer from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -40,7 +43,7 @@ def load_artifacts(run_cfg: RunConfig): else: dtype = "auto" - model = AutoModel.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( run_cfg.model, device_map={"": "cuda"}, quantization_config=( @@ -118,6 +121,8 @@ async def process_cache( hookpoints: list[str], tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, latent_range: Tensor | None, + model, + hookpoint_to_sparse_encode ): """ Converts SAE latent activations in on-disk cache in the `latents_path` directory @@ -218,6 +223,12 @@ def none_postprocessor(result): postprocess=none_postprocessor, ) ) + + def custom_serializer(obj): + """A custom serializer for orjson to handle specific types.""" + if isinstance(obj, Tensor): + return obj.tolist() + raise TypeError # Builds the record from result returned by the pipeline def scorer_preprocess(result): @@ -230,12 +241,22 @@ def scorer_preprocess(result): return record # Saves the score to a file - def scorer_postprocess(result, score_dir): + # In your __main__.py file + + def scorer_postprocess(result, score_dir, scorer_name=None): + if isinstance(result, list): + if not result: + return + result = result[0] + safe_latent_name = str(result.record.latent).replace("/", "--") with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) + # This line now works universally. For other scorers, it saves their simple + # score. For surprisal_intervention, it saves the rich 'final_payload'. 
+ f.write(orjson.dumps(result.score, default=custom_serializer)) + scorers = [] for scorer_name in run_cfg.scorers: scorer_path = scores_path / scorer_name @@ -257,6 +278,29 @@ def scorer_postprocess(result, score_dir): verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, ) + elif scorer_name == "intervention": + scorer = InterventionScorer( + llm_client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) + elif scorer_name == "logprob_intervention": + scorer = LogProbInterventionScorer( + llm_client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) + elif scorer_name == "surprisal_intervention": + scorer = SurprisalInterventionScorer( + model, + hookpoint_to_sparse_encode, + hookpoints = run_cfg.hookpoints, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) else: raise ValueError(f"Scorer {scorer_name} not supported") @@ -396,6 +440,8 @@ async def run( hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg) tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token) + model.tokenizer = tokenizer + nrh = assert_type( dict, non_redundant_hookpoints( @@ -412,7 +458,6 @@ async def run( transcode, ) - del model, hookpoint_to_sparse_encode if run_cfg.constructor_cfg.non_activating_source == "neighbours": nrh = assert_type( list, @@ -445,8 +490,12 @@ async def run( nrh, tokenizer, latent_range, + model, + hookpoint_to_sparse_encode ) + del model, hookpoint_to_sparse_encode + if run_cfg.verbose: log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers) diff --git a/delphi/config.py b/delphi/config.py index 6e49b09d..0d2193c5 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -152,14 +152,17 @@ class RunConfig(Serializable): "fuzz", "detection", "simulation", + "intervention", + "logprob_intervention", + "surprisal_intervention" ], default=[ "fuzz", "detection", ], ) - """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and - 'simulation'.""" + """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', + 'simulation' and 'intervention'.""" name: str = "" """The name of the run. Results are saved in a directory with this name.""" diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 0f4ff94d..ca08ffaa 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -157,6 +157,13 @@ class LatentRecord: """Frequency of the latent. Number of activations in a context per total number of contexts.""" + @property + def feature_id(self) -> int: + """ + Returns the unique feature index for this latent. 
+ """ + return self.latent.latent_index + @property def max_activation(self) -> float: """ diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 747db837..ad84c15f 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -6,6 +6,9 @@ from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer +from .intervention.intervention_scorer import InterventionScorer +from .intervention.logprob_intervention_scorer import LogProbInterventionScorer +from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer __all__ = [ "FuzzingScorer", @@ -16,4 +19,8 @@ "EmbeddingScorer", "IntruderScorer", "ExampleEmbeddingScorer", + "SurprisalInterventionScorer", + "InterventionScorer", + "LogProbInterventionScorer", + ] diff --git a/delphi/scorers/intervention/__init__.py b/delphi/scorers/intervention/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/delphi/scorers/intervention/output_based_intervention_scorer.py b/delphi/scorers/intervention/output_based_intervention_scorer.py new file mode 100644 index 00000000..9c706962 --- /dev/null +++ b/delphi/scorers/intervention/output_based_intervention_scorer.py @@ -0,0 +1,141 @@ +# Output-based intervention scorer (Gur-Arieh et al. 2025) +from dataclasses import dataclass +import torch +import torch.nn.functional as F +import random +from ...scorer import Scorer, ScorerResult +from ...latents import LatentRecord, ActivatingExample +from transformers import PreTrainedModel + +@dataclass +class OutputInterventionResult: + """Result of output-based intervention evaluation.""" + score: int # +1 if target set chosen, -1 otherwise + explanation: str + example_text: str + +class OutputInterventionScorer(Scorer): + """ + Output-based evaluation by steering (clamping) the feature and using a judge LLM + to pick which outputs best match the description:contentReference[oaicite:5]{index=5}. + We generate texts for the target feature and for a few random features, + then ask the judge to choose the matching set. 
+ """ + name = "output_intervention" + + def __init__(self, subject_model: PreTrainedModel, explainer_model, **kwargs): + self.subject_model = subject_model + self.explainer_model = explainer_model + self.steering_strength = kwargs.get("strength", 5.0) + self.num_prompts = kwargs.get("num_prompts", 3) + self.num_random = kwargs.get("num_random_features", 2) + self.hookpoint = kwargs.get("hookpoint", "transformer.h.6.mlp") + self.tokenizer = getattr(subject_model, "tokenizer", None) + + async def __call__(self, record: LatentRecord) -> ScorerResult: + # Prepare activating prompts + examples = [ex for ex in record.test if isinstance(ex, ActivatingExample)] + random.shuffle(examples) + prompts = ["".join(str(t) for t in ex.str_tokens) for ex in examples[:self.num_prompts]] + + # Generate text for the target feature + target_texts = [] + for p in prompts: + text, _ = await self._generate(p, record.feature_id, self.steering_strength) + target_texts.append(text) + + # Pick a few random feature IDs (avoid the target) + random_ids = [] + while len(random_ids) < self.num_random: + rid = random.randint(0, 999) + if rid != record.feature_id: + random_ids.append(rid) + + # Generate texts for random features + random_sets = [] + for fid in random_ids: + rand_texts = [] + for p in prompts: + text, _ = await self._generate(p, fid, self.steering_strength) + rand_texts.append(text) + random_sets.append(rand_texts) + + # Create prompt for judge LLM + judge_prompt = self._format_judge_prompt(record.explanation, target_texts, random_sets) + judge_response = await self._ask_judge(judge_prompt) + + # Parse judge response: check if target set was chosen + resp_lower = judge_response.lower() + if "target" in resp_lower or "set 1" in resp_lower: + score = 1 + elif "set 2" in resp_lower or "set 3" in resp_lower or "random" in resp_lower: + score = -1 + else: + score = 0 + + example_text = prompts[0] if prompts else "" + detailed = OutputInterventionResult( + score=score, + explanation=record.explanation, + example_text=example_text + ) + return ScorerResult(record=record, score=detailed) + + async def _generate(self, prompt: str, feature_id: int, strength: float): + """ + Generates text with the feature clamped (added to hidden state). + Returns the (partial) generated text and logits. + """ + tokenizer = self.tokenizer or __import__("transformers").AutoTokenizer.from_pretrained("gpt2") + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + # Forward hook to clamp feature activation + direction = self.explainer_model.get_feature_vector(feature_id) + def hook_fn(module, inp, out): + out[:, -1, :] = out[:, -1, :] + strength * direction.to(out.device) + return out + layer = self._find_layer(self.subject_model, self.hookpoint) + handle = layer.register_forward_hook(hook_fn) + + with torch.no_grad(): + outputs = self.subject_model(input_ids) + logits = outputs.logits[0, -1, :] + log_probs = F.log_softmax(logits, dim=-1) + handle.remove() + + text = tokenizer.decode(input_ids[0]) + return text, log_probs + + def _format_judge_prompt(self, explanation: str, target_texts: list, other_sets: list): + """ + Constructs a prompt for the judge LLM listing each set of texts + under the target feature and random features. 
+ """ + prompt = f"Feature description: \"{explanation}\"\n" + prompt += "Which of the following sets of generated texts best matches this description?\n\n" + prompt += "Set 1 (target feature):\n" + for txt in target_texts: + prompt += f"- {txt}\n" + for i, rand_set in enumerate(other_sets, start=2): + prompt += f"\nSet {i} (random feature):\n" + for txt in rand_set: + prompt += f"- {txt}\n" + prompt += "\nAnswer (mention the set number or 'target'/'random'): " + return prompt + + async def _ask_judge(self, prompt: str) -> str: + """ + Queries a judge LLM (e.g., GPT-4) with the prompt. Stubbed here. + """ + # TODO: Implement actual LLM call to get response + return "" + + def _find_layer(self, model, name: str): + """Locate a module by its dotted name.""" + current = model + for attr in name.split('.'): + if attr.isdigit(): + current = current[int(attr)] + else: + current = getattr(current, attr) + return current diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py new file mode 100644 index 00000000..f3678c9d --- /dev/null +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -0,0 +1,449 @@ +# surprisal_intervention_scorer.py +import functools +import random +import copy +from dataclasses import dataclass +from typing import Any, List, Dict, Tuple + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer + +# Assuming 'delphi' is your project structure. +# If not, you may need to adjust these relative imports. +from ..scorer import Scorer, ScorerResult +from ...latents import LatentRecord, ActivatingExample + +@dataclass +class SurprisalInterventionResult: + """ + Detailed results from the SurprisalInterventionScorer. + + Attributes: + score: The final computed score. + avg_kl: The average KL divergence between the clean and intervened next-token distributions. + explanation: The explanation string that was scored. + """ + score: float + avg_kl: float + explanation: str + + +class SurprisalInterventionScorer(Scorer): + """ + Implements the Surprisal / Log-Probability Intervention Scorer. + + This scorer evaluates an explanation for a model's latent feature by measuring + how much an intervention in the feature's direction increases the model's belief + (log-probability) in the explanation. The change in log-probability is normalized + by the intervention's strength, measured by the KL divergence between the clean + and intervened next-token distributions. + + Reference: Paulo et al., "Automatically Interpreting Millions of Features in Large Language Models" + (https://arxiv.org/pdf/2410.13928), Section 3.3.5[cite: 206, 207]. + + Pipeline: + 1. For a small set of activating prompts: + a. Generate a continuation and get the next-token distribution ("clean"). + b. Add a directional vector for the feature to the activations and repeat ("intervened"). + 2. Compute the log-probability of the explanation conditioned on both the clean + and intervened generated texts: log P(explanation | text)[cite: 209]. + 3. Compute the KL divergence between the clean and intervened next-token distributions[cite: 216]. + 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: + score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε)[cite: 209]. 
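+
+    Worked example (illustrative numbers only, not taken from a real run): if
+    steering raises the explanation's average log-probability by 0.6 nats and
+    the mean next-token KL divergence between the clean and intervened
+    distributions is 0.3, then score = 0.6 / (0.3 + 1e-9) ≈ 2.0. Larger
+    positive scores indicate that the intervention shifts the model toward the
+    explanation per unit of distributional change.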
+ """ + name = "surprisal_intervention" + + def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): + """ + Args: + subject_model: The language model to generate from and score with. + explainer_model: An optional model (e.g., an SAE) used to get feature directions. + **kwargs: Configuration options. + strength (float): The magnitude of the intervention. Default: 5.0. + num_prompts (int): Number of activating examples to test. Default: 3. + max_new_tokens (int): Max tokens to generate for continuations. Default: 20. + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') for the intervention. + """ + self.subject_model = subject_model + self.explainer_model = explainer_model + self.strength = float(kwargs.get("strength", 5.0)) + self.num_prompts = int(kwargs.get("num_prompts", 3)) + self.max_new_tokens = int(kwargs.get("max_new_tokens", 20)) + self.hookpoints = kwargs.get("hookpoints") + + if len(self.hookpoints): + self.hookpoint_str = self.hookpoints[0] + + # Ensure tokenizer is available + if hasattr(subject_model, "tokenizer"): + self.tokenizer = subject_model.tokenizer + else: + # Fallback to a standard tokenizer if not attached to the model + self.tokenizer = AutoTokenizer.from_pretrained("gpt2") + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id + + def _get_device(self) -> torch.device: + """Safely gets the device of the subject model.""" + try: + return next(self.subject_model.parameters()).device + except StopIteration: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def _find_layer(self, model: Any, name: str) -> torch.nn.Module: + """Resolves a module by its dotted path name.""" + if name is None: + raise ValueError("Hookpoint name is not configured.") + current = model + for part in name.split("."): + if part.isdigit(): + current = current[int(part)] + else: + current = getattr(current, part) + return current + + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: + """ + Dynamically finds the correct model prefix and resolves the full hookpoint path. + + This makes the scorer agnostic to different transformer architectures. + """ + parts = hookpoint_str.split('.') + + # 1. Validate the string format. + is_valid_format = ( + len(parts) == 3 and + parts[0] in ['layers', 'h'] and + parts[1].isdigit() and + parts[2] in ['mlp', 'attention', 'attn'] + ) + + if not is_valid_format: + # Fallback for simple block types at the top level, e.g. 'embed_in' + if len(parts) == 1 and hasattr(model, hookpoint_str): + return getattr(model, hookpoint_str) + raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") + # --- End of changes --- + + # 2. Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(model, p): + candidate_body = getattr(model, p) + # Use parts[0] to get the layer block name ('layers' or 'h') + if hasattr(candidate_body, parts[0]): + prefix = p + break + + full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + + # 3. Use the simple path finder to get the module. + try: + return self._find_layer(model, full_path) + except AttributeError as e: + raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. 
Original error: {e}") + + + + + # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + # """Ensures examples are in a consistent format: a list of dictionaries with 'str_tokens'.""" + # sanitized = [] + # for ex in examples: + # if isinstance(ex, dict) and "str_tokens" in ex: + # sanitized.append(ex) + # elif hasattr(ex, "str_tokens"): + # sanitized.append({"str_tokens": [str(t) for t in ex.str_tokens]}) + # elif isinstance(ex, str): + # sanitized.append({"str_tokens": [ex]}) + # elif isinstance(ex, (list, tuple)): + # sanitized.append({"str_tokens": [str(t) for t in ex]}) + # else: + # sanitized.append({"str_tokens": [str(ex)]}) + # return sanitized + + + def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + sanitized = [] + for ex in examples: + # --- NEW, MORE ROBUST LOGIC --- + # 1. Prioritize handling objects that have the data we need (like ActivatingExample) + if hasattr(ex, 'str_tokens') and ex.str_tokens is not None: + # This correctly handles ActivatingExample objects and similar structures. + # It extracts the string tokens instead of converting the whole object to a string. + sanitized.append({'str_tokens': ex.str_tokens}) + + # 2. Handle cases where the item is already a correct dictionary + elif isinstance(ex, dict) and "str_tokens" in ex: + sanitized.append(ex) + + # 3. Handle plain strings + elif isinstance(ex, str): + sanitized.append({"str_tokens": [ex]}) + + # 4. Handle lists/tuples of strings as a fallback + elif isinstance(ex, (list, tuple)): + sanitized.append({"str_tokens": [str(t) for t in ex]}) + + # 5. Handle any other unexpected type as a last resort + else: + sanitized.append({"str_tokens": [str(ex)]}) + + return sanitized + + + # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + + # sanitized = [] + # for i, ex in enumerate(examples): + + + # if isinstance(ex, dict) and "str_tokens" in ex: + # sanitized.append(ex) + + + # elif isinstance(ex, str): + # # This is the key conversion + # converted_ex = {"str_tokens": [ex]} + # sanitized.append(converted_ex) + + + # elif isinstance(ex, (list, tuple)): + # converted_ex = {"str_tokens": [str(t) for t in ex]} + # sanitized.append(converted_ex) + + # else: + # converted_ex = {"str_tokens": [str(ex)]} + # sanitized.append(converted_ex) + + # print("fin this") + # return sanitized + + async def __call__(self, record: LatentRecord) -> ScorerResult: + # --- MODIFICATION START --- + # 1. Create a deep copy to work on, ensuring we don't interfere + # with other parts of the pipeline that might use the original record. + record_copy = copy.deepcopy(record) + + # 2. Read the raw examples from our copy. + raw_examples = getattr(record_copy, "test", []) or [] + + if not raw_examples: + result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation) + # Return the result with the original record since no changes were made. + return ScorerResult(record=record, score=result) + + # 3. Sanitize the examples. + examples = self._sanitize_examples(raw_examples) + + # 4. Overwrite the attributes on the copy with the clean data. + record_copy.test = examples + record_copy.examples = examples + record_copy.train = examples + + # Now, use the sanitized 'examples' and the 'record_copy' for all subsequent operations. + prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]] + + total_diff = 0.0 + total_kl = 0.0 + n = 0 + + for prompt in prompts: + # Pass the clean record_copy to the generation methods. 
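+            # For each prompt we generate a clean and an intervened continuation,
+            # score log P(explanation | continuation) under both, and accumulate
+            # the next-token KL divergence that later normalises the final score.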
+ clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False) + int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True) + + logp_clean = await self._score_explanation(clean_text, record_copy.explanation) + logp_int = await self._score_explanation(int_text, record_copy.explanation) + + p_clean = torch.exp(clean_logp_dist) + kl_div = F.kl_div(int_logp_dist, p_clean, reduction='sum', log_target=False).item() + + total_diff += logp_int - logp_clean + total_kl += kl_div + n += 1 + + avg_diff = total_diff / n if n > 0 else 0.0 + avg_kl = total_kl / n if n > 0 else 0.0 + final_score = avg_diff / (avg_kl + 1e-9) if n > 0 else 0.0 + + final_output_list = [] + for ex in examples[:self.num_prompts]: + final_output_list.append({ + "str_tokens": ex["str_tokens"], + # Add the final scores. These will be duplicated for each example. + "final_score": final_score, + "avg_kl_divergence": avg_kl, + # Add placeholder keys that the parser expects, with default values. + "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, + }) + return ScorerResult(record=record_copy, score=final_output_list) + + async def _generate_with_and_without_intervention( + self, prompt: str, record: LatentRecord, intervene: bool + ) -> Tuple[str, torch.Tensor]: + """ + Generates a text continuation and returns the next-token log-probabilities. + + If `intervene` is True, it adds a feature direction to the activations at the + specified hookpoint before generation. + + Returns: + A tuple containing: + - The generated text (string). + - The log-probability distribution for the token immediately following the prompt (Tensor). + """ + device = self._get_device() + enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) + input_ids = enc["input_ids"].to(device) + + hooks = [] + if intervene: + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + if hookpoint_str is None: + raise ValueError("No hookpoint string specified for intervention.") + + # Resolve the string into the actual layer module. + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + + direction = self._get_intervention_direction(record).to(device) + direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D] + + def hook_fn(module, inp, out): + # Gracefully handle both tuple and tensor outputs + hidden_states = out[0] if isinstance(out, tuple) else out + + # Apply intervention to the last token's hidden state + hidden_states[:, -1:, :] += self.strength * direction + + # Return the modified activations in their original format + if isinstance(out, tuple): + return (hidden_states,) + out[1:] + return hidden_states + + hooks.append(layer_to_hook.register_forward_hook(hook_fn)) + + try: + with torch.no_grad(): + # 1. Get next-token logits for KL divergence calculation + outputs = self.subject_model(input_ids) + next_token_logits = outputs.logits[0, -1, :] + log_probs_next_token = F.log_softmax(next_token_logits, dim=-1) + + # 2. 
Generate the full text continuation + gen_ids = self.subject_model.generate( + input_ids, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, + ) + generated_text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) + finally: + for h in hooks: + h.remove() + + return generated_text, log_probs_next_token.cpu() + + async def _score_explanation(self, generated_text: str, explanation: str) -> float: + """Computes log P(explanation | generated_text) under the subject model.""" + device = self._get_device() + + # Create the full input sequence: context + explanation + context_enc = self.tokenizer(generated_text, return_tensors="pt") + explanation_enc = self.tokenizer(explanation, return_tensors="pt") + + full_input_ids = torch.cat([context_enc.input_ids, explanation_enc.input_ids], dim=1).to(device) + + with torch.no_grad(): + outputs = self.subject_model(full_input_ids) + logits = outputs.logits + + # We only need to score the explanation part + context_len = context_enc.input_ids.shape[1] + # Get logits for positions that predict the explanation tokens + explanation_logits = logits[:, context_len - 1:-1, :] + + # Get the target token IDs for the explanation + target_ids = explanation_enc.input_ids.to(device) + + log_probs = F.log_softmax(explanation_logits, dim=-1) + + # Gather the log-probabilities of the actual explanation tokens + token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) + + return token_log_probs.sum().item() + + def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: + """ + Gets the feature direction vector, preferring an SAE if available, + otherwise falling back to estimating it from activations. + """ + # --- Fast Path: Try to get vector from an SAE-like explainer model --- + if self.explainer_model: + sae = None + candidate = self.explainer_model + if isinstance(self.explainer_model, dict): + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + candidate = self.explainer_model.get(hookpoint_str) + + if hasattr(candidate, 'get_feature_vector'): + sae = candidate + elif hasattr(candidate, 'sae') and hasattr(candidate.sae, 'get_feature_vector'): + sae = candidate.sae + + if sae: + direction = sae.get_feature_vector(record.feature_id) + if not isinstance(direction, torch.Tensor): + direction = torch.tensor(direction, dtype=torch.float32) + direction = direction.squeeze() + return F.normalize(direction, p=2, dim=0) + + # --- Fallback: Estimate direction from activating examples --- + return self._estimate_direction_from_examples(record) + + def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: + """Estimates an intervention direction by averaging activations.""" + device = self._get_device() + + examples = self._sanitize_examples(getattr(record, "test", []) or []) + if not examples: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + captured_activations = [] + def capture_hook(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + + # Now, hidden_states is guaranteed to be the 3D activation tensor + captured_activations.append(hidden_states[:, -1, :].detach().cpu()) + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + handle = layer_to_hook.register_forward_hook(capture_hook) + + try: + for ex in examples[:min(8, self.num_prompts)]: + prompt = 
"".join(ex["str_tokens"]) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) + with torch.no_grad(): + self.subject_model(input_ids) + finally: + handle.remove() + + if not captured_activations: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + activations = torch.cat(captured_activations, dim=0).to(device) + direction = activations.mean(dim=0) + + return F.normalize(direction, p=2, dim=0) \ No newline at end of file From 0ad4424ef88ec5b17c21bb424f0c6937c02d3556 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Thu, 28 Aug 2025 20:44:05 +0000 Subject: [PATCH 02/11] Add metrics for surprisal_intervention --- delphi/log/result_analysis.py | 326 ++++++++++++---------------------- 1 file changed, 116 insertions(+), 210 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 9937bd96..4af7030a 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -14,7 +14,17 @@ def plot_firing_vs_f1( ) -> None: out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): + + if 'firing_count' not in module_df.columns: + print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") + continue + module_df = module_df.copy() + # Filter out rows where f1_score is NaN to avoid errors in plotting + module_df = module_df[module_df['f1_score'].notna()] + if module_df.empty: + continue + module_df["firing_rate"] = module_df["firing_count"] / num_tokens fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) fig.update_layout( @@ -26,30 +36,32 @@ def plot_firing_vs_f1( def import_plotly(): """Import plotly with mitigiation for MathJax bug.""" try: - import plotly.express as px # type: ignore - import plotly.io as pio # type: ignore + import plotly.express as px + import plotly.io as pio except ImportError: raise ImportError( "Plotly is not installed.\n" "Please install it using `pip install plotly`, " "or install the `[visualize]` extra." 
) - pio.kaleido.scope.mathjax = None # https://github.com/plotly/plotly.py/issues/3469 + pio.kaleido.scope.mathjax = None return px def compute_auc(df: pd.DataFrame) -> float | None: - if not df.probability.nunique(): - return None - + # Filter for rows where probability is not None and there's more than one unique value valid_df = df[df.probability.notna()] - - return roc_auc_score(valid_df.activating, valid_df.probability) # type: ignore + if valid_df.probability.nunique() <= 1: + return None + return roc_auc_score(valid_df.activating, valid_df.probability) def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): + # Filter out surprisal_intervention as 'accuracy' is not relevant for it + if label == 'surprisal_intervention': + continue fig = px.histogram( df[df["score_type"] == label], x="accuracy", @@ -60,11 +72,10 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - if not df.probability.nunique(): - return - - # filter out NANs + # Filter for rows where probability is not None and there's more than one unique value valid_df = df[df.probability.notna()] + if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: + return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) auc = roc_auc_score(valid_df.activating, valid_df.probability) @@ -85,67 +96,41 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path): def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: df_valid = df[df["prediction"].notna()] - act = df_valid["activating"].astype(bool) + if df_valid.empty: + return dict(true_positives=0, true_negatives=0, false_positives=0, false_negatives=0, + total_examples=0, total_positives=0, total_negatives=0, failed_count=len(df)) + act = df_valid["activating"].astype(bool) total = len(df_valid) pos = act.sum() neg = total - pos - tp = ((df_valid.prediction >= threshold) & act).sum() tn = ((df_valid.prediction < threshold) & ~act).sum() fp = ((df_valid.prediction >= threshold) & ~act).sum() fn = ((df_valid.prediction < threshold) & act).sum() - assert fp <= neg and tn <= neg and tp <= pos and fn <= pos - return dict( - true_positives=tp, - true_negatives=tn, - false_positives=fp, - false_negatives=fn, - total_examples=total, - total_positives=pos, - total_negatives=neg, - failed_count=len(df_valid) - total, + true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, + total_examples=total, total_positives=pos, total_negatives=neg, + failed_count=len(df) - len(df_valid), ) def compute_classification_metrics(conf: dict) -> dict: - tp = conf["true_positives"] - tn = conf["true_negatives"] - fp = conf["false_positives"] - fn = conf["false_negatives"] - total = conf["total_examples"] - pos = conf["total_positives"] - neg = conf["total_negatives"] - - assert pos + neg == total, "pos + neg must equal total" - - # accuracy = (tp + tn) / total if total > 0 else 0 - balanced_accuracy = ( - (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) - ) / 2 - + tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] + pos, neg = conf["total_positives"], conf["total_negatives"] + + balanced_accuracy = ((tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0)) / 2 precision = tp / (tp + fp) if tp + fp > 0 else 0 recall = tp / pos if pos > 0 else 0 - f1 = ( - 2 * (precision * recall) / (precision + recall) if 
precision + recall > 0 else 0 - ) + f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 return dict( - precision=precision, - recall=recall, - f1_score=f1, - accuracy=balanced_accuracy, + precision=precision, recall=recall, f1_score=f1, accuracy=balanced_accuracy, true_positive_rate=tp / pos if pos > 0 else 0, true_negative_rate=tn / neg if neg > 0 else 0, false_positive_rate=fp / neg if neg > 0 else 0, false_negative_rate=fn / pos if pos > 0 else 0, - total_examples=total, - total_positives=pos, - total_negatives=neg, - positive_class_ratio=pos / total if total > 0 else 0, - negative_class_ratio=neg / total if total > 0 else 0, ) @@ -153,27 +138,32 @@ def load_data(scores_path: Path, modules: list[str]): """Load all on-disk data into a single DataFrame.""" def parse_score_file(path: Path) -> pd.DataFrame: - """ - Load a score file and return a raw DataFrame - """ try: data = orjson.loads(path.read_bytes()) except orjson.JSONDecodeError: print(f"Error decoding JSON from {path}. Skipping file.") return pd.DataFrame() + + if not isinstance(data, list): + print(f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file.") + return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) + # --- MODIFICATION 1: PARSE THE NEW METRICS --- + # Updated to extract all possible keys safely using .get() return pd.DataFrame( [ { - "text": "".join(ex["str_tokens"]), - "distance": ex["distance"], - "activating": ex["activating"], - "prediction": ex["prediction"], - "probability": ex["probability"], - "correct": ex["correct"], - "activations": ex["activations"], + "text": "".join(ex.get("str_tokens", [])), + "distance": ex.get("distance"), + "activating": ex.get("activating"), + "prediction": ex.get("prediction"), + "probability": ex.get("probability"), + "correct": ex.get("correct"), + "activations": ex.get("activations"), + "final_score": ex.get("final_score"), + "avg_kl_divergence": ex.get("avg_kl_divergence"), "latent_idx": latent_idx, } for ex in data @@ -187,197 +177,113 @@ def parse_score_file(path: Path) -> pd.DataFrame: print(f"Missing modules: {[m for m in modules if m not in counts]}") counts = None - # Collect per-latent data latent_dfs = [] for score_type_dir in scores_path.iterdir(): if not score_type_dir.is_dir(): continue for module in modules: for file in score_type_dir.glob(f"*{module}*"): - latent_idx = int(file.stem.split("latent")[-1]) - latent_df = parse_score_file(file) + if latent_df.empty: + continue latent_df["score_type"] = score_type_dir.name latent_df["module"] = module - latent_df["latent_idx"] = latent_idx if counts: + latent_idx = latent_df["latent_idx"].iloc[0] latent_df["firing_count"] = ( counts[module][latent_idx].item() - if latent_idx in counts[module] + if module in counts and latent_idx in counts[module] else None ) - latent_dfs.append(latent_df) + if not latent_dfs: + return pd.DataFrame(), counts + return pd.concat(latent_dfs, ignore_index=True), counts -def frequency_weighted_f1( - df: pd.DataFrame, counts: dict[str, torch.Tensor] -) -> float | None: - rows = [] - for (module, latent_idx), grp in df.groupby(["module", "latent_idx"]): - f1 = compute_classification_metrics(compute_confusion(grp))["f1_score"] - fire = counts[module][latent_idx].item() - rows.append( - { - "module": module, - "latent_idx": latent_idx, - "f1_score": f1, - "firing_count": fire, - } - ) - - latent_df = pd.DataFrame(rows) - - per_module_f1 = [] - for module in latent_df["module"].unique(): - module_df = 
latent_df[latent_df["module"] == module] - - firing_weights = counts[module][module_df["latent_idx"]].float() - total_weight = firing_weights.sum() - if total_weight == 0: - continue - - f1_tensor = torch.as_tensor(module_df["f1_score"].values, dtype=torch.float32) - module_f1 = (f1_tensor * firing_weights).sum() / firing_weights.sum() - per_module_f1.append(module_f1) - - overall_frequency_weighted_f1 = torch.stack(per_module_f1).mean() - return ( - overall_frequency_weighted_f1.item() - if not overall_frequency_weighted_f1.isnan() - else None - ) - - def get_agg_metrics( latent_df: pd.DataFrame, counts: Optional[dict[str, torch.Tensor]] ) -> pd.DataFrame: processed_rows = [] for score_type, group_df in latent_df.groupby("score_type"): + # For surprisal_intervention, we don't compute classification metrics + if score_type == 'surprisal_intervention': + continue + conf = compute_confusion(group_df) class_m = compute_classification_metrics(conf) auc = compute_auc(group_df) f1_w = frequency_weighted_f1(group_df, counts) if counts else None - + row = { "score_type": score_type, - **conf, - **class_m, - "auc": auc, - "weighted_f1": f1_w, + **conf, **class_m, "auc": auc, "weighted_f1": f1_w } processed_rows.append(row) return pd.DataFrame(processed_rows) -def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: - f1s = ( - latent_df.groupby(["module", "latent_idx"]) - .apply( - lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] - ) - .reset_index(name="f1_score") # <- naive (un-weighted) F1 - ) - return latent_df.merge(f1s, on=["module", "latent_idx"]) - - def log_results( scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str] ): import_plotly() latent_df, counts = load_data(scores_path, modules) - latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - latent_df = add_latent_f1(latent_df) - - plot_firing_vs_f1( - latent_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name - ) - if latent_df.empty: - print("No data found") + print("No data to analyze.") return - - dead = sum((counts[m] == 0).sum().item() for m in modules) - print(f"Number of dead features: {dead}") - print(f"Number of interpreted live features: {len(latent_df)}") - - # Load constructor config for run - with open(scores_path.parent / "run_config.json", "r") as f: - run_cfg = orjson.loads(f.read()) - constructor_cfg = run_cfg.get("constructor_cfg", {}) - min_examples = constructor_cfg.get("min_examples", None) - print("min examples", min_examples) - - if min_examples is not None: - uninterpretable_features = sum( - [(counts[m] < min_examples).sum() for m in modules] - ) - print( - f"Number of features below the interpretation firing" - f" count threshold: {uninterpretable_features}" - ) - - plot_roc_curve(latent_df, viz_path) - - processed_df = get_agg_metrics(latent_df, counts) - - plot_accuracy_hist(processed_df, viz_path) - - for score_type in processed_df.score_type.unique(): - score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] - print(f"\n--- {score_type.title()} Metrics ---") - print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") - print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") - print( - "Note: the frequency-weighted F1 score is computed over each" - " hookpoint and averaged" - ) - print(f"Precision: {score_type_summary['precision']:.3f}") - print(f"Recall: {score_type_summary['recall']:.3f}") - # Only 
print AUC if unbalanced AUC is not -1. - if score_type_summary["auc"] is not None: - print(f"AUC: {score_type_summary['auc']:.3f}") + + latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] + + # Separate the dataframes for different processing + classification_df = latent_df[latent_df['score_type'] != 'surprisal_intervention'] + surprisal_df = latent_df[latent_df['score_type'] == 'surprisal_intervention'] + + if not classification_df.empty: + classification_df = add_latent_f1(classification_df) + if counts: + plot_firing_vs_f1(classification_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name) + plot_roc_curve(classification_df, viz_path) + processed_df = get_agg_metrics(classification_df, counts) + plot_accuracy_hist(processed_df, viz_path) + + if counts: + dead = sum((counts[m] == 0).sum().item() for m in modules) + print(f"Number of dead features: {dead}") + + # --- MODIFICATION 2: ADD CONDITIONAL REPORTING --- + # Loop through all scorer types found in the data + for score_type in latent_df["score_type"].unique(): + + # Handle the new scorer with its specific metrics + if score_type == 'surprisal_intervention': + # Drop duplicates since score is per-latent, not per-example + unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx']) + avg_score = unique_latents['final_score'].mean() + avg_kl = unique_latents['avg_kl_divergence'].mean() + + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Average Normalized Score: {avg_score:.3f}") + print(f"Average KL Divergence: {avg_kl:.3f}") + + # Handle all other scorers with the original classification metrics else: - print("Logits not available.") - - fractions_failed = [ - score_type_summary["failed_count"] - / ( - ( - score_type_summary["total_examples"] - + score_type_summary["failed_count"] - ) - ) - ] - print( - f"""Average fraction of failed examples: \ -{sum(fractions_failed) / len(fractions_failed)}""" - ) - - print("\nConfusion Matrix:") - print( - f"True Positive Rate: {score_type_summary['true_positive_rate']:.3f} " - f"({score_type_summary['true_positives'].sum()})" - ) - print( - f"True Negative Rate: {score_type_summary['true_negative_rate']:.3f} " - f"({score_type_summary['true_negatives'].sum()})" - ) - print( - f"False Positive Rate: {score_type_summary['false_positive_rate']:.3f} " - f"({score_type_summary['false_positives'].sum()})" - ) - print( - f"False Negative Rate: {score_type_summary['false_negative_rate']:.3f} " - f"({score_type_summary['false_negatives'].sum()})" - ) - - print("\nClass Distribution:") - print(f"""Positives: {score_type_summary['total_positives'].sum():.0f}""") - print(f"""Negatives: {score_type_summary['total_negatives'].sum():.0f}""") - print(f"Total: {score_type_summary['total_examples'].sum():.0f}") + if not classification_df.empty: + score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") + print(f"F1 Score: {score_type_summary['f1_score']:.3f}") + + if counts and score_type_summary['weighted_f1'] is not None: + print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") + + print(f"Precision: {score_type_summary['precision']:.3f}") + print(f"Recall: {score_type_summary['recall']:.3f}") + + if score_type_summary["auc"] is not None: + print(f"AUC: {score_type_summary['auc']:.3f}") + else: + print("AUC not available.") \ No newline at end of file From 
aa12cf23c3e09e7c9534ee4475c75b008fb15a1d Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 14:15:20 +0000 Subject: [PATCH 03/11] Code cleaning --- delphi/__main__.py | 19 +---- delphi/config.py | 4 +- delphi/log/result_analysis.py | 6 +- .../surprisal_intervention_scorer.py | 82 ++----------------- 4 files changed, 9 insertions(+), 102 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 46111bc1..7a8fd399 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -30,7 +30,7 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, InterventionScorer, LogProbInterventionScorer, SurprisalInterventionScorer +from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, SurprisalInterventionScorer from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -252,8 +252,6 @@ def scorer_postprocess(result, score_dir, scorer_name=None): safe_latent_name = str(result.record.latent).replace("/", "--") with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: - # This line now works universally. For other scorers, it saves their simple - # score. For surprisal_intervention, it saves the rich 'final_payload'. f.write(orjson.dumps(result.score, default=custom_serializer)) @@ -278,20 +276,7 @@ def scorer_postprocess(result, score_dir, scorer_name=None): verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, ) - elif scorer_name == "intervention": - scorer = InterventionScorer( - llm_client, - n_examples_shown=run_cfg.num_examples_per_scorer_prompt, - verbose=run_cfg.verbose, - log_prob=run_cfg.log_probs, - ) - elif scorer_name == "logprob_intervention": - scorer = LogProbInterventionScorer( - llm_client, - n_examples_shown=run_cfg.num_examples_per_scorer_prompt, - verbose=run_cfg.verbose, - log_prob=run_cfg.log_probs, - ) + elif scorer_name == "surprisal_intervention": scorer = SurprisalInterventionScorer( model, diff --git a/delphi/config.py b/delphi/config.py index 0d2193c5..9d54c26e 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -152,8 +152,6 @@ class RunConfig(Serializable): "fuzz", "detection", "simulation", - "intervention", - "logprob_intervention", "surprisal_intervention" ], default=[ @@ -162,7 +160,7 @@ class RunConfig(Serializable): ], ) """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', - 'simulation' and 'intervention'.""" + 'simulation' and 'surprisal_intervention'.""" name: str = "" """The name of the run. 
Results are saved in a directory with this name.""" diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 4af7030a..99666acb 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -150,7 +150,6 @@ def parse_score_file(path: Path) -> pd.DataFrame: latent_idx = int(path.stem.split("latent")[-1]) - # --- MODIFICATION 1: PARSE THE NEW METRICS --- # Updated to extract all possible keys safely using .get() return pd.DataFrame( [ @@ -254,11 +253,9 @@ def log_results( dead = sum((counts[m] == 0).sum().item() for m in modules) print(f"Number of dead features: {dead}") - # --- MODIFICATION 2: ADD CONDITIONAL REPORTING --- - # Loop through all scorer types found in the data + for score_type in latent_df["score_type"].unique(): - # Handle the new scorer with its specific metrics if score_type == 'surprisal_intervention': # Drop duplicates since score is per-latent, not per-example unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx']) @@ -269,7 +266,6 @@ def log_results( print(f"Average Normalized Score: {avg_score:.3f}") print(f"Average KL Divergence: {avg_kl:.3f}") - # Handle all other scorers with the original classification metrics else: if not classification_df.empty: score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index f3678c9d..c683cca9 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -1,4 +1,3 @@ -# surprisal_intervention_scorer.py import functools import random import copy @@ -9,8 +8,6 @@ import torch.nn.functional as F from transformers import AutoTokenizer -# Assuming 'delphi' is your project structure. -# If not, you may need to adjust these relative imports. from ..scorer import Scorer, ScorerResult from ...latents import LatentRecord, ActivatingExample @@ -75,11 +72,9 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): if len(self.hookpoints): self.hookpoint_str = self.hookpoints[0] - # Ensure tokenizer is available if hasattr(subject_model, "tokenizer"): self.tokenizer = subject_model.tokenizer else: - # Fallback to a standard tokenizer if not attached to the model self.tokenizer = AutoTokenizer.from_pretrained("gpt2") if self.tokenizer.pad_token is None: @@ -113,7 +108,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: """ parts = hookpoint_str.split('.') - # 1. Validate the string format. is_valid_format = ( len(parts) == 3 and parts[0] in ['layers', 'h'] and @@ -122,129 +116,68 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: ) if not is_valid_format: - # Fallback for simple block types at the top level, e.g. 'embed_in' if len(parts) == 1 and hasattr(model, hookpoint_str): return getattr(model, hookpoint_str) raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") - # --- End of changes --- - # 2. Heuristically find the model prefix. + #Heuristically find the model prefix. prefix = None for p in ["gpt_neox", "transformer", "model"]: if hasattr(model, p): candidate_body = getattr(model, p) - # Use parts[0] to get the layer block name ('layers' or 'h') if hasattr(candidate_body, parts[0]): prefix = p break full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str - # 3. 
Use the simple path finder to get the module. try: return self._find_layer(model, full_path) except AttributeError as e: raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") - - - - # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: - # """Ensures examples are in a consistent format: a list of dictionaries with 'str_tokens'.""" - # sanitized = [] - # for ex in examples: - # if isinstance(ex, dict) and "str_tokens" in ex: - # sanitized.append(ex) - # elif hasattr(ex, "str_tokens"): - # sanitized.append({"str_tokens": [str(t) for t in ex.str_tokens]}) - # elif isinstance(ex, str): - # sanitized.append({"str_tokens": [ex]}) - # elif isinstance(ex, (list, tuple)): - # sanitized.append({"str_tokens": [str(t) for t in ex]}) - # else: - # sanitized.append({"str_tokens": [str(ex)]}) - # return sanitized - def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + """ + Function used for formatting results to run smoothly in the delphi pipeline + """ sanitized = [] for ex in examples: - # --- NEW, MORE ROBUST LOGIC --- - # 1. Prioritize handling objects that have the data we need (like ActivatingExample) if hasattr(ex, 'str_tokens') and ex.str_tokens is not None: - # This correctly handles ActivatingExample objects and similar structures. - # It extracts the string tokens instead of converting the whole object to a string. sanitized.append({'str_tokens': ex.str_tokens}) - # 2. Handle cases where the item is already a correct dictionary elif isinstance(ex, dict) and "str_tokens" in ex: sanitized.append(ex) - # 3. Handle plain strings elif isinstance(ex, str): sanitized.append({"str_tokens": [ex]}) - # 4. Handle lists/tuples of strings as a fallback elif isinstance(ex, (list, tuple)): sanitized.append({"str_tokens": [str(t) for t in ex]}) - # 5. Handle any other unexpected type as a last resort else: sanitized.append({"str_tokens": [str(ex)]}) return sanitized - # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: - - # sanitized = [] - # for i, ex in enumerate(examples): - - - # if isinstance(ex, dict) and "str_tokens" in ex: - # sanitized.append(ex) - - - # elif isinstance(ex, str): - # # This is the key conversion - # converted_ex = {"str_tokens": [ex]} - # sanitized.append(converted_ex) - - - # elif isinstance(ex, (list, tuple)): - # converted_ex = {"str_tokens": [str(t) for t in ex]} - # sanitized.append(converted_ex) - - # else: - # converted_ex = {"str_tokens": [str(ex)]} - # sanitized.append(converted_ex) - - # print("fin this") - # return sanitized async def __call__(self, record: LatentRecord) -> ScorerResult: - # --- MODIFICATION START --- - # 1. Create a deep copy to work on, ensuring we don't interfere - # with other parts of the pipeline that might use the original record. + record_copy = copy.deepcopy(record) - # 2. Read the raw examples from our copy. raw_examples = getattr(record_copy, "test", []) or [] if not raw_examples: result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation) - # Return the result with the original record since no changes were made. return ScorerResult(record=record, score=result) - # 3. Sanitize the examples. examples = self._sanitize_examples(raw_examples) - # 4. Overwrite the attributes on the copy with the clean data. 
record_copy.test = examples record_copy.examples = examples record_copy.train = examples - # Now, use the sanitized 'examples' and the 'record_copy' for all subsequent operations. prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]] total_diff = 0.0 @@ -252,7 +185,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: n = 0 for prompt in prompts: - # Pass the clean record_copy to the generation methods. clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False) int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True) @@ -274,7 +206,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: for ex in examples[:self.num_prompts]: final_output_list.append({ "str_tokens": ex["str_tokens"], - # Add the final scores. These will be duplicated for each example. "final_score": final_score, "avg_kl_divergence": avg_kl, # Add placeholder keys that the parser expects, with default values. @@ -312,14 +243,12 @@ async def _generate_with_and_without_intervention( if hookpoint_str is None: raise ValueError("No hookpoint string specified for intervention.") - # Resolve the string into the actual layer module. layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) direction = self._get_intervention_direction(record).to(device) direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D] def hook_fn(module, inp, out): - # Gracefully handle both tuple and tensor outputs hidden_states = out[0] if isinstance(out, tuple) else out # Apply intervention to the last token's hidden state @@ -423,7 +352,6 @@ def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tenso def capture_hook(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out - # Now, hidden_states is guaranteed to be the 3D activation tensor captured_activations.append(hidden_states[:, -1, :].detach().cpu()) hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) From bbf915a005d5367a7063b96ff05b0e452e60d2d9 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:01:03 +0000 Subject: [PATCH 04/11] Fix results --- delphi/log/result_analysis.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 99666acb..bffa7f6b 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -225,6 +225,17 @@ def get_agg_metrics( return pd.DataFrame(processed_rows) +def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: + f1s = ( + latent_df.groupby(["module", "latent_idx"]) + .apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ) + .reset_index(name="f1_score") # <- naive (un-weighted) F1 + ) + return latent_df.merge(f1s, on=["module", "latent_idx"]) + + def log_results( scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str] ): From 90e8a34daf75acffb3f1d2c096b362e3087728d4 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:19:11 +0000 Subject: [PATCH 05/11] Remove output-based intervention --- .../output_based_intervention_scorer.py | 141 ------------------ 1 file changed, 141 deletions(-) delete mode 100644 delphi/scorers/intervention/output_based_intervention_scorer.py diff --git a/delphi/scorers/intervention/output_based_intervention_scorer.py b/delphi/scorers/intervention/output_based_intervention_scorer.py deleted 
file mode 100644 index 9c706962..00000000 --- a/delphi/scorers/intervention/output_based_intervention_scorer.py +++ /dev/null @@ -1,141 +0,0 @@ -# Output-based intervention scorer (Gur-Arieh et al. 2025) -from dataclasses import dataclass -import torch -import torch.nn.functional as F -import random -from ...scorer import Scorer, ScorerResult -from ...latents import LatentRecord, ActivatingExample -from transformers import PreTrainedModel - -@dataclass -class OutputInterventionResult: - """Result of output-based intervention evaluation.""" - score: int # +1 if target set chosen, -1 otherwise - explanation: str - example_text: str - -class OutputInterventionScorer(Scorer): - """ - Output-based evaluation by steering (clamping) the feature and using a judge LLM - to pick which outputs best match the description:contentReference[oaicite:5]{index=5}. - We generate texts for the target feature and for a few random features, - then ask the judge to choose the matching set. - """ - name = "output_intervention" - - def __init__(self, subject_model: PreTrainedModel, explainer_model, **kwargs): - self.subject_model = subject_model - self.explainer_model = explainer_model - self.steering_strength = kwargs.get("strength", 5.0) - self.num_prompts = kwargs.get("num_prompts", 3) - self.num_random = kwargs.get("num_random_features", 2) - self.hookpoint = kwargs.get("hookpoint", "transformer.h.6.mlp") - self.tokenizer = getattr(subject_model, "tokenizer", None) - - async def __call__(self, record: LatentRecord) -> ScorerResult: - # Prepare activating prompts - examples = [ex for ex in record.test if isinstance(ex, ActivatingExample)] - random.shuffle(examples) - prompts = ["".join(str(t) for t in ex.str_tokens) for ex in examples[:self.num_prompts]] - - # Generate text for the target feature - target_texts = [] - for p in prompts: - text, _ = await self._generate(p, record.feature_id, self.steering_strength) - target_texts.append(text) - - # Pick a few random feature IDs (avoid the target) - random_ids = [] - while len(random_ids) < self.num_random: - rid = random.randint(0, 999) - if rid != record.feature_id: - random_ids.append(rid) - - # Generate texts for random features - random_sets = [] - for fid in random_ids: - rand_texts = [] - for p in prompts: - text, _ = await self._generate(p, fid, self.steering_strength) - rand_texts.append(text) - random_sets.append(rand_texts) - - # Create prompt for judge LLM - judge_prompt = self._format_judge_prompt(record.explanation, target_texts, random_sets) - judge_response = await self._ask_judge(judge_prompt) - - # Parse judge response: check if target set was chosen - resp_lower = judge_response.lower() - if "target" in resp_lower or "set 1" in resp_lower: - score = 1 - elif "set 2" in resp_lower or "set 3" in resp_lower or "random" in resp_lower: - score = -1 - else: - score = 0 - - example_text = prompts[0] if prompts else "" - detailed = OutputInterventionResult( - score=score, - explanation=record.explanation, - example_text=example_text - ) - return ScorerResult(record=record, score=detailed) - - async def _generate(self, prompt: str, feature_id: int, strength: float): - """ - Generates text with the feature clamped (added to hidden state). - Returns the (partial) generated text and logits. 
- """ - tokenizer = self.tokenizer or __import__("transformers").AutoTokenizer.from_pretrained("gpt2") - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - # Forward hook to clamp feature activation - direction = self.explainer_model.get_feature_vector(feature_id) - def hook_fn(module, inp, out): - out[:, -1, :] = out[:, -1, :] + strength * direction.to(out.device) - return out - layer = self._find_layer(self.subject_model, self.hookpoint) - handle = layer.register_forward_hook(hook_fn) - - with torch.no_grad(): - outputs = self.subject_model(input_ids) - logits = outputs.logits[0, -1, :] - log_probs = F.log_softmax(logits, dim=-1) - handle.remove() - - text = tokenizer.decode(input_ids[0]) - return text, log_probs - - def _format_judge_prompt(self, explanation: str, target_texts: list, other_sets: list): - """ - Constructs a prompt for the judge LLM listing each set of texts - under the target feature and random features. - """ - prompt = f"Feature description: \"{explanation}\"\n" - prompt += "Which of the following sets of generated texts best matches this description?\n\n" - prompt += "Set 1 (target feature):\n" - for txt in target_texts: - prompt += f"- {txt}\n" - for i, rand_set in enumerate(other_sets, start=2): - prompt += f"\nSet {i} (random feature):\n" - for txt in rand_set: - prompt += f"- {txt}\n" - prompt += "\nAnswer (mention the set number or 'target'/'random'): " - return prompt - - async def _ask_judge(self, prompt: str) -> str: - """ - Queries a judge LLM (e.g., GPT-4) with the prompt. Stubbed here. - """ - # TODO: Implement actual LLM call to get response - return "" - - def _find_layer(self, model, name: str): - """Locate a module by its dotted name.""" - current = model - for attr in name.split('.'): - if attr.isdigit(): - current = current[int(attr)] - else: - current = getattr(current, attr) - return current From 9fca7db9f5d922d62af9869338158ded361756b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:20:56 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/__main__.py | 19 +- delphi/config.py | 7 +- delphi/log/result_analysis.py | 124 ++++++++---- delphi/scorers/__init__.py | 7 +- .../surprisal_intervention_scorer.py | 177 ++++++++++-------- 5 files changed, 201 insertions(+), 133 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 7a8fd399..bf24a557 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -5,14 +5,11 @@ from pathlib import Path from typing import Callable -from dataclasses import asdict - import orjson import torch from simple_parsing import ArgumentParser from torch import Tensor from transformers import ( - AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, @@ -30,7 +27,12 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, SurprisalInterventionScorer +from delphi.scorers import ( + DetectionScorer, + FuzzingScorer, + OpenAISimulator, + SurprisalInterventionScorer, +) from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -122,7 +124,7 @@ async def process_cache( tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, latent_range: Tensor | 
None, model, - hookpoint_to_sparse_encode + hookpoint_to_sparse_encode, ): """ Converts SAE latent activations in on-disk cache in the `latents_path` directory @@ -223,7 +225,7 @@ def none_postprocessor(result): postprocess=none_postprocessor, ) ) - + def custom_serializer(obj): """A custom serializer for orjson to handle specific types.""" if isinstance(obj, Tensor): @@ -254,7 +256,6 @@ def scorer_postprocess(result, score_dir, scorer_name=None): with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: f.write(orjson.dumps(result.score, default=custom_serializer)) - scorers = [] for scorer_name in run_cfg.scorers: scorer_path = scores_path / scorer_name @@ -281,7 +282,7 @@ def scorer_postprocess(result, score_dir, scorer_name=None): scorer = SurprisalInterventionScorer( model, hookpoint_to_sparse_encode, - hookpoints = run_cfg.hookpoints, + hookpoints=run_cfg.hookpoints, n_examples_shown=run_cfg.num_examples_per_scorer_prompt, verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, @@ -476,7 +477,7 @@ async def run( tokenizer, latent_range, model, - hookpoint_to_sparse_encode + hookpoint_to_sparse_encode, ) del model, hookpoint_to_sparse_encode diff --git a/delphi/config.py b/delphi/config.py index 9d54c26e..b20c3324 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -148,12 +148,7 @@ class RunConfig(Serializable): the default single token explainer, and 'none' for no explanation generation.""" scorers: list[str] = list_field( - choices=[ - "fuzz", - "detection", - "simulation", - "surprisal_intervention" - ], + choices=["fuzz", "detection", "simulation", "surprisal_intervention"], default=[ "fuzz", "detection", diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index bffa7f6b..bf53da22 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -15,13 +15,15 @@ def plot_firing_vs_f1( out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): - if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") + if "firing_count" not in module_df.columns: + print( + f"WARNING: 'firing_count' column not found for module {module}. Skipping plot." 
+ ) continue module_df = module_df.copy() # Filter out rows where f1_score is NaN to avoid errors in plotting - module_df = module_df[module_df['f1_score'].notna()] + module_df = module_df[module_df["f1_score"].notna()] if module_df.empty: continue @@ -60,7 +62,7 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): # Filter out surprisal_intervention as 'accuracy' is not relevant for it - if label == 'surprisal_intervention': + if label == "surprisal_intervention": continue fig = px.histogram( df[df["score_type"] == label], @@ -74,7 +76,11 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): # Filter for rows where probability is not None and there's more than one unique value valid_df = df[df.probability.notna()] - if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: + if ( + valid_df.empty + or valid_df.activating.nunique() <= 1 + or valid_df.probability.nunique() <= 1 + ): return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) @@ -97,8 +103,16 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path): def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: df_valid = df[df["prediction"].notna()] if df_valid.empty: - return dict(true_positives=0, true_negatives=0, false_positives=0, false_negatives=0, - total_examples=0, total_positives=0, total_negatives=0, failed_count=len(df)) + return dict( + true_positives=0, + true_negatives=0, + false_positives=0, + false_negatives=0, + total_examples=0, + total_positives=0, + total_negatives=0, + failed_count=len(df), + ) act = df_valid["activating"].astype(bool) total = len(df_valid) @@ -110,23 +124,40 @@ def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: fn = ((df_valid.prediction < threshold) & act).sum() return dict( - true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, - total_examples=total, total_positives=pos, total_negatives=neg, + true_positives=tp, + true_negatives=tn, + false_positives=fp, + false_negatives=fn, + total_examples=total, + total_positives=pos, + total_negatives=neg, failed_count=len(df) - len(df_valid), ) def compute_classification_metrics(conf: dict) -> dict: - tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] + tp, tn, fp, fn = ( + conf["true_positives"], + conf["true_negatives"], + conf["false_positives"], + conf["false_negatives"], + ) pos, neg = conf["total_positives"], conf["total_negatives"] - - balanced_accuracy = ((tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0)) / 2 + + balanced_accuracy = ( + (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) + ) / 2 precision = tp / (tp + fp) if tp + fp > 0 else 0 recall = tp / pos if pos > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 + f1 = ( + 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 + ) return dict( - precision=precision, recall=recall, f1_score=f1, accuracy=balanced_accuracy, + precision=precision, + recall=recall, + f1_score=f1, + accuracy=balanced_accuracy, true_positive_rate=tp / pos if pos > 0 else 0, true_negative_rate=tn / neg if neg > 0 else 0, false_positive_rate=fp / neg if neg > 0 else 0, @@ -143,9 +174,11 @@ def parse_score_file(path: Path) -> pd.DataFrame: except orjson.JSONDecodeError: print(f"Error decoding JSON from 
{path}. Skipping file.") return pd.DataFrame() - + if not isinstance(data, list): - print(f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file.") + print( + f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file." + ) return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) @@ -198,7 +231,7 @@ def parse_score_file(path: Path) -> pd.DataFrame: if not latent_dfs: return pd.DataFrame(), counts - + return pd.concat(latent_dfs, ignore_index=True), counts @@ -208,17 +241,20 @@ def get_agg_metrics( processed_rows = [] for score_type, group_df in latent_df.groupby("score_type"): # For surprisal_intervention, we don't compute classification metrics - if score_type == 'surprisal_intervention': + if score_type == "surprisal_intervention": continue - + conf = compute_confusion(group_df) class_m = compute_classification_metrics(conf) auc = compute_auc(group_df) f1_w = frequency_weighted_f1(group_df, counts) if counts else None - + row = { "score_type": score_type, - **conf, **class_m, "auc": auc, "weighted_f1": f1_w + **conf, + **class_m, + "auc": auc, + "weighted_f1": f1_w, } processed_rows.append(row) @@ -245,17 +281,22 @@ def log_results( if latent_df.empty: print("No data to analyze.") return - + latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - + # Separate the dataframes for different processing - classification_df = latent_df[latent_df['score_type'] != 'surprisal_intervention'] - surprisal_df = latent_df[latent_df['score_type'] == 'surprisal_intervention'] + classification_df = latent_df[latent_df["score_type"] != "surprisal_intervention"] + surprisal_df = latent_df[latent_df["score_type"] == "surprisal_intervention"] if not classification_df.empty: classification_df = add_latent_f1(classification_df) if counts: - plot_firing_vs_f1(classification_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name) + plot_firing_vs_f1( + classification_df, + num_tokens=10_000_000, + out_dir=viz_path, + run_label=scores_path.name, + ) plot_roc_curve(classification_df, viz_path) processed_df = get_agg_metrics(classification_df, counts) plot_accuracy_hist(processed_df, viz_path) @@ -263,34 +304,39 @@ def log_results( if counts: dead = sum((counts[m] == 0).sum().item() for m in modules) print(f"Number of dead features: {dead}") - for score_type in latent_df["score_type"].unique(): - - if score_type == 'surprisal_intervention': + + if score_type == "surprisal_intervention": # Drop duplicates since score is per-latent, not per-example - unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx']) - avg_score = unique_latents['final_score'].mean() - avg_kl = unique_latents['avg_kl_divergence'].mean() - + unique_latents = surprisal_df.drop_duplicates( + subset=["module", "latent_idx"] + ) + avg_score = unique_latents["final_score"].mean() + avg_kl = unique_latents["avg_kl_divergence"].mean() + print(f"\n--- {score_type.title()} Metrics ---") print(f"Average Normalized Score: {avg_score:.3f}") print(f"Average KL Divergence: {avg_kl:.3f}") else: if not classification_df.empty: - score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] + score_type_summary = processed_df[ + processed_df.score_type == score_type + ].iloc[0] print(f"\n--- {score_type.title()} Metrics ---") print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - if counts and score_type_summary['weighted_f1'] is not None: - 
print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") - + if counts and score_type_summary["weighted_f1"] is not None: + print( + f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}" + ) + print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") - + if score_type_summary["auc"] is not None: print(f"AUC: {score_type_summary['auc']:.3f}") else: - print("AUC not available.") \ No newline at end of file + print("AUC not available.") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index ad84c15f..6eeed35b 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -3,12 +3,12 @@ from .classifier.intruder import IntruderScorer from .embedding.embedding import EmbeddingScorer from .embedding.example_embedding import ExampleEmbeddingScorer -from .scorer import Scorer -from .simulator.oai_simulator import OpenAISimulator -from .surprisal.surprisal import SurprisalScorer from .intervention.intervention_scorer import InterventionScorer from .intervention.logprob_intervention_scorer import LogProbInterventionScorer from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer +from .scorer import Scorer +from .simulator.oai_simulator import OpenAISimulator +from .surprisal.surprisal import SurprisalScorer __all__ = [ "FuzzingScorer", @@ -22,5 +22,4 @@ "SurprisalInterventionScorer", "InterventionScorer", "LogProbInterventionScorer", - ] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index c683cca9..88c8497c 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -1,15 +1,14 @@ -import functools -import random import copy from dataclasses import dataclass -from typing import Any, List, Dict, Tuple +from typing import Any, Dict, List, Tuple import torch import torch.nn.functional as F from transformers import AutoTokenizer +from ...latents import LatentRecord from ..scorer import Scorer, ScorerResult -from ...latents import LatentRecord, ActivatingExample + @dataclass class SurprisalInterventionResult: @@ -21,6 +20,7 @@ class SurprisalInterventionResult: avg_kl: The average KL divergence between the clean and intervened next-token distributions. explanation: The explanation string that was scored. """ + score: float avg_kl: float explanation: str @@ -49,6 +49,7 @@ class SurprisalInterventionScorer(Scorer): 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε)[cite: 209]. """ + name = "surprisal_intervention" def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): @@ -76,7 +77,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): self.tokenizer = subject_model.tokenizer else: self.tokenizer = AutoTokenizer.from_pretrained("gpt2") - + if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id @@ -99,28 +100,30 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module: else: current = getattr(current, part) return current - + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: """ Dynamically finds the correct model prefix and resolves the full hookpoint path. 
- + This makes the scorer agnostic to different transformer architectures. """ - parts = hookpoint_str.split('.') - + parts = hookpoint_str.split(".") + is_valid_format = ( - len(parts) == 3 and - parts[0] in ['layers', 'h'] and - parts[1].isdigit() and - parts[2] in ['mlp', 'attention', 'attn'] + len(parts) == 3 + and parts[0] in ["layers", "h"] + and parts[1].isdigit() + and parts[2] in ["mlp", "attention", "attn"] ) if not is_valid_format: if len(parts) == 1 and hasattr(model, hookpoint_str): - return getattr(model, hookpoint_str) - raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") + return getattr(model, hookpoint_str) + raise ValueError( + f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'." + ) - #Heuristically find the model prefix. + # Heuristically find the model prefix. prefix = None for p in ["gpt_neox", "transformer", "model"]: if hasattr(model, p): @@ -128,14 +131,15 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: if hasattr(candidate_body, parts[0]): prefix = p break - + full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") - + raise AttributeError( + f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}" + ) def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ @@ -143,57 +147,69 @@ def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ sanitized = [] for ex in examples: - if hasattr(ex, 'str_tokens') and ex.str_tokens is not None: - sanitized.append({'str_tokens': ex.str_tokens}) - + if hasattr(ex, "str_tokens") and ex.str_tokens is not None: + sanitized.append({"str_tokens": ex.str_tokens}) + elif isinstance(ex, dict) and "str_tokens" in ex: sanitized.append(ex) - + elif isinstance(ex, str): sanitized.append({"str_tokens": [ex]}) - + elif isinstance(ex, (list, tuple)): sanitized.append({"str_tokens": [str(t) for t in ex]}) - + else: sanitized.append({"str_tokens": [str(ex)]}) - - return sanitized - + return sanitized async def __call__(self, record: LatentRecord) -> ScorerResult: record_copy = copy.deepcopy(record) raw_examples = getattr(record_copy, "test", []) or [] - + if not raw_examples: - result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation) + result = SurprisalInterventionResult( + score=0.0, avg_kl=0.0, explanation=record_copy.explanation + ) return ScorerResult(record=record, score=result) examples = self._sanitize_examples(raw_examples) - + record_copy.test = examples record_copy.examples = examples record_copy.train = examples - - prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]] - + + prompts = ["".join(ex["str_tokens"]) for ex in examples[: self.num_prompts]] + total_diff = 0.0 total_kl = 0.0 n = 0 for prompt in prompts: - clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False) - int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True) - - logp_clean = await self._score_explanation(clean_text, record_copy.explanation) + clean_text, clean_logp_dist = ( + await self._generate_with_and_without_intervention( + prompt, record_copy, intervene=False + ) + ) + int_text, 
int_logp_dist = ( + await self._generate_with_and_without_intervention( + prompt, record_copy, intervene=True + ) + ) + + logp_clean = await self._score_explanation( + clean_text, record_copy.explanation + ) logp_int = await self._score_explanation(int_text, record_copy.explanation) - + p_clean = torch.exp(clean_logp_dist) - kl_div = F.kl_div(int_logp_dist, p_clean, reduction='sum', log_target=False).item() - + kl_div = F.kl_div( + int_logp_dist, p_clean, reduction="sum", log_target=False + ).item() + total_diff += logp_int - logp_clean total_kl += kl_div n += 1 @@ -203,19 +219,21 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: final_score = avg_diff / (avg_kl + 1e-9) if n > 0 else 0.0 final_output_list = [] - for ex in examples[:self.num_prompts]: - final_output_list.append({ - "str_tokens": ex["str_tokens"], - "final_score": final_score, - "avg_kl_divergence": avg_kl, - # Add placeholder keys that the parser expects, with default values. - "distance": None, - "activating": None, - "prediction": None, - "correct": None, - "probability": None, - "activations": None, - }) + for ex in examples[: self.num_prompts]: + final_output_list.append( + { + "str_tokens": ex["str_tokens"], + "final_score": final_score, + "avg_kl_divergence": avg_kl, + # Add placeholder keys that the parser expects, with default values. + "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, + } + ) return ScorerResult(record=record_copy, score=final_output_list) async def _generate_with_and_without_intervention( @@ -235,7 +253,7 @@ async def _generate_with_and_without_intervention( device = self._get_device() enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) input_ids = enc["input_ids"].to(device) - + hooks = [] if intervene: @@ -246,14 +264,16 @@ async def _generate_with_and_without_intervention( layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) direction = self._get_intervention_direction(record).to(device) - direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D] + direction = direction.unsqueeze(0).unsqueeze( + 0 + ) # Shape for broadcasting: [1, 1, D] def hook_fn(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out - + # Apply intervention to the last token's hidden state hidden_states[:, -1:, :] += self.strength * direction - + # Return the modified activations in their original format if isinstance(out, tuple): return (hidden_states,) + out[1:] @@ -285,13 +305,15 @@ def hook_fn(module, inp, out): async def _score_explanation(self, generated_text: str, explanation: str) -> float: """Computes log P(explanation | generated_text) under the subject model.""" device = self._get_device() - + # Create the full input sequence: context + explanation context_enc = self.tokenizer(generated_text, return_tensors="pt") explanation_enc = self.tokenizer(explanation, return_tensors="pt") - - full_input_ids = torch.cat([context_enc.input_ids, explanation_enc.input_ids], dim=1).to(device) - + + full_input_ids = torch.cat( + [context_enc.input_ids, explanation_enc.input_ids], dim=1 + ).to(device) + with torch.no_grad(): outputs = self.subject_model(full_input_ids) logits = outputs.logits @@ -299,16 +321,16 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo # We only need to score the explanation part context_len = context_enc.input_ids.shape[1] # Get logits for positions that predict the explanation 
tokens - explanation_logits = logits[:, context_len - 1:-1, :] - + explanation_logits = logits[:, context_len - 1 : -1, :] + # Get the target token IDs for the explanation target_ids = explanation_enc.input_ids.to(device) - + log_probs = F.log_softmax(explanation_logits, dim=-1) - + # Gather the log-probabilities of the actual explanation tokens token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) - + return token_log_probs.sum().item() def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: @@ -324,9 +346,11 @@ def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) candidate = self.explainer_model.get(hookpoint_str) - if hasattr(candidate, 'get_feature_vector'): + if hasattr(candidate, "get_feature_vector"): sae = candidate - elif hasattr(candidate, 'sae') and hasattr(candidate.sae, 'get_feature_vector'): + elif hasattr(candidate, "sae") and hasattr( + candidate.sae, "get_feature_vector" + ): sae = candidate.sae if sae: @@ -342,16 +366,17 @@ def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: """Estimates an intervention direction by averaging activations.""" device = self._get_device() - + examples = self._sanitize_examples(getattr(record, "test", []) or []) if not examples: hidden_dim = self.subject_model.config.hidden_size return torch.zeros(hidden_dim, device=device) captured_activations = [] + def capture_hook(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out - + captured_activations.append(hidden_states[:, -1, :].detach().cpu()) hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) @@ -359,9 +384,11 @@ def capture_hook(module, inp, out): handle = layer_to_hook.register_forward_hook(capture_hook) try: - for ex in examples[:min(8, self.num_prompts)]: + for ex in examples[: min(8, self.num_prompts)]: prompt = "".join(ex["str_tokens"]) - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( + device + ) with torch.no_grad(): self.subject_model(input_ids) finally: @@ -373,5 +400,5 @@ def capture_hook(module, inp, out): activations = torch.cat(captured_activations, dim=0).to(device) direction = activations.mean(dim=0) - - return F.normalize(direction, p=2, dim=0) \ No newline at end of file + + return F.normalize(direction, p=2, dim=0) From d3d269e53d678643ea8bb61a75de51e7003f3cbe Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:51:10 +0000 Subject: [PATCH 07/11] Fix pre-commit --- delphi/log/result_analysis.py | 13 +++++---- delphi/scorers/__init__.py | 4 --- .../surprisal_intervention_scorer.py | 29 +++++++++++-------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index bffa7f6b..9f8c4e65 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,7 +16,8 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") + print(f"""WARNING: 'firing_count' column not found for module {module}. 
+ Skipping plot.""") continue module_df = module_df.copy() @@ -49,7 +50,7 @@ def import_plotly(): def compute_auc(df: pd.DataFrame) -> float | None: - # Filter for rows where probability is not None and there's more than one unique value + valid_df = df[df.probability.notna()] if valid_df.probability.nunique() <= 1: return None @@ -72,7 +73,7 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - # Filter for rows where probability is not None and there's more than one unique value + valid_df = df[df.probability.notna()] if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: return @@ -145,7 +146,8 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file.") + print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. + Skipping file.""") return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) @@ -285,7 +287,8 @@ def log_results( print(f"F1 Score: {score_type_summary['f1_score']:.3f}") if counts and score_type_summary['weighted_f1'] is not None: - print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") + print(f"""Frequency-Weighted F1 Score: + {score_type_summary['weighted_f1']:.3f}""") print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index ad84c15f..3bdd9d4d 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -6,8 +6,6 @@ from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer -from .intervention.intervention_scorer import InterventionScorer -from .intervention.logprob_intervention_scorer import LogProbInterventionScorer from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer __all__ = [ @@ -20,7 +18,5 @@ "IntruderScorer", "ExampleEmbeddingScorer", "SurprisalInterventionScorer", - "InterventionScorer", - "LogProbInterventionScorer", ] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index c683cca9..a1fe6618 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -18,7 +18,7 @@ class SurprisalInterventionResult: Attributes: score: The final computed score. - avg_kl: The average KL divergence between the clean and intervened next-token distributions. + avg_kl: The average KL-D between clean & intervened next-token distributions. explanation: The explanation string that was scored. """ score: float @@ -36,18 +36,19 @@ class SurprisalInterventionScorer(Scorer): by the intervention's strength, measured by the KL divergence between the clean and intervened next-token distributions. - Reference: Paulo et al., "Automatically Interpreting Millions of Features in Large Language Models" + Reference: Paulo et al., "Automatically Interpreting Millions of Features in LLMs" (https://arxiv.org/pdf/2410.13928), Section 3.3.5[cite: 206, 207]. Pipeline: 1. For a small set of activating prompts: a. Generate a continuation and get the next-token distribution ("clean"). - b. Add a directional vector for the feature to the activations and repeat ("intervened"). + b. 
Add directional vector for the feature to the activations ("intervened"). 2. Compute the log-probability of the explanation conditioned on both the clean and intervened generated texts: log P(explanation | text)[cite: 209]. - 3. Compute the KL divergence between the clean and intervened next-token distributions[cite: 216]. - 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: - score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε)[cite: 209]. + 3. Compute KL divergence between the clean & intervened next-token distributions. + 4. The final score is the mean change in explanation log-prob, divided by the + mean KL divergence: + score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε). """ name = "surprisal_intervention" @@ -55,12 +56,13 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): """ Args: subject_model: The language model to generate from and score with. - explainer_model: An optional model (e.g., an SAE) used to get feature directions. + explainer_model: A model (e.g., an SAE) used to get feature directions. **kwargs: Configuration options. strength (float): The magnitude of the intervention. Default: 5.0. num_prompts (int): Number of activating examples to test. Default: 3. - max_new_tokens (int): Max tokens to generate for continuations. Default: 20. - hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') for the intervention. + max_new_tokens (int): Max tokens to generate for continuations. + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') + for the intervention. """ self.subject_model = subject_model self.explainer_model = explainer_model @@ -118,7 +120,8 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: if not is_valid_format: if len(parts) == 1 and hasattr(model, hookpoint_str): return getattr(model, hookpoint_str) - raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") + raise ValueError(f"""Hookpoint string '{hookpoint_str}' is not in a recognized format + like 'layers.6.mlp'.""") #Heuristically find the model prefix. prefix = None @@ -134,7 +137,8 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") + raise AttributeError(f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. Original error: {e}""") def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: @@ -230,7 +234,8 @@ async def _generate_with_and_without_intervention( Returns: A tuple containing: - The generated text (string). - - The log-probability distribution for the token immediately following the prompt (Tensor). + - The log-probability distribution for the token immediately following + the prompt (Tensor). 
""" device = self._get_device() enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) From 126b97814f93f440a32ff11ef980110fda4a5dce Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:55:24 +0000 Subject: [PATCH 08/11] Fix pre-commit --- delphi/log/result_analysis.py | 2 +- delphi/scorers/__init__.py | 1 + delphi/scorers/intervention/surprisal_intervention_scorer.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 9f8c4e65..52f4b70f 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,7 +16,7 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"""WARNING: 'firing_count' column not found for module {module}. + print(f"""WARNING:'firing_count' column not found for module {module}. Skipping plot.""") continue diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 3bdd9d4d..84a98012 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -8,6 +8,7 @@ from .surprisal.surprisal import SurprisalScorer from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer + __all__ = [ "FuzzingScorer", "OpenAISimulator", diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index a1fe6618..387ed0eb 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -18,7 +18,8 @@ class SurprisalInterventionResult: Attributes: score: The final computed score. - avg_kl: The average KL-D between clean & intervened next-token distributions. + avg_kl: The average KL divergence between clean & intervened + next-token distributions. explanation: The explanation string that was scored. """ score: float From 8e893e063899ac485444fc366849dec33b623839 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:01:34 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/log/result_analysis.py | 8 +++---- delphi/scorers/__init__.py | 3 +-- .../surprisal_intervention_scorer.py | 21 +++++++++++-------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index ae1aca3b..65334e52 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,7 +16,7 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. + print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") continue @@ -175,7 +175,7 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. + print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. 
Skipping file.""") return pd.DataFrame() @@ -327,9 +327,9 @@ def log_results( print(f"F1 Score: {score_type_summary['f1_score']:.3f}") if counts and score_type_summary['weighted_f1'] is not None: - print(f"""Frequency-Weighted F1 Score: + print(f"""Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}""") - + print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 3bdd9d4d..1191688c 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -3,10 +3,10 @@ from .classifier.intruder import IntruderScorer from .embedding.embedding import EmbeddingScorer from .embedding.example_embedding import ExampleEmbeddingScorer +from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer -from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer __all__ = [ "FuzzingScorer", @@ -18,5 +18,4 @@ "IntruderScorer", "ExampleEmbeddingScorer", "SurprisalInterventionScorer", - ] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 06eb2300..191d4c6f 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -17,7 +17,7 @@ class SurprisalInterventionResult: Attributes: score: The final computed score. - avg_kl: The average KL divergence between clean & intervened + avg_kl: The average KL divergence between clean & intervened next-token distributions. explanation: The explanation string that was scored. """ @@ -47,7 +47,7 @@ class SurprisalInterventionScorer(Scorer): 2. Compute the log-probability of the explanation conditioned on both the clean and intervened generated texts: log P(explanation | text)[cite: 209]. 3. Compute KL divergence between the clean & intervened next-token distributions. - 4. The final score is the mean change in explanation log-prob, divided by the + 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε). """ @@ -63,7 +63,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): strength (float): The magnitude of the intervention. Default: 5.0. num_prompts (int): Number of activating examples to test. Default: 3. max_new_tokens (int): Max tokens to generate for continuations. - hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') for the intervention. """ self.subject_model = subject_model @@ -121,9 +121,11 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: if not is_valid_format: if len(parts) == 1 and hasattr(model, hookpoint_str): - return getattr(model, hookpoint_str) - raise ValueError(f"""Hookpoint string '{hookpoint_str}' is not in a recognized format - like 'layers.6.mlp'.""") + return getattr(model, hookpoint_str) + raise ValueError( + f"""Hookpoint string '{hookpoint_str}' is not in a recognized format + like 'layers.6.mlp'.""" + ) # Heuristically find the model prefix. 
prefix = None @@ -139,8 +141,9 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") - + raise AttributeError( + f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}" + ) def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ @@ -249,7 +252,7 @@ async def _generate_with_and_without_intervention( Returns: A tuple containing: - The generated text (string). - - The log-probability distribution for the token immediately following + - The log-probability distribution for the token immediately following the prompt (Tensor). """ device = self._get_device() From 2a546e09ce3bc59a6ec0fffb11162972cdb34052 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 16:07:31 +0000 Subject: [PATCH 10/11] Fix EOFs --- delphi/log/result_analysis.py | 9 +++++---- .../intervention/surprisal_intervention_scorer.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index ae1aca3b..c8c8a48e 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,8 +16,8 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. - Skipping plot.") + print(f"""WARNING: 'firing_count' column not found for module {module}. + Skipping plot.""") continue module_df = module_df.copy() @@ -175,8 +175,9 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. - Skipping file.""") + print(f"""Warning: Expected a list of results in {path}, + but found {type(data)}. + Skipping file.""") return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 06eb2300..66421eef 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -139,7 +139,9 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") + raise AttributeError(f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. 
+ Original error: {e}""") def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: From 6a6368c559d26be30d3f1dbe416388d9aa605d40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:09:54 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/log/result_analysis.py | 24 ++++++++++++------- .../surprisal_intervention_scorer.py | 9 +++---- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 07fec861..4852e8d6 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -15,9 +15,11 @@ def plot_firing_vs_f1( out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): - if 'firing_count' not in module_df.columns: - print(f"""WARNING: 'firing_count' column not found for module {module}. - Skipping plot.""") + if "firing_count" not in module_df.columns: + print( + f"""WARNING: 'firing_count' column not found for module {module}. + Skipping plot.""" + ) continue module_df = module_df.copy() @@ -175,9 +177,11 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"""Warning: Expected a list of results in {path}, - but found {type(data)}. - Skipping file.""") + print( + f"""Warning: Expected a list of results in {path}, + but found {type(data)}. + Skipping file.""" + ) return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) @@ -327,9 +331,11 @@ def log_results( print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - if counts and score_type_summary['weighted_f1'] is not None: - print(f"""Frequency-Weighted F1 Score: - {score_type_summary['weighted_f1']:.3f}""") + if counts and score_type_summary["weighted_f1"] is not None: + print( + f"""Frequency-Weighted F1 Score: + {score_type_summary['weighted_f1']:.3f}""" + ) print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 1fd5ef14..b47d61a7 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -141,10 +141,11 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"""Could not resolve path '{full_path}'. - Model structure might be unexpected. - Original error: {e}""") - + raise AttributeError( + f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. + Original error: {e}""" + ) def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """
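For reference, the scoring rule that SurprisalInterventionScorer implements in the hunks above (the change in log P(explanation | generated text), normalized by the KL divergence between the clean and intervened next-token distributions) can be sketched compactly. The snippet below is an illustrative, self-contained reimplementation, not code taken from the patch: the "gpt2" checkpoint and the helper names logp_explanation and surprisal_intervention_score are assumptions made for this example, while the 1e-9 epsilon mirrors the value used in the scorer itself.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()


def logp_explanation(context: str, explanation: str) -> float:
    # Sum of log P(explanation tokens | context) under the causal LM,
    # analogous to the scorer's _score_explanation method.
    ctx_ids = tok(context, return_tensors="pt").input_ids
    exp_ids = tok(explanation, return_tensors="pt").input_ids
    full_ids = torch.cat([ctx_ids, exp_ids], dim=1)
    with torch.no_grad():
        logits = lm(full_ids).logits
    # Logits at position i predict token i + 1, so the slice that predicts the
    # explanation starts one position before the explanation begins.
    ctx_len = ctx_ids.shape[1]
    pred = F.log_softmax(logits[:, ctx_len - 1 : -1, :], dim=-1)
    return pred.gather(2, exp_ids.unsqueeze(-1)).squeeze(-1).sum().item()


def surprisal_intervention_score(diffs: list[float], kls: list[float]) -> float:
    # diffs[i] = logp_intervened - logp_clean for prompt i.
    # kls[i]   = KL(clean || intervened) over the next-token distribution.
    if not diffs:
        return 0.0
    return (sum(diffs) / len(diffs)) / (sum(kls) / len(kls) + 1e-9)


# Example: two prompts whose explanation log-prob rose by 0.8 and 0.3 nats
# under interventions that shifted the next-token distribution only slightly.
print(surprisal_intervention_score([0.8, 0.3], [0.05, 0.02]))

Dividing by the mean KL keeps the metric from rewarding interventions that simply blow up the output distribution: a large gain in explanation log-probability only scores well if it was achieved with a comparatively gentle perturbation of the model's next-token behaviour.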