Skip to content
Open
47 changes: 41 additions & 6 deletions delphi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from simple_parsing import ArgumentParser
from torch import Tensor
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedModel,
Expand All @@ -27,7 +27,12 @@
from delphi.latents.neighbours import NeighbourCalculator
from delphi.log.result_analysis import log_results
from delphi.pipeline import Pipe, Pipeline, process_wrapper
from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator
from delphi.scorers import (
DetectionScorer,
FuzzingScorer,
OpenAISimulator,
SurprisalInterventionScorer,
)
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
from delphi.utils import assert_type, load_tokenized_data

Expand All @@ -40,7 +45,7 @@ def load_artifacts(run_cfg: RunConfig):
else:
dtype = "auto"

model = AutoModel.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
run_cfg.model,
device_map={"": "cuda"},
quantization_config=(
Expand Down Expand Up @@ -118,6 +123,8 @@ async def process_cache(
hookpoints: list[str],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
latent_range: Tensor | None,
model,
hookpoint_to_sparse_encode,
):
"""
Converts SAE latent activations in on-disk cache in the `latents_path` directory
Expand Down Expand Up @@ -219,6 +226,12 @@ def none_postprocessor(result):
)
)

def custom_serializer(obj):
    """Fallback serializer handed to ``orjson.dumps`` through ``default=``.

    Args:
        obj: The object that could not be serialized natively.

    Returns:
        The nested Python list (or scalar) produced by ``Tensor.tolist()``
        when ``obj`` is a ``torch.Tensor``.

    Raises:
        TypeError: If ``obj`` is not a ``Tensor``. The message names the
            offending type so a failed score dump is easy to diagnose.
    """
    if isinstance(obj, Tensor):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

# Builds the record from result returned by the pipeline
def scorer_preprocess(result):
if isinstance(result, list):
Expand All @@ -230,11 +243,18 @@ def scorer_preprocess(result):
return record

# Saves the score to a file
def scorer_postprocess(result, score_dir, scorer_name=None):
    """Write the score of a pipeline result to ``<score_dir>/<latent>.txt``.

    Args:
        result: A scorer result exposing ``record.latent`` and ``score``, or
            a list of such results. Only the first element of a list is
            written; an empty list is ignored.
        score_dir: Directory path object supporting the ``/`` operator.
        scorer_name: Accepted for caller compatibility; not used here.
    """
    if isinstance(result, list):
        if not result:
            # Nothing was scored for this latent, so there is nothing to save.
            return
        result = result[0]

    # Replace "/" so the latent name cannot introduce extra path segments
    # (presumably latent names embed hookpoint paths; confirm against caller).
    safe_latent_name = str(result.record.latent).replace("/", "--")

    with open(score_dir / f"{safe_latent_name}.txt", "wb") as f:
        # `default=` lets orjson serialize scores that contain Tensors.
        f.write(orjson.dumps(result.score, default=custom_serializer))

scorers = []
for scorer_name in run_cfg.scorers:
Expand All @@ -257,6 +277,16 @@ def scorer_postprocess(result, score_dir):
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)

elif scorer_name == "surprisal_intervention":
scorer = SurprisalInterventionScorer(
model,
hookpoint_to_sparse_encode,
hookpoints=run_cfg.hookpoints,
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)
else:
raise ValueError(f"Scorer {scorer_name} not supported")

Expand Down Expand Up @@ -396,6 +426,8 @@ async def run(
hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

model.tokenizer = tokenizer

nrh = assert_type(
dict,
non_redundant_hookpoints(
Expand All @@ -412,7 +444,6 @@ async def run(
transcode,
)

del model, hookpoint_to_sparse_encode
if run_cfg.constructor_cfg.non_activating_source == "neighbours":
nrh = assert_type(
list,
Expand Down Expand Up @@ -445,8 +476,12 @@ async def run(
nrh,
tokenizer,
latent_range,
model,
hookpoint_to_sparse_encode,
)

del model, hookpoint_to_sparse_encode

if run_cfg.verbose:
log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers)

Expand Down
10 changes: 3 additions & 7 deletions delphi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,14 @@ class RunConfig(Serializable):
the default single token explainer, and 'none' for no explanation generation."""

scorers: list[str] = list_field(
choices=[
"fuzz",
"detection",
"simulation",
],
choices=["fuzz", "detection", "simulation", "surprisal_intervention"],
default=[
"fuzz",
"detection",
],
)
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and
'simulation'."""
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection',
'simulation' and 'surprisal_intervention'."""

name: str = ""
"""The name of the run. Results are saved in a directory with this name."""
Expand Down
7 changes: 7 additions & 0 deletions delphi/latents/latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ class LatentRecord:
"""Frequency of the latent. Number of activations in a context per total
number of contexts."""

@property
def feature_id(self) -> int:
    """Index of this latent, read from ``self.latent.latent_index``."""
    latent = self.latent
    return latent.latent_index

@property
def max_activation(self) -> float:
"""
Expand Down
Loading