20
20
sys .path .append ("../language-modeling/" )
21
21
from run_time_clm import get_special_tokens
22
22
23
def get_classification_model(model_args):
    """Load the fine-tuned section classifier used by the generation metrics.

    Args:
        model_args: namespace with ``classification_model`` (path or hub id of
            a BERT sequence-classification checkpoint) and ``device`` (target
            torch device).

    Returns:
        tuple: ``(tokenizer, model)`` — a ``PreTrainedTokenizerFast`` and a
        ``BertForSequenceClassification`` already moved to
        ``model_args.device``.
    """
    model_path = model_args.classification_model
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path)
    model.to(model_args.device)
    return tokenizer, model
57
23
class GenerationMetrics :
58
24
59
25
def __init__ (self , model , device , tokenizer , dataset_name , fname ,
@@ -75,12 +41,10 @@ def __init__(self, model, device, tokenizer, dataset_name, fname,
75
41
self .section_ids = self .section_ids [:- 1 ]
76
42
self ._info = []
77
43
self ._examples = []
78
- self ._classification_examples = dict ()
79
44
self .metrics = defaultdict (lambda : [])
80
45
self .examples = {}
81
46
self .fname = fname
82
47
83
- self .classification_tokenizer , self .classification_model = get_classification_model (model_args )
84
48
self .mode = "section" if ("splitsection" in dataset_name ) else "doc"
85
49
86
50
def calculate (self , input_ids , raw_seq , section_name = None ,
@@ -97,7 +61,6 @@ def calculate(self, input_ids, raw_seq, section_name=None,
97
61
self ._examples .append ({'text' : raw_seq })
98
62
99
63
def _stories (self , input_ids , raw_seq ):
100
- # TODO get story classification
101
64
# Check for redundancy in WP and Prompt
102
65
info = {}
103
66
for special_tok , name in zip ([50257 , 50258 ], ['[ WP ]' , '[ RESPONSE ]' ]):
@@ -132,9 +95,6 @@ def _stories(self, input_ids, raw_seq):
132
95
def _track_doc_examples (self , raw_seq ):
133
96
self .examples ['ordering = {}' .format (self .metrics ['ordering' ][- 1 ])] = raw_seq
134
97
135
- for k , v in self ._classification_examples .items ():
136
- self .examples [k ] = v
137
-
138
98
for section_i , section_name in enumerate (self .section_names ):
139
99
is_present = self .metrics ['{} present' .format (section_name )]
140
100
is_redundant = self .metrics ['{} redundant' .format (section_name )]
@@ -186,19 +146,6 @@ def _check_total_length(self, input_ids, info):
186
146
info ['total length' ] = input_ids .shape [- 1 ]
187
147
return info
188
148
189
- def _check_classification (self , raw_seq , info ):
190
- classification_results = self ._get_classification (raw_seq )
191
- histograms = dict ()
192
- for k , v in classification_results .items ():
193
- # if list, create a histogram and include mean
194
- if isinstance (v , list ):
195
- histograms [self .prepend_ + k + " hist" ] = wandb .Histogram (v )
196
- v = np .mean (v )
197
- info [k ] = v
198
- wandb .log (histograms )
199
- return info
200
-
201
-
202
149
def _taskmaster_section_length (self , input_ids , idxs , section_name , info ):
203
150
lengths = []
204
151
other_id = 50258 if 'USER' in section_name else 50257
@@ -300,8 +247,6 @@ def _document(self, input_ids, raw_seq, gt_raw_seq):
300
247
info = {}
301
248
302
249
info = self ._check_total_length (input_ids = input_ids , info = info )
303
- if 'taskmaster' not in self .dataset_name :
304
- info = self ._check_classification (raw_seq = raw_seq , info = info )
305
250
info = self ._check_ordering (input_ids = input_ids , raw_seq = raw_seq , info = info )
306
251
for section_id , section_name in zip (self .section_ids , self .section_names ):
307
252
idxs = (input_ids == section_id ).nonzero (as_tuple = True )
@@ -385,68 +330,6 @@ def _document(self, input_ids, raw_seq, gt_raw_seq):
385
330
386
331
wandb .log (most_recent )
387
332
388
- def _get_classification (self , raw_seq ):
389
- results = defaultdict (lambda : [])
390
- self ._classification_examples = dict ()
391
- raw_seq = raw_seq .replace ("<|endoftext|> " , "" )
392
- split_seq = raw_seq .split (". " )
393
- sec_id = 0
394
- seq_idxs = []
395
- for seq_idx , seq in enumerate (split_seq ):
396
- if not seq :
397
- continue
398
- seq_idxs .append (seq_idx )
399
- seq += "."
400
- for tok in self .section_names :
401
- if tok in seq :
402
- sec_id = self .section_names .index (tok )
403
- seq = seq .replace (tok + " " , "" )
404
- try :
405
- assert tok not in seq
406
- except :
407
- seq = seq .replace (tok , "" )
408
-
409
- tokenized_seq = self .classification_tokenizer (seq , return_tensors = 'pt' ).to (
410
- self .classification_model .device
411
- )
412
- result = self .classification_model (input_ids = tokenized_seq ['input_ids' ][:, :512 ])
413
- probs = torch .nn .functional .softmax (result .logits , dim = 1 )
414
-
415
- acc = int (torch .argmax (probs ) == sec_id )
416
- entropy = - torch .sum (probs * torch .log (probs )).detach ().cpu ().numpy ()
417
- prob_sec_id = probs [0 , sec_id ].detach ().cpu ().numpy ()
418
-
419
- # uniform_p = torch.tensor([0.25]*4)
420
- # y_entropy = -torch.sum(uniform_p * torch.log(uniform_p))
421
- # mi = float(y_entropy - entropy)
422
-
423
- self .metrics ["{} class acc" .format (self .section_names [sec_id ])].append (acc )
424
- self .metrics ["{} class entropy" .format (self .section_names [sec_id ])].append (entropy )
425
- # self.metrics["{} MI".format(self.section_names[sec_id])].append(mi)
426
- self .metrics ["{} p(section_id*|x)" .format (self .section_names [sec_id ])].append (prob_sec_id )
427
- results ["{} class acc" .format (self .section_names [sec_id ])].append (acc )
428
- results ["{} class entropy" .format (self .section_names [sec_id ])].append (entropy )
429
- # results["{} MI".format(self.section_names[sec_id])].append(mi)
430
- results ["{} p(section_id*|x)" .format (self .section_names [sec_id ])].append (prob_sec_id )
431
-
432
- # sentences that are induce high/low acc/entropy/mi
433
- for key , metric in zip (["{} class acc" , "{} class entropy" ], [acc , entropy ,]):
434
- key = key .format (self .section_names [sec_id ])
435
- if results [key ] and max (results [key ]) == metric :
436
- self ._classification_examples [key + " MAX" ] = seq
437
- self .metrics [key + " MAX IDX" ].append (
438
- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
439
- results [key + " MAX IDX" ].append (
440
- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
441
- if results [key ] and min (results [key ]) == metric :
442
- self ._classification_examples [key + " MIN" ] = seq
443
- self .metrics [key + " MIN IDX" ].append (
444
- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
445
- results [key + " MIN IDX" ].append (
446
- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
447
-
448
- return results
449
-
450
333
def print_results (self ):
451
334
print ("Examples" )
452
335
extreme_ex = []
0 commit comments