from transformers import (
    get_linear_schedule_with_warmup,
    AutoModelForCausalLM,
    Trainer,
)
import torch
import torch.optim as optim

model_id = "roberta-base"


def deepview_model_provider():
    # roberta-base is an encoder-only checkpoint; is_decoder=True configures
    # it as a causal LM so AutoModelForCausalLM can load and run it.
    return AutoModelForCausalLM.from_pretrained(model_id, is_decoder=True).cuda()


def deepview_input_provider(batch_size=2):
    # Synthetic inputs: random token ids at roberta-base's maximum
    # sequence length of 512.
    vocab_size = 50265  # roberta-base vocabulary size
    src_seq_len = 512
    tgt_seq_len = 512

    device = torch.device("cuda")

    source = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, src_seq_len),
        dtype=torch.int64,
        device=device,
    )
    target = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, tgt_seq_len),
        dtype=torch.int64,
        device=device,
    )
    return (source, target)


def deepview_iteration_provider(model):
    optimizer = optim.AdamW(
        params=model.parameters(),
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.01,
        lr=1e-4,
    )
    # 10000 warmup steps out of 500000 total training steps.
    scheduler = get_linear_schedule_with_warmup(optimizer, 10000, 500000)
    trainer = Trainer(model=model, optimizers=(optimizer, scheduler))

    def iteration(source, label):
        # One training step: forward pass, loss computation, and backward pass.
        trainer.training_step(model, {"input_ids": source, "labels": label})

    return iteration
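

# A minimal sketch of how the three providers fit together; DeepView normally
# invokes them itself, so this manual loop is illustrative only (assumes a
# CUDA device is available):
#
#   model = deepview_model_provider()
#   source, target = deepview_input_provider(batch_size=2)
#   iteration = deepview_iteration_provider(model)
#   iteration(source, target)  # one profiled training step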