@@ -24,6 +24,7 @@ import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_E
2424import com .johnsnowlabs .nlp .annotators .common .{NerTagged , WordpieceEmbeddingsSentence }
2525import com .johnsnowlabs .nlp .annotators .ner .{ModelMetrics , NerApproach , Verbose }
2626import com .johnsnowlabs .nlp .annotators .param .EvaluationDLParams
27+ import com .johnsnowlabs .nlp .training .NerDLDataLoader
2728import com .johnsnowlabs .nlp .util .io .{OutputHelper , ResourceHelper }
2829import com .johnsnowlabs .nlp .{AnnotatorApproach , AnnotatorType , ParamsAndFeaturesWritable }
2930import com .johnsnowlabs .storage .HasStorageRef
@@ -450,6 +451,14 @@ class NerDLApproach(override val uid: String)
450451
451452 }
452453
/** Number of batches to prefetch while training with the memory optimizer enabled. According to
  * the param description, it has no effect when the memory optimizer is disabled.
  *
  * @group param
  */
454+ val prefetchBatches = new IntParam (
455+ this ,
456+ " prefetchBatches" ,
457+ " Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled." )
458+
/** Returns the currently configured value of the `prefetchBatches` param (its default when the
  * param has not been set explicitly).
  *
  * @group getParam
  */
def getPrefetchBatches: Int = getOrDefault(prefetchBatches)
/** Sets the `prefetchBatches` param.
  *
  * @param value
  *   number of batches to prefetch
  * @return
  *   this instance, so that setter calls can be chained
  * @group setParam
  */
def setPrefetchBatches(value: Int): this.type = {
  set(prefetchBatches, value)
}
461+
453462 setDefault(
454463 minEpochs -> 0 ,
455464 maxEpochs -> 70 ,
@@ -462,7 +471,8 @@ class NerDLApproach(override val uid: String)
462471 includeAllConfidenceScores -> false ,
463472 enableMemoryOptimizer -> false ,
464473 useBestModel -> false ,
465- bestModelMetric -> ModelMetrics .loss)
474+ bestModelMetric -> ModelMetrics .loss,
475+ prefetchBatches -> 0 )
466476
/** Verbosity level, resolved from the `verbose` param at the point this value is initialized. */
override val verboseLevel: Verbose.Level = Verbose(getOrDefault(verbose))
468478
@@ -485,6 +495,24 @@ class NerDLApproach(override val uid: String)
485495 $(validationSplit) <= 1f | $(validationSplit) >= 0f ,
486496 " The validationSplit must be between 0f and 1f" )
487497
/** Builds the batch-iterator factory for a single dataset split.
  *
  * When the memory optimizer is disabled, this delegates to the `NerDLApproach.getIteratorFunc`
  * overload that takes `enableMemoryOptimizer`. When it is enabled, the overload that takes
  * `prefetchBatches` is used instead and the configured prefetch count is logged.
  *
  * @param split
  *   dataset split to iterate over
  * @return
  *   the iterator factory produced by the selected `NerDLApproach.getIteratorFunc` overload
  */
def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
  // No memory optimizer
  NerDLApproach.getIteratorFunc(
    split,
    inputColumns = getInputCols,
    labelColumn = $(labelColumn),
    batchSize = $(batchSize),
    enableMemoryOptimizer = $(enableMemoryOptimizer))
} else {
  // Interpolate the param *value*: `$prefetchBatches` would render the IntParam object
  // (its identifier string) instead of the configured number of batches.
  logger.info(s"Using memory optimizer with $getPrefetchBatches prefetch batches.")
  NerDLApproach.getIteratorFunc(
    split,
    inputColumns = getInputCols,
    labelColumn = $(labelColumn),
    batchSize = $(batchSize),
    prefetchBatches = getPrefetchBatches)
}
515+
488516 val embeddingsRef =
489517 HasStorageRef .getStorageRefFromInput(dataset, $(inputCols), AnnotatorType .WORD_EMBEDDINGS )
490518
@@ -506,26 +534,10 @@ class NerDLApproach(override val uid: String)
506534 (cacheIfNeeded(trainSplit), cacheIfNeeded(validSplit), cacheIfNeeded(test))
507535 }
508536
509- val trainIteratorFunc = NerDLApproach .getIteratorFunc(
510- trainSplit,
511- inputColumns = getInputCols,
512- labelColumn = $(labelColumn),
513- batchSize = $(batchSize),
514- enableMemoryOptimizer = $(enableMemoryOptimizer))
515-
516- val validIteratorFunc = NerDLApproach .getIteratorFunc(
517- validSplit,
518- inputColumns = getInputCols,
519- labelColumn = $(labelColumn),
520- batchSize = $(batchSize),
521- enableMemoryOptimizer = $(enableMemoryOptimizer))
522-
523- val testIteratorFunc = NerDLApproach .getIteratorFunc(
524- test,
525- inputColumns = getInputCols,
526- labelColumn = $(labelColumn),
527- batchSize = $(batchSize),
528- enableMemoryOptimizer = $(enableMemoryOptimizer))
537+ // TODO DHA: Better way to do this?
538+ val trainIteratorFunc = getIteratorFunc(trainSplit)
539+ val validIteratorFunc = getIteratorFunc(validSplit)
540+ val testIteratorFunc = getIteratorFunc(test)
529541
530542 val (
531543 labels : mutable.Set [AnnotatorType ],
@@ -752,8 +764,9 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
752764 : () => Iterator [Array [(TextSentenceLabels , WordpieceEmbeddingsSentence )]] = {
753765
754766 if (enableMemoryOptimizer) { () =>
767+ // Old implementation, kept for backward compatibility but won't be called from NerDLApproach.train
768+ // NerDLDataLoader will be used with memory optimizer
755769 NerTagged .iterateOnDataframe(dataset, inputColumns, labelColumn, batchSize)
756-
757770 } else {
758771 val inMemory = dataset
759772 .select(labelColumn, inputColumns.toSeq: _* )
@@ -763,6 +776,21 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
763776 }
764777 }
765778
/** Creates an iterator factory backed by `NerDLDataLoader.iterateOnDataframe`.
  *
  * Nothing is read when this method is called; each invocation of the returned function calls
  * `NerDLDataLoader.iterateOnDataframe` again with the arguments captured here.
  *
  * @param dataset
  *   dataset to iterate over
  * @param inputColumns
  *   input column names forwarded to the data loader
  * @param labelColumn
  *   label column name forwarded to the data loader
  * @param batchSize
  *   batch size forwarded to the data loader
  * @param prefetchBatches
  *   number of batches to prefetch, forwarded to the data loader
  * @return
  *   a function that produces an iterator of batches on each call
  */
def getIteratorFunc(
    dataset: Dataset[Row],
    inputColumns: Array[String],
    labelColumn: String,
    batchSize: Int,
    prefetchBatches: Int)
    : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {

  // Named local function so the deferred call site is explicit.
  def openIterator(): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] =
    NerDLDataLoader.iterateOnDataframe(
      dataset = dataset,
      inputColumns = inputColumns,
      labelColumn = labelColumn,
      batchSize = batchSize,
      prefetchBatches = prefetchBatches)

  () => openIterator()
}
793+
766794 def getDataSetParams (dsIt : Iterator [Array [(TextSentenceLabels , WordpieceEmbeddingsSentence )]])
767795 : (mutable.Set [String ], mutable.Set [Char ], Int , Long ) = {
768796
0 commit comments