Skip to content

Commit a6ecb9f

Browse files
authored
[SPARKNLP-1315] Changing input data type for CamemBertForTokenClassification (#14701)
* SPARKNLP-1315: change input data type for CamemBertForTokenClassification from int64 to int32
* SPARKNLP-1315: add test for TensorFlow models
1 parent 5a43dfc commit a6ecb9f

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,28 +164,27 @@ private[johnsnowlabs] class CamemBertClassification(
164164
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
165165
val batchLength = batch.length
166166

167-
val tokenBuffers: LongDataBuffer = tensors.createLongBuffer(batchLength * maxSentenceLength)
168-
val maskBuffers: LongDataBuffer = tensors.createLongBuffer(batchLength * maxSentenceLength)
167+
val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
168+
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
169169

170170
// [nb of encoded sentences , maxSentenceLength]
171171
val shape = Array(batch.length.toLong, maxSentenceLength)
172172

173173
batch.zipWithIndex
174174
.foreach { case (sentence, idx) =>
175-
val sentenceLong = sentence.map(x => x.toLong)
176175
val offset = idx * maxSentenceLength
177-
tokenBuffers.offset(offset).write(sentenceLong)
176+
tokenBuffers.offset(offset).write(sentence)
178177
maskBuffers
179178
.offset(offset)
180-
.write(sentence.map(x => if (x == sentencePadTokenId) 0L else 1L))
179+
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
181180
}
182181

183182
val runner = tensorflowWrapper.get
184183
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
185184
.runner
186185

187-
val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers)
188-
val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers)
186+
val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
187+
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
189188

190189
runner
191190
.feed(

src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassificationTestSpec.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,37 @@ class CamemBertForTokenClassificationTestSpec extends AnyFlatSpec {
157157

158158
assert(totalTokens == totalTags)
159159
}
160+
161+
162+
"CamemBertForTokenClassification" should "work with tensorflow models" taggedAs SlowTest in {
163+
164+
val tokenClassifier: CamemBertForTokenClassification = CamemBertForTokenClassification
165+
.pretrained("camembert_classifier_base_wikipedia_4gb_finetuned_job_ner")
166+
.setInputCols(Array("token", "document"))
167+
.setOutputCol("ner")
168+
.setCaseSensitive(true)
169+
.setMaxSentenceLength(512)
170+
171+
val pipeline = new Pipeline().setStages(Array(document, tokenizer, tokenClassifier))
172+
173+
val pipelineModel = pipeline.fit(ddd)
174+
val pipelineDF = pipelineModel.transform(ddd)
175+
176+
pipelineDF.select("token.result").show(false)
177+
pipelineDF.select("ner.result").show(false)
178+
pipelineDF
179+
.withColumn("token_size", size(col("token")))
180+
.withColumn("ner_size", size(col("ner")))
181+
.where(col("token_size") =!= col("ner_size"))
182+
.select("token_size", "ner_size", "token.result", "ner.result")
183+
.show(false)
184+
185+
val totalTokens = pipelineDF.select(explode($"token.result")).count.toInt
186+
val totalEmbeddings = pipelineDF.select(explode($"ner.result")).count.toInt
187+
188+
println(s"total tokens: $totalTokens")
189+
println(s"total embeddings: $totalEmbeddings")
190+
191+
}
192+
160193
}

0 commit comments

Comments
 (0)