|
13 | 13 | import logging |
14 | 14 | import importlib |
15 | 15 | import itertools |
| 16 | +from types import ModuleType |
16 | 17 | from typing import Any, Dict, List, Tuple, Union, Optional |
17 | 18 |
|
18 | 19 | from .utils import text_, bytes_ |
@@ -75,31 +76,54 @@ def load( |
75 | 76 | # this plugin_ is implementing |
76 | 77 | base_klass = None |
77 | 78 | for k in mro: |
78 | | - if bytes_(k.__name__) in p: |
| 79 | + if bytes_(k.__qualname__) in p: |
79 | 80 | base_klass = k |
80 | 81 | break |
81 | 82 | if base_klass is None: |
82 | 83 | raise ValueError('%s is NOT a valid plugin' % text_(plugin_)) |
83 | | - if klass not in p[bytes_(base_klass.__name__)]: |
84 | | - p[bytes_(base_klass.__name__)].append(klass) |
85 | | - logger.info('Loaded plugin %s.%s', module_name, klass.__name__) |
| 84 | + if klass not in p[bytes_(base_klass.__qualname__)]: |
| 85 | + p[bytes_(base_klass.__qualname__)].append(klass) |
| 86 | + logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__) |
86 | 87 | # print(p) |
87 | 88 | return p |
88 | 89 |
|
89 | 90 | @staticmethod |
90 | 91 | def importer(plugin: Union[bytes, type]) -> Tuple[type, str]: |
91 | 92 | """Import and returns the plugin.""" |
92 | 93 | if isinstance(plugin, type): |
93 | | - return (plugin, '__main__') |
| 94 | + if inspect.isclass(plugin): |
| 95 | + return (plugin, plugin.__module__ or '__main__') |
| 96 | + raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin)) |
94 | 97 | plugin_ = text_(plugin.strip()) |
95 | 98 | assert plugin_ != '' |
96 | | - module_name, klass_name = plugin_.rsplit(text_(DOT), 1) |
97 | | - klass = getattr( |
98 | | - importlib.import_module( |
99 | | - module_name.replace( |
100 | | - os.path.sep, text_(DOT), |
101 | | - ), |
102 | | - ), |
103 | | - klass_name, |
104 | | - ) |
| 99 | + path = plugin_.split(text_(DOT)) |
| 100 | + klass = None |
| 101 | + |
| 102 | + def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]: |
| 103 | + klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT)) |
| 104 | + try: |
| 105 | + klass_module = importlib.import_module(klass_module_name) |
| 106 | + except ModuleNotFoundError: |
| 107 | + return None |
| 108 | + klass_container: Union[ModuleType, type] = klass_module |
| 109 | + for klass_path_part in klass_path: |
| 110 | + try: |
| 111 | + klass_container = getattr(klass_container, klass_path_part) |
| 112 | + except AttributeError: |
| 113 | + return None |
| 114 | + if not isinstance(klass_container, type) or not inspect.isclass(klass_container): |
| 115 | + return None |
| 116 | + return klass_container |
| 117 | + |
| 118 | + module_name = None |
| 119 | + for module_name_parts in range(len(path) - 1, 0, -1): |
| 120 | + module_name = '.'.join(path[0:module_name_parts]) |
| 121 | + klass = locate_klass(module_name, path[module_name_parts:]) |
| 122 | + if klass: |
| 123 | + break |
| 124 | + if klass is None: |
| 125 | + module_name = '__main__' |
| 126 | + klass = locate_klass(module_name, path) |
| 127 | + if klass is None or module_name is None: |
| 128 | + raise ValueError('%s is not resolvable as a plugin class' % text_(plugin)) |
105 | 129 | return (klass, module_name) |
0 commit comments