1
1
import json
2
2
from pathlib import Path
3
- from typing import Iterable , List , Optional , Union
3
+ from typing import Iterable , List , Optional , Union , Type
4
4
5
5
import torch
6
6
from torch .utils .data import DataLoader
18
18
from xturing .trainers .base import BaseTrainer
19
19
from xturing .trainers .lightning_trainer import LightningTrainer
20
20
from xturing .utils .logging import configure_logger
21
+ from pytorch_lightning .loggers import Logger
21
22
22
23
logger = configure_logger (__name__ )
23
24
@@ -63,7 +64,8 @@ def _make_collate_fn(self, dataset: Union[TextDataset, InstructionDataset]):
63
64
dataset .meta ,
64
65
)
65
66
66
- def _make_trainer (self , dataset : Union [TextDataset , InstructionDataset ]):
67
+ def _make_trainer (self , dataset : Union [TextDataset , InstructionDataset ],
68
+ logger : Union [Logger , Iterable [Logger ], bool ] = True ):
67
69
return BaseTrainer .create (
68
70
LightningTrainer .config_name ,
69
71
self .engine ,
@@ -73,14 +75,16 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
73
75
int (self .finetuning_args .batch_size ),
74
76
float (self .finetuning_args .learning_rate ),
75
77
self .finetuning_args .optimizer_name ,
78
+ logger = logger ,
76
79
)
77
80
78
- def finetune (self , dataset : Union [TextDataset , InstructionDataset ]):
81
+ def finetune (self , dataset : Union [TextDataset , InstructionDataset ],
82
+ logger : Union [Logger , Iterable [Logger ], bool ] = True ):
79
83
assert dataset .config_name in [
80
84
"text_dataset" ,
81
85
"instruction_dataset" ,
82
86
], "Please make sure the dataset_type is text_dataset or instruction_dataset"
83
- trainer = self ._make_trainer (dataset )
87
+ trainer = self ._make_trainer (dataset , logger )
84
88
trainer .fit ()
85
89
86
90
def evaluate (self , dataset : Union [TextDataset , InstructionDataset ]):
@@ -188,7 +192,8 @@ class CausalLoraModel(CausalModel):
188
192
def __init__ (self , engine : str , weights_path : Optional [str ] = None ):
189
193
super ().__init__ (engine , weights_path )
190
194
191
- def _make_trainer (self , dataset : Union [TextDataset , InstructionDataset ]):
195
+ def _make_trainer (self , dataset : Union [TextDataset , InstructionDataset ],
196
+ logger : Union [Logger , Iterable [Logger ], bool ] = True ):
192
197
return BaseTrainer .create (
193
198
LightningTrainer .config_name ,
194
199
self .engine ,
@@ -200,6 +205,7 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
200
205
self .finetuning_args .optimizer_name ,
201
206
True ,
202
207
True ,
208
+ logger = logger ,
203
209
)
204
210
205
211
0 commit comments