Skip to content

Commit eaa6b62

Browse files
authored
Refactor blocks, processors and datasources. Add tests (#8)
* Update documentation * Add test fixtures * Reafactor processor interface and add tests
1 parent 38dd35f commit eaa6b62

File tree

128 files changed

+2154
-3094
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

128 files changed

+2154
-3094
lines changed

apps/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import List
55

66
from datasources.handlers.datasource_type_interface import DataSourceEntryItem
7-
from datasources.handlers.datasource_type_interface import DataSourceTypeInterface
7+
from datasources.handlers.datasource_type_interface import DataSourceProcessor
88
from datasources.models import DataSource
99
from datasources.models import DataSourceEntry
1010
from datasources.models import DataSourceEntryStatus
@@ -19,7 +19,7 @@ def add_data_entry_task(datasource: DataSource, datasource_entry_items: List[Dat
1919
datasource_entry_handler_cls = DataSourceTypeFactory.get_datasource_type_handler(
2020
datasource.type,
2121
)
22-
datasource_entry_handler: DataSourceTypeInterface = datasource_entry_handler_cls(
22+
datasource_entry_handler: DataSourceProcessor = datasource_entry_handler_cls(
2323
datasource,
2424
)
2525

File renamed without changes.
File renamed without changes.

common/promptly/core/base.py renamed to common/blocks/base/processor.py

Lines changed: 73 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,19 @@
11
import logging
2-
import pickle
32
from abc import ABC
43
from abc import abstractmethod
54
from typing import Generator
65
from typing import Generic
76
from typing import Optional
87
from typing import TypeVar
98

10-
from pydantic import BaseModel
119
from pydantic import Field
1210

13-
LOGGER = logging.getLogger(__name__)
14-
15-
16-
class Schema(BaseModel):
17-
pass
11+
from common.blocks.base.schema import BaseSchema as Schema
1812

13+
LOGGER = logging.getLogger(__name__)
1914

2015
class BaseInputEnvironment(Schema):
21-
bypass_cache: bool = Field(True, description='Bypass cache')
22-
23-
24-
class BaseError(Schema):
25-
code: int
26-
message: str
27-
16+
pass
2817

2918
class BaseInput(Schema):
3019
env: Optional[BaseInputEnvironment] = Field(
@@ -36,10 +25,6 @@ class BaseConfiguration(Schema):
3625
pass
3726

3827

39-
class BaseErrorOutput(Schema):
40-
error: Optional[BaseError] = Field(None, description='Error Object')
41-
42-
4328
class BaseOutput(Schema):
4429
metadata: Optional[Schema] = Field(
4530
default={}, description='Metadata',
@@ -61,27 +46,17 @@ def delete(self, key):
6146

6247

6348
BaseInputType = TypeVar('BaseInputType', BaseInput, dict)
64-
BaseOutputType = TypeVar(
65-
'BaseOutputType', BaseOutput,
66-
BaseErrorOutput, Generator[BaseOutput, None, None],
67-
)
68-
BaseConfigurationType = TypeVar(
69-
'BaseConfigurationType', BaseConfiguration, dict,
70-
)
71-
72-
73-
class BaseProcessor(Generic[BaseInputType, BaseOutputType, BaseConfigurationType], ABC):
74-
"""
75-
Base class for all processors
76-
"""
49+
BaseOutputType = TypeVar('BaseOutputType', BaseOutput, dict)
50+
BaseConfigurationType = TypeVar('BaseConfigurationType', BaseConfiguration, dict)
7751

52+
class ProcessorInterface(Generic[BaseInputType, BaseOutputType, BaseConfigurationType], ABC):
7853
@staticmethod
7954
def name() -> str:
8055
raise NotImplementedError
8156

8257
@staticmethod
8358
def slug() -> str:
84-
return BaseProcessor.name().lower().replace(' ', '_')
59+
raise NotImplementedError
8560

8661
@classmethod
8762
def _get_input_class(cls) -> BaseInputType:
@@ -90,26 +65,88 @@ def _get_input_class(cls) -> BaseInputType:
9065
@classmethod
9166
def _get_output_class(cls) -> BaseOutputType:
9267
return cls.__orig_bases__[0].__args__[1]
93-
68+
9469
@classmethod
9570
def _get_configuration_class(cls) -> BaseConfigurationType:
9671
return cls.__orig_bases__[0].__args__[2]
9772

9873
@classmethod
99-
def get_input_schema(cls) -> dict:
74+
def _get_input_schema(cls) -> dict:
10075
api_processor_interface_class = cls.__orig_bases__[0]
10176
return api_processor_interface_class.__args__[0].schema_json()
10277

10378
@classmethod
104-
def get_output_schema(cls) -> dict:
79+
def _get_output_schema(cls) -> dict:
10580
api_processor_interface_class = cls.__orig_bases__[0]
10681
return api_processor_interface_class.__args__[1].schema_json()
10782

10883
@classmethod
109-
def get_configuration_schema(cls) -> dict:
84+
def _get_configuration_schema(cls) -> dict:
11085
api_processor_interface_class = cls.__orig_bases__[0]
11186
return api_processor_interface_class.__args__[2].schema_json()
87+
88+
@classmethod
89+
def _get_input_ui_schema(cls) -> dict:
90+
api_processor_interface_class = cls.__orig_bases__[0]
91+
return api_processor_interface_class.__args__[0].get_ui_schema()
92+
93+
@classmethod
94+
def _get_output_ui_schema(cls) -> dict:
95+
api_processor_interface_class = cls.__orig_bases__[0]
96+
return api_processor_interface_class.__args__[1].get_ui_schema()
97+
98+
@classmethod
99+
def _get_configuration_ui_schema(cls) -> dict:
100+
api_processor_interface_class = cls.__orig_bases__[0]
101+
return api_processor_interface_class.__args__[2].get_ui_schema()
102+
103+
@classmethod
104+
def get_input_cls(cls) -> BaseInputType:
105+
return cls._get_input_class()
106+
107+
@classmethod
108+
def get_output_cls(cls) -> BaseOutputType:
109+
return cls._get_output_class()
110+
111+
@classmethod
112+
def get_configuration_cls(cls) -> BaseConfigurationType:
113+
return cls._get_configuration_class()
114+
115+
@classmethod
116+
def get_input_schema(cls) -> dict:
117+
return cls._get_input_schema()
118+
119+
@classmethod
120+
def get_output_schema(cls) -> dict:
121+
return cls._get_output_schema()
122+
123+
@classmethod
124+
def get_configuration_schema(cls) -> dict:
125+
return cls._get_configuration_schema()
126+
127+
@classmethod
128+
def get_input_ui_schema(cls) -> dict:
129+
return cls._get_input_ui_schema()
130+
131+
@classmethod
132+
def get_output_ui_schema(cls) -> dict:
133+
return cls._get_output_ui_schema()
134+
135+
@classmethod
136+
def get_configuration_ui_schema(cls) -> dict:
137+
return cls._get_configuration_ui_schema()
112138

139+
def process(self, input: BaseInputType, configuration: BaseConfigurationType) -> BaseOutputType:
140+
raise NotImplementedError()
141+
142+
def process_iter(self, input: BaseInputType, configuration: BaseConfigurationType) -> Generator[BaseOutputType, None, None]:
143+
raise NotImplementedError()
144+
145+
class BaseProcessor(ProcessorInterface[BaseInputType, BaseOutputType, BaseConfigurationType]):
146+
"""
147+
Base class for all processors
148+
"""
149+
113150
def __init__(self, configuration: dict, cache_manager: CacheManager = None, input_tx_cb: callable = None, output_tx_cb: callable = None):
114151
self.configuration = self.parse_validate_configuration(configuration)
115152
self.cache_manager = cache_manager
@@ -154,9 +191,6 @@ def process_iter(self, input: dict) -> BaseOutputType:
154191
LOGGER.exception('Exception occurred while processing')
155192
raise ex
156193

157-
def serialize(self):
158-
return pickle.dumps(self)
159-
160194
@property
161195
def id(self):
162196
return id(self)

common/blocks/base/schema.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from pydantic import BaseModel
2+
3+
def custom_json_dumps(v, **kwargs):
4+
import ujson as json
5+
6+
default_arg = kwargs.get('default', None)
7+
return json.dumps(v, default=default_arg)
8+
9+
10+
def custom_json_loads(v, **kwargs):
11+
import ujson as json
12+
13+
return json.loads(v, **kwargs)
14+
15+
16+
class BaseSchema(BaseModel):
17+
class Config:
18+
json_dumps = custom_json_dumps
19+
json_loads = custom_json_loads
20+
21+
@ classmethod
22+
def get_json_schema(cls):
23+
return super().schema_json(indent=2)
24+
25+
@ classmethod
26+
def get_schema(cls):
27+
return super().schema()
28+
29+
@ classmethod
30+
def get_ui_schema(cls):
31+
"""
32+
This function receives a class method; gets the schema of the class
33+
Calls the function form_ui_per_prop to form the dictionary of UI schema values
34+
The resultant UI Schema only contains the properties
35+
"""
36+
schema = cls.get_schema()
37+
ui_schema = {
38+
'ui:order': list(schema['properties'].keys()),
39+
'properties': {}
40+
}
41+
42+
for p_key, p_val in schema['properties'].items():
43+
ui_schema['properties'][p_key] = cls.form_ui_per_prop(p_key, p_val)
44+
45+
return ui_schema['properties']
46+
47+
@ classmethod
48+
def form_ui_per_prop(cls, p_key, prop_schema_dict):
49+
"""
50+
This functions receives the property key and its schema dictionary
51+
It checks the type of the property value and based on its type,
52+
assigns the correct UI widget and other UI properties to it
53+
"""
54+
ui_prop = {}
55+
56+
if 'title' in prop_schema_dict:
57+
ui_prop['ui:label'] = prop_schema_dict['title']
58+
59+
if 'description' in prop_schema_dict:
60+
ui_prop['ui:description'] = prop_schema_dict['description']
61+
62+
type_val = prop_schema_dict.get('type')
63+
64+
if type_val == 'string':
65+
ui_prop['ui:widget'] = 'text'
66+
elif type_val in ('integer', 'number'):
67+
if 'minimum' in prop_schema_dict and 'maximum' in prop_schema_dict:
68+
ui_prop['ui:widget'] = 'range'
69+
ui_prop['ui:options'] = {
70+
'min': prop_schema_dict['minimum'],
71+
'max': prop_schema_dict['maximum'],
72+
}
73+
else:
74+
ui_prop['ui:widget'] = 'updown'
75+
elif type_val == 'boolean':
76+
ui_prop['ui:widget'] = 'checkbox'
77+
elif 'enum' in prop_schema_dict:
78+
ui_prop['ui:widget'] = 'select'
79+
ui_prop['ui:options'] = {
80+
'enumOptions': [{'value': val, 'label': val} for val in prop_schema_dict['enum']]
81+
}
82+
83+
if 'widget' in prop_schema_dict:
84+
ui_prop['ui:widget'] = prop_schema_dict['widget']
85+
86+
if prop_schema_dict.get('format') == 'date-time':
87+
ui_prop['ui:widget'] = 'datetime'
88+
89+
# Unless explicitly mentioned, all properties are advanced parameters
90+
ui_prop['ui:advanced'] = prop_schema_dict.get('advanced_parameter', True)
91+
92+
return ui_prop
File renamed without changes.

common/promptly/blocks/utils.py renamed to common/blocks/cache/in_memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import hashlib
22

3-
from common.promptly.core.base import CacheManager
3+
from common.blocks.base.processor import CacheManager
44

55

66
def get_hash(input: dict):

common/blocks/data/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from typing import Any, Dict, Optional
2+
from common.blocks.base.schema import BaseSchema
3+
4+
class DataDocument(BaseSchema):
5+
content:Optional[bytes]
6+
content_text: Optional[str]
7+
metadata: Dict[str, Any] = {}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from pydantic import AnyUrl
2+
from typing import List, Optional
3+
from common.blocks.base.schema import BaseSchema
4+
from common.blocks.data import DataDocument
5+
6+
class DataUrl(AnyUrl):
7+
@classmethod
8+
def __modify_schema__(cls, field_schema):
9+
field_schema.update(
10+
{
11+
'format': 'data-url',
12+
'pattern': r'data:(.*);name=(.*);base64,(.*)',
13+
},
14+
)
15+
16+
class DataSourceEnvironmentSchema(BaseSchema):
17+
openai_key: str
18+
class DataSourceInputSchema(BaseSchema):
19+
env: Optional[DataSourceEnvironmentSchema]
20+
21+
class DataSourceConfigurationSchema(BaseSchema):
22+
pass
23+
24+
class DataSourceOutputSchema(BaseSchema):
25+
documents: List[DataDocument]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
import re
3+
from pydantic import root_validator
4+
5+
from common.blocks.base.processor import ProcessorInterface
6+
from common.blocks.data import DataDocument
7+
from common.blocks.data.source import DataSourceInputSchema, DataSourceConfigurationSchema, DataSourceOutputSchema
8+
9+
10+
class DirectoryTextLoaderInputSchema(DataSourceInputSchema):
11+
directory: str
12+
recursive: bool = False
13+
14+
@root_validator()
15+
@classmethod
16+
def validate_directory(cls, field_values) -> str:
17+
value = field_values.get("directory")
18+
recursive = field_values.get("recursive")
19+
20+
# TODO: Validate that directory is a valid directory path and the directory exists
21+
if not re.match(r"^[a-zA-Z0-9_\-\.\/]+$", value):
22+
raise ValueError("Directory must be a valid string")
23+
24+
return value
25+
26+
class DirectoryTextLoader(ProcessorInterface[DirectoryTextLoaderInputSchema, DataSourceOutputSchema, DataSourceConfigurationSchema]):
27+
28+
def process(self, input: DirectoryTextLoaderInputSchema, configuration: DataSourceConfigurationSchema) -> DataSourceOutputSchema:
29+
result = []
30+
files = []
31+
# If recursive is true, then we need to recursively walk the directory
32+
if input.recursive:
33+
for dir, dirname, filename in os.walk(input.directory):
34+
files.extend(filename)
35+
else:
36+
files = os.listdir(input.directory)
37+
38+
for file in files:
39+
with open(os.path.join(input.directory, file), "r") as f:
40+
result.append(DataDocument(name=file, content=f.read()))
41+
42+
return DataSourceOutputSchema(documents=result)
43+
44+

0 commit comments

Comments
 (0)