|
14 | 14 | # limitations under the License. |
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +from collections import defaultdict |
17 | 18 | from typing import ( |
18 | 19 | Any, |
19 | 20 | ClassVar, |
@@ -336,37 +337,40 @@ def _get_connections(self) -> list[ConnectionDefinition]: |
336 | 337 | return connections |
337 | 338 |
|
338 | 339 | def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: |
339 | | - run_params = {} |
340 | | - if self.lexical_graph_config: |
341 | | - run_params["extractor"] = { |
342 | | - "lexical_graph_config": self.lexical_graph_config, |
343 | | - } |
344 | | - run_params["writer"] = { |
345 | | - "lexical_graph_config": self.lexical_graph_config, |
346 | | - } |
347 | | - run_params["pruner"] = { |
348 | | - "lexical_graph_config": self.lexical_graph_config, |
349 | | - } |
350 | 340 | text = user_input.get("text") |
351 | 341 | file_path = user_input.get("file_path") |
352 | | - if not ((text is None) ^ (file_path is None)): |
353 | | - # exactly one of text or user_input must be set |
| 342 | + if text is None and file_path is None: |
| 343 | + # user must provide either text or file_path or both |
354 | 344 | raise PipelineDefinitionError( |
355 | | - "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." |
| 345 | + "At least one of `text` (when from_pdf=False) or `file_path` (when from_pdf=True) argument must be provided." |
356 | 346 | ) |
| 347 | + run_params: dict[str, dict[str, Any]] = defaultdict(dict) |
| 348 | + if self.lexical_graph_config: |
| 349 | + run_params["extractor"]["lexical_graph_config"] = self.lexical_graph_config |
| 350 | + run_params["writer"]["lexical_graph_config"] = self.lexical_graph_config |
| 351 | + run_params["pruner"]["lexical_graph_config"] = self.lexical_graph_config |
357 | 352 | if self.from_pdf: |
358 | 353 | if not file_path: |
359 | 354 | raise PipelineDefinitionError( |
360 | 355 | "Expected 'file_path' argument when 'from_pdf' is True." |
361 | 356 | ) |
362 | | - run_params["pdf_loader"] = {"filepath": file_path} |
| 357 | + run_params["pdf_loader"]["filepath"] = file_path |
| 358 | + run_params["pdf_loader"]["metadata"] = user_input.get("document_metadata") |
363 | 359 | else: |
364 | 360 | if not text: |
365 | 361 | raise PipelineDefinitionError( |
366 | 362 | "Expected 'text' argument when 'from_pdf' is False." |
367 | 363 | ) |
368 | | - run_params["splitter"] = {"text": text} |
| 364 | + run_params["splitter"]["text"] = text |
369 | 365 | # Add full text to schema component for automatic schema extraction |
370 | 366 | if not self.has_user_provided_schema(): |
371 | | - run_params["schema"] = {"text": text} |
| 367 | + run_params["schema"]["text"] = text |
| 368 | + run_params["extractor"]["document_info"] = dict( |
| 369 | + path=user_input.get( |
| 370 | + "file_path", |
| 371 | + ) |
| 372 | + or "document.txt", |
| 373 | + metadata=user_input.get("document_metadata"), |
| 374 | + document_type="inline_text", |
| 375 | + ) |
372 | 376 | return run_params |
0 commit comments