From f50eba324b0606285a555c0777909c4fd3c539cb Mon Sep 17 00:00:00 2001
From: Danilo Burbano
Date: Mon, 3 Nov 2025 09:55:51 -0500
Subject: [PATCH 1/2] Adding changes to load models with complex objects

---
 .../nlp/ParamsAndFeaturesReadable.scala       | 148 +++++++++++++++++-
 .../nlp/annotators/er/EntityRulerTest.scala   |  23 ++-
 2 files changed, 168 insertions(+), 3 deletions(-)

diff --git a/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala b/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
index 2cb3d505da3dd5..4f389f1f89c5c9 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
@@ -16,8 +16,14 @@
 
 package com.johnsnowlabs.nlp
 
+import com.johnsnowlabs.nlp.LegacyMetadataSupport.ParamsReflection
+import org.apache.hadoop.fs.Path
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.param.Params
 import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
 import org.apache.spark.sql.SparkSession
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+import org.json4s.{DefaultFormats, JNothing, JNull, JObject, JValue}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}
@@ -25,11 +31,20 @@ import scala.util.{Failure, Success, Try}
 
 class FeaturesReader[T <: HasFeatures](
     baseReader: MLReader[T],
     onRead: (T, String, SparkSession) => Unit)
-    extends MLReader[T] {
+    extends MLReader[T]
+    with Logging {
 
   override def load(path: String): T = {
-    val instance = baseReader.load(path)
+    val instance =
+      try {
+        // Let Spark's own loader handle modern bundles.
+        baseReader.load(path)
+      } catch {
+        case e: NoSuchElementException if isMissingParamError(e) =>
+          // Reconstruct legacy models that referenced params removed in newer releases.
+          loadWithLegacyParams(path)
+      }
 
     for (feature <- instance.features) {
       val value = feature.deserialize(sparkSession, path, feature.name)
@@ -40,6 +55,59 @@
 
     instance
   }
 
+  private def isMissingParamError(e: NoSuchElementException): Boolean = {
+    val msg = Option(e.getMessage).getOrElse("")
+    msg.contains("Param")
+  }
+
+  private def loadWithLegacyParams(path: String): T = {
+    val metadata = LegacyMetadataSupport.load(path, sparkSession)
+    val cls = Class.forName(metadata.className)
+    val ctor = cls.getConstructor(classOf[String])
+    val instance = ctor.newInstance(metadata.uid).asInstanceOf[Params]
+    setParamsIgnoringUnknown(instance, metadata)
+    instance.asInstanceOf[T]
+  }
+
+  private def setParamsIgnoringUnknown(
+      instance: Params,
+      metadata: LegacyMetadataSupport.Metadata): Unit = {
+    // Replay active params; skip mismatches so legacy bundles still come back.
+    assignParams(instance, metadata.params, isDefault = false, metadata)
+
+    val hasDefaultSection = metadata.defaultParams != JNothing && metadata.defaultParams != JNull
+    if (hasDefaultSection) {
+      // If the metadata carried defaults, restore only those that still exist.
+      assignParams(instance, metadata.defaultParams, isDefault = true, metadata)
+    }
+  }
+
+  private def assignParams(
+      instance: Params,
+      jsonParams: JValue,
+      isDefault: Boolean,
+      metadata: LegacyMetadataSupport.Metadata): Unit = {
+    jsonParams match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          if (instance.hasParam(paramName)) {
+            val param = instance.getParam(paramName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            if (isDefault) {
+              // Spark keeps setDefault protected; call it via reflection to restore legacy defaults.
+              ParamsReflection.setDefault(instance, param, value)
+            } else {
+              instance.set(param, value)
+            }
+          }
+        }
+      case JNothing | JNull =>
+      case other =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata when loading legacy params for ${metadata.className}: $other")
+    }
+  }
 }
 
 trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {
@@ -137,3 +205,79 @@ trait ParamsAndFeaturesFallbackReadable[T <: HasFeatures] extends ParamsAndFeatu
 
   override def read: MLReader[T] = new FeaturesFallbackReader(super.read, onRead, fallbackLoad)
 }
+
+// Minimal metadata parser + helper utilities for replaying legacy params.
+// Note: a top-level object cannot be `protected` in Scala; use a package-private modifier instead.
+private[nlp] object LegacyMetadataSupport {
+
+  object ParamsReflection {
+    private val setDefaultMethod = {
+      val maybeMethod = classOf[Params].getDeclaredMethods.find { method =>
+        method.getName == "setDefault" && method.getParameterCount == 2
+      }
+
+      maybeMethod match {
+        case Some(method) =>
+          method.setAccessible(true)
+          method
+        case None =>
+          throw new NoSuchMethodException("Params.setDefault(Param, value) not found via reflection")
+      }
+    }
+
+    def setDefault[T](
+        params: Params,
+        param: org.apache.spark.ml.param.Param[T],
+        value: T): Unit = {
+      setDefaultMethod.invoke(params, param, toAnyRef(value))
+    }
+
+    // Mirror JVM boxing rules so reflection can call the protected method safely.
+    private def toAnyRef(value: Any): AnyRef = {
+      if (value == null) {
+        null
+      } else {
+        value match {
+          case v: AnyRef => v
+          case v: Boolean => java.lang.Boolean.valueOf(v)
+          case v: Byte => java.lang.Byte.valueOf(v)
+          case v: Short => java.lang.Short.valueOf(v)
+          case v: Int => java.lang.Integer.valueOf(v)
+          case v: Long => java.lang.Long.valueOf(v)
+          case v: Float => java.lang.Float.valueOf(v)
+          case v: Double => java.lang.Double.valueOf(v)
+          case v: Char => java.lang.Character.valueOf(v)
+          case other =>
+            throw new IllegalArgumentException(
+              s"Unsupported default value type ${other.getClass}")
+        }
+      }
+    }
+  }
+
+  case class Metadata(
+      className: String,
+      uid: String,
+      sparkVersion: String,
+      params: JValue,
+      defaultParams: JValue,
+      metadataJson: String)
+
+  def load(path: String, spark: SparkSession): Metadata = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = spark.sparkContext.textFile(metadataPath, 1).first()
+    parseMetadata(metadataStr)
+  }
+
+  private def parseMetadata(metadataStr: String): Metadata = {
+    val metadata = parse(metadataStr)
+    implicit val format: DefaultFormats.type = DefaultFormats
+
+    val className = (metadata \ "class").extract[String]
+    val uid = (metadata \ "uid").extract[String]
+    val sparkVersion = (metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0")
+    val params = metadata \ "paramMap"
+    val defaultParams = metadata \ "defaultParamMap"
+
+    Metadata(className, uid, sparkVersion, params, defaultParams, metadataStr)
+  }
+}
diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala
index 28819bac14c443..ea8465dc7c2e76 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala
@@ -21,7 +21,7 @@ import com.johnsnowlabs.nlp.annotators.SparkSessionTest
 import com.johnsnowlabs.nlp.annotators.er.EntityRulerFixture._
 import com.johnsnowlabs.nlp.base.LightPipeline
 import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs}
-import com.johnsnowlabs.tags.FastTest
+import com.johnsnowlabs.tags.{FastTest, SlowTest}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.scalatest.flatspec.AnyFlatSpec
 
@@ -850,4 +850,25 @@ class EntityRulerTest extends AnyFlatSpec with SparkSessionTest {
 
     entityRulerPipeline
   }
 
+  it should "serialize EntityRulerModel" taggedAs SlowTest in {
+    //Should br run with Java 8 and Scala 2.12
+    val entityRuler = new EntityRulerApproach()
+      .setInputCols("document", "token")
+      .setOutputCol("entities")
+      .setPatternsResource("src/test/resources/entity-ruler/keywords_only.json", ReadAs.TEXT)
+    val entityRulerModel = entityRuler.fit(emptyDataSet)
+
+    entityRulerModel.write.overwrite().save("./tmp_entity_ruler_model_java8_scala2_12")
+  }
+
+  it should "deserialize EntityRulerModel" in {
+    val textDataSet = Seq(text1).toDS.toDF("text")
+    val loadedEntityRulerModel = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")
+
+    val pipeline =
+      new Pipeline().setStages(Array(documentAssembler, tokenizer, loadedEntityRulerModel))
+    val resultDf = pipeline.fit(emptyDataSet).transform(textDataSet)
+    resultDf.select("entities").show(truncate = false)
+  }
+
 }

From bcb02f708ba2ce8beef60617b9c5b85deb804b2b Mon Sep 17 00:00:00 2001
From: Danilo Burbano
Date: Mon, 3 Nov 2025 09:59:27 -0500
Subject: [PATCH 2/2] Removing logging trait

---
 .../com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala     | 3 +--
 .../com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala b/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
index 4f389f1f89c5c9..549e0a2b617009 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala
@@ -31,8 +31,7 @@ import scala.util.{Failure, Success, Try}
 class FeaturesReader[T <: HasFeatures](
     baseReader: MLReader[T],
     onRead: (T, String, SparkSession) => Unit)
-    extends MLReader[T]
-    with Logging {
+    extends MLReader[T] {
 
   override def load(path: String): T = {
 
diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala
index ea8465dc7c2e76..97f10493c2661f 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala
@@ -851,7 +851,7 @@ class EntityRulerTest extends AnyFlatSpec with SparkSessionTest {
   }
 
   it should "serialize EntityRulerModel" taggedAs SlowTest in {
-    //Should br run with Java 8 and Scala 2.12
+    //Should be run with Java 8 and Scala 2.12
     val entityRuler = new EntityRulerApproach()
       .setInputCols("document", "token")
       .setOutputCol("entities")
@@ -861,7 +861,7 @@ class EntityRulerTest extends AnyFlatSpec with SparkSessionTest {
     entityRulerModel.write.overwrite().save("./tmp_entity_ruler_model_java8_scala2_12")
   }
 
-  it should "deserialize EntityRulerModel" in {
+  it should "deserialize EntityRulerModel" taggedAs SlowTest in {
     val textDataSet = Seq(text1).toDS.toDF("text")
     val loadedEntityRulerModel = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")
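
Reviewer note (not part of the patches): a minimal, self-contained Scala sketch
of the param-replay behaviour that assignParams implements. The param names are
invented for illustration: "caseSensitive" stands in for a param that still
exists on the annotator, "useStorage" for one that was removed. It shows why a
stale entry in legacy metadata is skipped instead of failing the whole load:

    import org.json4s.JObject
    import org.json4s.jackson.JsonMethods.{compact, parse, render}

    // A paramMap as an older release might have written it; "useStorage"
    // no longer exists on the current annotator (both names are invented).
    val legacyParamMap = parse("""{"caseSensitive": false, "useStorage": true}""")

    legacyParamMap match {
      case JObject(pairs) =>
        pairs.foreach { case (name, json) =>
          // Stand-in for instance.hasParam(name) in the patch.
          val stillExists = name == "caseSensitive"
          if (stillExists) println(s"set $name = ${compact(render(json))}")
          else println(s"skipped stale param: $name")
        }
      case _ => // JNothing / JNull: no params recorded
    }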
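
A second sketch, for what LegacyMetadataSupport.load actually reads: Spark ML
writers persist a single-line JSON document under <modelPath>/metadata (a
part-00000 file). The field names below are the ones parseMetadata extracts;
all values are invented for illustration:

    import org.json4s.DefaultFormats
    import org.json4s.jackson.JsonMethods.parse

    implicit val formats: DefaultFormats.type = DefaultFormats

    // Invented stand-in for the line read via sparkContext.textFile(path, 1).first().
    val metadataStr =
      """{"class":"com.johnsnowlabs.nlp.annotators.er.EntityRulerModel",
        |"uid":"ENTITY_RULER_abc123","sparkVersion":"2.4.8",
        |"paramMap":{"caseSensitive":false},"defaultParamMap":{}}""".stripMargin
    val metadata = parse(metadataStr)

    println((metadata \ "class").extract[String])  // class used to rebuild the instance
    println((metadata \ "uid").extract[String])    // uid passed to the String constructor
    println((metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0"))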