Commit 271a2b3

added Hugging Face example (#81)
Co-authored-by: John Calderon <[email protected]>
1 parent f371351 commit 271a2b3

File tree

1 file changed: +55 -0 lines


examples/huggingface/entry_point.py

Lines changed: 55 additions & 0 deletions
from transformers import (
    get_linear_schedule_with_warmup,
    AutoModelForCausalLM,
    Trainer,
)
import torch
import torch.optim as optim

model_id = "roberta-base"


def deepview_model_provider():
    # Load RoBERTa configured as a decoder-style causal LM and move it to the GPU.
    return AutoModelForCausalLM.from_pretrained(model_id, is_decoder=True).cuda()


def deepview_input_provider(batch_size=2):
    vocab_size = 30522
    src_seq_len = 512
    tgt_seq_len = 512

    device = torch.device("cuda")

    # Random token ids stand in for real tokenized text.
    source = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, src_seq_len),
        dtype=torch.int64,
        device=device,
    )
    target = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, tgt_seq_len),
        dtype=torch.int64,
        device=device,
    )
    return (source, target)


def deepview_iteration_provider(model):
    # AdamW with a linear warmup schedule: 10,000 warmup steps out of
    # 500,000 total training steps.
    optimizer = optim.AdamW(
        params=model.parameters(),
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.01,
        lr=1e-4,
    )
    scheduler = get_linear_schedule_with_warmup(optimizer, 10000, 500000)
    trainer = Trainer(model=model, optimizers=(optimizer, scheduler))

    def iteration(source, label):
        trainer.training_step(model, {"input_ids": source, "labels": label})

    return iteration
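
As a usage note: a minimal, hypothetical smoke test (not part of this commit) could exercise the three provider hooks directly, assuming a CUDA-capable GPU and the definitions in the file above.

# Hypothetical smoke test (not part of the commit): run one training step
# using the provider hooks defined above. Assumes a CUDA-capable GPU.
if __name__ == "__main__":
    model = deepview_model_provider()
    source, target = deepview_input_provider(batch_size=2)
    iteration = deepview_iteration_provider(model)
    iteration(source, target)  # one forward/backward pass via Trainer.training_step
    print("Single training step completed.")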
