@@ -54,6 +54,7 @@ def on_epoch_end(self):
     "GradientClipCallback",
     "EarlyStopCallback",
     "TensorboardCallback",
+    "FitlogCallback",
     "LRScheduler",
     "ControlC",
@@ -65,6 +66,7 @@ def on_epoch_end(self):
 
 import torch
 from copy import deepcopy
+
 try:
     from tensorboardX import SummaryWriter
@@ -81,6 +83,7 @@ def on_epoch_end(self):
 except:
     pass
 
+
 class Callback(object):
     """
     Alias: :class:`fastNLP.Callback` :class:`fastNLP.core.callback.Callback`
@@ -367,16 +370,17 @@ class GradientClipCallback(Callback):
 
     Before each backward pass, clips the parameters' gradients to a given range.
 
-    :param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained via model.parameters(). If None, all
-        parameters of the Trainer's model are clipped by default
+    :param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained via model.parameters().
+        If None, all parameters of the Trainer's model are clipped by default
     :param float clip_value: restricts the gradient to [-clip_value, clip_value]; clip_value should be positive
     :param str clip_type: supports two modes, 'norm' and 'value'::
 
        1 'norm', rescales the gradient so that its norm does not exceed clip_value
 
-       2 'value', clamps the gradient to [-clip_value, clip_value]; gradients below -clip_value are set to -clip_value;
-          gradients above clip_value are set to clip_value.
+       2 'value', clamps the gradient to [-clip_value, clip_value];
+         gradients below -clip_value are set to -clip_value;
+         gradients above clip_value are set to clip_value.
 
     """
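For reference, a minimal sketch of the two clip_type behaviors described above, using PyTorch's stock clipping utilities (the nn.Linear model and the dummy loss are placeholders, not part of this commit)::

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)                  # placeholder model
    loss = model(torch.randn(8, 4)).sum()    # dummy loss, only to produce gradients
    loss.backward()

    # clip_type='norm': rescale gradients so their combined norm is at most clip_value
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

    # clip_type='value': clamp each gradient entry into [-clip_value, clip_value]
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=5.0)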
@@ -431,14 +435,13 @@ def on_exception(self, exception):
         else:
             raise exception  # re-raise unrecognized errors
 
+
 class FitlogCallback(Callback):
     """
-    Alias: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
-
     This callback automatically writes the loss and progress to fitlog; if the Trainer has dev data, the dev results are automatically written to the log; it also supports passing in
-    one (or more) test datasets for evaluation (only usable when the trainer has dev data); after every evaluation on dev, these datasets
-    are evaluated as well, and the results are written to fitlog. Results for these datasets follow the best dev result:
-    if dev peaked at the 3rd epoch, the results recorded in fitlog for these datasets are those from the third epoch.
+    one (or more) test datasets for evaluation (only usable when the trainer has dev data); after every evaluation on dev,
+    these datasets are evaluated as well, and the results are written to fitlog. Results for these datasets follow the best
+    dev result: if dev peaked at the 3rd epoch, the results recorded in fitlog for these datasets are those from the third epoch.
 
     :param DataSet,dict(DataSet) data: pass in a DataSet object; the Trainer's metrics will be used to evaluate the data. If you need to pass multiple
         DataSets, pass them via a dict; the dict keys are passed to fitlog as the names of the corresponding datasets. When tester is not None, data must be passed via
@@ -447,7 +450,9 @@ class FitlogCallback(Callback):
     :param int verbose: whether to print results to the terminal; 0 means no printing
     :param bool log_exception: whether fitlog records information about any exception that occurs
     """
-
+    # not yet exported at the fastNLP package level
+    # Alias: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
+
     def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
         super().__init__()
         self.datasets = {}
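A hedged usage sketch for the callback added here; the dataset, model, and metric variables are assumed to be defined elsewhere, and the log directory and the name 'extra' are arbitrary illustrations, not taken from this commit::

    import fitlog
    from fastNLP import Trainer
    from fastNLP.core.callback import FitlogCallback  # not yet exported at the top level

    fitlog.set_log_dir('logs/')  # fitlog must be initialised before training

    # dict keys become the dataset names reported to fitlog;
    # 'test' is reserved when a tester is also passed
    callback = FitlogCallback(data={'extra': extra_dataset}, verbose=1)

    # dev_data is required: on_train_begin raises RuntimeError without it
    trainer = Trainer(train_data=train_data, model=model, dev_data=dev_data,
                      metrics=metric, callbacks=[callback])
    trainer.train()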
@@ -460,7 +465,7 @@ def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
             assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed."
             setattr(tester, 'verbose', 0)
             self.testers['test'] = tester
-
+
         if isinstance(data, dict):
             for key, value in data.items():
                 assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
@@ -470,46 +475,46 @@ def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
             self.datasets['test'] = data
         else:
             raise TypeError("data receives dict[DataSet] or DataSet object.")
-
+
         self.verbose = verbose
-
+
     def on_train_begin(self):
-        if (len(self.datasets)>0 or len(self.testers)>0) and self.trainer.dev_data is None:
+        if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None:
             raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")
-
-        if len(self.datasets)>0:
+
+        if len(self.datasets) > 0:
             for key, data in self.datasets.items():
                 tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics,
                                 verbose=0)
                 self.testers[key] = tester
         fitlog.add_progress(total_steps=self.n_steps)
-
+
     def on_backward_begin(self, loss):
         fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)
-
+
     def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
         if better_result:
             eval_result = deepcopy(eval_result)
             eval_result['step'] = self.step
             eval_result['epoch'] = self.epoch
             fitlog.add_best_metric(eval_result)
         fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch)
-        if len(self.testers)>0:
+        if len(self.testers) > 0:
             for key, tester in self.testers.items():
                 try:
                     eval_result = tester.test()
-                    if self.verbose != 0:
+                    if self.verbose != 0:
                         self.pbar.write("Evaluation on DataSet {}:".format(key))
                         self.pbar.write(tester._format_eval_results(eval_result))
                     fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch)
                     if better_result:
                         fitlog.add_best_metric(eval_result, name=key)
                 except Exception:
                     self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
-
+
     def on_train_end(self):
         fitlog.finish()
-
+
     def on_exception(self, exception):
         fitlog.finish(status=1)
         if self._log_exception:
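Taken together, the fitlog calls made by this callback follow a simple lifecycle. A standalone sketch using only the functions visible in this diff, with dummy values; the set_log_dir setup is an assumption, not shown in this commit::

    import fitlog

    fitlog.set_log_dir('logs/')                           # assumed setup step
    fitlog.add_progress(total_steps=1000)                 # on_train_begin
    fitlog.add_loss(0.7, name='loss', step=1, epoch=1)    # on_backward_begin, every step
    fitlog.add_metric({'acc': 0.9}, step=100, epoch=1)    # on_valid_end, every validation
    fitlog.add_best_metric({'acc': 0.9})                  # on_valid_end, when dev improves
    fitlog.finish()                                       # on_train_end; finish(status=1) on exception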