|
22 | 22 | HuggingFace, |
23 | 23 | Ollama, |
24 | 24 | OpenAI, |
| 25 | + OneApi |
25 | 26 | ) |
26 | 27 | from ..utils.logging import set_verbosity_debug, set_verbosity_warning |
27 | 28 |
|
@@ -55,19 +56,20 @@ class AbstractGraph(ABC): |
55 | 56 | ... # Implementation of graph creation here |
56 | 57 | ... return graph |
57 | 58 | ... |
58 | | - >>> my_graph = MyGraph("Example Graph", {"llm": {"model": "gpt-3.5-turbo"}}, "example_source") |
| 59 | + >>> my_graph = MyGraph("Example Graph", |
| 60 | + ...     {"llm": {"model": "gpt-3.5-turbo"}}, "example_source") |
59 | 61 | >>> result = my_graph.run() |
60 | 62 | """ |
61 | 63 |
|
62 | | - def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[str] = None): |
| 64 | + def __init__(self, prompt: str, config: dict, |
| 65 | + source: Optional[str] = None, schema: Optional[str] = None): |
63 | 66 |
|
64 | 67 | self.prompt = prompt |
65 | 68 | self.source = source |
66 | 69 | self.config = config |
67 | 70 | self.schema = schema |
68 | 71 | self.llm_model = self._create_llm(config["llm"], chat=True) |
69 | | - self.embedder_model = self._create_default_embedder(llm_config=config["llm"] |
70 | | - ) if "embeddings" not in config else self._create_embedder( |
| 72 | + self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder( |
71 | 73 | config["embeddings"]) |
72 | 74 | self.verbose = False if config is None else config.get( |
73 | 75 | "verbose", False) |
@@ -99,7 +101,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None, sche |
99 | 101 | "llm_model": self.llm_model, |
100 | 102 | "embedder_model": self.embedder_model |
101 | 103 | } |
102 | | - |
| 104 | + |
103 | 105 | self.set_common_params(common_params, overwrite=False) |
104 | 106 |
|
105 | 107 | # set burr config |
@@ -174,7 +176,14 @@ def _create_llm(self, llm_config: dict, chat=False) -> object: |
174 | 176 | except KeyError as exc: |
175 | 177 | raise KeyError("Model not supported") from exc |
176 | 178 | return OpenAI(llm_params) |
177 | | - |
| 179 | + elif "oneapi" in llm_params["model"]: |
| 180 | + # take the model name after the last slash |
| 181 | + llm_params["model"] = llm_params["model"].split("/")[-1] |
| 182 | + try: |
| 183 | + self.model_token = models_tokens["oneapi"][llm_params["model"]] |
| 184 | + except KeyError as exc: |
| 185 | + raise KeyError("Model Model not supported") from exc |
| 186 | + return OneApi(llm_params) |
178 | 187 | elif "azure" in llm_params["model"]: |
179 | 188 | # take the model name after the last slash |
180 | 189 | llm_params["model"] = llm_params["model"].split("/")[-1] |
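The new `oneapi` branch mirrors the existing provider branches: the provider prefix is stripped at the last "/" and the remainder drives the `models_tokens["oneapi"]` lookup. A hedged usage sketch, assuming `SmartScraperGraph` as the concrete subclass and that `api_key`/`base_url` are the accepted extra keys (neither is confirmed by this diff):

```python
from scrapegraphai.graphs import SmartScraperGraph  # assumed concrete subclass

graph_config = {
    "llm": {
        # "oneapi/" prefix is stripped, so models_tokens["oneapi"]["gpt-3.5-turbo"]
        # supplies self.model_token
        "model": "oneapi/gpt-3.5-turbo",
        "api_key": "YOUR_ONEAPI_KEY",            # assumed key name
        "base_url": "http://localhost:3000/v1",  # assumed OneAPI-compatible endpoint
    },
}

graph = SmartScraperGraph(
    prompt="Extract the page title",
    source="https://example.com",
    config=graph_config,
)
print(graph.run())
```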
|