11#!/usr/bin/env python
22
3- # Copyright (c) 2024 Oracle and/or its affiliates.
3+ # Copyright (c) 2025 Oracle and/or its affiliates.
44# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
6+ import logging
7+ import os
8+ from pathlib import Path
69from typing import Dict , Optional
710
811from ads .model .extractor .embedding_onnx_extractor import EmbeddingONNXExtractor
912from ads .model .generic_model import FrameworkSpecificModel
1013
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Base name of the model configuration file searched for in the artifact directory.
CONFIG = "config.json"
# Base names that are accepted as evidence of a tokenizer in the artifact directory.
TOKENIZERS = [
    "tokenizer.json",
    "tokenizer_config.json",
    "spiece.model",
    "vocab.txt",
    "vocab.json",
]
24+
1125
1226class EmbeddingONNXModel (FrameworkSpecificModel ):
1327 """EmbeddingONNXModel class for embedding onnx model.
@@ -18,6 +32,12 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
1832 The algorithm of the model.
1933 artifact_dir: str
2034 Artifact directory to store the files needed for deployment.
35+ model_file_name: str
36+ Path to the model artifact.
37+ config_json: str
38+ Path to the config.json file.
39+ tokenizer_dir: str
40+ Path to the tokenizer directory.
2141 auth: Dict
2242 Default authentication is set using the `ads.set_auth` API. To override the
2343 default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create
@@ -166,6 +186,9 @@ class EmbeddingONNXModel(FrameworkSpecificModel):
166186 def __init__ (
167187 self ,
168188 artifact_dir : Optional [str ] = None ,
189+ model_file_name : Optional [str ] = None ,
190+ config_json : Optional [str ] = None ,
191+ tokenizer_dir : Optional [str ] = None ,
169192 auth : Optional [Dict ] = None ,
170193 serialize : bool = False ,
171194 ** kwargs : dict ,
@@ -175,8 +198,14 @@ def __init__(
175198
176199 Parameters
177200 ----------
178- artifact_dir: str
201+ artifact_dir: ( str, optional). Defaults to None.
179202 Directory for the generated artifacts.
203+ model_file_name: (str, optional). Defaults to None.
204+ Path to the model artifact.
205+ config_json: (str, optional). Defaults to None.
206+ Path to the config.json file.
207+ tokenizer_dir: (str, optional). Defaults to None.
208+ Path to the tokenizer directory.
180209 auth: (Dict, optional). Defaults to None.
181210 The default authentication is set using `ads.set_auth` API. If you need to override the
182211 default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
@@ -260,12 +289,63 @@ def __init__(
260289 ** kwargs ,
261290 )
262291
292+ self ._validate_artifact_directory (
293+ model_file_name = model_file_name ,
294+ config_json = config_json ,
295+ tokenizer_dir = tokenizer_dir ,
296+ )
297+
263298 self ._extractor = EmbeddingONNXExtractor ()
264299 self .framework = self ._extractor .framework
265300 self .algorithm = self ._extractor .algorithm
266301 self .version = self ._extractor .version
267302 self .hyperparameter = self ._extractor .hyperparameter
268303
def _validate_artifact_directory(
    self,
    model_file_name: Optional[str] = None,
    config_json: Optional[str] = None,
    tokenizer_dir: Optional[str] = None,
):
    """Check that the artifact directory contains the expected model files.

    The tree under ``self.artifact_dir`` is walked recursively and only the
    file base names are collected, so a matching file in any sub-directory
    satisfies a check. A check is skipped entirely when the corresponding
    argument is truthy; the supplied value itself is not inspected here.

    Parameters
    ----------
    model_file_name: (str, optional). Defaults to None.
        When falsy, at least one collected file must have an ``onnx``
        extension (compared case-insensitively).
    config_json: (str, optional). Defaults to None.
        When falsy, a file named ``CONFIG`` is looked for among the collected files.
    tokenizer_dir: (str, optional). Defaults to None.
        When falsy, any file whose name is listed in ``TOKENIZERS`` is looked
        for among the collected files.

    Raises
    ------
    ValueError
        If no file at all is found under ``self.artifact_dir``, or if
        `model_file_name` is falsy and no onnx file is found.
    """
    artifacts = []
    for _, _, files in os.walk(self.artifact_dir):
        artifacts.extend(files)

    if not artifacts:
        raise ValueError(
            f"No files found in {self.artifact_dir}. Specify a valid `artifact_dir`."
        )

    if not model_file_name and not any(
        Path(artifact).suffix.lower() == ".onnx" for artifact in artifacts
    ):
        raise ValueError(
            f"No onnx model found in {self.artifact_dir}. Specify a valid `artifact_dir` or `model_file_name`."
        )

    # A missing config or tokenizer is only reported as a warning, never raised.
    if not config_json and CONFIG not in artifacts:
        logger.warning(
            f"No {CONFIG} found in {self.artifact_dir}. Specify a valid `artifact_dir` or `config_json`."
        )

    if not tokenizer_dir and not any(
        artifact in TOKENIZERS for artifact in artifacts
    ):
        logger.warning(
            f"No tokenizer found in {self.artifact_dir}. Specify a valid `artifact_dir` or `tokenizer_dir`."
        )
348+
269349 def verify (
270350 self , data = None , reload_artifacts = True , auto_serialize_data = False , ** kwargs
271351 ):
0 commit comments