Skip to content

Commit c48e40b

Browse files
xray bugfix (#276)
* xray bugfix duplicate seeds
* Backward compatibility with bgym stepInfo
* clean up xray-fixes

---------

Co-authored-by: Aman Jaiswal <[email protected]>
1 parent 6522057 commit c48e40b

File tree

1 file changed

+28
-32
lines changed

1 file changed

+28
-32
lines changed

src/agentlab/analyze/agent_xray.py

Lines changed: 28 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import pandas as pd
1515
from attr import dataclass
16+
from browsergym.experiments.loop import StepInfo as BGymStepInfo
1617
from langchain.schema import BaseMessage, HumanMessage
1718
from openai import OpenAI
1819
from openai.types.responses import ResponseFunctionToolCall
@@ -74,6 +75,7 @@ class EpisodeId:
7475
agent_id: str = None
7576
task_name: str = None
7677
seed: int = None
78+
row_index: int = None # unique row index to disambiguate selections
7779

7880

7981
@dataclass
@@ -99,24 +101,9 @@ def update_exp_result(self, episode_id: EpisodeId):
99101
if self.result_df is None or episode_id.task_name is None or episode_id.seed is None:
100102
self.exp_result = None
101103

102-
# find unique row for task_name and seed
104+
# find unique row using idx
103105
result_df = self.agent_df.reset_index(inplace=False)
104-
sub_df = result_df[
105-
(result_df[TASK_NAME_KEY] == episode_id.task_name)
106-
& (result_df[TASK_SEED_KEY] == episode_id.seed)
107-
]
108-
if len(sub_df) == 0:
109-
self.exp_result = None
110-
raise ValueError(
111-
f"Could not find task_name: {episode_id.task_name} and seed: {episode_id.seed}"
112-
)
113-
114-
if len(sub_df) > 1:
115-
warning(
116-
f"Found multiple rows for task_name: {episode_id.task_name} and seed: {episode_id.seed}. Using the first one."
117-
)
118-
119-
exp_dir = sub_df.iloc[0]["exp_dir"]
106+
exp_dir = result_df.iloc[episode_id.row_index]["exp_dir"]
120107
print(exp_dir)
121108
self.exp_result = ExpResult(exp_dir)
122109
self.step = 0
@@ -128,16 +115,15 @@ def get_agent_id(self, row: pd.Series):
128115
return agent_id
129116

130117
def filter_agent_id(self, agent_id: list[tuple]):
131-
# query_str = " & ".join([f"`{col}` == {repr(val)}" for col, val in agent_id])
132-
# agent_df = info.result_df.query(query_str)
133-
134-
agent_df = self.result_df.reset_index(inplace=False)
135-
agent_df.set_index(TASK_NAME_KEY, inplace=True)
118+
# Preserve a stable row index to disambiguate selections later
119+
tmp_df = self.result_df.reset_index(inplace=False)
120+
tmp_df["_row_index"] = tmp_df.index
121+
tmp_df.set_index(TASK_NAME_KEY, inplace=True)
136122

137123
for col, val in agent_id:
138124
col = col.replace(".\n", ".")
139-
agent_df = agent_df[agent_df[col] == val]
140-
self.agent_df = agent_df
125+
tmp_df = tmp_df[tmp_df[col] == val]
126+
self.agent_df = tmp_df
141127

142128

143129
info = Info()
@@ -735,7 +721,7 @@ def dict_msg_to_markdown(d: dict):
735721
case _:
736722
parts.append(f"\n```\n{str(item)}\n```\n")
737723

738-
markdown = f"### {d["role"].capitalize()}\n"
724+
markdown = f"### {d['role'].capitalize()}\n"
739725
markdown += "\n".join(parts)
740726
return markdown
741727

@@ -1003,14 +989,17 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str):
1003989
def extract_columns(row: pd.Series):
1004990
return pd.Series(
1005991
{
1006-
"seed": row[TASK_SEED_KEY],
992+
"idx": row.get("_row_index", None),
993+
"seed": row.get(TASK_SEED_KEY, None),
1007994
"reward": row.get("cum_reward", None),
1008995
"err": bool(row.get("err_msg", None)),
1009996
"n_steps": row.get("n_steps", None),
1010997
}
1011998
)
1012999

10131000
seed_df = result_df.apply(extract_columns, axis=1)
1001+
# Ensure column order and readability
1002+
seed_df = seed_df[["seed", "reward", "err", "n_steps", "idx"]]
10141003
return seed_df
10151004

10161005

@@ -1028,15 +1017,20 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]):
10281017
def update_seeds(agent_task_id: tuple):
10291018
agent_id, task_name = agent_task_id
10301019
seed_df = get_seeds_df(info.agent_df, task_name)
1031-
first_seed = seed_df.iloc[0]["seed"]
1032-
return seed_df, EpisodeId(agent_id=agent_id, task_name=task_name, seed=first_seed)
1020+
first_seed = int(seed_df.iloc[0]["seed"])
1021+
first_index = int(seed_df.iloc[0]["idx"])
1022+
return seed_df, EpisodeId(
1023+
agent_id=agent_id, task_name=task_name, seed=first_seed, row_index=first_index
1024+
)
10331025

10341026

10351027
def on_select_seed(evt: gr.SelectData, df: pd.DataFrame, agent_task_id: tuple):
10361028
agent_id, task_name = agent_task_id
10371029
col_idx = df.columns.get_loc("seed")
1038-
seed = evt.row_value[col_idx] # seed should be the first column
1039-
return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed)
1030+
idx_col = df.columns.get_loc("idx")
1031+
seed = evt.row_value[col_idx]
1032+
row_index = evt.row_value[idx_col]
1033+
return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed, row_index=row_index)
10401034

10411035

10421036
def new_episode(episode_id: EpisodeId, progress=gr.Progress()):
@@ -1134,7 +1128,7 @@ def new_exp_dir(study_names: list, progress=gr.Progress(), just_refresh=False):
11341128
study_names.remove(select_dir_instructions)
11351129

11361130
if len(study_names) == 0:
1137-
return None, None
1131+
return None, None, None, None, None, None
11381132

11391133
info.study_dirs = [info.results_dir / study_name.split(" - ")[0] for study_name in study_names]
11401134
info.result_df = inspect_results.load_result_df(info.study_dirs, progress_fn=progress.tqdm)
@@ -1287,7 +1281,9 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12871281
all_times = []
12881282
step_times = []
12891283
for i, step_info in progress_fn(list(enumerate(step_info_list)), desc="Building plot."):
1290-
assert isinstance(step_info, StepInfo), f"Expected StepInfo, got {type(step_info)}"
1284+
assert isinstance(
1285+
step_info, (StepInfo, BGymStepInfo)
1286+
), f"Expected StepInfo or BGymStepInfo, got {type(step_info)}"
12911287
step = step_info.step
12921288

12931289
prof = deepcopy(step_info.profiling)

0 commit comments

Comments
 (0)