
Commit f50eba3

danilojsl authored and DevinTDHa committed
Adding changes to load models with complex objects
1 parent 5d0f4d0 commit f50eba3

File tree

2 files changed: +168 -3 lines changed


src/main/scala/com/johnsnowlabs/nlp/ParamsAndFeaturesReadable.scala

Lines changed: 146 additions & 2 deletions
@@ -16,20 +16,35 @@

 package com.johnsnowlabs.nlp

+import com.johnsnowlabs.nlp.LegacyMetadataSupport.ParamsReflection
+import org.apache.hadoop.fs.Path
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.param.Params
 import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
 import org.apache.spark.sql.SparkSession
+import org.json4s.jackson.JsonMethods.{compact, parse, render}
+import org.json4s.{DefaultFormats, JNothing, JNull, JObject, JValue}

 import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}

 class FeaturesReader[T <: HasFeatures](
     baseReader: MLReader[T],
     onRead: (T, String, SparkSession) => Unit)
-    extends MLReader[T] {
+    extends MLReader[T]
+    with Logging {

   override def load(path: String): T = {

-    val instance = baseReader.load(path)
+    val instance =
+      try {
+        // Let Spark's own loader handle modern bundles.
+        baseReader.load(path)
+      } catch {
+        case e: NoSuchElementException if isMissingParamError(e) =>
+          // Reconstruct legacy models that referenced params removed in newer releases.
+          loadWithLegacyParams(path)
+      }

     for (feature <- instance.features) {
       val value = feature.deserialize(sparkSession, path, feature.name)
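
For context on the guard above: Spark's Params.getParam throws NoSuchElementException with a message of the form "Param x does not exist." when stored metadata names a param the current class no longer declares, which is what isMissingParamError keys on. A minimal standalone sketch of the same try-first, fall-back shape (the loader functions are hypothetical stand-ins, not part of this commit):

// Hypothetical sketch of the fallback pattern used in load() above.
def loadWithFallback[T](primary: () => T, legacy: () => T): T =
  try {
    // Modern bundle: the stock Spark reader succeeds directly.
    primary()
  } catch {
    case e: NoSuchElementException if Option(e.getMessage).exists(_.contains("Param")) =>
      // Legacy bundle: metadata referenced a since-removed param; rebuild it by hand.
      legacy()
  }
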
@@ -40,6 +55,59 @@ class FeaturesReader[T <: HasFeatures](

     instance
   }
+
+  private def isMissingParamError(e: NoSuchElementException): Boolean = {
+    val msg = Option(e.getMessage).getOrElse("")
+    msg.contains("Param")
+  }
+
+  private def loadWithLegacyParams(path: String): T = {
+    val metadata = LegacyMetadataSupport.load(path, sparkSession)
+    val cls = Class.forName(metadata.className)
+    val ctor = cls.getConstructor(classOf[String])
+    val instance = ctor.newInstance(metadata.uid).asInstanceOf[Params]
+    setParamsIgnoringUnknown(instance, metadata)
+    instance.asInstanceOf[T]
+  }
+
+  private def setParamsIgnoringUnknown(
+      instance: Params,
+      metadata: LegacyMetadataSupport.Metadata): Unit = {
+    // Replay active params; skip mismatches so legacy bundles still come back.
+    assignParams(instance, metadata.params, isDefault = false, metadata)
+
+    val hasDefaultSection = metadata.defaultParams != JNothing && metadata.defaultParams != JNull
+    if (hasDefaultSection) {
+      // If the metadata carried defaults, restore only those that still exist.
+      assignParams(instance, metadata.defaultParams, isDefault = true, metadata)
+    }
+  }
+
+  private def assignParams(
+      instance: Params,
+      jsonParams: JValue,
+      isDefault: Boolean,
+      metadata: LegacyMetadataSupport.Metadata): Unit = {
+    jsonParams match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          if (instance.hasParam(paramName)) {
+            val param = instance.getParam(paramName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            if (isDefault) {
+              // Spark keeps setDefault protected; call it via reflection to restore legacy defaults.
+              ParamsReflection.setDefault(instance, param, value)
+            } else {
+              instance.set(param, value)
+            }
+          }
+        }
+      case JNothing | JNull =>
+      case other =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata when loading legacy params for ${metadata.className}: $other")
+    }
+  }
 }

 trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {
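
The decode-and-set loop in assignParams mirrors what Spark's DefaultParamsReader does when restoring params, except that unknown names are skipped instead of raising. A small sketch of that round trip, using a stock Tokenizer purely as a convenient Params instance (the JSON keys are illustrative):

import org.apache.spark.ml.feature.Tokenizer
import org.json4s.JObject
import org.json4s.jackson.JsonMethods.{compact, parse, render}

val instance = new Tokenizer("tok_demo")
// "removedParam" stands in for a param dropped in a newer release.
val legacyParams = parse("""{"inputCol": "text", "removedParam": 42}""")

legacyParams match {
  case JObject(pairs) =>
    pairs.foreach { case (name, jsonValue) =>
      if (instance.hasParam(name)) { // "removedParam" fails this guard and is skipped
        val param = instance.getParam(name)
        instance.set(param, param.jsonDecode(compact(render(jsonValue))))
      }
    }
  case _ => ()
}
// inputCol is now set to "text"; the stale entry was dropped without error.
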
@@ -137,3 +205,79 @@ trait ParamsAndFeaturesFallbackReadable[T <: HasFeatures] extends ParamsAndFeatu

   override def read: MLReader[T] = new FeaturesFallbackReader(super.read, onRead, fallbackLoad)
 }
+
+// Minimal metadata parser + helper utilities for replaying legacy params.
+private[nlp] object LegacyMetadataSupport {
+
+  object ParamsReflection {
+    private val setDefaultMethod = {
+      val maybeMethod = classOf[Params].getDeclaredMethods.find { method =>
+        method.getName == "setDefault" && method.getParameterCount == 2
+      }
+
+      maybeMethod match {
+        case Some(method) =>
+          method.setAccessible(true)
+          method
+        case None =>
+          throw new NoSuchMethodException("Params.setDefault(Param, value) not found via reflection")
+      }
+    }
+
+    def setDefault[T](
+        params: Params,
+        param: org.apache.spark.ml.param.Param[T],
+        value: T): Unit = {
+      setDefaultMethod.invoke(params, param, toAnyRef(value))
+    }
+
+    // Mirror JVM boxing rules so reflection can call the protected method safely.
+    private def toAnyRef(value: Any): AnyRef = {
+      if (value == null) {
+        null
+      } else {
+        value match {
+          case v: AnyRef => v
+          case v: Boolean => java.lang.Boolean.valueOf(v)
+          case v: Byte => java.lang.Byte.valueOf(v)
+          case v: Short => java.lang.Short.valueOf(v)
+          case v: Int => java.lang.Integer.valueOf(v)
+          case v: Long => java.lang.Long.valueOf(v)
+          case v: Float => java.lang.Float.valueOf(v)
+          case v: Double => java.lang.Double.valueOf(v)
+          case v: Char => java.lang.Character.valueOf(v)
+          case other =>
+            throw new IllegalArgumentException(
+              s"Unsupported default value type ${other.getClass}")
+        }
+      }
+    }
+  }
+
+  case class Metadata(
+      className: String,
+      uid: String,
+      sparkVersion: String,
+      params: JValue,
+      defaultParams: JValue,
+      metadataJson: String)
+
+  def load(path: String, spark: SparkSession): Metadata = {
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = spark.sparkContext.textFile(metadataPath, 1).first()
+    parseMetadata(metadataStr)
+  }
+
+  private def parseMetadata(metadataStr: String): Metadata = {
+    val metadata = parse(metadataStr)
+    implicit val format: DefaultFormats.type = DefaultFormats
+
+    val className = (metadata \ "class").extract[String]
+    val uid = (metadata \ "uid").extract[String]
+    val sparkVersion = (metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0")
+    val params = metadata \ "paramMap"
+    val defaultParams = metadata \ "defaultParamMap"
+
+    Metadata(className, uid, sparkVersion, params, defaultParams, metadataStr)
+  }
+}
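
For reference, the file consumed by LegacyMetadataSupport.load is the single-line JSON document that Spark's DefaultParamsWriter places under <model path>/metadata/. A plausible shape, parsed with the same json4s calls as parseMetadata (all field values are illustrative, not taken from a real bundle):

import org.json4s._
import org.json4s.jackson.JsonMethods.parse

// Illustrative metadata line; a real file carries the concrete class, uid, and params.
val metadataStr =
  """{"class":"com.johnsnowlabs.nlp.annotators.er.EntityRulerModel","timestamp":1700000000000,"sparkVersion":"3.2.1","uid":"ENTITY_RULER_1234","paramMap":{"caseSensitive":false},"defaultParamMap":{}}"""

implicit val formats: DefaultFormats.type = DefaultFormats
val json = parse(metadataStr)
val className = (json \ "class").extract[String] // fully qualified model class
val params = json \ "paramMap"                   // JObject with the saved params
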

src/test/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerTest.scala

Lines changed: 22 additions & 1 deletion
@@ -21,7 +21,7 @@ import com.johnsnowlabs.nlp.annotators.SparkSessionTest
 import com.johnsnowlabs.nlp.annotators.er.EntityRulerFixture._
 import com.johnsnowlabs.nlp.base.LightPipeline
 import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs}
-import com.johnsnowlabs.tags.FastTest
+import com.johnsnowlabs.tags.{FastTest, SlowTest}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.scalatest.flatspec.AnyFlatSpec

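
SlowTest is a ScalaTest marker tag; its exact definition in com.johnsnowlabs.tags is not shown in this diff, but such tags are conventionally declared along these lines (assumed shape):

import org.scalatest.Tag

// Assumed definition; the real one lives in com.johnsnowlabs.tags.
object SlowTest extends Tag("com.johnsnowlabs.tags.SlowTest")

Tagged tests can then be included or excluded at run time with ScalaTest's -n / -l tag filters.
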

@@ -850,4 +850,25 @@ class EntityRulerTest extends AnyFlatSpec with SparkSessionTest {
     entityRulerPipeline
   }

+  it should "serialize EntityRulerModel" taggedAs SlowTest in {
+    // Should be run with Java 8 and Scala 2.12
+    val entityRuler = new EntityRulerApproach()
+      .setInputCols("document", "token")
+      .setOutputCol("entities")
+      .setPatternsResource("src/test/resources/entity-ruler/keywords_only.json", ReadAs.TEXT)
+    val entityRulerModel = entityRuler.fit(emptyDataSet)
+
+    entityRulerModel.write.overwrite().save("./tmp_entity_ruler_model_java8_scala2_12")
+  }
+
+  it should "deserialize EntityRulerModel" in {
+    val textDataSet = Seq(text1).toDS.toDF("text")
+    val loadedEntityRulerModel = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")
+
+    val pipeline =
+      new Pipeline().setStages(Array(documentAssembler, tokenizer, loadedEntityRulerModel))
+    val resultDf = pipeline.fit(emptyDataSet).transform(textDataSet)
+    resultDf.select("entities").show(truncate = false)
+  }
 }
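
Taken together, the two tests exercise the cross-version path: per its comment, the serialization half is meant to run under Java 8 / Scala 2.12 to produce an old-style bundle, which the deserialization half then loads on the current build. A sketch of the same round trip outside the suite (the session setup is an assumption; the path matches the tests above):

import com.johnsnowlabs.nlp.annotators.er.EntityRulerModel
import org.apache.spark.sql.SparkSession

// Assumed local session; in the suite this comes from SparkSessionTest.
val spark = SparkSession.builder().master("local[*]").appName("legacy-load").getOrCreate()

// Load a bundle written by an older release; the FeaturesReader fallback kicks in
// only if the stored metadata names params the current model class no longer has.
val model = EntityRulerModel.load("./tmp_entity_ruler_model_java8_scala2_12")
println(model.uid)
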
