1919from typing import Callable , Dict , Generator , List , Optional , Union
2020
2121from torchx .specs import AppDef
22- from torchx .specs .file_linter import get_fn_docstring , validate
22+ from torchx .specs .file_linter import get_fn_docstring , TorchxFunctionValidator , validate
2323from torchx .util import entrypoints
2424from torchx .util .io import read_conf_file
2525from torchx .util .types import none_throws
@@ -59,7 +59,9 @@ class _Component:
5959
6060class ComponentsFinder (abc .ABC ):
6161 @abc .abstractmethod
62- def find (self ) -> List [_Component ]:
62+ def find (
63+ self , validators : Optional [List [TorchxFunctionValidator ]]
64+ ) -> List [_Component ]:
6365 """
6466 Retrieves a set of components. A component is defined as a python
6567 function that conforms to ``torchx.specs.file_linter`` linter.
@@ -203,10 +205,12 @@ def _iter_modules_recursive(
203205 else :
204206 yield self ._try_import (module_info .name )
205207
206- def find (self ) -> List [_Component ]:
208+ def find (
209+ self , validators : Optional [List [TorchxFunctionValidator ]]
210+ ) -> List [_Component ]:
207211 components = []
208212 for m in self ._iter_modules_recursive (self .base_module ):
209- components += self ._get_components_from_module (m )
213+ components += self ._get_components_from_module (m , validators )
210214 return components
211215
212216 def _try_import (self , module : Union [str , ModuleType ]) -> ModuleType :
@@ -221,7 +225,9 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
221225 else :
222226 return module
223227
224- def _get_components_from_module (self , module : ModuleType ) -> List [_Component ]:
228+ def _get_components_from_module (
229+ self , module : ModuleType , validators : Optional [List [TorchxFunctionValidator ]]
230+ ) -> List [_Component ]:
225231 functions = getmembers (module , isfunction )
226232 component_defs = []
227233
@@ -230,7 +236,7 @@ def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
230236 module_path = os .path .abspath (module_path )
231237 rel_module_name = module_relname (module , relative_to = self .base_module )
232238 for function_name , function in functions :
233- linter_errors = validate (module_path , function_name )
239+ linter_errors = validate (module_path , function_name , validators )
234240 component_desc , _ = get_fn_docstring (function )
235241
236242 # remove empty string to deal with group=""
@@ -255,13 +261,20 @@ def __init__(self, filepath: str, function_name: str) -> None:
255261 self ._filepath = filepath
256262 self ._function_name = function_name
257263
258- def _get_validation_errors (self , path : str , function_name : str ) -> List [str ]:
259- linter_errors = validate (path , function_name )
264+ def _get_validation_errors (
265+ self ,
266+ path : str ,
267+ function_name : str ,
268+ validators : Optional [List [TorchxFunctionValidator ]],
269+ ) -> List [str ]:
270+ linter_errors = validate (path , function_name , validators )
260271 return [linter_error .description for linter_error in linter_errors ]
261272
262- def find (self ) -> List [_Component ]:
273+ def find (
274+ self , validators : Optional [List [TorchxFunctionValidator ]]
275+ ) -> List [_Component ]:
263276 validation_errors = self ._get_validation_errors (
264- self ._filepath , self ._function_name
277+ self ._filepath , self ._function_name , validators
265278 )
266279
267280 file_source = read_conf_file (self ._filepath )
@@ -284,7 +297,9 @@ def find(self) -> List[_Component]:
284297 ]
285298
286299
287- def _load_custom_components () -> List [_Component ]:
300+ def _load_custom_components (
301+ validators : Optional [List [TorchxFunctionValidator ]],
302+ ) -> List [_Component ]:
288303 component_modules = {
289304 name : load_fn ()
290305 for name , load_fn in
@@ -303,11 +318,13 @@ def _load_custom_components() -> List[_Component]:
303318 # _0 = torchx.components.dist
304319 # _1 = torchx.components.utils
305320 group = "" if group .startswith ("_" ) else group
306- components += ModuleComponentsFinder (module , group ).find ()
321+ components += ModuleComponentsFinder (module , group ).find (validators )
307322 return components
308323
309324
310- def _load_components () -> Dict [str , _Component ]:
325+ def _load_components (
326+ validators : Optional [List [TorchxFunctionValidator ]],
327+ ) -> Dict [str , _Component ]:
311328 """
312329 Loads either the custom component defs from the entrypoint ``[torchx.components]``
313330 or the default builtins from ``torchx.components`` module.
@@ -318,37 +335,43 @@ def _load_components() -> Dict[str, _Component]:
318335
319336 """
320337
321- components = _load_custom_components ()
338+ components = _load_custom_components (validators )
322339 if not components :
323- components = ModuleComponentsFinder ("torchx.components" , "" ).find ()
340+ components = ModuleComponentsFinder ("torchx.components" , "" ).find (validators )
324341 return {c .name : c for c in components }
325342
326343
327344_components : Optional [Dict [str , _Component ]] = None
328345
329346
330- def _find_components () -> Dict [str , _Component ]:
347+ def _find_components (
348+ validators : Optional [List [TorchxFunctionValidator ]],
349+ ) -> Dict [str , _Component ]:
331350 global _components
332351 if not _components :
333- _components = _load_components ()
352+ _components = _load_components (validators )
334353 return none_throws (_components )
335354
336355
337356def _is_custom_component (component_name : str ) -> bool :
338357 return ":" in component_name
339358
340359
341- def _find_custom_components (name : str ) -> Dict [str , _Component ]:
360+ def _find_custom_components (
361+ name : str , validators : Optional [List [TorchxFunctionValidator ]]
362+ ) -> Dict [str , _Component ]:
342363 if ":" not in name :
343364 raise ValueError (
344365 f"Invalid custom component: { name } , valid template : `FILEPATH`:`FUNCTION_NAME`"
345366 )
346367 filepath , component_name = name .split (":" )
347- components = CustomComponentsFinder (filepath , component_name ).find ()
368+ components = CustomComponentsFinder (filepath , component_name ).find (validators )
348369 return {component .name : component for component in components }
349370
350371
351- def get_components () -> Dict [str , _Component ]:
372+ def get_components (
373+ validators : Optional [List [TorchxFunctionValidator ]] = None ,
374+ ) -> Dict [str , _Component ]:
352375 """
353376 Returns all custom components registered via ``[torchx.components]`` entrypoints
354377 OR builtin components that ship with TorchX (but not both).
@@ -395,23 +418,25 @@ def get_components() -> Dict[str, _Component]:
395418 """
396419
397420 valid_components : Dict [str , _Component ] = {}
398- for component_name , component in _find_components ().items ():
421+ for component_name , component in _find_components (validators ).items ():
399422 if len (component .validation_errors ) == 0 :
400423 valid_components [component_name ] = component
401424 return valid_components
402425
403426
404- def get_component (name : str ) -> _Component :
427+ def get_component (
428+ name : str , validators : Optional [List [TorchxFunctionValidator ]] = None
429+ ) -> _Component :
405430 """
406431 Retrieves components by the provided name.
407432
408433 Returns:
409434 Component or None if no component with ``name`` exists
410435 """
411436 if _is_custom_component (name ):
412- components = _find_custom_components (name )
437+ components = _find_custom_components (name , validators )
413438 else :
414- components = _find_components ()
439+ components = _find_components (validators )
415440 if name not in components :
416441 raise ComponentNotFoundException (
417442 f"Component `{ name } ` not found. Please make sure it is one of the "
@@ -428,7 +453,9 @@ def get_component(name: str) -> _Component:
428453 return component
429454
430455
431- def get_builtin_source (name : str ) -> str :
456+ def get_builtin_source (
457+ name : str , validators : Optional [List [TorchxFunctionValidator ]] = None
458+ ) -> str :
432459 """
433460 Returns a string of the the builtin component's function source code
434461 with all the import statements. Intended to be used to make a copy
@@ -446,7 +473,7 @@ def get_builtin_source(name: str) -> str:
446473 are optimized and formatting adheres to your organization's standards.
447474 """
448475
449- component = get_component (name )
476+ component = get_component (name , validators )
450477 fn = component .fn
451478 fn_name = component .name .split ("." )[- 1 ]
452479
0 commit comments