
Commit 5cbc3ee

Merge commit (2 parents: cb3d345 + 28a8736)

File tree: 3 files changed, +4 −121 lines


language_modeling_via_stochastic_processes/transformers/examples/pytorch/text-generation/generation_metrics.py

Lines changed: 0 additions & 117 deletions
@@ -20,40 +20,6 @@
 sys.path.append("../language-modeling/")
 from run_time_clm import get_special_tokens
 
-def get_classification_model(model_args):
-    model_path = model_args.classification_model
-    cache_dir = "/nlp/scr/rewang/huggingface/"
-    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
-    model = BertForSequenceClassification.from_pretrained(model_path)
-    model.to(model_args.device)
-    # tokenizer = AutoTokenizer.from_pretrained(
-    #     model_path,
-    #     cache_dir=cache_dir,
-    #     use_fast=True,  # model_args.use_fast_tokenizer,
-    #     revision="main",  # model_args.model_revision,
-    #     # use_auth_token=True if model_args.use_auth_token else None,
-    #     use_auth_token=None,
-    # )
-    # config = AutoConfig.from_pretrained(
-    #     # model_args.config_name if model_args.config_name else model_args.model_name_or_path,
-    #     model_path,
-    #     num_labels=4,  # num_labels,
-    #     finetuning_task="wikisection",  # data_args.task_name,
-    #     cache_dir=cache_dir,
-    #     revision="main",
-    #     use_auth_token=None,  # True if model_args.use_auth_token else None,
-    # )
-    # model = AutoModelForSequenceClassification.from_pretrained(
-    #     model_path,
-    #     # model_args.model_name_or_path,
-    #     from_tf=bool(".ckpt" in model_path),
-    #     config=config,
-    #     cache_dir=cache_dir,
-    #     revision="main",  # model_args.model_revision,
-    #     use_auth_token=None,
-    # )
-    return tokenizer, model
-
 class GenerationMetrics:
 
     def __init__(self, model, device, tokenizer, dataset_name, fname,
@@ -75,12 +41,10 @@ def __init__(self, model, device, tokenizer, dataset_name, fname,
         self.section_ids = self.section_ids[:-1]
         self._info = []
         self._examples = []
-        self._classification_examples = dict()
         self.metrics = defaultdict(lambda: [])
         self.examples = {}
         self.fname = fname
 
-        self.classification_tokenizer, self.classification_model = get_classification_model(model_args)
         self.mode = "section" if ("splitsection" in dataset_name) else "doc"
 
     def calculate(self, input_ids, raw_seq, section_name=None,
@@ -97,7 +61,6 @@ def calculate(self, input_ids, raw_seq, section_name=None,
         self._examples.append({'text': raw_seq})
 
     def _stories(self, input_ids, raw_seq):
-        # TODO get story classification
         # Check for redundancy in WP and Prompt
         info = {}
         for special_tok, name in zip([50257, 50258], ['[ WP ]', '[ RESPONSE ]']):
@@ -132,9 +95,6 @@ def _stories(self, input_ids, raw_seq):
     def _track_doc_examples(self, raw_seq):
         self.examples['ordering = {}'.format(self.metrics['ordering'][-1])] = raw_seq
 
-        for k, v in self._classification_examples.items():
-            self.examples[k] = v
-
         for section_i, section_name in enumerate(self.section_names):
             is_present = self.metrics['{} present'.format(section_name)]
             is_redundant = self.metrics['{} redundant'.format(section_name)]
186146
info['total length'] = input_ids.shape[-1]
187147
return info
188148

189-
def _check_classification(self, raw_seq, info):
190-
classification_results = self._get_classification(raw_seq)
191-
histograms = dict()
192-
for k, v in classification_results.items():
193-
# if list, create a histogram and include mean
194-
if isinstance(v, list):
195-
histograms[self.prepend_ + k + " hist"] = wandb.Histogram(v)
196-
v = np.mean(v)
197-
info[k] = v
198-
wandb.log(histograms)
199-
return info
200-
201-
202149
def _taskmaster_section_length(self, input_ids, idxs, section_name, info):
203150
lengths = []
204151
other_id = 50258 if 'USER' in section_name else 50257
@@ -300,8 +247,6 @@ def _document(self, input_ids, raw_seq, gt_raw_seq):
300247
info = {}
301248

302249
info = self._check_total_length(input_ids=input_ids, info=info)
303-
if 'taskmaster' not in self.dataset_name:
304-
info = self._check_classification(raw_seq=raw_seq, info=info)
305250
info = self._check_ordering(input_ids=input_ids, raw_seq=raw_seq, info=info)
306251
for section_id, section_name in zip(self.section_ids, self.section_names):
307252
idxs = (input_ids == section_id).nonzero(as_tuple=True)
@@ -385,68 +330,6 @@ def _document(self, input_ids, raw_seq, gt_raw_seq):
385330

386331
wandb.log(most_recent)
387332

388-
def _get_classification(self, raw_seq):
389-
results = defaultdict(lambda: [])
390-
self._classification_examples = dict()
391-
raw_seq = raw_seq.replace("<|endoftext|> ", "")
392-
split_seq = raw_seq.split(". ")
393-
sec_id = 0
394-
seq_idxs = []
395-
for seq_idx, seq in enumerate(split_seq):
396-
if not seq:
397-
continue
398-
seq_idxs.append(seq_idx)
399-
seq += "."
400-
for tok in self.section_names:
401-
if tok in seq:
402-
sec_id = self.section_names.index(tok)
403-
seq = seq.replace(tok+" ", "")
404-
try:
405-
assert tok not in seq
406-
except:
407-
seq = seq.replace(tok, "")
408-
409-
tokenized_seq = self.classification_tokenizer(seq, return_tensors='pt').to(
410-
self.classification_model.device
411-
)
412-
result = self.classification_model(input_ids=tokenized_seq['input_ids'][:, :512])
413-
probs = torch.nn.functional.softmax(result.logits, dim=1)
414-
415-
acc = int(torch.argmax(probs) == sec_id)
416-
entropy = -torch.sum(probs * torch.log(probs)).detach().cpu().numpy()
417-
prob_sec_id = probs[0, sec_id].detach().cpu().numpy()
418-
419-
# uniform_p = torch.tensor([0.25]*4)
420-
# y_entropy = -torch.sum(uniform_p * torch.log(uniform_p))
421-
# mi = float(y_entropy - entropy)
422-
423-
self.metrics["{} class acc".format(self.section_names[sec_id])].append(acc)
424-
self.metrics["{} class entropy".format(self.section_names[sec_id])].append(entropy)
425-
# self.metrics["{} MI".format(self.section_names[sec_id])].append(mi)
426-
self.metrics["{} p(section_id*|x)".format(self.section_names[sec_id])].append(prob_sec_id)
427-
results["{} class acc".format(self.section_names[sec_id])].append(acc)
428-
results["{} class entropy".format(self.section_names[sec_id])].append(entropy)
429-
# results["{} MI".format(self.section_names[sec_id])].append(mi)
430-
results["{} p(section_id*|x)".format(self.section_names[sec_id])].append(prob_sec_id)
431-
432-
# sentences that are induce high/low acc/entropy/mi
433-
for key, metric in zip(["{} class acc", "{} class entropy"], [acc, entropy,]):
434-
key = key.format(self.section_names[sec_id])
435-
if results[key] and max(results[key]) == metric:
436-
self._classification_examples[key + " MAX"] = seq
437-
self.metrics[key + " MAX IDX"].append(
438-
len(results["{} class acc".format(self.section_names[sec_id])]))
439-
results[key + " MAX IDX"].append(
440-
len(results["{} class acc".format(self.section_names[sec_id])]))
441-
if results[key] and min(results[key]) == metric:
442-
self._classification_examples[key + " MIN"] = seq
443-
self.metrics[key + " MIN IDX"].append(
444-
len(results["{} class acc".format(self.section_names[sec_id])]))
445-
results[key + " MIN IDX"].append(
446-
len(results["{} class acc".format(self.section_names[sec_id])]))
447-
448-
return results
449-
450333
def print_results(self):
451334
print("Examples")
452335
extreme_ex = []
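
The code removed above amounted to a per-sentence section-classification probe: split the generated document into sentences, run each through a BERT section classifier, and log top-1 accuracy, predictive entropy, and p(section_id*|x). For reference, the core of that pattern reduces to the sketch below; the checkpoint path is a hypothetical placeholder (the original hard-coded its own path), and the only assumption is that the checkpoint is a sequence-classification model over the section labels.

    # Minimal sketch of the removed per-sentence classification metric.
    # "path/to/section-classifier" is a placeholder, not the original checkpoint.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/section-classifier")
    model = AutoModelForSequenceClassification.from_pretrained("path/to/section-classifier")
    model.eval()

    def score_sentence(sentence: str, expected_label: int):
        inputs = tokenizer(sentence, return_tensors="pt")
        with torch.no_grad():
            # truncate to 512 tokens, as the removed code did
            logits = model(input_ids=inputs["input_ids"][:, :512]).logits
        probs = torch.nn.functional.softmax(logits, dim=1)
        acc = int(torch.argmax(probs) == expected_label)       # top-1 hit on the expected section
        entropy = float(-torch.sum(probs * torch.log(probs)))  # predictive entropy
        p_expected = float(probs[0, expected_label])           # p(section_id* | x)
        return acc, entropy, p_expected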

language_modeling_via_stochastic_processes/transformers/setup.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@
     "flake8>=3.8.3",
     "flax>=0.3.4",
     "fugashi>=1.0",
-    "huggingface-hub==0.0.8",
+    "huggingface-hub==0.1.0",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",

setup.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
     packages=['language_modeling_via_stochastic_processes',],
     install_requires=[
         "dotmap==1.3.23",
-        "datasets=2.0.0",
+        "datasets==2.0.0",
         "hydra-core==1.1.1",
         "matplotlib==3.3.4",
         "numpy==1.19.2",
@@ -21,7 +21,7 @@
         "tensorflow==2.4.1",
         "torch",
         "torchvision",
-        "tqdm==4.49.0",
+        "tqdm==4.62.1",
         "wandb==0.10.23",
         "numpy",
         "Pillow",
@@ -32,4 +32,4 @@
         "tqdm",
         "packaging",
     ]
-)
\ No newline at end of file
+)
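
The one-character fix in the datasets pin matters: "datasets=2.0.0" is not a valid PEP 440 requirement (a bare "=" is not a comparison operator), so pip would reject the package's metadata at install time. The packaging library, already listed in install_requires, can be used to check pin syntax before shipping a setup.py:

    # "==" parses; bare "=" raises InvalidRequirement, which is exactly
    # what the pre-fix setup.py would have tripped over.
    from packaging.requirements import InvalidRequirement, Requirement

    for spec in ["datasets==2.0.0", "datasets=2.0.0"]:
        try:
            Requirement(spec)
            print(f"{spec!r}: ok")
        except InvalidRequirement as err:
            print(f"{spec!r}: invalid ({err})")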
