
Commit 5cba216

Introduce NerDLDataLoader for NerDLApproach
A threaded NerDLDataLoader fetches batches in the background while NerDLApproach trains, reducing idle time in the driver thread.
1 parent 5dcb0af commit 5cba216
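
For context, the background fetching described in the commit message follows a bounded producer/consumer pattern: a worker thread pulls batches from a source iterator into a small queue while the training loop consumes them. The sketch below only illustrates that idea; the class name and API are assumptions, not the actual NerDLDataLoader added by this commit.

```scala
import java.util.concurrent.ArrayBlockingQueue

// Illustrative sketch only (not the NerDLDataLoader API): wraps a batch iterator so that
// up to `prefetch` batches are fetched on a background thread while the caller trains.
class PrefetchingIterator[T](source: Iterator[T], prefetch: Int) extends Iterator[T] {
  private val queue = new ArrayBlockingQueue[Option[T]](math.max(prefetch, 1))

  private val producer = new Thread(() => {
    source.foreach(batch => queue.put(Some(batch))) // blocks when the queue is full
    queue.put(None) // end-of-data marker
  })
  producer.setDaemon(true)
  producer.start()

  private var upcoming: Option[T] = queue.take()

  override def hasNext: Boolean = upcoming.isDefined

  override def next(): T = {
    val batch = upcoming.get
    upcoming = queue.take() // blocks until the producer has the next batch ready
    batch
  }
}
```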

4 files changed: +432 −28 lines


src/main/scala/com/johnsnowlabs/nlp/annotators/common/Tagged.scala

Lines changed: 19 additions & 6 deletions
@@ -118,7 +118,11 @@ trait Tagged[T >: TaggedSentence <: TaggedSentence] extends Annotated[T] {
     row.getAs[Seq[Row]](colNum).map(obj => Annotation(obj))
   }
 
-  protected def getLabelsFromSentences(
+  def getAnnotations(row: Row, col: String): Seq[Annotation] = {
+    row.getAs[Seq[Row]](col).map(obj => Annotation(obj))
+  }
+
+  def getLabelsFromSentences(
       sentences: Seq[WordpieceEmbeddingsSentence],
       labelAnnotations: Seq[Annotation]): Seq[TextSentenceLabels] = {
     val sortedLabels = labelAnnotations.sortBy(a => a.begin).toArray
@@ -203,16 +207,25 @@ object NerTagged extends Tagged[NerTaggedSentence] {
       dataset: Dataset[Row],
       sentenceCols: Seq[String],
       labelColumn: String,
-      batchSize: Int): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
+      batchSize: Int,
+      shuffleInPartition: Boolean = true)
+      : Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
 
     new Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] {
       import com.johnsnowlabs.nlp.annotators.common.DatasetHelpers._
 
       // Send batches, don't collect(), only keeping a single batch in memory anytime
-      val it: util.Iterator[Row] = dataset
-        .select(labelColumn, sentenceCols: _*)
-        .randomize // to improve training
-        .toLocalIterator() // Uses as much memory as the largest partition, potentially all data if not careful
+      val it: util.Iterator[Row] = {
+        val selected = dataset
+          .select(labelColumn, sentenceCols: _*)
+        (
+          // to improve training
+          // NOTE: This might have implications on model performance, partitions are not shuffled
+          if (shuffleInPartition) selected.randomize
+          else
+            selected
+        ).toLocalIterator() // Uses as much memory as the largest partition, potentially all data if not careful
+      }
 
       // create a batch
      override def next(): Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)] = {
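
The new `shuffleInPartition` flag controls whether `randomize` (from `DatasetHelpers`) is applied before the rows are streamed through `toLocalIterator()`. As the added NOTE points out, this shuffle only reorders rows inside each partition, which is what lets `toLocalIterator()` keep streaming one partition at a time. A minimal sketch of such an in-partition shuffle, assuming a `mapPartitions`-based helper (not necessarily the exact `DatasetHelpers.randomize` implementation):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Shuffle rows within each partition only; the partitions themselves keep their
// contents and order, so toLocalIterator() still pulls one partition at a time.
// Note: each partition is materialized to shuffle it, matching the memory caveat above.
def shuffleWithinPartitions(df: DataFrame): DataFrame = {
  implicit val rowEncoder = RowEncoder(df.schema)
  df.mapPartitions(rows => scala.util.Random.shuffle(rows.toSeq).iterator)
}
```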

src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala

Lines changed: 50 additions & 22 deletions
@@ -24,6 +24,7 @@ import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_E
 import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
 import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
 import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
+import com.johnsnowlabs.nlp.training.NerDLDataLoader
 import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
 import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
 import com.johnsnowlabs.storage.HasStorageRef
@@ -450,6 +451,14 @@ class NerDLApproach(override val uid: String)
 
   }
 
+  val prefetchBatches = new IntParam(
+    this,
+    "prefetchBatches",
+    "Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled.")
+
+  def getPrefetchBatches: Int = $(this.prefetchBatches)
+  def setPrefetchBatches(value: Int): this.type = set(this.prefetchBatches, value)
+
   setDefault(
     minEpochs -> 0,
     maxEpochs -> 70,
@@ -462,7 +471,8 @@ class NerDLApproach(override val uid: String)
     includeAllConfidenceScores -> false,
     enableMemoryOptimizer -> false,
     useBestModel -> false,
-    bestModelMetric -> ModelMetrics.loss)
+    bestModelMetric -> ModelMetrics.loss,
+    prefetchBatches -> 0)
 
   override val verboseLevel: Verbose.Level = Verbose($(verbose))
 
@@ -485,6 +495,24 @@ class NerDLApproach(override val uid: String)
       $(validationSplit) <= 1f | $(validationSplit) >= 0f,
       "The validationSplit must be between 0f and 1f")
 
+    def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
+      // No memory optimizer
+      NerDLApproach.getIteratorFunc(
+        split,
+        inputColumns = getInputCols,
+        labelColumn = $(labelColumn),
+        batchSize = $(batchSize),
+        enableMemoryOptimizer = $(enableMemoryOptimizer))
+    } else {
+      logger.info(s"Using memory optimizer with $prefetchBatches prefetch batches.")
+      NerDLApproach.getIteratorFunc(
+        split,
+        inputColumns = getInputCols,
+        labelColumn = $(labelColumn),
+        batchSize = $(batchSize),
+        prefetchBatches = getPrefetchBatches)
+    }
+
     val embeddingsRef =
       HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)
 
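For reference, a hedged usage sketch of how the new parameter would be set when training (standard NerDLApproach setters; the column names and values below are illustrative, not taken from this commit):

```scala
// Prefetching only takes effect together with the memory optimizer, which streams
// batches from the Dataset instead of collecting everything up front.
val ner = new NerDLApproach()
  .setInputCols("sentence", "token", "embeddings")
  .setOutputCol("ner")
  .setLabelColumn("label")
  .setMaxEpochs(10)
  .setBatchSize(32)
  .setEnableMemoryOptimizer(true)
  .setPrefetchBatches(2) // new: keep up to 2 batches ready in the background
```
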
@@ -506,26 +534,10 @@ class NerDLApproach(override val uid: String)
       (cacheIfNeeded(trainSplit), cacheIfNeeded(validSplit), cacheIfNeeded(test))
     }
 
-    val trainIteratorFunc = NerDLApproach.getIteratorFunc(
-      trainSplit,
-      inputColumns = getInputCols,
-      labelColumn = $(labelColumn),
-      batchSize = $(batchSize),
-      enableMemoryOptimizer = $(enableMemoryOptimizer))
-
-    val validIteratorFunc = NerDLApproach.getIteratorFunc(
-      validSplit,
-      inputColumns = getInputCols,
-      labelColumn = $(labelColumn),
-      batchSize = $(batchSize),
-      enableMemoryOptimizer = $(enableMemoryOptimizer))
-
-    val testIteratorFunc = NerDLApproach.getIteratorFunc(
-      test,
-      inputColumns = getInputCols,
-      labelColumn = $(labelColumn),
-      batchSize = $(batchSize),
-      enableMemoryOptimizer = $(enableMemoryOptimizer))
+    // TODO DHA: Better way to do this?
+    val trainIteratorFunc = getIteratorFunc(trainSplit)
+    val validIteratorFunc = getIteratorFunc(validSplit)
+    val testIteratorFunc = getIteratorFunc(test)
 
     val (
       labels: mutable.Set[AnnotatorType],
@@ -752,8 +764,9 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
       : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
 
     if (enableMemoryOptimizer) { () =>
+      // Old implementation, kept for backward compatibility but won't be called from NerDLApproach.train
+      // NerDLDataLoader will be used with memory optimizer
       NerTagged.iterateOnDataframe(dataset, inputColumns, labelColumn, batchSize)
-
     } else {
       val inMemory = dataset
         .select(labelColumn, inputColumns.toSeq: _*)
@@ -763,6 +776,21 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
     }
   }
 
+  def getIteratorFunc(
+      dataset: Dataset[Row],
+      inputColumns: Array[String],
+      labelColumn: String,
+      batchSize: Int,
+      prefetchBatches: Int)
+      : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = { () =>
+    NerDLDataLoader.iterateOnDataframe(
+      dataset = dataset,
+      inputColumns = inputColumns,
+      labelColumn = labelColumn,
+      batchSize = batchSize,
+      prefetchBatches = prefetchBatches)
+  }
+
   def getDataSetParams(dsIt: Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])
       : (mutable.Set[String], mutable.Set[Char], Int, Long) = {
 
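Note that, as before, these helpers return a factory `() => Iterator[...]` rather than an iterator itself: the training code can re-invoke the factory to start a fresh pass over the data (and a fresh prefetching cycle) each epoch. Roughly, and only as an illustration of that design rather than the actual training loop:

```scala
import com.johnsnowlabs.nlp.annotators.common.{TextSentenceLabels, WordpieceEmbeddingsSentence}

// Illustration of why a factory is used: each epoch needs a new iterator over the batches.
def runEpochs(
    maxEpochs: Int,
    batchesPerEpoch: () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])(
    trainOnBatch: Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)] => Unit): Unit = {
  (0 until maxEpochs).foreach { _ =>
    batchesPerEpoch().foreach(trainOnBatch) // restarts streaming/prefetching for this epoch
  }
}
```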