diff --git a/business_objects/general.py b/business_objects/general.py
index 0475f39b..6a5978ca 100644
--- a/business_objects/general.py
+++ b/business_objects/general.py
@@ -77,7 +77,7 @@ def force_remove_and_refresh_session_by_id(session_id: str) -> bool:
     if session_id not in session_lookup:
         return False
     # context vars cant be closed from a different context but we can work around it by using a thread (which creates a new context) with the same id
-    daemon.run_without_db_token(__close_in_context(session_id))
+    daemon.run_without_db_token(__close_in_context, session_id)
     return True
diff --git a/business_objects/monitor.py b/business_objects/monitor.py
index 65429c55..c56aa426 100644
--- a/business_objects/monitor.py
+++ b/business_objects/monitor.py
@@ -5,6 +5,7 @@
 from submodules.model.models import TaskQueue, Organization
 from submodules.model.util import prevent_sql_injection
 from submodules.model.session import session
+from submodules.model.global_objects import etl_task as etl_task_db_bo
 from submodules.model.cognition_objects import (
     macro as macro_db_bo,
     markdown_file as markdown_file_db_bo,
@@ -220,6 +221,26 @@ def set_integration_task_to_failed(
     )
 
 
+def set_etl_task_to_failed(
+    id: str,
+    is_active: bool = False,
+    error_message: Optional[str] = None,
+    state: Optional[
+        enums.CognitionMarkdownFileState
+    ] = enums.CognitionMarkdownFileState.FAILED,
+    with_commit: bool = True,
+) -> None:
+    # argument `state` is a workaround for cognition-gateway/api/routes/integrations.delete_many
+    etl_task_db_bo.update(
+        id=id,
+        state=state,
+        finished_at=datetime.datetime.now(datetime.timezone.utc),
+        is_active=is_active,
+        error_message=error_message,
+        with_commit=with_commit,
+    )
+
+
 def __select_running_information_source_payloads(
     project_id: Optional[str] = None,
     only_running: bool = False,
diff --git a/cognition_objects/environment_variable.py b/cognition_objects/environment_variable.py
index 68e86301..cb6787ca 100644
--- a/cognition_objects/environment_variable.py
+++ b/cognition_objects/environment_variable.py
@@ -64,6 +64,22 @@ def get_by_name_and_org_id(
     )
 
 
+def get_by_id_and_org_id(
+    org_id: str,
+    id: str,
+) -> CognitionEnvironmentVariable:
+
+    return (
+        session.query(CognitionEnvironmentVariable)
+        .filter(
+            CognitionEnvironmentVariable.organization_id == org_id,
+            CognitionEnvironmentVariable.project_id == None,
+            CognitionEnvironmentVariable.id == id,
+        )
+        .first()
+    )
+
+
 def get_dataset_env_var_value(
     dataset_id: str, org_id: str, scope: Literal["extraction", "transformation"]
 ) -> Union[str, None]:
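
The general.py fix above is easy to misread: the old call evaluated __close_in_context(session_id) in the *current* context and handed its return value to the daemon, so the close never ran from a fresh context. A minimal sketch of the call shape the fix relies on; `run_without_db_token` itself is outside this diff, and this thread-based body is an assumption based on the comment above the call:

    import threading

    def run_without_db_token(target, *args):
        # a new thread gets a fresh context; `target` must be the callable itself,
        # not the result of already calling it in the current context
        threading.Thread(target=target, args=args, daemon=True).start()
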
diff --git a/cognition_objects/integration.py b/cognition_objects/integration.py
index 7b6fce78..c9f0ca94 100644
--- a/cognition_objects/integration.py
+++ b/cognition_objects/integration.py
@@ -3,7 +3,9 @@
 from sqlalchemy import func
 from sqlalchemy.orm.attributes import flag_modified
 
+
 from ..business_objects import general
+from ..integration_objects import manager as integration_manager_db_bo
 from ..session import session
 from ..models import CognitionIntegration, CognitionGroup
 from ..enums import (
@@ -200,6 +202,7 @@ def create(
 
 def update(
     id: str,
+    project_id: Optional[str] = None,
     updated_by: Optional[str] = None,
     name: Optional[str] = None,
     description: Optional[str] = None,
@@ -219,6 +222,8 @@
     if not integration:
         return None
 
+    if project_id is not None and integration.project_id is None:
+        integration.project_id = project_id
     if updated_by is not None:
         integration.updated_by = updated_by
     if name is not None:
@@ -278,6 +283,16 @@ def execution_finished(id: str) -> bool:
 def delete_many(
     ids: List[str], delete_cognition_groups: bool = True, with_commit: bool = True
 ) -> None:
+    for id in ids:
+        integration_records, IntegrationModel = (
+            integration_manager_db_bo.get_all_by_integration_id(id)
+        )
+        integration_manager_db_bo.delete_many(
+            IntegrationModel,
+            ids=[rec.id for rec in integration_records],
+            with_commit=True,
+        )
+
     (
         session.query(CognitionIntegration)
         .filter(CognitionIntegration.id.in_(ids))
@@ -289,6 +304,7 @@ def delete_many(
         .filter(CognitionGroup.meta_data.op("->>")("integration_id").in_(ids))
         .delete(synchronize_session=False)
     )
+    general.flush_or_commit(with_commit)
diff --git a/cognition_objects/markdown_dataset.py b/cognition_objects/markdown_dataset.py
index a7c2cbfe..1035ee0f 100644
--- a/cognition_objects/markdown_dataset.py
+++ b/cognition_objects/markdown_dataset.py
@@ -3,9 +3,10 @@
 
 from ..business_objects import general
 from ..session import session
-from ..models import CognitionMarkdownDataset, Project
+from ..models import CognitionMarkdownDataset, CognitionMarkdownFile, Project
 from ..enums import Tablenames, MarkdownFileCategoryOrigin
 from ..util import prevent_sql_injection
+from .markdown_file import delete_many as delete_many_md_files
 
 
 def get(org_id: str, id: str) -> CognitionMarkdownDataset:
@@ -184,6 +185,21 @@ def delete_many(org_id: str, dataset_ids: List[str], with_commit: bool = True) -
         ),
     ).delete(synchronize_session=False)
 
+    md_file_ids = (
+        session.query(CognitionMarkdownFile.id)
+        .filter(
+            CognitionMarkdownFile.organization_id == org_id,
+            CognitionMarkdownFile.dataset_id.in_(dataset_ids),
+        )
+        .all()
+    )
+
+    delete_many_md_files(
+        org_id=org_id,
+        md_file_ids=[md_file_id for (md_file_id,) in md_file_ids],
+        with_commit=True,
+    )
+
     session.query(CognitionMarkdownDataset).filter(
         CognitionMarkdownDataset.organization_id == org_id,
         CognitionMarkdownDataset.id.in_(dataset_ids),
diff --git a/cognition_objects/markdown_file.py b/cognition_objects/markdown_file.py
index 8b1e2d26..91d8e814 100644
--- a/cognition_objects/markdown_file.py
+++ b/cognition_objects/markdown_file.py
@@ -4,7 +4,7 @@ from .. import enums
 from ..business_objects import general
 from ..session import session
-from ..models import CognitionMarkdownFile
+from ..models import CognitionMarkdownFile, EtlTask
 from ..util import prevent_sql_injection
 
 
@@ -19,6 +19,17 @@ def get(org_id: str, md_file_id: str) -> CognitionMarkdownFile:
     )
 
 
+def get_by_etl_task_id(org_id: str, etl_task_id: str) -> CognitionMarkdownFile:
+    return (
+        session.query(CognitionMarkdownFile)
+        .filter(
+            CognitionMarkdownFile.organization_id == org_id,
+            CognitionMarkdownFile.etl_task_id == etl_task_id,
+        )
+        .first()
+    )
+
+
 def get_enriched(org_id: str, md_file_id: str) -> Dict[str, Any]:
     org_id = prevent_sql_injection(org_id, isinstance(org_id, str))
     md_file_id = prevent_sql_injection(md_file_id, isinstance(org_id, str))
@@ -71,8 +82,12 @@ def __get_enriched_query(
         )
     else:
         mf_select = "mf.*"
+    et_state = "et.state"
+    mf_state = "mf.state"
 
-    query = f"""SELECT {mf_select} FROM cognition.markdown_file mf
+    query = f"""SELECT {mf_select}, COALESCE({et_state}, {mf_state}) AS etl_state
+    FROM cognition.markdown_file mf
+    LEFT JOIN global.etl_task et ON mf.etl_task_id = et.id
     """
     query += f"WHERE mf.organization_id = '{org_id}' {where_add}"
     query += query_add
@@ -175,6 +190,7 @@ def update(
     finished_at: Optional[datetime] = None,
     error: Optional[str] = None,
     meta_data: Optional[Dict[str, Any]] = None,
+    etl_task_id: Optional[str] = None,
     overwrite_meta_data: bool = True,
     with_commit: bool = True,
 ) -> CognitionMarkdownFile:
@@ -199,22 +215,35 @@ def update(
             markdown_file.meta_data = meta_data
         else:
             markdown_file.meta_data = {**markdown_file.meta_data, **meta_data}
+    if etl_task_id is not None:
+        markdown_file.etl_task_id = etl_task_id
     general.flush_or_commit(with_commit)
     return markdown_file
 
 
 def delete(org_id: str, md_file_id: str, with_commit: bool = True) -> None:
-    session.query(CognitionMarkdownFile).filter(
+    md_file = session.query(CognitionMarkdownFile).filter(
         CognitionMarkdownFile.organization_id == org_id,
         CognitionMarkdownFile.id == md_file_id,
+    )
+    # md_file is a Query, not a row, so resolve the linked task id via a scalar lookup
+    etl_task_id = md_file.with_entities(CognitionMarkdownFile.etl_task_id).scalar()
+    session.query(EtlTask).filter(
+        EtlTask.organization_id == org_id, EtlTask.id == etl_task_id
     ).delete()
+    md_file.delete()
     general.flush_or_commit(with_commit)
 
 
 def delete_many(org_id: str, md_file_ids: List[str], with_commit: bool = True) -> None:
-    session.query(CognitionMarkdownFile).filter(
+    md_files = session.query(CognitionMarkdownFile).filter(
         CognitionMarkdownFile.organization_id == org_id,
         CognitionMarkdownFile.id.in_(md_file_ids),
+    )
+    session.query(EtlTask).filter(
+        EtlTask.organization_id == org_id,
+        EtlTask.id.in_([mf.etl_task_id for mf in md_files]),
     ).delete(synchronize_session=False)
+    md_files.delete(synchronize_session=False)
     general.flush_or_commit(with_commit)
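
Taken together, the three cognition_objects changes above make dataset deletion cascade by hand: markdown files (and, via markdown_file.delete_many, their global.etl_task rows) are removed before the dataset rows themselves. A usage sketch; the import path is assumed from the submodules.model layout seen in monitor.py:

    from submodules.model.cognition_objects import markdown_dataset

    # placeholder ids; md files and their ETL tasks go first, then the datasets
    markdown_dataset.delete_many(
        org_id="<org-uuid>",
        dataset_ids=["<dataset-uuid>"],
        with_commit=True,
    )
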
diff --git a/enums.py b/enums.py
index 128b8e99..f7196d62 100644
--- a/enums.py
+++ b/enums.py
@@ -2,6 +2,21 @@
 from typing import Any
 
 
+class EnumKern(Enum):
+    @classmethod
+    def all(cls):
+        return [e.value for e in cls]
+
+    @classmethod
+    def from_string(cls, value: str):
+        changed_value = value.upper().replace(" ", "_").replace("-", "_")
+        for member in cls:
+            if member.value == changed_value:
+                return member
+        print(f"ERROR: unknown enum {cls.__name__}: {value}", flush=True)
+        raise ValueError(f"Unknown enum {cls.__name__}: {value}")
+
+
 class DataTypes(Enum):
     INTEGER = "INTEGER"
     FLOAT = "FLOAT"
@@ -178,6 +193,8 @@ class Tablenames(Enum):
     ADMIN_QUERY_MESSAGE_SUMMARY = "admin_query_message_summary"
     RELEASE_NOTIFICATION = "release_notification"
     TIMED_EXECUTIONS = "timed_executions"
+    ETL_TASK = "etl_task"
+    ETL_CONFIG_PRESET = "etl_config_preset"
 
     def snake_case_to_pascal_case(self):
         # the type name (written in PascalCase) of a table is needed to create backrefs
@@ -470,22 +487,18 @@ class TokenScope(Enum):
     READ = "READ"
     READ_WRITE = "READ_WRITE"
 
-    def all():
-        return [
-            TokenScope.READ.value,
-            TokenScope.READ_WRITE.value,
-        ]
+    @classmethod
+    def all(cls):
+        return [e.value for e in cls]
 
 
 class TokenSubject(Enum):
     PROJECT = Tablenames.PROJECT.value.upper()
     MARKDOWN_DATASET = Tablenames.MARKDOWN_DATASET.value.upper()
 
-    def all():
-        return [
-            TokenSubject.PROJECT.value,
-            TokenSubject.MARKDOWN_DATASET.value,
-        ]
+    @classmethod
+    def all(cls):
+        return [e.value for e in cls]
 
 
 class TokenizationTaskTypes(Enum):
@@ -517,6 +530,7 @@ class TaskType(Enum):
     RUN_COGNITION_MACRO = "RUN_COGNITION_MACRO"
     PARSE_COGNITION_FILE = "PARSE_COGNITION_FILE"
     EXECUTE_INTEGRATION = "EXECUTE_INTEGRATION"
+    EXECUTE_ETL = "EXECUTE_ETL"
 
 
 class TaskQueueAction(Enum):
@@ -807,11 +821,9 @@ class MacroType(Enum):
     DOCUMENT_MESSAGE_QUEUE = "DOCUMENT_MESSAGE_QUEUE"
     FOLDER_MESSAGE_QUEUE = "FOLDER_MESSAGE_QUEUE"
 
-    def all():
-        return [
-            MacroType.DOCUMENT_MESSAGE_QUEUE.value,
-            MacroType.FOLDER_MESSAGE_QUEUE.value,
-        ]
+    @classmethod
+    def all(cls):
+        return [e.value for e in cls]
 
 
 # currently only one option, but could be extended in the future
@@ -1017,3 +1029,78 @@ class MessageInitiationType(Enum):
 
 class TimedExecutionKey(Enum):
     LAST_RESET_USER_MESSAGE_COUNT = "LAST_RESET_USER_MESSAGE_COUNT"
+
+
+class ETLSplitStrategy(EnumKern):
+    CHUNK = "CHUNK"
+    SHRINK = "SHRINK"
+
+
+class ETLFileType(Enum):
+    PDF = "PDF"
+    WORD = "WORD"
+    MD = "MD"
+
+    @classmethod
+    def from_string(cls, value: str):
+        changed_value = value.upper().replace(" ", "_").replace("-", "_")
+        for member in cls:
+            if member.value == changed_value:
+                return member
+        print(
+            f"WARNING: unknown enum {cls.__name__}: {value}, defaulting to {cls.__name__}.MD",
+            flush=True,
+        )
+        return cls.MD
+
+
+class ETLExtractorMD(EnumKern):
+    FILESYSTEM = "FILESYSTEM"
+
+
+class ETLExtractorPDF(Enum):
+    VISION = "VISION"
+    AZURE_DI = "AZURE_DI"
+    PDF2MD = "PDF2MD"
+
+    @classmethod
+    def from_string(cls, value: str):
+        changed_value = value.upper().replace(" ", "_").replace("-", "_")
+        for member in cls:
+            if member.value == changed_value:
+                return member
+        if changed_value == "PDF2MARKDOWN":
+            return cls.PDF2MD
+        return cls.VISION
+
+
+class ETLExtractorWord(EnumKern):
+    FILESYSTEM = "FILESYSTEM"
+
+
+class ETLExtractor:
+    MD = ETLExtractorMD
+    PDF = ETLExtractorPDF
+    WORD = ETLExtractorWord
+
+    @classmethod
+    def from_string(cls, value: str):
+        changed_value = value.upper().replace(" ", "_").replace("-", "_")
+        # plain class (not an Enum), so resolve via attribute lookup instead of iteration
+        extractor = getattr(cls, changed_value, None)
+        if extractor is not None:
+            return extractor
+        raise ValueError(f"ERROR: Unknown enum {cls.__name__}: {value}")
+
+
+class ETLTransformer(EnumKern):
+    SUMMARIZE = "SUMMARIZE"
+    CLEANSE = "CLEANSE"
+    TEXT_TO_TABLE = "TEXT_TO_TABLE"
+
+
+class ETLCacheKeys(EnumKern):
+    FILE_CACHE = "use_file_cache"
+    EXTRACTION = "use_extraction_cache"
+    SPLITTING = "use_splitting_cache"
+    TRANSFORMATION = "use_transformation_cache"
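
The enums.py changes replace three copies of a hand-written all() with the generic classmethod form and introduce EnumKern, which bundles that all() with a strict from_string; ETLFileType and ETLExtractorPDF keep lenient variants that default instead of raising. A quick illustration of the behavior defined above:

    from submodules.model import enums

    assert enums.ETLSplitStrategy.all() == ["CHUNK", "SHRINK"]
    assert enums.ETLSplitStrategy.from_string("chunk") is enums.ETLSplitStrategy.CHUNK
    # the lenient variants never raise:
    assert enums.ETLFileType.from_string("xlsx") is enums.ETLFileType.MD
    assert enums.ETLExtractorPDF.from_string("pdf2markdown") is enums.ETLExtractorPDF.PDF2MD
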
"etl_task" + ETL_CONFIG_PRESET = "etl_config_preset" def snake_case_to_pascal_case(self): # the type name (written in PascalCase) of a table is needed to create backrefs @@ -470,22 +487,18 @@ class TokenScope(Enum): READ = "READ" READ_WRITE = "READ_WRITE" - def all(): - return [ - TokenScope.READ.value, - TokenScope.READ_WRITE.value, - ] + @classmethod + def all(cls): + return [e.value for e in cls] class TokenSubject(Enum): PROJECT = Tablenames.PROJECT.value.upper() MARKDOWN_DATASET = Tablenames.MARKDOWN_DATASET.value.upper() - def all(): - return [ - TokenSubject.PROJECT.value, - TokenSubject.MARKDOWN_DATASET.value, - ] + @classmethod + def all(cls): + return [e.value for e in cls] class TokenizationTaskTypes(Enum): @@ -517,6 +530,7 @@ class TaskType(Enum): RUN_COGNITION_MACRO = "RUN_COGNITION_MACRO" PARSE_COGNITION_FILE = "PARSE_COGNITION_FILE" EXECUTE_INTEGRATION = "EXECUTE_INTEGRATION" + EXECUTE_ETL = "EXECUTE_ETL" class TaskQueueAction(Enum): @@ -807,11 +821,9 @@ class MacroType(Enum): DOCUMENT_MESSAGE_QUEUE = "DOCUMENT_MESSAGE_QUEUE" FOLDER_MESSAGE_QUEUE = "FOLDER_MESSAGE_QUEUE" - def all(): - return [ - MacroType.DOCUMENT_MESSAGE_QUEUE.value, - MacroType.FOLDER_MESSAGE_QUEUE.value, - ] + @classmethod + def all(cls): + return [e.value for e in cls] # currently only one option, but could be extended in the future @@ -1017,3 +1029,77 @@ class MessageInitiationType(Enum): class TimedExecutionKey(Enum): LAST_RESET_USER_MESSAGE_COUNT = "LAST_RESET_USER_MESSAGE_COUNT" + + +class ETLSplitStrategy(EnumKern): + CHUNK = "CHUNK" + SHRINK = "SHRINK" + + +class ETLFileType(Enum): + PDF = "PDF" + WORD = "WORD" + MD = "MD" + + @classmethod + def from_string(cls, value: str): + changed_value = value.upper().replace(" ", "_").replace("-", "_") + for member in cls: + if member.value == changed_value: + return member + print( + f"WARNING: unknown enum {cls.__name__}: {value}, defaulting to {cls.__name__}.MD", + flush=True, + ) + return cls.MD + + +class ETLExtractorMD(EnumKern): + FILESYSTEM = "FILESYSTEM" + + +class ETLExtractorPDF(Enum): + VISION = "VISION" + AZURE_DI = "AZURE_DI" + PDF2MD = "PDF2MD" + + @classmethod + def from_string(cls, value: str): + changed_value = value.upper().replace(" ", "_").replace("-", "_") + for member in cls: + if member.value == changed_value: + return member + if changed_value == "PDF2MARKDOWN": + return cls.PDF2MD + return cls.VISION + + +class ETLExtractorWord(EnumKern): + FILESYSTEM = "FILESYSTEM" + + +class ETLExtractor: + MD = ETLExtractorMD + PDF = ETLExtractorPDF + WORD = ETLExtractorWord + + @classmethod + def from_string(cls, value: str): + changed_value = value.upper().replace(" ", "_").replace("-", "_") + for member in cls: + if member.name == changed_value: + return member + raise ValueError(f"ERROR: Unknown enum {cls.__name__}: {value}") + + +class ETLTransformer(EnumKern): + SUMMARIZE = "SUMMARIZE" + CLEANSE = "CLEANSE" + TEXT_TO_TABLE = "TEXT_TO_TABLE" + + +class ETLCacheKeys(EnumKern): + FILE_CACHE = "use_file_cache" + EXTRACTION = "use_extraction_cache" + SPLITTING = "use_splitting_cache" + TRANSFORMATION = "use_transformation_cache" diff --git a/global_objects/etl_task.py b/global_objects/etl_task.py new file mode 100644 index 00000000..a5b6dd50 --- /dev/null +++ b/global_objects/etl_task.py @@ -0,0 +1,350 @@ +from typing import List, Optional, Dict, Union +from sqlalchemy.orm.attributes import flag_modified + +import datetime + +from submodules.model import enums +from submodules.model.session import session +from submodules.model.business_objects 
+def get_or_create_markdown_file_etl_task(
+    org_id: str,
+    file_reference: FileReference,
+    markdown_file: CognitionMarkdownFile,
+    markdown_dataset: CognitionMarkdownDataset,
+    extractor: str,
+    cache_config: Dict,
+    split_config: Dict,
+    transform_config: Dict,
+    load_config: Dict,
+    notify_config: Dict,
+    priority: Optional[int] = -1,
+    fallback_extractors: Optional[list[enums.ETLExtractorPDF]] = [],
+) -> EtlTask:
+    if etl_task := (
+        session.query(EtlTask).filter(EtlTask.id == markdown_file.etl_task_id).first()
+    ):
+        return etl_task
+
+    file_type = enums.ETLFileType.from_string(markdown_file.category_origin)
+    extractor = enums.ETLExtractorPDF.from_string(extractor)
+    fallback_extractors = list(
+        filter(
+            lambda x: x != extractor,
+            (fallback_extractors or DEFAULT_FALLBACK_EXTRACTORS.get(file_type, [])),
+        )
+    )
+
+    return create(
+        org_id=org_id,
+        user_id=markdown_file.created_by,
+        file_size_bytes=file_reference.file_size_bytes,
+        cache_config=cache_config,
+        extract_config={
+            "file_type": file_type.value,
+            "extractor": extractor.value,
+            "fallback_extractors": [fe.value for fe in fallback_extractors],
+            "minio_path": file_reference.minio_path,
+            "original_file_name": file_reference.original_file_name,
+        },
+        split_config=split_config,
+        transform_config=transform_config,
+        load_config=load_config,
+        notify_config=notify_config,
+        llm_config=markdown_dataset.llm_config,
+        tokenizer=markdown_dataset.tokenizer,
+        priority=priority,
+    )
+
+
+def get_or_create_integration_etl_task(
+    org_id: str,
+    integration: CognitionIntegration,
+    record: IntegrationSharepoint,
+    file_path: str,
+    extractor: Optional[str],
+    cache_config: Dict,
+    split_config: Dict,
+    transform_config: Dict,
+    load_config: Dict,
+    notify_config: Optional[Dict] = None,
+    priority: Optional[int] = -1,
+    fallback_extractors: Optional[list[enums.ETLExtractorPDF]] = [],
+) -> EtlTask:
+    if etl_task := (
+        session.query(EtlTask).filter(EtlTask.id == record.etl_task_id).first()
+    ):
+        return etl_task
+
+    if record.extension.replace(".", "") == "FOLDER":
+        _file_type = "md"
+        file_size_bytes = 0
+    else:
+        _file_type = record.extension.replace(".", "")
+        file_size_bytes = record.size
+
+    file_type = enums.ETLFileType.from_string(_file_type)
+    extractor = extractor or DEFAULT_EXTRACTORS.get(
+        file_type, enums.ETLExtractorMD.FILESYSTEM
+    )
+    if isinstance(extractor, str):
+        # caller-provided extractors arrive as plain strings; resolve against the
+        # extractor enum matching the file type so `.value` works in extract_config
+        extractor = enums.ETLExtractor.from_string(file_type.name).from_string(
+            extractor
+        )
+
+    if fallback_extractors is None:
+        fallback_extractors = []
+    else:
+        fallback_extractors = list(
+            filter(
+                lambda x: x != extractor,
+                (fallback_extractors or DEFAULT_FALLBACK_EXTRACTORS.get(file_type, [])),
+            )
+        )
+
+    return create(
+        org_id=org_id,
+        user_id=integration.created_by,
+        file_path=file_path,
+        file_size_bytes=file_size_bytes,
+        cache_config=cache_config,
+        extract_config={
+            "file_type": file_type.value,
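+            "extractor": extractor.value,
+            "fallback_extractors": [fe.value for fe in fallback_extractors],
+        },
+        split_config=split_config,
+        transform_config=transform_config,
+        load_config=load_config,
+        notify_config=notify_config,
+        llm_config=integration.llm_config,
+        tokenizer=integration.tokenizer,
+        priority=priority,
+    )
+
+
+# the config dicts below pass straight through from the get_or_create_* callers;
+# their shapes mirror the column comments on the EtlTask model at the end of this
+# diff, e.g. split_config={"chunk": True, "shrink": False}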
"extractor": extractor.value, + "fallback_extractors": [fe.value for fe in fallback_extractors], + "minio_path": file_reference.minio_path, + "original_file_name": file_reference.original_file_name, + }, + split_config=split_config, + transform_config=transform_config, + load_config=load_config, + notify_config=notify_config, + llm_config=markdown_dataset.llm_config, + tokenizer=markdown_dataset.tokenizer, + priority=priority, + ) + + +def get_or_create_integration_etl_task( + org_id: str, + integration: CognitionIntegration, + record: IntegrationSharepoint, + file_path: str, + extractor: Optional[str], + cache_config: Dict, + split_config: Dict, + transform_config: Dict, + load_config: Dict, + notify_config: Optional[Dict] = None, + priority: Optional[int] = -1, + fallback_extractors: Optional[list[enums.ETLExtractorPDF]] = [], +) -> EtlTask: + if etl_task := ( + session.query(EtlTask).filter(EtlTask.id == record.etl_task_id).first() + ): + return etl_task + + if record.extension.replace(".", "") == "FOLDER": + _file_type = "md" + file_size_bytes = 0 + else: + _file_type = record.extension.replace(".", "") + file_size_bytes = record.size + + file_type = enums.ETLFileType.from_string(_file_type) + extractor = extractor or DEFAULT_EXTRACTORS.get( + file_type, enums.ETLExtractorMD.FILESYSTEM + ) + + if fallback_extractors is None: + fallback_extractors = [] + else: + fallback_extractors = list( + filter( + lambda x: x != extractor, + (fallback_extractors or DEFAULT_FALLBACK_EXTRACTORS.get(file_type, [])), + ) + ) + + return create( + org_id=org_id, + user_id=integration.created_by, + file_path=file_path, + file_size_bytes=file_size_bytes, + cache_config=cache_config, + extract_config={ + "file_type": file_type.value, + "extractor": extractor.value, + "fallback_extractors": [fe.value for fe in fallback_extractors], + }, + split_config=split_config, + transform_config=transform_config, + load_config=load_config, + notify_config=notify_config, + llm_config=integration.llm_config, + tokenizer=integration.tokenizer, + priority=priority, + ) + + +def create( + org_id: str, + user_id: str, + file_size_bytes: int, + cache_config: Dict, + extract_config: Dict, + split_config: Dict, + transform_config: Dict, + load_config: Dict, + notify_config: Dict, + llm_config: Dict, + tokenizer: str, + priority: Optional[int] = -1, + file_path: Optional[str] = None, + id: Optional[str] = None, + with_commit: bool = True, +) -> EtlTask: + etl_task: EtlTask = EtlTask( + id=id, + organization_id=org_id, + created_by=user_id, + file_path=file_path, + file_size_bytes=file_size_bytes, + cache_config=cache_config, + extract_config=extract_config, + split_config=split_config, + transform_config=transform_config, + load_config=load_config, + notify_config=notify_config, + llm_config=llm_config, + tokenizer=tokenizer, + priority=priority, + ) + general.add(etl_task, with_commit) + + return etl_task + + +def update( + id: Optional[str] = None, + etl_task: Optional[EtlTask] = None, + updated_by: Optional[str] = None, + file_path: Optional[str] = None, + file_size_bytes: Optional[int] = None, + cache_config: Optional[Dict] = None, + extract_config: Optional[Dict] = None, + split_config: Optional[Dict] = None, + transform_config: Optional[Dict] = None, + load_config: Optional[Dict] = None, + notify_config: Optional[Dict] = None, + llm_config: Optional[Dict] = None, + started_at: Optional[datetime.datetime] = None, + finished_at: Optional[Union[str, datetime.datetime]] = None, + state: Optional[enums.CognitionMarkdownFileState] 
+def execution_finished(id: str) -> bool:
+    if not get_by_id(id):
+        return True
+    return bool(
+        session.query(EtlTask)
+        .filter(
+            EtlTask.id == id,
+            EtlTask.state.in_(FINISHED_STATES),
+        )
+        .first()
+    )
+
+
+def delete_many(ids: List[str], with_commit: bool = True) -> None:
+    # TODO: cascade delete cached files
+    (
+        session.query(EtlTask)
+        .filter(EtlTask.id.in_(ids))
+        .delete(synchronize_session=False)
+    )
+    general.flush_or_commit(with_commit)
diff --git a/integration_objects/manager.py b/integration_objects/manager.py
index ea2a53fa..2bb302ff 100644
--- a/integration_objects/manager.py
+++ b/integration_objects/manager.py
@@ -1,12 +1,21 @@
-from typing import List, Optional, Dict, Union, Type, Any
+from typing import List, Optional, Dict, Tuple, Union, Type, Any
 from datetime import datetime
 from sqlalchemy import func
 from sqlalchemy.orm.attributes import flag_modified
 
+from ..enums import CognitionIntegrationType
 from ..business_objects import general
 from ..cognition_objects import integration as integration_db_bo
+from ..global_objects import etl_task as etl_task_db_bo
 from ..session import session
 from .helper import get_supported_metadata_keys
+from ..models import (
+    IntegrationSharepoint,
+    IntegrationPdf,
+    IntegrationGithubIssue,
+    IntegrationGithubFile,
+    EtlTask,
+)
 
 
 def get(
@@ -30,6 +39,17 @@ def get_by_id(
     return session.query(IntegrationModel).filter(IntegrationModel.id == id).first()
 
 
+def get_by_etl_task_id(
+    IntegrationModel: Type,
+    etl_task_id: str,
+) -> object:
+    return (
+        session.query(IntegrationModel)
+        .filter(IntegrationModel.etl_task_id == etl_task_id)
+        .first()
+    )
+
+
 def get_by_running_id(
     IntegrationModel: Type,
     integration_id: str,
@@ -61,17 +81,34 @@ def get_by_source(
 
 
 def get_all_by_integration_id(
-    IntegrationModel: Type,
     integration_id: str,
-) -> List[object]:
+) -> Tuple[List[object], Type]:
+    IntegrationModel = integration_model(integration_id)
     return (
-        session.query(IntegrationModel)
-        .filter(IntegrationModel.integration_id == integration_id)
-        .order_by(IntegrationModel.created_at)
-        .all()
+        (
+            session.query(IntegrationModel)
+            .filter(IntegrationModel.integration_id == integration_id)
+            .order_by(IntegrationModel.created_at)
+            .all()
+        ),
+        IntegrationModel,
     )
 
 
+def integration_model(integration_id: str) -> Type:
+    integration = integration_db_bo.get_by_id(integration_id)
+    if integration.type == CognitionIntegrationType.SHAREPOINT.value:
+        return IntegrationSharepoint
+    elif integration.type == CognitionIntegrationType.PDF.value:
+        return IntegrationPdf
+    elif integration.type == CognitionIntegrationType.GITHUB_FILE.value:
+        return IntegrationGithubFile
+    elif integration.type == CognitionIntegrationType.GITHUB_ISSUE.value:
+        return IntegrationGithubIssue
+    else:
+        raise ValueError(f"Unsupported integration type: {integration.type}")
+
+
 def get_all_by_project_id(
     IntegrationModel: Type,
     project_id: str,
@@ -88,23 +125,38 @@ def get_all_by_project_id(
 
 
 def get_existing_integration_records(
-    IntegrationModel: Type,
     integration_id: str,
     by: str = "source",
 ) -> Dict[str, object]:
     # TODO(extension): make return type Dict[str, List[object]]
     # once an object_id can reference multiple different integration records
-    return {
-        getattr(record, by, record.source): record
-        for record in get_all_by_integration_id(IntegrationModel, integration_id)
-    }
+    records, _ = get_all_by_integration_id(integration_id)
+    return {getattr(record, by, record.source): record for record in records}
+
+
+def get_active_integration_records(
+    integration_id: str,
+) -> List[object]:
+    IntegrationModel = integration_model(integration_id)
+    return (
+        session.query(IntegrationModel)
+        .join(
+            EtlTask,
+            IntegrationModel.etl_task_id == EtlTask.id,
+        )
+        .filter(
+            IntegrationModel.integration_id == integration_id,
+            EtlTask.is_active == True,
+        )
+        .all()
+    )
 
 
 def get_running_ids(
-    IntegrationModel: Type,
     integration_id: str,
     by: str = "source",
 ) -> Dict[str, int]:
+    IntegrationModel = integration_model(integration_id)
     return dict(
         session.query(
             getattr(IntegrationModel, by, IntegrationModel.source),
@@ -155,6 +207,7 @@ def update(
     running_id: Optional[int] = None,
     updated_at: Optional[datetime] = None,
     error_message: Optional[str] = None,
+    etl_task_id: Optional[str] = None,
     with_commit: bool = True,
     **metadata,
 ) -> Optional[object]:
@@ -172,6 +225,8 @@ def update(
         integration_record.updated_at = updated_at
     if error_message is not None:
         integration_record.error_message = error_message
+    if etl_task_id is not None and integration_record.etl_task_id is None:
+        integration_record.etl_task_id = etl_task_id
 
     record_updated = False
     for key, value in metadata.items():
@@ -199,6 +254,9 @@ def delete_many(
     integration_records = session.query(IntegrationModel).filter(
         IntegrationModel.id.in_(ids)
     )
+    etl_task_db_bo.delete_many(
+        ids=[record.etl_task_id for record in integration_records]
+    )
     integration_records.delete(synchronize_session=False)
     general.flush_or_commit(with_commit)
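
get_all_by_integration_id no longer takes the model class; it resolves it through the new integration_model() dispatch and returns it alongside the records, which is what integration.delete_many earlier in this diff relies on. A usage sketch (import path assumed):

    from submodules.model.integration_objects import manager

    records, IntegrationModel = manager.get_all_by_integration_id("<integration-uuid>")
    active = manager.get_active_integration_records("<integration-uuid>")
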
diff --git a/models.py b/models.py
index 00e624ea..2553e72c 100644
--- a/models.py
+++ b/models.py
@@ -20,6 +20,7 @@
     TokenSubject,
     UploadStates,
     UserRoles,
+    CognitionMarkdownFileState,
 )
 from sqlalchemy import (
     BigInteger,
@@ -1562,7 +1563,14 @@ class CognitionMarkdownDataset(Base):
 
 class CognitionMarkdownFile(Base):
     __tablename__ = Tablenames.MARKDOWN_FILE.value
-    __table_args__ = {"schema": "cognition"}
+    __table_args__ = (
+        UniqueConstraint(
+            "id",
+            "etl_task_id",
+            name=f"unique_{__tablename__}_etl_task_id",
+        ),
+        {"schema": "cognition"},
+    )
     id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
     organization_id = Column(
         UUID(as_uuid=True),
@@ -1592,6 +1600,12 @@ class CognitionMarkdownFile(Base):
     is_reviewed = Column(Boolean, default=False)
     meta_data = Column(JSON)
 
+    etl_task_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"global.{Tablenames.ETL_TASK.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+
 
 class FileTransformationLLMLogs(Base):
     __tablename__ = Tablenames.FILE_TRANSFORMATION_LLM_LOGS.value
@@ -2005,6 +2019,33 @@ class CognitionGroupMember(Base):
     created_at = Column(DateTime, default=sql.func.now())
 
 
+class ETLConfigPresets(Base):
+    __tablename__ = Tablenames.ETL_CONFIG_PRESET.value
+    __table_args__ = {"schema": "cognition"}
+    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    organization_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+    project_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"cognition.{Tablenames.PROJECT.value}.id", ondelete="CASCADE"),
+        index=True,
+        nullable=True,  # future proofing for organization-wide presets/etl page presets
+    )
+    name = Column(String, unique=True)
+    description = Column(String)
+    created_at = Column(DateTime, default=sql.func.now())
+    created_by = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
+        index=True,
+    )
+    etl_config = Column(JSON)  # full ETL config JSON schema for how to run the ETL
+    add_config = Column(JSON)  # additional config for e.g. setting scope dict values
+
+
 # =========================== Global tables ===========================
 class GlobalWebsocketAccess(Base):
     # table to store prepared websocket configuration.
@@ -2230,6 +2271,7 @@ class IntegrationGithubFile(Base):
             "integration_id",
             "running_id",
             "source",
+            "etl_task_id",
             name=f"unique_{__tablename__}_source",
         ),
         {"schema": "integration"},
@@ -2261,6 +2303,12 @@ class IntegrationGithubFile(Base):
     sha = Column(String)
     code_language = Column(String)
 
+    etl_task_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"global.{Tablenames.ETL_TASK.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+
 
 class IntegrationGithubIssue(Base):
     __tablename__ = Tablenames.INTEGRATION_GITHUB_ISSUE.value
@@ -2269,6 +2317,7 @@ class IntegrationGithubIssue(Base):
             "integration_id",
             "running_id",
             "source",
+            "etl_task_id",
             name=f"unique_{__tablename__}_source",
         ),
         {"schema": "integration"},
@@ -2303,6 +2352,12 @@ class IntegrationGithubIssue(Base):
     milestone = Column(String)
     number = Column(Integer)
 
+    etl_task_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"global.{Tablenames.ETL_TASK.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+
 
 class IntegrationPdf(Base):
     __tablename__ = Tablenames.INTEGRATION_PDF.value
@@ -2311,6 +2366,7 @@ class IntegrationPdf(Base):
             "integration_id",
             "running_id",
             "source",
+            "etl_task_id",
             name=f"unique_{__tablename__}_source",
         ),
         {"schema": "integration"},
@@ -2343,6 +2399,12 @@ class IntegrationPdf(Base):
     total_pages = Column(Integer)
     title = Column(String)
 
+    etl_task_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"global.{Tablenames.ETL_TASK.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+
 
 class IntegrationSharepoint(Base):
     __tablename__ = Tablenames.INTEGRATION_SHAREPOINT.value
@@ -2351,6 +2413,7 @@ class IntegrationSharepoint(Base):
             "integration_id",
             "running_id",
             "source",
+            "etl_task_id",
             name=f"unique_{__tablename__}_source",
         ),
         {"schema": "integration"},
@@ -2394,6 +2457,12 @@ class IntegrationSharepoint(Base):
     permissions = Column(JSON)
     file_properties = Column(JSON)
 
+    etl_task_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"global.{Tablenames.ETL_TASK.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+
 
 class IntegrationSharepointPropertySync(Base):
     __tablename__ = Tablenames.INTEGRATION_SHAREPOINT_PROPERTY_SYNC.value
@@ -2512,3 +2581,46 @@ class TimedExecutions(Base):
     __table_args__ = {"schema": "global"}
     time_key = Column(String, unique=True, primary_key=True)  # enums.TimedExecutionKey
     last_executed_at = Column(DateTime)
+
+
+class EtlTask(Base):
+    __tablename__ = Tablenames.ETL_TASK.value
+    __table_args__ = {"schema": "global"}
+    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    organization_id = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"{Tablenames.ORGANIZATION.value}.id", ondelete="CASCADE"),
+        index=True,
+    )
+    created_at = Column(DateTime, default=sql.func.now())
+    created_by = Column(
+        UUID(as_uuid=True),
+        ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"),
+        index=True,
+    )
+    file_path = Column(String)
+    file_size_bytes = Column(BigInteger)
+    tokenizer = Column(String)
+
+    cache_config = Column(
+        JSON
+    )  # {"use_file_cache": true, "use_extraction_cache": false, "use_transformation_cache": true}
+    extract_config = Column(JSON)  # schema depends on the file type
+    split_config = Column(JSON)  # {"chunk": true, "shrink": false}
+    transform_config = Column(
+        JSON
+    )  # {"summarize": true, "cleanse": true, "text_to_table": true}
+    load_config = Column(JSON)  # {"refinery_project": false, "markdown_file": true}
+    notify_config = Column(
+        JSON
+    )  # {"http": {"url": "http://cognition-gateway:80/etl/complete/{task_id}", "method": "POST"}}
+    llm_config = Column(JSON)
+
+    started_at = Column(DateTime)
+    finished_at = Column(DateTime)
+    state = Column(
+        String, default=CognitionMarkdownFileState.QUEUE.value
+    )  # of type enums.CognitionMarkdownFileState
+    is_active = Column(Boolean, default=False)
+    priority = Column(Integer, default=0)
+    error_message = Column(String)
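
End-to-end failure path for the new table: monitor.set_etl_task_to_failed (first file in this diff) delegates to etl_task.update, which stamps the state, a UTC finished_at and the error message. A usage sketch, import path assumed:

    from submodules.model.business_objects import monitor

    monitor.set_etl_task_to_failed(
        id="<etl-task-uuid>",  # placeholder
        error_message="extraction failed after all fallback extractors",
    )
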