package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.LegacyMetadataSupport.ParamsReflection
import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
import org.apache.spark.sql.SparkSession
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.{DefaultFormats, JNothing, JNull, JObject, JValue}

import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}

class FeaturesReader[T <: HasFeatures](
    baseReader: MLReader[T],
    onRead: (T, String, SparkSession) => Unit)
    extends MLReader[T]
    with Logging {

  override def load(path: String): T = {

    val instance =
      try {
        // Let Spark's own loader handle modern bundles.
        baseReader.load(path)
      } catch {
        case e: NoSuchElementException if isMissingParamError(e) =>
          // Reconstruct legacy models that referenced params removed in newer releases.
          loadWithLegacyParams(path)
      }

    for (feature <- instance.features) {
      val value = feature.deserialize(sparkSession, path, feature.name)
@@ -40,6 +55,59 @@ class FeaturesReader[T <: HasFeatures](

    instance
  }

  private def isMissingParamError(e: NoSuchElementException): Boolean = {
    val msg = Option(e.getMessage).getOrElse("")
    msg.contains("Param")
  }
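
  // Rationale: Spark's DefaultParamsReader resolves each persisted param through
  // getParam, which in current Spark releases throws
  // NoSuchElementException("Param <name> does not exist.") when a param was
  // removed in a newer release. Matching on "Param" is deliberately broad: if
  // the message format ever changes we merely skip the fallback, we never
  // corrupt a modern load.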

  private def loadWithLegacyParams(path: String): T = {
    val metadata = LegacyMetadataSupport.load(path, sparkSession)
    val cls = Class.forName(metadata.className)
    val ctor = cls.getConstructor(classOf[String])
    val instance = ctor.newInstance(metadata.uid).asInstanceOf[Params]
    setParamsIgnoringUnknown(instance, metadata)
    instance.asInstanceOf[T]
  }
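
  // This relies on the standard Spark ML Identifiable convention: the model
  // class exposes a single-String constructor taking the uid, e.g.
  // `class SomeModel(override val uid: String) ...` (SomeModel is illustrative).
  // A class without such a constructor fails here with NoSuchMethodException.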

  private def setParamsIgnoringUnknown(
      instance: Params,
      metadata: LegacyMetadataSupport.Metadata): Unit = {
    // Replay active params; skip mismatches so legacy bundles still come back.
    assignParams(instance, metadata.params, isDefault = false, metadata)

    val hasDefaultSection = metadata.defaultParams != JNothing && metadata.defaultParams != JNull
    if (hasDefaultSection) {
      // If the metadata carried defaults, restore only those that still exist.
      assignParams(instance, metadata.defaultParams, isDefault = true, metadata)
    }
  }
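
  // The metadata being replayed follows Spark's DefaultParamsWriter layout; a
  // sketch with illustrative values:
  //
  //   {"class": "com.johnsnowlabs.nlp.SomeModel", "timestamp": 1700000000000,
  //    "sparkVersion": "3.2.0", "uid": "SomeModel_abc123",
  //    "paramMap": {"caseSensitive": false},
  //    "defaultParamMap": {"lazyAnnotator": false}}
  //
  // Bundles written before Spark persisted defaultParamMap simply skip the
  // second pass.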

  private def assignParams(
      instance: Params,
      jsonParams: JValue,
      isDefault: Boolean,
      metadata: LegacyMetadataSupport.Metadata): Unit = {
    jsonParams match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          if (instance.hasParam(paramName)) {
            val param = instance.getParam(paramName)
            val value = param.jsonDecode(compact(render(jsonValue)))
            if (isDefault) {
              // Spark keeps setDefault protected; call it via reflection to restore legacy defaults.
              ParamsReflection.setDefault(instance, param, value)
            } else {
              instance.set(param, value)
            }
          }
        }
      case JNothing | JNull => // absent section: nothing to replay
      case other =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata when loading legacy params for ${metadata.className}: $other")
    }
  }
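
  // One replay step in isolation, assuming a "caseSensitive" param survives on
  // the newer class (the param name is illustrative): the live Param decodes its
  // own JSON scalar back into a typed value.
  //
  //   val p = instance.getParam("caseSensitive")
  //   instance.set(p, p.jsonDecode("false")) // yields Boolean false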
}

trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {
@@ -137,3 +205,79 @@ trait ParamsAndFeaturesFallbackReadable[T <: HasFeatures] extends ParamsAndFeaturesReadable[T] {

  override def read: MLReader[T] = new FeaturesFallbackReader(super.read, onRead, fallbackLoad)
}

// Minimal metadata parser + helper utilities for replaying legacy params.
private[nlp] object LegacyMetadataSupport {

  object ParamsReflection {

    // Locate Spark's protected Params.setDefault(Param, value) once, up front.
    private val setDefaultMethod = {
      val maybeMethod = classOf[Params].getDeclaredMethods.find { method =>
        method.getName == "setDefault" && method.getParameterCount == 2
      }

      maybeMethod match {
        case Some(method) =>
          method.setAccessible(true)
          method
        case None =>
          throw new NoSuchMethodException("Params.setDefault(Param, value) not found via reflection")
      }
    }

    def setDefault[T](
        params: Params,
        param: org.apache.spark.ml.param.Param[T],
        value: T): Unit = {
      setDefaultMethod.invoke(params, param, toAnyRef(value))
    }

    // Mirror JVM boxing rules so reflection can call the protected method safely.
    private def toAnyRef(value: Any): AnyRef = {
      if (value == null) {
        null
      } else {
        value match {
          case v: AnyRef => v
          case v: Boolean => java.lang.Boolean.valueOf(v)
          case v: Byte => java.lang.Byte.valueOf(v)
          case v: Short => java.lang.Short.valueOf(v)
          case v: Int => java.lang.Integer.valueOf(v)
          case v: Long => java.lang.Long.valueOf(v)
          case v: Float => java.lang.Float.valueOf(v)
          case v: Double => java.lang.Double.valueOf(v)
          case v: Char => java.lang.Character.valueOf(v)
          case other =>
            throw new IllegalArgumentException(
              s"Unsupported default value type ${other.getClass}")
        }
      }
    }
  }
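
  // Usage sketch, with a hypothetical "dimension" param: restore a legacy
  // default without widening Spark's protected API.
  //
  //   val dim = model.getParam("dimension")
  //   ParamsReflection.setDefault(model, dim, dim.jsonDecode("768"))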

  case class Metadata(
      className: String,
      uid: String,
      sparkVersion: String,
      params: JValue,
      defaultParams: JValue,
      metadataJson: String)

  def load(path: String, spark: SparkSession): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = spark.sparkContext.textFile(metadataPath, 1).first()
    parseMetadata(metadataStr)
  }
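
  // Spark writes model metadata as a one-line JSON document under
  // <modelPath>/metadata (typically a single part-00000 file), so reading the
  // first line of that directory recovers the whole document.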

  private def parseMetadata(metadataStr: String): Metadata = {
    val metadata = parse(metadataStr)
    implicit val format: DefaultFormats.type = DefaultFormats

    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val sparkVersion = (metadata \ "sparkVersion").extractOpt[String].getOrElse("0.0")
    val params = metadata \ "paramMap"
    val defaultParams = metadata \ "defaultParamMap"

    Metadata(className, uid, sparkVersion, params, defaultParams, metadataStr)
  }
}
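
// End-to-end sketch (SomeAnnotatorModel is hypothetical): a companion that mixes
// in ParamsAndFeaturesReadable picks up the legacy fallback transparently.
//
//   object SomeAnnotatorModel extends ParamsAndFeaturesReadable[SomeAnnotatorModel]
//
//   // A bundle saved by an older release that still carried a since-removed
//   // param now loads: the stale param is skipped, surviving params and defaults
//   // are replayed, and serialized features are restored as before.
//   val model = SomeAnnotatorModel.read.load("/models/legacy_model")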