13
13
import numpy as np
14
14
import pandas as pd
15
15
from attr import dataclass
16
+ from browsergym .experiments .loop import StepInfo as BGymStepInfo
16
17
from langchain .schema import BaseMessage , HumanMessage
17
18
from openai import OpenAI
18
19
from openai .types .responses import ResponseFunctionToolCall
@@ -74,6 +75,7 @@ class EpisodeId:
74
75
agent_id : str = None
75
76
task_name : str = None
76
77
seed : int = None
78
+ row_index : int = None # unique row index to disambiguate selections
77
79
78
80
79
81
@dataclass
def update_exp_result(self, episode_id: EpisodeId):
    """Load the ExpResult selected by `episode_id` into self.exp_result.

    Clears the current result and bails out when the selection is
    incomplete; otherwise resolves the experiment directory through the
    stable row index carried in the EpisodeId (set by filter_agent_id).
    """
    if (
        self.result_df is None
        or episode_id.task_name is None
        or episode_id.seed is None
        or episode_id.row_index is None  # iloc below needs a concrete index
    ):
        self.exp_result = None
        # Bug fix: without this return, the lookup below ran with an
        # incomplete id and crashed on iloc[None].
        return

    # find unique row using idx
    result_df = self.agent_df.reset_index(inplace=False)
    exp_dir = result_df.iloc[episode_id.row_index]["exp_dir"]
    print(exp_dir)
    self.exp_result = ExpResult(exp_dir)
    self.step = 0
@@ -128,16 +115,15 @@ def get_agent_id(self, row: pd.Series):
128
115
return agent_id
129
116
130
117
def filter_agent_id(self, agent_id: list[tuple]):
    """Keep only the rows of result_df matching every (column, value) pair.

    Before filtering, a "_row_index" column is recorded so that later
    selections can refer back to one unambiguous row even when several
    rows share the same task_name/seed.
    """
    filtered = self.result_df.reset_index(inplace=False)
    filtered["_row_index"] = filtered.index
    filtered.set_index(TASK_NAME_KEY, inplace=True)

    for column, value in agent_id:
        # Column headers may have been wrapped for display; undo the wrap.
        column = column.replace(".\n", ".")
        filtered = filtered[filtered[column] == value]
    self.agent_df = filtered
141
127
142
128
143
129
# Module-level singleton holding the viewer's shared state
# (result/agent dataframes, current exp_result, step index, ...).
info = Info()
@@ -735,7 +721,7 @@ def dict_msg_to_markdown(d: dict):
735
721
case _:
736
722
parts .append (f"\n ```\n { str (item )} \n ```\n " )
737
723
738
- markdown = f"### { d [" role" ].capitalize ()} \n "
724
+ markdown = f"### { d [' role' ].capitalize ()} \n "
739
725
markdown += "\n " .join (parts )
740
726
return markdown
741
727
@@ -1003,14 +989,17 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str):
1003
989
def extract_columns(row: pd.Series):
    """Project one result row onto the columns shown in the seed table.

    "idx" is the stable "_row_index" added by filter_agent_id, used to
    disambiguate rows sharing the same task_name/seed.
    """
    err_msg = row.get("err_msg", None)
    return pd.Series(
        {
            "idx": row.get("_row_index", None),
            "seed": row.get(TASK_SEED_KEY, None),
            "reward": row.get("cum_reward", None),
            # Bug fix: a missing err_msg is NaN in pandas and bool(NaN)
            # is True, which flagged error-free rows as errors.
            "err": bool(pd.notna(err_msg) and err_msg),
            "n_steps": row.get("n_steps", None),
        }
    )
1012
999
1013
1000
seed_df = result_df .apply (extract_columns , axis = 1 )
1001
+ # Ensure column order and readability
1002
+ seed_df = seed_df [["seed" , "reward" , "err" , "n_steps" , "idx" ]]
1014
1003
return seed_df
1015
1004
1016
1005
@@ -1028,15 +1017,20 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]):
1028
1017
def update_seeds(agent_task_id: tuple):
    """Build the per-seed table for a task and pre-select its first row."""
    agent_id, task_name = agent_task_id
    seed_df = get_seeds_df(info.agent_df, task_name)
    top_row = seed_df.iloc[0]
    episode = EpisodeId(
        agent_id=agent_id,
        task_name=task_name,
        seed=int(top_row["seed"]),
        row_index=int(top_row["idx"]),
    )
    return seed_df, episode
1033
1025
1034
1026
1035
1027
def on_select_seed(evt: gr.SelectData, df: pd.DataFrame, agent_task_id: tuple):
    """Build an EpisodeId from the row the user clicked in the seed table.

    Args:
        evt: gradio selection event carrying the clicked row's cell values.
        df: the seed table currently displayed (has "seed" and "idx" columns).
        agent_task_id: (agent_id, task_name) of the current selection.
    """
    agent_id, task_name = agent_task_id
    seed_col = df.columns.get_loc("seed")
    idx_col = df.columns.get_loc("idx")
    # Cast to int: the gradio table may hand cell values back as float/str,
    # and row_index is used with DataFrame.iloc, which requires an integer.
    # This mirrors the int() casts already done in update_seeds.
    seed = int(evt.row_value[seed_col])
    row_index = int(evt.row_value[idx_col])
    return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed, row_index=row_index)
1040
1034
1041
1035
1042
1036
def new_episode (episode_id : EpisodeId , progress = gr .Progress ()):
@@ -1134,7 +1128,7 @@ def new_exp_dir(study_names: list, progress=gr.Progress(), just_refresh=False):
1134
1128
study_names .remove (select_dir_instructions )
1135
1129
1136
1130
if len (study_names ) == 0 :
1137
- return None , None
1131
+ return None , None , None , None , None , None
1138
1132
1139
1133
info .study_dirs = [info .results_dir / study_name .split (" - " )[0 ] for study_name in study_names ]
1140
1134
info .result_df = inspect_results .load_result_df (info .study_dirs , progress_fn = progress .tqdm )
@@ -1287,7 +1281,9 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
1287
1281
all_times = []
1288
1282
step_times = []
1289
1283
for i , step_info in progress_fn (list (enumerate (step_info_list )), desc = "Building plot." ):
1290
- assert isinstance (step_info , StepInfo ), f"Expected StepInfo, got { type (step_info )} "
1284
+ assert isinstance (
1285
+ step_info , (StepInfo , BGymStepInfo )
1286
+ ), f"Expected StepInfo or BGymStepInfo, got { type (step_info )} "
1291
1287
step = step_info .step
1292
1288
1293
1289
prof = deepcopy (step_info .profiling )
0 commit comments