diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/common/TokenizedWithSentence.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/common/TokenizedWithSentence.scala index 163dd884a98642..06f66f763d8cd4 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/common/TokenizedWithSentence.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/common/TokenizedWithSentence.scala @@ -26,7 +26,6 @@ object TokenizedWithSentence extends Annotated[TokenizedSentence] { val tokens = annotations .filter(_.annotatorType == annotatorType) .toArray - val sentences = SentenceSplit.unpack(annotations) /** // Evaluate whether to enable this validation to check proper usage of DOCUMENT and @@ -37,7 +36,10 @@ object TokenizedWithSentence extends Annotated[TokenizedSentence] { sentences .map(sentence => { val sentenceTokens = tokens - .filter(token => token.begin >= sentence.start & token.end <= sentence.end) + .filter(token => + token.begin >= sentence.start && + token.end <= sentence.end && + token.metadata.getOrElse("sentence", "0").toInt == sentence.index) .map(token => IndexedToken(token.result, token.begin, token.end)) sentenceTokens }) diff --git a/src/test/resources/word-embedding/test-repeated-tokens/part-00000-13a8c543-e8bc-46c9-904f-81967baf0b76-c000.snappy.parquet b/src/test/resources/word-embedding/test-repeated-tokens/part-00000-13a8c543-e8bc-46c9-904f-81967baf0b76-c000.snappy.parquet new file mode 100644 index 00000000000000..8dec94cecca379 Binary files /dev/null and b/src/test/resources/word-embedding/test-repeated-tokens/part-00000-13a8c543-e8bc-46c9-904f-81967baf0b76-c000.snappy.parquet differ diff --git a/src/test/resources/word-embedding/test-repeated-tokens/part-00001-13a8c543-e8bc-46c9-904f-81967baf0b76-c000.snappy.parquet b/src/test/resources/word-embedding/test-repeated-tokens/part-00001-13a8c543-e8bc-46c9-904f-81967baf0b76-c000.snappy.parquet new file mode 100644 index 00000000000000..1fa1fa37faffb5 Binary files /dev/null and b/src/test/resources/word-embedding/test-repeated-tokens/part-00001-13a8c543-e8bc-46c9-904f-81967baf0b76-c000.snappy.parquet differ diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/WordEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/WordEmbeddingsTestSpec.scala index 4b0c27a3751f39..fbe863f67307a1 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/WordEmbeddingsTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/WordEmbeddingsTestSpec.scala @@ -23,6 +23,7 @@ import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations} import com.johnsnowlabs.tags.{FastTest, SlowTest} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col import org.scalatest.flatspec.AnyFlatSpec class WordEmbeddingsTestSpec extends AnyFlatSpec with SparkSessionTest { @@ -31,6 +32,34 @@ class WordEmbeddingsTestSpec extends AnyFlatSpec with SparkSessionTest { .option("header", "true") .csv("src/test/resources/embeddings/clinical_words.txt") + "Word Embeddings" should "Should not repeat tokens" taggedAs FastTest in { + + val loaded = spark.read.parquet("src/test/resources/word-embedding/test-repeated-tokens") + + val embeddings = WordEmbeddingsModel + .pretrained("glove_100d", "en") + .setInputCols(Array("splitter", "token")) + .setOutputCol("embedding") + + val pipeline = new Pipeline() + .setStages(Array(embeddings)) + + val model = pipeline.fit(loaded) + + val result = model.transform(loaded) + val duplicateBegins = result + .selectExpr("explode(embedding) as e") + .select(col("e.begin").alias("begin")) + .groupBy("begin") + .count() + .filter(col("count") > 2) + .count() + + assert( + duplicateBegins == 0, + s"Found $duplicateBegins repeated tokens (duplicate begin positions)") + } + "Word Embeddings" should "correctly embed clinical words not embed non-existent words" taggedAs SlowTest in { val notWords = spark.read