diff --git a/.env.example b/.env.example index 54c752a..3432b0e 100644 --- a/.env.example +++ b/.env.example @@ -1,32 +1,65 @@ # External API keys -OPENAI_API_KEY= -ANTHROPIC_API_KEY= -COHERE_API_KEY= -TOGETHER_API_KEY= -HYPERBOLIC_API_KEY= -CEREBRAS_API_KEY= -SAMBANOVA_API_KEY= -DEEPSEEK_API_KEY= -CUSTOM_API_KEY= +OWL_ANTHROPIC_API_KEY= +OWL_AZURE_API_KEY= +OWL_AZURE_AI_API_KEY= +OWL_BEDROCK_API_KEY= +OWL_CEREBRAS_API_KEY= +OWL_COHERE_API_KEY= +OWL_DEEPSEEK_API_KEY= +OWL_ELLM_API_KEY= +OWL_GEMINI_API_KEY= +OWL_GROQ_API_KEY= +OWL_HYPERBOLIC_API_KEY= +OWL_JINA_AI_API_KEY= +OWL_OPENAI_API_KEY= +OWL_OPENROUTER_API_KEY= +OWL_SAGEMAKER_API_KEY= +OWL_SAMBANOVA_API_KEY= +OWL_TOGETHER_AI_API_KEY= +OWL_VERTEX_AI_API_KEY= +OWL_VOYAGE_API_KEY= +OWL_STRIPE_API_KEY= +OWL_STRIPE_PUBLISHABLE_KEY_LIVE= +OWL_STRIPE_PUBLISHABLE_KEY_TEST= +OWL_STRIPE_WEBHOOK_SECRET_LIVE= +OWL_STRIPE_WEBHOOK_SECRET_TEST= +OWL_AUTH0_API_KEY= -# Service URLs -DOCIO_URL=http://docio:6979/api/docio -UNSTRUCTUREDIO_URL=http://unstructuredio:6989 -JAMAI_API_BASE=http://owl:6969/api +# CI +JAMAI_TOKEN= +JAMAI_API_BASE=http://localhost:6969/api -# Frontend config -JAMAI_URL=http://owl:6969 -PUBLIC_JAMAI_URL= -PUBLIC_IS_SPA=false -CHECK_ORIGIN=false +# Service connection (dev) +# OWL_DB_PATH=postgresql+psycopg://owlpguser:owlpgpassword@localhost:5432/jamaibase_owl +# OWL_CLICKHOUSE_HOST=localhost +# OWL_CLICKHOUSE_PORT=8123 +# OWL_OPENTELEMETRY_HOST=localhost +# OWL_OPENTELEMETRY_PORT=4317 +# OWL_REDIS_HOST=localhost +# OWL_REDIS_PORT=6379 +# OWL_VICTORIA_METRICS_HOST=localhost +# OWL_VICTORIA_LOGS_HOST=localhost +# OWL_CODE_EXECUTOR_ENDPOINT=http://localhost:5569 +# OWL_DOCLING_URL=http://localhost:5001 +# OWL_TEST_LLM_API_BASE=http://localhost:6970/v1 +# OWL_S3_ENDPOINT=http://localhost:9000 +# OWL_FILE_PROXY_URL=website-url # Configuration OWL_PORT=6969 OWL_WORKERS=3 -DOCIO_WORKERS=1 -DOCIO_DEVICE=cpu -EMBEDDING_MODEL=BAAI/bge-small-en-v1.5 -RERANKER_MODEL=mixedbread-ai/mxbai-rerank-xsmall-v1 OWL_CONCURRENT_ROWS_BATCH_SIZE=5 OWL_CONCURRENT_COLS_BATCH_SIZE=5 OWL_MAX_WRITE_BATCH_SIZE=1000 +PB_MAX_CLIENT_CONN=500 +PB_MAX_CLIENT_CONN=80 +PG_MAX_CONNECTIONS=100 + +# Frontend config +HOST=localhost +ORIGIN=http://localhost:4000 +AUTH_SECRET="changeme" +OWL_URL=http://owl:6969 +PUBLIC_JAMAI_URL= +PUBLIC_IS_SPA=false +CHECK_ORIGIN=false diff --git a/.gitattributes b/.gitattributes index 31b0a04..9c014d8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,12 +33,18 @@ *.db binary *.doc binary *.docx binary +*.gif binary *.gz binary +*.heic* binary +*.heif* binary *.jar binary *.jpeg binary *.jpg binary +*.mov binary +*.mp* binary *.npy binary *.npz binary +*.parquet binary *.pcd binary *.pdf binary *.pkl binary @@ -47,13 +53,16 @@ *.pptx binary *.pth binary *.so binary +*.ttf binary +*.webp binary *.xls binary *.xlsx binary *.zip binary +# Track with LFS +# *.pth filter=lfs diff=lfs merge=lfs -text +# *.parquet filter=lfs diff=lfs merge=lfs -text + # These files should not be processed by Linguist for language detection on GitHub.com *.p linguist-detectable=false *.gz linguist-detectable=false - -# Track with Git LFS -*.parquet filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b69078a..ddcb2b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,13 +1,15 @@ -name: CI (OSS) +name: CI on: pull_request: branches: - main + - legacy-lancedb push: branches: - main + - legacy-lancedb tags: - "v*" @@ -39,7 +41,8 @@ jobs: if: ${{ 
!(needs.check_changes.outputs.has-changes == 'true' || github.event_name == 'push') }} strategy: matrix: - python-version: ["3.10"] + jamai-mode: ["oss"] + test-group: [group1, group2, group3, group4] timeout-minutes: 2 steps: - name: No-op @@ -52,23 +55,28 @@ jobs: if: needs.check_changes.outputs.has-changes == 'true' || github.event_name == 'push' strategy: matrix: - python-version: ["3.10"] - timeout-minutes: 60 + jamai-mode: ["oss"] + test-group: [group1, group2, group3, group4] + timeout-minutes: 90 steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: lfs: true + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: "3.12" - name: Inspect git version - run: | - git --version + run: git --version - name: Check Docker Version run: docker version @@ -76,26 +84,34 @@ jobs: - name: Check Docker Compose Version run: docker compose version - - name: Remove cloud-only modules and install Python client + - name: Remove cloud-only modules + if: matrix.jamai-mode == 'oss' + run: bash scripts/remove_cloud_modules.sh + + - name: Inspect directory tree + run: tree + + - name: Install jamaibase & owl run: | - set -e - bash scripts/remove_cloud_modules.sh - cd clients/python - python -m pip install .[test] + pushd clients/python + uv pip install --system -e .[test] + popd + pushd services/api + uv pip install --system -e .[test] - - name: Install ffmpeg + - name: Inspect jamaibase environment run: | - set -e - sudo apt-get update -qq && sudo apt-get install ffmpeg libavcodec-extra -y + uv pip list - - name: Authenticating to the Container registry - run: echo $JH_PAT | docker login ghcr.io -u tanjiahuei@gmail.com --password-stdin - env: - JH_PAT: ${{ secrets.JH_PAT }} + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} - name: Edit env file run: | - set -e mv .env.example .env ORGS=$(printenv | grep API_KEY | xargs -I {} echo {} | cut -d '=' -f 1) @@ -111,125 +127,391 @@ jobs: # Replace the org with the key in the .env file sed -i "s/$org=.*/$org=$key/g" .env done - sed -i "s:EMBEDDING_MODEL=.*:EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2:g" .env - sed -i "s:RERANKER_MODEL=.*:RERANKER_MODEL=cross-encoder/ms-marco-TinyBERT-L-2:g" .env - echo 'OWL_MODELS_CONFIG=models_ci.json' >> .env + echo "OWL_DB_INIT=False" >> .env + echo "OWL_COMPUTE_STORAGE_PERIOD_SEC=15" >> .env + echo "OWL_STRIPE_WEBHOOK_SECRET_TEST=${OWL_STRIPE_WEBHOOK_SECRET_TEST}" >> .env + echo "OWL_STRIPE_PUBLISHABLE_KEY_TEST=${OWL_STRIPE_PUBLISHABLE_KEY_TEST}" >> .env + echo 'OWL_SERVICE_KEY=lalala' >> .env + echo 'JAMAI_TOKEN=lalala' >> .env + echo 'JAMAI_API_BASE=http://localhost:6969/api' >> .env + echo 'OWL_FLUSH_CLICKHOUSE_BUFFER_SEC=5' >> .env env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }} - COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} - HYPERBOLIC_API_KEY: ${{ secrets.HYPERBOLIC_API_KEY }} - CUSTOM_API_KEY: ${{ secrets.CUSTOM_API_KEY }} - - - name: Launch services (OSS) - id: launch_oss + OWL_ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + OWL_COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} + OWL_HYPERBOLIC_API_KEY: ${{ secrets.HYPERBOLIC_API_KEY }} + OWL_JINA_AI_API_KEY: ${{ secrets.JINA_AI_API_KEY }} 
+ OWL_OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + OWL_TOGETHER_AI_API_KEY: ${{ secrets.TOGETHER_AI_API_KEY }} + OWL_ELLM_API_KEY: ${{ secrets.CUSTOM_API_KEY }} + OWL_STRIPE_WEBHOOK_SECRET_TEST: ${{ secrets.OWL_STRIPE_WEBHOOK_SECRET_TEST }} + OWL_STRIPE_PUBLISHABLE_KEY_TEST: ${{ secrets.OWL_STRIPE_PUBLISHABLE_KEY_TEST }} + + - name: Launch services + id: launch_services timeout-minutes: 20 - run: | - set -e - docker compose -p jamai -f docker/compose.cpu.yml --profile minio --profile kopi up --quiet-pull -d --wait + if: always() + run: docker compose -p jm -f docker/compose.ci.yml up --quiet-pull -d --wait env: COMPOSE_DOCKER_CLI_BUILD: 1 DOCKER_BUILDKIT: 1 - - name: Inspect owl Python version - run: docker exec jamai-owl-1 python -V + - name: Inspect owl logs if failed to launch + timeout-minutes: 1 + if: failure() && steps.launch_services.outcome == 'failure' + run: docker compose -p jm -f docker/compose.ci.yml logs owl + + - name: Inspect owl UV and Python version + run: | + docker exec jm-owl-1 uv -V + docker exec jm-owl-1 $(docker exec jm-owl-1 uv python find) -V - name: Inspect owl environment - run: docker exec jamai-owl-1 pip list - - - name: Python SDK tests (OSS) - id: python_sdk_test_oss - if: always() && steps.launch_oss.outcome == 'success' - run: | - set -e - export JAMAI_API_BASE=http://localhost:6969/api - python -m pytest -vv \ - --timeout 300 \ - --doctest-modules \ - --junitxml=junit/test-results-${{ matrix.python-version }}.xml \ - --cov-report=xml \ - --no-flaky-report \ - clients/python/tests/oss/ - - - name: Inspect owl logs if Python SDK tests failed - if: failure() && steps.python_sdk_test_oss.outcome == 'failure' - timeout-minutes: 1 - run: docker exec jamai-owl-1 cat /app/api/logs/owl.log + if: always() + run: docker exec jm-owl-1 uv pip list + + - name: Copy OpenAPI JSON + id: copy_openapi + if: always() && matrix.test-group == 'group1' + run: | + curl localhost:6969/api/public/openapi.json > openapi.json - - name: Upload Pytest Test Results + - name: Generate OpenAPI Redoc HTML page + id: generate_redoc_html + if: always() && matrix.test-group == 'group1' && steps.copy_openapi.outcome == 'success' + run: | + npx @redocly/cli@latest build-docs openapi.json + mkdir openapi + mv redoc-static.html openapi + mv openapi.json openapi + + - name: Upload Redoc HTML + id: upload_redoc_html uses: actions/upload-artifact@v4 + if: always() && matrix.test-group == 'group1' && steps.generate_redoc_html.outcome == 'success' with: - name: pytest-results-${{ matrix.python-version }} - path: junit/test-results-${{ matrix.python-version }}.xml - # Always run this step to publish test results even when there are test failures - if: always() + name: redoc-html-${{ matrix.jamai-mode }} + path: openapi + + - name: Publish Redoc HTML link as PR comment + uses: thollander/actions-comment-pull-request@v3 + if: always() && matrix.test-group == 'group1' && github.event_name == 'pull_request' && steps.upload_redoc_html.outcome == 'success' + with: + message: | + [Link to OpenAPI Redoc HTML (${{ matrix.jamai-mode }})](${{ steps.upload_redoc_html.outputs.artifact-url }}) + comment-tag: redoc_html_comment_${{ matrix.jamai-mode }} - - name: TS/JS SDK tests (OSS) - id: ts_sdk_test_oss - if: always() && steps.launch_oss.outcome == 'success' + - name: Python SDK tests + id: python_sdk_test + if: always() && steps.launch_services.outcome == 'success' run: | - cd clients/typescript - echo "BASEURL=http://localhost:6969" >> __tests__/.env - npm install - npm run test + cp .env services/api/.env + 
cd services/api + + if [ "${{ matrix.test-group }}" = "group1" ]; then + DIRS=(tests --ignore=tests/gen_table/test_row_ops.py --ignore=tests/gen_table/test_row_ops_v2.py --ignore=tests/routers --ignore=tests/utils) + elif [ "${{ matrix.test-group }}" = "group2" ]; then + DIRS=(tests/gen_table/test_row_ops.py) + elif [ "${{ matrix.test-group }}" = "group3" ]; then + DIRS=(tests/gen_table/test_row_ops_v2.py tests/utils) + else + DIRS=(tests/routers) + fi + + coverage run --data-file=coverage/.coverage.${{ matrix.test-group }} --rcfile=pyproject.toml -m \ + pytest \ + --timeout 300 \ + --no-flaky-report \ + --junitxml=pytest_regular.xml \ + -m "not (${{ matrix.jamai-mode == 'cloud' && 'oss' || 'cloud' }} or stripe)" \ + "${DIRS[@]}" + env: + OWL_DB_PATH: postgresql+psycopg://owlpguser:owlpgpassword@localhost:5432/jamaibase_owl + OWL_CLICKHOUSE_HOST: localhost + OWL_REDIS_HOST: localhost - - name: Inspect owl logs if TS/JS SDK tests failed - if: failure() && steps.ts_sdk_test_oss.outcome == 'failure' + - name: Inspect owl logs + if: always() && steps.launch_services.outcome == 'success' timeout-minutes: 1 - run: docker exec jamai-owl-1 cat /app/api/logs/owl.log - - - name: Update owl service for S3 test - run: | - # Update the .env file to include the new environment variable - echo 'OWL_FILE_DIR=s3://file' >> .env - echo 'S3_ENDPOINT=http://minio:9000' >> .env - echo 'S3_ACCESS_KEY_ID=minioadmin' >> .env - echo 'S3_SECRET_ACCESS_KEY=minioadmin' >> .env - - # Restart the owl service with the updated environment - docker compose -p jamai -f docker/compose.cpu.yml up --quiet-pull -d --wait --no-deps --build --force-recreate owl - - - name: Python SDK tests (File API, OSS) - id: python_sdk_test_oss_file - if: always() && steps.launch_oss.outcome == 'success' - run: | - set -e - export JAMAI_API_BASE=http://localhost:6969/api - python -m pytest -vv \ - --timeout 300 \ - --doctest-modules \ - --junitxml=junit/test-results-${{ matrix.python-version }}.xml \ - --cov-report=xml \ - --no-flaky-report \ - clients/python/tests/oss/test_file.py - - lance_tests: - name: Lance tests + run: mkdir -p logs && docker compose -p jm -f docker/compose.ci.yml logs owl > logs/owl.log + + - name: Inspect starling logs + if: always() && steps.launch_services.outcome == 'success' + timeout-minutes: 1 + run: mkdir -p logs && docker compose -p jm -f docker/compose.ci.yml logs starling > logs/starling.log + + - name: Test Stripe integration (Cloud only) + id: test_stripe + if: matrix.jamai-mode == 'cloud' && matrix.test-group == 'group1' && steps.launch_services.outcome == 'success' + run: | + # Shut down owl to allow coverage data to be flushed + docker compose -p jm -f docker/compose.ci.yml down + # Copy Pytest coverage data + sudo cp -r docker_data docker_data_tmp + sudo rm -rf docker_data + + # Relaunch + echo "OWL_STRIPE_API_KEY=${OWL_STRIPE_API_KEY}" >> .env + docker compose -p jm -f docker/compose.ci.yml up --quiet-pull -d --wait --force-recreate + + # Install Stripe CLI + curl -L https://github.com/stripe/stripe-cli/releases/download/v1.27.0/stripe_1.27.0_linux_x86_64.tar.gz --output stripe.tar.gz + tar -xvf stripe.tar.gz + + # Listen for Stripe events and forward them to local endpoint + nohup ./stripe listen \ + --forward-to http://localhost:6969/api/v2/organizations/webhooks/stripe & + # --events customer.created,invoice.paid + # Give stripe listen a moment to establish the tunnel + sleep 5 + + # Run tests + pushd services/api + coverage run --data-file=coverage/.coverage.stripe --rcfile=pyproject.toml -m \ + 
pytest \ + --timeout 300 \ + --no-flaky-report \ + --junitxml=pytest_stripe.xml \ + -m stripe \ + tests + + # Move existing coverage data + popd + mv docker_data_tmp/owl/db/* docker_data/owl/db/. 2>/dev/null || true + env: + STRIPE_API_KEY: ${{ secrets.OWL_STRIPE_API_KEY }} + OWL_STRIPE_API_KEY: ${{ secrets.OWL_STRIPE_API_KEY }} + OWL_CLICKHOUSE_HOST: localhost + OWL_REDIS_HOST: localhost + + - name: Inspect owl logs if Stripe integration failed + timeout-minutes: 1 + if: failure() && matrix.jamai-mode == 'cloud' && steps.test_stripe.outcome == 'failure' + run: mkdir -p logs && docker compose -p jm --env-file .env -f docker/compose.ci.yml logs owl > logs/owl_stripe.log + + # - name: TS/JS SDK tests + # id: ts_sdk_test + # if: always() && steps.launch_services.outcome == 'success' + # run: | + # cd clients/typescript + # echo "BASEURL=http://localhost:6969" >> __tests__/.env + # npm install + # npm run test + + - name: Upload logs + id: upload_logs + uses: actions/upload-artifact@v4 + if: always() && steps.launch_services.outcome == 'success' + with: + name: logs-${{ matrix.jamai-mode }}-${{ matrix.test-group }} + path: logs + + - name: Publish logs link as PR comment + uses: thollander/actions-comment-pull-request@v3 + if: always() && github.event_name == 'pull_request' && steps.upload_logs.outcome == 'success' + with: + message: | + [Link to logs (${{ matrix.jamai-mode }}, ${{ matrix.test-group }})](${{ steps.upload_logs.outputs.artifact-url }}) + comment-tag: logs_comment_${{ matrix.jamai-mode }}-${{ matrix.test-group }} + + - name: Upload pytest coverage file + uses: actions/upload-artifact@v4 + if: always() && steps.python_sdk_test.outcome == 'success' + with: + name: pytest-coverage-data-${{ matrix.jamai-mode }}-${{ matrix.test-group }} + path: services/api/coverage + include-hidden-files: true + if-no-files-found: error + + - name: Merge JUnit XML and Coverage data + id: merge_test_data + if: always() && steps.launch_services.outcome == 'success' + run: | + # Shut down owl to allow coverage data to be flushed + docker compose -p jm --env-file .env -f docker/compose.ci.yml down + + # Combine coverage data + coverage combine --keep --data-file=services/api/coverage/.coverage --rcfile=services/api/pyproject.toml \ + docker_data/owl/db services/api/coverage + + # Merge JUnit XML files + mkdir -p services/api/junit_xml + junitparser merge --glob "services/api/pytest_*.xml" services/api/junit_xml/pytest-${{ matrix.jamai-mode }}-${{ matrix.test-group }}.xml + + - name: Upload coverage data + uses: actions/upload-artifact@v4 + if: always() && steps.merge_test_data.outcome == 'success' + with: + name: docker_data-coverage-data-${{ matrix.jamai-mode }}-${{ matrix.test-group }} + path: docker_data/owl/db + include-hidden-files: true + if-no-files-found: error + + - name: Upload JUnit XML file + uses: actions/upload-artifact@v4 + if: always() && steps.merge_test_data.outcome == 'success' + with: + name: junit-xml-data-${{ matrix.jamai-mode }}-${{ matrix.test-group }} + path: services/api/junit_xml + + - name: Log coverage data files + run: | + find docker_data/owl/db -type f | head -50 + find services/api/coverage -type f | head -50 + find services/api/junit_xml -type f | head -50 + + - name: Generate coverage reports + id: generate_coverage_report + if: always() && steps.merge_test_data.outcome == 'success' + run: | + cd services/api + coverage html --data-file=coverage/.coverage -d coverage/html + coverage xml --data-file=coverage/.coverage -o coverage/coverage.xml + coverage report 
--data-file=coverage/.coverage + + - name: Upload coverage HTML report + id: upload_coverage_html + uses: actions/upload-artifact@v4 + if: always() && steps.merge_test_data.outcome == 'success' + with: + name: pytest-coverage-${{ matrix.jamai-mode }}-${{ matrix.test-group }} + path: services/api/coverage/html + + merge_coverage: + name: Merge Coverage Reports runs-on: ubuntu-latest + needs: sdk_tests + if: always() && (needs.sdk_tests.result == 'success' || needs.sdk_tests.result == 'skipped') strategy: matrix: - python-version: ["3.12"] - timeout-minutes: 60 - + jamai-mode: ["oss"] steps: + - name: Skip coverage merge (no tests ran) + if: needs.sdk_tests.result == 'skipped' + run: echo "Skipping coverage merge - no tests were executed (no changes detected)" + - name: Checkout code + if: needs.sdk_tests.result == 'success' uses: actions/checkout@v4 + with: + lfs: true + + - name: Install uv + if: needs.sdk_tests.result == 'success' + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true - name: Set up Python + if: needs.sdk_tests.result == 'success' uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: "3.12" - - name: Inspect git version + - name: Install jamaibase & owl + if: needs.sdk_tests.result == 'success' run: | - git --version + pushd services/api + uv pip install --system -e .[test] + + - name: Download pytest coverage data artifacts + uses: actions/download-artifact@v4 + if: needs.sdk_tests.result == 'success' + with: + pattern: pytest-coverage-data-* + path: ./ - - name: Install owl + - name: Download coverage data artifacts + uses: actions/download-artifact@v4 + if: needs.sdk_tests.result == 'success' + with: + pattern: docker_data-coverage-data-* + path: ./ + + - name: Download junit xml artifacts + uses: actions/download-artifact@v4 + if: needs.sdk_tests.result == 'success' + with: + pattern: junit-xml-data-* + path: junit-xml-data + + - name: Log coverage data files + if: needs.sdk_tests.result == 'success' + run: | + find docker_data-coverage-data-* -type f | head -50 + find pytest-coverage-data-* -type f | head -50 + find junit-xml-data -type f | head -50 + + - name: Merge JUnit XML and Coverage data (${{ matrix.jamai-mode }}) + id: merge_coverage + if: needs.sdk_tests.result == 'success' + run: | + coverage combine --keep --data-file=services/api/coverage/.coverage --rcfile=services/api/pyproject.toml \ + ./docker_data-coverage-data-${{ matrix.jamai-mode }}-group[1-4] \ + ./pytest-coverage-data-${{ matrix.jamai-mode }}-group[1-4] + + # Merge JUnit XML files + junitparser merge --glob "junit-xml-data/junit-xml-data-${{ matrix.jamai-mode }}-*/pytest_*.xml" \ + junit-xml-data/pytest-${{ matrix.jamai-mode }}.xml + + - name: Generate coverage report + if: always() && steps.merge_coverage.outcome == 'success' run: | - set -e cd services/api - python -m pip install .[test] + coverage xml --data-file=coverage/.coverage -o coverage/coverage-${{ matrix.jamai-mode }}.xml + coverage report --data-file=coverage/.coverage - - name: Run tests - run: pytest services/api/tests/test_lance.py + - name: Pytest coverage comment + uses: MishaKav/pytest-coverage-comment@main + if: always() && github.event_name == 'pull_request' && steps.merge_coverage.outcome == 'success' + with: + title: Coverage Report (${{ matrix.jamai-mode }}) + pytest-xml-coverage-path: services/api/coverage/coverage-${{ matrix.jamai-mode }}.xml + junitxml-path: junit-xml-data/pytest-${{ matrix.jamai-mode }}.xml + unique-id-for-comment: coverage_report_comment_${{ 
matrix.jamai-mode }} + report-only-changed-files: true + + - name: Merge All JUnit XML and Coverage data + id: merge_all_test_data + if: needs.sdk_tests.result == 'success' && matrix.jamai-mode == 'oss' + run: | + coverage combine --keep --data-file=services/api/coverage/.coverage --rcfile=services/api/pyproject.toml \ + ./docker_data-coverage-data-oss-group[1-4] \ + ./pytest-coverage-data-oss-group[1-4] + + # Merge JUnit XML files + junitparser merge --glob "junit-xml-data/junit-xml-data-*/pytest_*.xml" junit-xml-data/pytest.xml + + - name: Generate coverage reports + id: generate_coverage_report + if: always() && steps.merge_all_test_data.outcome == 'success' + run: | + cd services/api + coverage html --data-file=coverage/.coverage -d coverage/html + coverage xml --data-file=coverage/.coverage -o coverage/coverage.xml + coverage report --data-file=coverage/.coverage + + - name: Pytest coverage comment + uses: MishaKav/pytest-coverage-comment@main + if: always() && github.event_name == 'pull_request' && steps.merge_all_test_data.outcome == 'success' + with: + title: Coverage Report (all) + pytest-xml-coverage-path: services/api/coverage/coverage.xml + junitxml-path: junit-xml-data/pytest.xml + unique-id-for-comment: coverage_report_comment_all + report-only-changed-files: true + + - name: Upload all coverage HTML report + id: upload_all_coverage_html + uses: actions/upload-artifact@v4 + if: always() && steps.merge_all_test_data.outcome == 'success' + with: + name: pytest-coverage-all + path: services/api/coverage/html + + - name: Publish coverage HTML report link as PR comment + uses: thollander/actions-comment-pull-request@v3 + if: always() && github.event_name == 'pull_request' && steps.upload_all_coverage_html.outcome == 'success' + with: + message: | + [Link to coverage HTML report (All)](${{ steps.upload_all_coverage_html.outputs.artifact-url }}) + comment-tag: coverage_html_report_comment_all diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 85a0754..144cad8 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -24,13 +24,11 @@ jobs: - name: Install linting libraries run: | - set -e cd clients/python python3 -m pip install .[lint] - name: Check Python files using Ruff run: | - set -e ruff check --output-format github --config clients/python/pyproject.toml . ruff format --diff --config clients/python/pyproject.toml . @@ -43,11 +41,9 @@ jobs: uses: actions/checkout@v4 - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v5 with: - node-version: "16" # Specify the Node.js version you want to use + node-version: 24 - name: Check files using Prettier - run: | - npm install -g prettier@3.3.2 - prettier --check . + run: npx prettier@3.3.2 --check . 
diff --git a/.gitignore b/.gitignore index bbb9d49..c598157 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,18 @@ +# OS +thumbs.db +.DS_Store + +# Internal references, dependencies, temporary folders & files +.env +**/__ref__/ +*.log +*.lock +*.db +*.parquet + +# Docker +/docker_data/ + # Python __pycache__/ *.py[cod] @@ -6,38 +21,19 @@ __pycache__/ .pytest_cache .ipynb_checkpoints venv/ -*.npy -*.geojson -*.laz -*.db -*.parquet -# Internal references, dependencies, temporary folders & files -/db*/ -file/ -/infinity_cache/ -**/__ref__/ -/dependencies/ -logs/ -*.log -/datasets/ -/milvus_data/ -/vespa*/ -*.swp -.env -*.lock +# pip +**/build/ # pytest-cov +**/coverage.xml **/.coverage* /junit /htmlcov /coverage.xml -# jest-cov -**/coverage/* - -# pip -**/build/ +# ruff +.ruff_cache/ # OS thumbs.db diff --git a/.prettierrc b/.prettierrc index 74240b5..95e419e 100644 --- a/.prettierrc +++ b/.prettierrc @@ -1,4 +1,5 @@ { "printWidth": 150, - "proseWrap": "never" + "proseWrap": "never", + "trailingComma": "all" } diff --git a/CHANGELOG.md b/CHANGELOG.md index a2cc62a..52ea78c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,90 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +### ADDED + +API + +- Conversation API for JamAI Chat: Chat with agents pre-configured by your organisation admins. +- Table row update endpoint can now update multiple rows in a single call. +- Table row list endpoint `/v1/gen_tables/{table_type}/{table_id}/rows` now accepts: + - `order_by` string parameter that specifies the column to sort rows by. + - `where` string parameter that defines an SQL where clause. Defaults to "" (no filter). + - `search_columns` string parameter to restrict the columns that are searched by `search_query`. + +### CHANGED (BREAKING) + +Python Client + +- Some `JamAI` and `JamAIAsync` client methods are deprecated and/or removed. + +API + +- We unified our DB to migrate from LanceDB + SQLite to Postgres only. For OSS users, please use the provided migration script to migrate your data. +- Changed endpoint `/v1/model_names` -> `/v1/models/ids` +- Renamed external keys: + - `jina_api_key` -> `jina_ai_api_key` + - `together_api_key` -> `together_ai_api_key` +- Added `OWL_` prefix to API server environment variables. +- Major changes to internal APIs which affect OSS users + - Endpoints are now `/v2` + - Most of the path params are converted into query params + - Input and output schemas may have changed. + - List endpoints param changed from `order_descending` with a default of True to `order_ascending` with a default of True. + +### CHANGED + +Python Client + +- `JamAI` is now a wrapper around `JamAIAsync`. +- `JamaiException` classes are now subclasses of `Exception` rather than `RuntimeError`. +- Deprecated `jamaibase.protocol`, use `jamaibase.types` instead. +- Types / protocol: + - Deprecated `AdminOrderBy` enum; use strings instead. + - Deprecated `GenTableOrderBy` enum; use strings instead. + - Deprecated `ModelInfoResponse`; use `ModelInfoListResponse` instead. + - Deprecated `MessageToolCallFunction`; use `ToolCallFunction` instead. + - Deprecated `MessageToolCall`; use `ToolCall` instead. + - Deprecated `ChatCompletionChoiceDelta`; use `ChatCompletionChoice` instead. + - Deprecated `CompletionUsage`; use `ChatCompletionUsage` instead. + - Deprecated `ChatCompletionChunk`; use `ChatCompletionChunkResponse` instead. + - Deprecated `ChatCompletionChoiceOutput`; use `ChatCompletionMessage` instead. 
+ - Deprecated `ChatThread`; use `ChatThreadResponse` instead. + - Deprecated `ChatRequestWithTools`; use `ChatRequest` instead. + - Deprecated `GenTableStreamReferences`; use `CellReferencesResponse` instead. + - Deprecated `GenTableStreamChatCompletionChunk`; use `CellCompletionResponse` instead. + - Deprecated `GenTableChatCompletionChunks`; use `RowCompletionResponse` instead. + - Deprecated `GenTableRowsChatCompletionChunks`; use `MultiRowCompletionResponse` instead. + - Deprecated `RowAddRequest`; use `MultiRowAddRequest` instead. + - Deprecated `RowAddRequestWithLimit`; use `MultiRowAddRequestWithLimit` instead. + - Deprecated `RowRegenRequest`; use `MultiRowRegenRequest` instead. + - Deprecated `RowDeleteRequest`; use `MultiRowDeleteRequest` instead. + - All `reindex` parameters are removed. Reindexing now happens immediately. + +API + +- Improvements: + - All extra Knowledge Table columns are now injected into prompt + - RAG references are now stored alongside model response +- Fixed: + - Streaming responses from table row add endpoint now returns a final chunk with usage data. + - You can now set system prompt and prompt of Generative Tables to empty strings `""` without them being replaced with default prompts. + +### REMOVED + +API + +- Hybrid search endpoint: + - Removed parameters: `where`, `nprobes`, `refine_factor` + +### DEPRECATED + +API + +- Generative table endpoints: + - `order_descending` in table and row lists endpoints is deprecated and replaced with `order_ascending` with a default of True. + - Single row delete and update endpoints are deprecated for their multi-row counterparts. + ## [v0.4.1] (2025-02-26) ### CHANGED / FIXED @@ -39,6 +123,9 @@ Python SDK - jamaibase TS SDK - jamaibase +- Add `CodeGenConfigSchema` for code execution #446 +- Support audio data type + UI - Support chat mode multiturn option in add column and column resize #451 @@ -73,6 +160,7 @@ Python SDK - jamaibase TS SDK - jamaibase - Update the `uploadFile` method in `index.ts` to remove the trailing slash from the API endpoint #462 +- Update client and node enviroment conflict in file upload UI diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index 046611a..1bccfd1 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -1,69 +1 @@ # Migration Guide - -## [v0.4.0] - -This guide provides instructions to perform a database migration that adds a `version` column and `object` attribute to all gen_config in all action tables. (Migration from owl version earlier than v0.4.0). - -### Prerequisites - -1. Ensure **owl/jamaibase** has been updated to at least **v0.4.0**. - -### Steps to Perform the Migration - -1. Navigate to the **JamAIBase** repository directory (with `./db` and `./scripts` in it). - - ```bash - cd - ``` - -2. Run the migration script (ensure the current Python environment is the one with **owl** installed): - ```bash - python scripts/migration_v040.py - ``` - -### Expected Output - -- The script will print messages indicating whether the `file` column was renamed to `image` column. -- The script will print messages indicating whether the `Page` column was added to knowledge table or if it already exist. -- If any errors occur, they will be printed to the console. - -### Troubleshooting - -- Ensure that the migration script is run in the **JamAIBase** repository directory (`./db` and `./scripts` directories should be in this working directory). -- Ensure the Python environment is the one with **owl** installed. 
-- Check the script's error messages for any issues encountered during the migration process. -- Contact us for further assistance. - -## [v0.3.0] - -This guide provides instructions to perform a database migration that adds a `version` column and `object` attribute to all gen_config in all action tables. (Migration from owl version earlier than v0.3.0). - -### Prerequisites - -1. Ensure **owl/jamaibase** has been updated to at least **v0.3.0**. - -### Steps to Perform the Migration - -1. Navigate to the **JamAIBase** repository directory (with `./db` and `./scripts` in it). - - ```bash - cd - ``` - -2. Run the migration script (ensure the current Python environment is the one with **owl** installed): - ```bash - python scripts/migration_v030.py - ``` - -### Expected Output - -- The script will print messages indicating whether the `version` column was added or if it already exists in each database. -- The script will print messages indicating whether the `object` attribute was added into each `gen_config`. -- If any errors occur, they will be printed to the console. - -### Troubleshooting - -- Ensure that the migration script is run in the **JamAIBase** repository directory (`./db` and `./scripts` directories should be in this working directory). -- Ensure the Python environment is the one with **owl** installed. -- Check the script's error messages for any issues encountered during the migration process. -- Contact us for further assistance. diff --git a/clients/python/README.md b/clients/python/README.md index 8e1d122..88f4e39 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -12,10 +12,10 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y 1. First, [sign up for a free account on JamAI Base Cloud!](https://cloud.jamaibase.com/) 2. Create a project and give it any name that you want. -3. Create a Python (>= 3.10) environment and install `jamaibase` (here we use [micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) but you can use other tools such as [conda](https://conda.io/projects/conda/en/latest/user-guide/getting-started.html), virtualenv, etc): +3. Create a Python (>= 3.11) environment and install `jamaibase` (here we use [micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) but you can use other tools such as [conda](https://conda.io/projects/conda/en/latest/user-guide/getting-started.html), virtualenv, etc): ```shell - $ micromamba create -n jam310 python=3.10 -y + $ micromamba create -n jam310 python=3.11 -y $ micromamba activate jam310 $ pip install jamaibase ``` @@ -26,7 +26,7 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y - Project ID can be obtained by browsing to any of your projects. ```python - from jamaibase import JamAI, protocol as p + from jamaibase import JamAI, types as t jamai = JamAI(token="your_pat", project_id="your_project_id") ``` @@ -34,7 +34,7 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y Async is supported too: ```python - from jamaibase import JamAIAsync, protocol as p + from jamaibase import JamAIAsync, types as t jamai = JamAIAsync(token="your_pat", project_id="your_project_id") ``` @@ -54,14 +54,7 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y - `.env` specifies which model to run on the `infinity` service for locally-hosted embedding and reranking models. - `.env` also specifies all the third party API keys to be used. 
- For OSS mode, in order for you to see and use the other third party models such as OpenAI, you need to provide your own OpenAI API key in `.env` file. You can add one or more providers: - - ``` - OPENAI_API_KEY=... - ANTHROPIC_API_KEY=... - COHERE_API_KEY=... - TOGETHER_API_KEY=... - ``` + For OSS mode, in order for you to see and use the other third party models such as OpenAI, you need to provide your own OpenAI API key in `.env` file (refer to `.env.example` file). 3. Launch the Docker containers by running one of these: @@ -74,7 +67,7 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y ``` - By default, frontend and backend are accessible at ports 4000 and 6969. - - You can change the ports exposed to host by setting env var in `.env` or shell like so `API_PORT=6970 FRONTEND_PORT=4001 docker compose -f docker/compose.cpu.yml up --quiet-pull -d` + - You can change the ports exposed to host by setting env var in `.env` or shell like so `API_PORT=6968 FRONTEND_PORT=4001 docker compose -f docker/compose.cpu.yml up --quiet-pull -d` 4. Try the command below in your terminal, or open your browser and go to `localhost:4000`. @@ -89,7 +82,7 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y - `api_base` should point to the exposed port of `owl` service. ```python - from jamaibase import JamAI, protocol as p + from jamaibase import JamAI, types as t jamai = JamAI(api_base="http://localhost:6969/api") ``` @@ -97,7 +90,7 @@ The recommended way of using JamAI Base is via Cloud 🚀. Did we mention that y Async is supported too: ```python - from jamaibase import JamAIAsync, protocol as p + from jamaibase import JamAIAsync, types as t jamai = JamAIAsync(api_base="http://localhost:6969/api") ``` @@ -157,16 +150,16 @@ Let's start with creating simple tables. Create a table by defining a schema. 
```python # Create an Action Table table = jamai.table.create_action_table( - p.ActionTableSchemaCreate( + t.ActionTableSchemaCreate( id="action-simple", cols=[ - p.ColumnSchemaCreate(id="image", dtype="image"), # Image input - p.ColumnSchemaCreate(id="length", dtype="int"), # Integer input - p.ColumnSchemaCreate(id="question", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="image", dtype="image"), # Image input + t.ColumnSchemaCreate(id="length", dtype="int"), # Integer input + t.ColumnSchemaCreate(id="question", dtype="str"), + t.ColumnSchemaCreate( id="answer", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a default model system_prompt="You are a concise assistant.", prompt="Image: ${image}\n\nQuestion: ${question}\n\nAnswer the question in ${length} words.", @@ -182,7 +175,7 @@ print(table) # Create a Knowledge Table table = jamai.table.create_knowledge_table( - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id="knowledge-simple", cols=[], embedding_model="ellm/BAAI/bge-m3", @@ -192,14 +185,14 @@ print(table) # Create a Chat Table table = jamai.table.create_chat_table( - p.ChatTableSchemaCreate( + t.ChatTableSchemaCreate( id="chat-simple", cols=[ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a default model system_prompt="You are a pirate.", temperature=0.001, @@ -228,7 +221,7 @@ text_c = "Identify the subject of the image." # Streaming completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(length=5, question=text_a)], stream=True, @@ -243,7 +236,7 @@ print("") # Non-streaming completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(length=5, question=text_b)], stream=False, @@ -255,7 +248,7 @@ print(completion.rows[0].columns["answer"].text) upload_response = jamai.file.upload_file("clients/python/tests/files/jpeg/rabbit.jpeg") completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(image=upload_response.uri, length=5, question=text_c)], stream=True, @@ -270,7 +263,7 @@ print("") # Non-streaming (with image input) completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(image=upload_response.uri, length=5, question=text_c)], stream=False, @@ -286,7 +279,7 @@ Next let's try adding to Chat Table: # Streaming completion = jamai.table.add_table_rows( "chat", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="chat-simple", data=[dict(User="Who directed Arrival (2016)?")], stream=True, @@ -301,7 +294,7 @@ print("") # Non-streaming completion = jamai.table.add_table_rows( "chat", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="chat-simple", data=[dict(User="Who directed Dune (2024)?")], stream=False, @@ -324,7 +317,7 @@ Finally we can add rows to Knowledge Table too: # Streaming completion = jamai.table.add_table_rows( "knowledge", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="knowledge-simple", data=[dict(Title="Arrival (2016)", Text=text_a)], stream=True, @@ -335,7 +328,7 @@ assert len(list(completion)) == 0 # Non-streaming completion = 
jamai.table.add_table_rows( "knowledge", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="knowledge-simple", data=[dict(Title="Dune (2024)", Text=text_b)], stream=False, @@ -407,18 +400,18 @@ with TemporaryDirectory() as tmp_dir: # Create an Action Table with RAG table = jamai.table.create_action_table( - p.ActionTableSchemaCreate( + t.ActionTableSchemaCreate( id="action-rag", cols=[ - p.ColumnSchemaCreate(id="question", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="question", dtype="str"), + t.ColumnSchemaCreate( id="answer", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a default model system_prompt="You are a concise assistant.", prompt="${question}", - rag_params=p.RAGParams( + rag_params=t.RAGParams( table_id="knowledge-simple", k=2, ), @@ -435,7 +428,7 @@ print(table) # Ask a question with streaming completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-rag", data=[dict(question="Where did I go in 2018?")], stream=True, @@ -444,7 +437,7 @@ completion = jamai.table.add_table_rows( for chunk in completion: if chunk.output_column_name != "answer": continue - if isinstance(chunk, p.GenTableStreamReferences): + if isinstance(chunk, t.CellReferencesResponse): # References that are retrieved from KT assert len(chunk.chunks) == 2 # k = 2 print(chunk.chunks) @@ -493,7 +486,7 @@ Now that you know how to add rows into tables, let's see how to delete them inst rows = jamai.table.list_table_rows("action", "action-simple") response = jamai.table.delete_table_rows( "action", - p.RowDeleteRequest( + t.MultiRowDeleteRequest( table_id="action-simple", row_ids=[row["ID"] for row in rows.items], ), @@ -548,22 +541,22 @@ The full script is as follows: ```python from jamaibase import JamAI -from jamaibase import protocol as p +from jamaibase import types as t def create_tables(jamai: JamAI): # Create an Action Table table = jamai.table.create_action_table( - p.ActionTableSchemaCreate( + t.ActionTableSchemaCreate( id="action-simple", cols=[ - p.ColumnSchemaCreate(id="image", dtype="image"), # Image input - p.ColumnSchemaCreate(id="length", dtype="int"), # Integer input - p.ColumnSchemaCreate(id="question", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="image", dtype="image"), # Image input + t.ColumnSchemaCreate(id="length", dtype="int"), # Integer input + t.ColumnSchemaCreate(id="question", dtype="str"), + t.ColumnSchemaCreate( id="answer", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a default model system_prompt="You are a concise assistant.", prompt="Image: ${image}\n\nQuestion: ${question}\n\nAnswer the question in ${length} words.", @@ -579,7 +572,7 @@ def create_tables(jamai: JamAI): # Create a Knowledge Table table = jamai.table.create_knowledge_table( - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id="knowledge-simple", cols=[], embedding_model="ellm/BAAI/bge-m3", @@ -589,14 +582,14 @@ def create_tables(jamai: JamAI): # Create a Chat Table table = jamai.table.create_chat_table( - p.ChatTableSchemaCreate( + t.ChatTableSchemaCreate( id="chat-simple", cols=[ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a 
default model system_prompt="You are a pirate.", temperature=0.001, @@ -619,7 +612,7 @@ def add_rows(jamai: JamAI): # Streaming completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(length=5, question=text_a)], stream=True, @@ -634,7 +627,7 @@ def add_rows(jamai: JamAI): # Non-streaming completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(length=5, question=text_b)], stream=False, @@ -646,7 +639,7 @@ def add_rows(jamai: JamAI): upload_response = jamai.file.upload_file("clients/python/tests/files/jpeg/rabbit.jpeg") completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(image=upload_response.uri, length=5, question=text_c)], stream=True, @@ -661,7 +654,7 @@ def add_rows(jamai: JamAI): # Non-streaming (with image input) completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-simple", data=[dict(image=upload_response.uri, length=5, question=text_c)], stream=False, @@ -673,7 +666,7 @@ def add_rows(jamai: JamAI): # Streaming completion = jamai.table.add_table_rows( "chat", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="chat-simple", data=[dict(User="Who directed Arrival (2016)?")], stream=True, @@ -688,7 +681,7 @@ def add_rows(jamai: JamAI): # Non-streaming completion = jamai.table.add_table_rows( "chat", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="chat-simple", data=[dict(User="Who directed Dune (2024)?")], stream=False, @@ -700,7 +693,7 @@ def add_rows(jamai: JamAI): # Streaming completion = jamai.table.add_table_rows( "knowledge", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="knowledge-simple", data=[dict(Title="Arrival (2016)", Text=text_a)], stream=True, @@ -711,7 +704,7 @@ def add_rows(jamai: JamAI): # Non-streaming completion = jamai.table.add_table_rows( "knowledge", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="knowledge-simple", data=[dict(Title="Dune (2024)", Text=text_b)], stream=False, @@ -776,18 +769,18 @@ def rag(jamai: JamAI): # Create an Action Table with RAG table = jamai.table.create_action_table( - p.ActionTableSchemaCreate( + t.ActionTableSchemaCreate( id="action-rag", cols=[ - p.ColumnSchemaCreate(id="question", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="question", dtype="str"), + t.ColumnSchemaCreate( id="answer", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a default model system_prompt="You are a concise assistant.", prompt="${question}", - rag_params=p.RAGParams( + rag_params=t.RAGParams( table_id="knowledge-simple", k=2, ), @@ -804,7 +797,7 @@ def rag(jamai: JamAI): # Ask a question with streaming completion = jamai.table.add_table_rows( "action", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="action-rag", data=[dict(question="Where did I went in 2018?")], stream=True, @@ -813,7 +806,7 @@ def rag(jamai: JamAI): for chunk in completion: if chunk.output_column_name != "answer": continue - if isinstance(chunk, p.GenTableStreamReferences): + if isinstance(chunk, t.CellReferencesResponse): # References that are retrieved from KT assert len(chunk.chunks) == 2 # k = 2 print(chunk.chunks) @@ -854,7 +847,7 @@ def delete_rows(jamai: JamAI): rows = jamai.table.list_table_rows("action", "action-simple") response = jamai.table.delete_table_rows( "action", - 
p.RowDeleteRequest( + t.MultiRowDeleteRequest( table_id="action-simple", row_ids=[row["ID"] for row in rows.items], ), @@ -971,11 +964,11 @@ Generate chat completions using various models. Supports streaming and non-strea ```python # Streaming -request = p.ChatRequest( +request = t.ChatRequest( model="openai/gpt-4o-mini", messages=[ - p.ChatEntry.system("You are a concise assistant."), - p.ChatEntry.user("What is a llama?"), + t.ChatEntry.system("You are a concise assistant."), + t.ChatEntry.user("What is a llama?"), ], temperature=0.001, top_p=0.001, @@ -988,11 +981,11 @@ for chunk in completion: print("") # Non-streaming -request = p.ChatRequest( +request = t.ChatRequest( model="openai/gpt-4o-mini", messages=[ - p.ChatEntry.system("You are a concise assistant."), - p.ChatEntry.user("What is a llama?"), + t.ChatEntry.system("You are a concise assistant."), + t.ChatEntry.user("What is a llama?"), ], temperature=0.001, top_p=0.001, @@ -1010,7 +1003,7 @@ Generate embeddings for given input text. ```python texts = ["What is love?", "What is a llama?"] embeddings = jamai.generate_embeddings( - p.EmbeddingRequest( + t.EmbeddingRequest( model="ellm/BAAI/bge-m3", input=texts, ) @@ -1038,7 +1031,7 @@ print(f"Model: {model.id} Context length: {model.context_length}") # Get specific model info models = jamai.model_info(name="openai/gpt-4o") print(models.data[0]) -# id='openai/gpt-4o' object='model' name='OpenAI GPT-4' context_length=128000 languages=['en', 'cn'] capabilities=['chat'] owned_by='openai' +# id='openai/gpt-4o' object='model' name='OpenAI GPT-4' context_length=128000 languages=['en', 'cn'] capabilities=['chat'] owned_by=None # Filter based on capability: "chat", "embed", "rerank" models = jamai.model_info(capabilities=["chat"]) @@ -1094,14 +1087,14 @@ st.title("Simple chat") try: # Create a Chat Table jamai.table.create_chat_table( - p.ChatTableSchemaCreate( + t.ChatTableSchemaCreate( id="chat-simple", cols=[ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="openai/gpt-4o-mini", # Leave this out to use a default model system_prompt="You are a pirate.", temperature=0.001, @@ -1128,7 +1121,7 @@ for message in st.session_state.messages: def response_generator(_prompt): completion = jamai.table.add_table_rows( "chat", - p.RowAddRequest( + t.MultiRowAddRequest( table_id="chat-simple", data=[dict(User=_prompt)], stream=True, diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 64b37f3..c902ed0 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -5,11 +5,11 @@ # https://docs.pytest.org/en/latest/customize.html?highlight=pyproject#pyproject-toml [tool.pytest.ini_options] -timeout = 90 +timeout = 120 log_cli = true asyncio_mode = "auto" # log_cli_level = "DEBUG" -addopts = "--cov=jamaibase --doctest-modules" +addopts = "--doctest-modules -vv -ra --strict-markers --no-flaky-report" testpaths = ["tests"] filterwarnings = [ "ignore::DeprecationWarning:tensorflow.*", @@ -17,6 +17,20 @@ filterwarnings = [ "ignore::DeprecationWarning:matplotlib.*", "ignore::DeprecationWarning:flatbuffers.*", ] +markers = ["oss: Cloud-only tests", "cloud: Cloud-only tests"] + +# ----------------------------------------------------------------------------- +# Coverage configuration +# https://coverage.readthedocs.io/en + +[tool.coverage.run] +source = ["owl"] +relative_files = true 
+concurrency = ["multiprocessing", "thread", "greenlet"] +parallel = true + +# [tool.coverage.paths] +# source = ["services/api/src", "src"] # ----------------------------------------------------------------------------- # Ruff configuration @@ -25,7 +39,7 @@ filterwarnings = [ [tool.ruff] line-length = 99 indent-width = 4 -target-version = "py310" +target-version = "py312" extend-include = [".pyi?$", ".ipynb"] extend-exclude = ["archive/*"] respect-gitignore = true @@ -57,7 +71,7 @@ unfixable = ["B"] "**/{tests,docs,tools}/*" = ["E402"] [tool.ruff.lint.isort] -known-first-party = ["jamaibase", "owl", "docio"] +known-first-party = ["jamaibase", "owl"] [tool.ruff.lint.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. @@ -74,16 +88,16 @@ extend-immutable-calls = [ # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html [build-system] -requires = ["setuptools>=61.0", "setuptools-scm"] +requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "jamaibase" description = "JamAI Base: Let Your Database Orchestrate LLMs and RAG" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" # keywords = ["one", "two"] -license = { text = "Apache 2.0" } +license = "Apache-2.0" classifiers = [ # https://pypi.org/classifiers/ "Development Status :: 3 - Alpha", "Programming Language :: Python :: 3 :: Only", @@ -91,40 +105,40 @@ classifiers = [ # https://pypi.org/classifiers/ "Operating System :: Unix", ] # Sort your dependencies https://sortmylist.com/ +# In general, for v1 and above, we pin to minor version using ~= dependencies = [ - "filetype~=1.2.0", + "filetype~=1.2", "httpx>=0.25.0", "loguru>=0.7.2", - "numpy>=1.26.0,<2.0.0", - "orjson>=3.9.7", + "natsort[fast]~=8.4", + "numpy>=1.26.0", + "orjson~=3.9", "pandas", - "Pillow>=10.0.1", - "pydantic-settings>=2.0.3", - "pydantic>=2.4.2", - "srsly>=2.4.8", - "toml>=0.10.2", - "typing_extensions>=4.10.0", + "Pillow>=10.0", + "pydantic-extra-types~=2.9", + "pydantic-settings~=2.4", + "pydantic[email,timezone]~=2.10", + "pyyaml~=6.0", + "toml~=0.10.2", + "typing_extensions~=4.10", + "uuid-utils~=0.9", + "uuid7~=0.1", ] dynamic = ["version"] [project.optional-dependencies] -lint = ["ruff~=0.5.7"] +lint = ["ruff~=0.12.9"] test = [ "flaky~=3.8.1", + "locust~=2.39.1", "mypy~=1.11.1", - "pydub~=0.25.1", "pytest-asyncio>=0.23.8", "pytest-cov~=5.0.0", "pytest-timeout>=2.3.1", "pytest~=8.2.2", -] -docs = [ - "furo~=2024.8.6", # Sphinx theme (nice looking, with dark mode) - "myst-parser~=4.0.0", - "sphinx-autobuild~=2024.4.16", "sphinx-copybutton~=0.5.2", "sphinx>=7.0.0", - "sphinx_rtd_theme~=2.0.0", # Sphinx theme + "sphinx_rtd_theme~=2.0.0", # Sphinx theme ] build = [ "build", @@ -137,6 +151,12 @@ all = [ # [project.scripts] # jamaibase = "jamaibase.scripts.example:main_cli" +[project.urls] +Homepage = "https://www.jamaibase.com/" +Documentation = "https://docs.jamaibase.com/" +Repository = "https://github.com/EmbeddedLLM/JamAIBase" +Changelog = "https://github.com/EmbeddedLLM/JamAIBase/blob/main/CHANGELOG.md" + [tool.setuptools.dynamic] version = { attr = "jamaibase.version.__version__" } diff --git a/clients/python/src/jamaibase/client.py b/clients/python/src/jamaibase/client.py index d5df736..c5edd35 100644 --- a/clients/python/src/jamaibase/client.py +++ b/clients/python/src/jamaibase/client.py @@ -1,79 +1,112 @@ import platform -from mimetypes import guess_type -from os.path import split +import warnings +from contextlib import contextmanager +from 
datetime import datetime +from os.path import basename, split +from time import perf_counter from typing import Any, AsyncGenerator, BinaryIO, Generator, Literal, Type from urllib.parse import quote from warnings import warn -import filetype import httpx +import orjson +from loguru import logger from pydantic import BaseModel, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import deprecated -from jamaibase.exceptions import ResourceNotFoundError -from jamaibase.protocol import ( +from jamaibase.types import ( ActionTableSchemaCreate, AddActionColumnSchema, AddChatColumnSchema, AddKnowledgeColumnSchema, - AdminOrderBy, - ApiKeyCreate, - ApiKeyRead, - ChatCompletionChunk, + AgentMetaResponse, + CellCompletionResponse, + CellReferencesResponse, + ChatCompletionChunkResponse, + ChatCompletionResponse, ChatRequest, ChatTableSchemaCreate, - ChatThread, + ChatThreadResponse, + ChatThreadsResponse, ColumnDropRequest, ColumnRenameRequest, ColumnReorderRequest, + ConversationCreateRequest, + ConversationMetaResponse, + ConversationThreadsResponse, + DeploymentCreate, + DeploymentRead, + DeploymentUpdate, EmbeddingRequest, EmbeddingResponse, - EventCreate, - EventRead, - FileUploadRequest, FileUploadResponse, GenConfigUpdateRequest, - GenTableOrderBy, - GenTableRowsChatCompletionChunks, - GenTableStreamChatCompletionChunk, - GenTableStreamReferences, GetURLRequest, GetURLResponse, KnowledgeTableSchemaCreate, - ModelInfoResponse, - ModelListConfig, + MessageAddRequest, + MessagesRegenRequest, + MessageUpdateRequest, + ModelConfigCreate, + ModelConfigRead, + ModelConfigUpdate, + ModelInfoListResponse, ModelPrice, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowDeleteRequest, + MultiRowRegenRequest, + MultiRowUpdateRequest, OkResponse, OrganizationCreate, OrganizationRead, OrganizationUpdate, - OrgMemberCreate, OrgMemberRead, Page, - PATCreate, - PATRead, - Price, + PasswordChangeRequest, + PasswordLoginRequest, + PricePlanCreate, + PricePlanRead, + PricePlanUpdate, + ProgressState, ProjectCreate, + ProjectKeyCreate, + ProjectKeyRead, + ProjectKeyUpdate, + ProjectMemberRead, ProjectRead, ProjectUpdate, References, - RowAddRequest, - RowDeleteRequest, - RowRegenRequest, + RerankingRequest, + RerankingResponse, + Role, RowUpdateRequest, SearchRequest, - StringResponse, + StripePaymentInfo, TableDataImportRequest, TableImportRequest, TableMetaResponse, - TableType, - Template, + UsageResponse, UserCreate, UserRead, UserUpdate, + VerificationCodeRead, ) -from jamaibase.utils.io import json_loads +from jamaibase.utils import uuid7_str +from jamaibase.utils.background_loop import LOOP +from jamaibase.utils.exceptions import ( + AuthorizationError, + BadInputError, + ForbiddenError, + JamaiException, + RateLimitExceedError, + ResourceExistsError, + ResourceNotFoundError, + ServerBusyError, + UnexpectedError, +) +from jamaibase.utils.io import guess_mime, json_loads from jamaibase.version import __version__ USER_AGENT = f"SDK/{__version__} (Python/{platform.python_version()}; {platform.system()} {platform.release()}; {platform.machine()})" @@ -82,54 +115,64 @@ class EnvConfig(BaseSettings): - model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") - jamai_token: SecretStr = "" - jamai_api_key: SecretStr = "" - jamai_api_base: str = "https://api.jamaibase.com/api" - jamai_project_id: str = "default" - jamai_timeout_sec: float = 5 * 60.0 - jamai_file_upload_timeout_sec: float = 60 * 60.0 + model_config = 
SettingsConfigDict( + env_prefix="jamai_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + token: SecretStr = "" + api_base: str = "https://api.jamaibase.com/api" + project_id: str = "default" + timeout_sec: float = 60.0 * 5 # Default to 5 minutes + file_upload_timeout_sec: float = 60.0 * 15 # Default to 15 minutes @property - def jamai_token_plain(self): - api_key = self.jamai_api_key.get_secret_value().strip() - return self.jamai_token.get_secret_value().strip() or api_key + def token_plain(self): + return self.token.get_secret_value().strip() ENV_CONFIG = EnvConfig() GenTableChatResponseType = ( - GenTableRowsChatCompletionChunks - | Generator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None, None] + MultiRowCompletionResponse + | Generator[CellReferencesResponse | CellCompletionResponse, None, None] ) -class _Client: +class _ClientAsync: def __init__( self, + user_id: str, project_id: str, token: str, api_base: str, headers: dict | None, http_client: httpx.Client | httpx.AsyncClient, - file_upload_timeout: float | None, + timeout: float | None, + file_upload_timeout: float | None = None, ) -> None: """ Base client. Args: - project_id (str): The project ID. + user_id (str): User ID. + project_id (str): Project ID. token (str): Personal Access Token or organization API key (deprecated) for authentication. api_base (str): The base URL for the API. headers (dict | None): Additional headers to include in requests. http_client (httpx.Client | httpx.AsyncClient): The HTTPX client. - file_upload_timeout (float | None, optional): The timeout to use when sending file upload requests. """ if api_base.endswith("/"): api_base = api_base[:-1] + self.user_id = user_id self.project_id = project_id self.token = token self.api_base = api_base - self.headers = {"X-PROJECT-ID": project_id, "User-Agent": USER_AGENT} + self.headers = { + "X-USER-ID": user_id, + "X-PROJECT-ID": project_id, + "User-Agent": USER_AGENT, + } if token != "": self.headers["Authorization"] = f"Bearer {token}" if headers is not None: @@ -137,20 +180,64 @@ def __init__( raise TypeError("`headers` must be None or a dict.") self.headers.update(headers) self.http_client = http_client + self.timeout = timeout self.file_upload_timeout = file_upload_timeout - @property - def api_key(self) -> str: - return self.token + async def close(self) -> None: + """ + Close the HTTP async client. + """ + await self.http_client.aclose() + + @staticmethod + def _filter_params(params: dict[str, Any] | BaseModel | None) -> dict[str, Any] | None: + """ + Filter out None values from query parameters dictionary or Pydantic model. + + Args: + params (dict[str, Any] | BaseModel | None): Query parameters dictionary or Pydantic model. - def close(self) -> None: + Returns: + params (dict[str, Any] | None): Filtered query parameters dictionary. + """ + if isinstance(params, BaseModel): + params = params.model_dump() + if params is not None: + params = {k: v for k, v in params.items() if v is not None} + return params + + @staticmethod + def _process_body( + body: dict[str, Any] | BaseModel | None, + **kwargs, + ) -> dict[str, Any] | None: """ - Close the HTTP client. + Create a dictionary from request body. + + Args: + body (dict[str, Any] | BaseModel | None): JSON body dictionary or Pydantic model. + **kwargs: Keyword arguments to be pass into `model_dump`. + + Returns: + params (dict[str, Any] | None): JSON body dictionary. 
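+
+        Example (illustrative sketch added for clarity; the input value is made up):
+            >>> # A plain dict passes through unchanged; None stays None.
+            >>> _ClientAsync._process_body({"name": "demo"})
+            {'name': 'demo'}
+            >>> # A Pydantic model would instead be dumped via `model_dump(mode="json", **kwargs)`.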
""" - self.http_client.close() + if body is not None: + body = body if isinstance(body, dict) else body.model_dump(mode="json", **kwargs) + return body + + @contextmanager + def _log_call(self): + request_id = uuid7_str() + self.headers["X-REQUEST-ID"] = request_id + try: + yield + except JamaiException: + raise + except Exception as e: + raise JamaiException(f"Request {request_id} failed. {repr(e)}") from e @staticmethod - def raise_exception( + async def _raise_exception( response: httpx.Response, *, ignore_code: int | None = None, @@ -176,2232 +263,2202 @@ def raise_exception( try: error = response.text except httpx.ResponseNotRead: - error = response.read().decode() - error = json_loads(error) - err_mssg = error.get("message", error.get("detail", str(error))) - if code == 404: - exc = ResourceNotFoundError + error = (await response.aread()).decode() + try: + error = json_loads(error) + err_mssg = error.get("message", error.get("detail", str(error))) + except Exception: + err_mssg = error + request_id = response.headers.get("x-request-id", "") + err_mssg = f"Request {request_id} failed. {err_mssg}" + if code == 401: + exc_class = AuthorizationError + elif code == 403: + exc_class = ForbiddenError + elif code == 404: + exc_class = ResourceNotFoundError + elif code == 409: + exc_class = ResourceExistsError + elif code == 422: + exc_class = BadInputError + elif code == 429: + _headers = response.headers + used = _headers.get("x-ratelimit-used", None) + retry_after = _headers.get("retry-after", None) + meta = _headers.get("x-ratelimit-meta", None) + raise RateLimitExceedError( + err_mssg, + limit=int(_headers.get("x-ratelimit-limit", 0)), + remaining=int(_headers.get("x-ratelimit-remaining", 0)), + reset_at=int(_headers.get("x-ratelimit-reset", 0)), + used=None if used is None else int(used), + retry_after=None if retry_after is None else int(retry_after), + meta=None if meta is None else orjson.loads(meta), + ) + elif code == 500: + exc_class = UnexpectedError + elif code == 503: + exc_class = ServerBusyError else: - exc = RuntimeError - raise exc(err_mssg) + exc_class = JamaiException + raise exc_class(err_mssg) - @staticmethod - def _filter_params(params: dict[str, Any] | None): + async def _request( + self, + method: str, + address: str, + endpoint: str, + *, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, + body: BaseModel | dict[str, Any] | None = None, + response_model: Type[BaseModel] | None = None, + timeout: float | None = None, + ignore_code: int | None = None, + process_body_kwargs: dict[str, Any] | None = None, + **kwargs, + ) -> httpx.Response | BaseModel: """ - Filter out None values from the parameters dictionary. + Make an asynchronous request to the specified endpoint. Args: - params (dict[str, Any] | None): The parameters dictionary. + method (str): The HTTP method to use (e.g., "GET", "POST"). + address (str): The base address of the API. + endpoint (str): The API endpoint. + headers (dict[str, Any] | None, optional): Headers to include in the request. Defaults to None. + params (dict[str, Any] | None, optional): Query parameters. Defaults to None. + body (BaseModel | dict[str, Any] | None, optional): The body to send in the request. Defaults to None. + response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + ignore_code (int | None, optional): HTTP error code to ignore. 
+ process_body_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for processing the body. + **kwargs (Any): Keyword arguments for `httpx.request`. Returns: - params (dict[str, Any] | None): The filtered parameters dictionary. + response (httpx.Response | BaseModel): The response text or Pydantic response object. """ - if params is not None: - params = {k: v for k, v in params.items() if v is not None} - return params + with self._log_call(): + if process_body_kwargs is None: + process_body_kwargs = {} + response = await self.http_client.request( + method, + f"{address}{endpoint}", + headers=headers, + params=self._filter_params(params), + json=self._process_body(body, **process_body_kwargs), + timeout=timeout or self.timeout, + **kwargs, + ) + response = await self._raise_exception(response, ignore_code=ignore_code) + if response_model is None: + return response + try: + return response_model.model_validate_json(response.text) + except Exception as e: + raise JamaiException( + f"Failed to parse response (code={response.status_code}): {response.text}" + ) from e - def _get( + async def _get( self, - address: str, endpoint: str, *, - params: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, response_model: Type[BaseModel] | None = None, + timeout: float | None = None, **kwargs, ) -> httpx.Response | BaseModel: """ - Make a GET request to the specified endpoint. + Make an asynchronous GET request to the specified endpoint. Args: - address (str): The base address of the API. endpoint (str): The API endpoint. params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. - **kwargs (Any): Keyword arguments for `httpx.get`. + response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + **kwargs (Any): Keyword arguments for `httpx.request`. Returns: response (httpx.Response | BaseModel): The response text or Pydantic response object. """ - response = self.http_client.get( - f"{address}{endpoint}", - params=self._filter_params(params), + return await self._request( + "GET", + self.api_base, + endpoint, headers=self.headers, + params=params, + body=None, + response_model=response_model, + timeout=timeout, **kwargs, ) - response = self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) - def _post( + async def _post( self, - address: str, endpoint: str, *, - body: BaseModel | None, - params: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, + body: BaseModel | None = None, response_model: Type[BaseModel] | None = None, + timeout: float | None = None, **kwargs, ) -> httpx.Response | BaseModel: """ - Make a POST request to the specified endpoint. + Make an asynchronous POST request to the specified endpoint. Args: - address (str): The base address of the API. endpoint (str): The API endpoint. - body (BaseModel | None): The request body. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - **kwargs (Any): Keyword arguments for `httpx.post`. + body (BaseModel | None, optional): The body to send in the request. Defaults to None. 
+ response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + **kwargs (Any): Keyword arguments for `httpx.request`. Returns: response (httpx.Response | BaseModel): The response text or Pydantic response object. """ - if body is not None: - body = body.model_dump() - response = self.http_client.post( - f"{address}{endpoint}", - json=body, + return await self._request( + "POST", + self.api_base, + endpoint, headers=self.headers, - params=self._filter_params(params), + params=params, + body=body, + response_model=response_model, + timeout=timeout, **kwargs, ) - response = self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) - def _options( + async def _options( self, - address: str, endpoint: str, *, - params: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, response_model: Type[BaseModel] | None = None, + timeout: float | None = None, **kwargs, ) -> httpx.Response | BaseModel: """ - Make an OPTIONS request to the specified endpoint. + Make an asynchronous OPTIONS request to the specified endpoint. Args: - address (str): The base address of the API. endpoint (str): The API endpoint. params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. - **kwargs (Any): Keyword arguments for `httpx.options`. + response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + **kwargs (Any): Keyword arguments for `httpx.request`. Returns: response (httpx.Response | BaseModel): The response or Pydantic response object. """ - response = self.http_client.options( - f"{address}{endpoint}", - params=self._filter_params(params), + return await self._request( + "OPTIONS", + self.api_base, + endpoint, headers=self.headers, + params=params, + body=None, + response_model=response_model, + timeout=timeout, **kwargs, ) - response = self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) - def _patch( + async def _patch( self, - address: str, endpoint: str, *, - body: BaseModel | None, - params: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, + body: BaseModel | None = None, response_model: Type[BaseModel] | None = None, + timeout: float | None = None, **kwargs, ) -> httpx.Response | BaseModel: """ - Make a PATCH request to the specified endpoint. + Make an asynchronous PATCH request to the specified endpoint. Args: - address (str): The base address of the API. endpoint (str): The API endpoint. - body (BaseModel | None): The request body. + params (dict[str, Any] | None, optional): Query parameters. Defaults to None. + body (BaseModel | None, optional): The body to send in the request. Defaults to None. response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + **kwargs (Any): Keyword arguments for `httpx.request`. + + Returns: + response (httpx.Response | BaseModel): The response text or Pydantic response object. 
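+
+        Note (added for clarity; the request model and field name are illustrative):
+            The body is dumped with `exclude_unset=True`, so only fields explicitly set
+            on the Pydantic model are included in the PATCH payload, e.g.:
+
+                # Sends only {"name": "New name"}; unset fields are omitted.
+                await client._patch("/v2/users", body=UserUpdate(name="New name"))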
+ """ + return await self._request( + "PATCH", + self.api_base, + endpoint, + headers=self.headers, + params=params, + body=body, + response_model=response_model, + timeout=timeout, + process_body_kwargs={"exclude_unset": True}, + **kwargs, + ) + + async def _put( + self, + endpoint: str, + *, + params: dict[str, Any] | BaseModel | None = None, + body: BaseModel | None = None, + response_model: Type[BaseModel] | None = None, + timeout: float | None = None, + **kwargs, + ) -> httpx.Response | BaseModel: + """ + Make an asynchronous PUT request to the specified endpoint. + + Args: + endpoint (str): The API endpoint. params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - **kwargs (Any): Keyword arguments for `httpx.patch`. + body (BaseModel | None, optional): The body to send in the request. Defaults to None. + response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + **kwargs (Any): Keyword arguments for `httpx.request`. Returns: response (httpx.Response | BaseModel): The response text or Pydantic response object. """ - if body is not None: - body = body.model_dump() - response = self.http_client.patch( - f"{address}{endpoint}", - json=body, + return await self._request( + "PUT", + self.api_base, + endpoint, headers=self.headers, - params=self._filter_params(params), + params=params, + body=body, + response_model=response_model, + timeout=timeout, + process_body_kwargs={"exclude_unset": True}, **kwargs, ) - response = self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) - def _stream( + async def _stream( self, - address: str, endpoint: str, *, body: BaseModel | None, - params: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, + timeout: float | None = None, **kwargs, - ) -> Generator[str, None, None]: + ) -> AsyncGenerator[str, None]: """ - Make a streaming POST request to the specified endpoint. + Make an asynchronous streaming POST request to the specified endpoint. Args: - address (str): The base address of the API. endpoint (str): The API endpoint. - body (BaseModel | None): The request body. + body (BaseModel | None): The body body. params (dict[str, Any] | None, optional): Query parameters. Defaults to None. **kwargs (Any): Keyword arguments for `httpx.stream`. Yields: str: The response chunks. 
""" - if body is not None: - body = body.model_dump() - with self.http_client.stream( - "POST", - f"{address}{endpoint}", - json=body, - headers=self.headers, - params=self._filter_params(params), - **kwargs, - ) as response: - response = self.raise_exception(response) - for chunk in response.iter_lines(): - chunk = chunk.strip() - if chunk == "" or chunk == "data: [DONE]": - continue - yield chunk + with self._log_call(): + async with self.http_client.stream( + "POST", + f"{self.api_base}{endpoint}", + headers=self.headers, + params=self._filter_params(params), + json=self._process_body(body), + timeout=timeout or self.timeout, + **kwargs, + ) as response: + response = await self._raise_exception(response) + async for chunk in response.aiter_lines(): + chunk = chunk.strip() + if chunk == "" or chunk == "data: [DONE]": + continue + yield chunk - def _delete( + async def _delete( self, - address: str, endpoint: str, *, - params: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, response_model: Type[BaseModel] | None = None, + timeout: float | None = None, ignore_code: int | None = None, **kwargs, ) -> httpx.Response | BaseModel: """ - Make a DELETE request to the specified endpoint. + Make an asynchronous DELETE request to the specified endpoint. Args: - address (str): The base address of the API. endpoint (str): The API endpoint. params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. - ignore_code (int | None, optional): HTTP code to ignore. - **kwargs (Any): Keyword arguments for `httpx.delete`. + response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. + timeout (float | None, optional): Timeout for the request. Defaults to None. + ignore_code (int | None, optional): HTTP error code to ignore. + **kwargs (Any): Keyword arguments for `httpx.request`. Returns: response (httpx.Response | BaseModel): The response text or Pydantic response object. 
""" - response = self.http_client.delete( - f"{address}{endpoint}", - params=self._filter_params(params), + return await self._request( + "DELETE", + self.api_base, + endpoint, headers=self.headers, + params=params, + body=None, + response_model=response_model, + timeout=timeout, + ignore_code=ignore_code, **kwargs, ) - response = self.raise_exception(response, ignore_code=ignore_code) - if response_model is None: - return response + + @staticmethod + async def _empty_async_generator(): + """Returns an empty asynchronous generator.""" + return + # This line is never reached, but makes it an async generator + yield + + @staticmethod + def _empty_sync_generator(): + """Returns an empty synchronous generator.""" + return + # This line is never reached, but makes it a sync generator + yield + + async def _return_async_iterator( + self, + agen: AsyncGenerator[Any, None], + stream_models: list[Type[BaseModel]] | None = None, + ) -> AsyncGenerator[Any, None]: + # Get the first chunk outside of the loop so that errors can be raised immediately + try: + chunk = await anext(agen) + except StopAsyncIteration: + # Return empty async generator + return self._empty_async_generator() + + def _process(_chunk: str) -> BaseModel | str: + if stream_models is None: + return _chunk + for m in stream_models: + try: + return m.model_validate_json(_chunk[5:]) + except Exception: + pass + raise RuntimeError(f"Unexpected SSE chunk: {chunk}") + + # For streaming responses, return an asynchronous generator + async def gen(): + nonlocal chunk + yield _process(chunk) + async for chunk in agen: + yield _process(chunk) + + # Directly return the asynchronous generator + return gen() + + def _return_iterator( + self, + agen: AsyncGenerator[Any, None] | Any, + stream: bool, + ) -> Generator[Any, None, None] | Any: + if stream: + # Get the first chunk outside of the loop so that errors can be raised immediately + try: + chunk = LOOP.run(anext(agen)) + except StopAsyncIteration: + # Return empty sync generator + return self._empty_sync_generator() + + def gen(): + nonlocal chunk + yield chunk + while True: + try: + yield LOOP.run(anext(agen)) + except StopAsyncIteration: + break + + return gen() else: - return response_model.model_validate_json(response.text) + return agen -class _BackendAdminClient(_Client): - """Backend administration methods.""" +class _AuthAsync(_ClientAsync): + """Auth methods.""" - def create_user(self, request: UserCreate) -> UserRead: - return self._post( - self.api_base, - "/admin/backend/v1/users", - body=request, + async def register_password(self, body: UserCreate, **kwargs) -> UserRead: + return await self._post( + "/v2/auth/register/password", + body=body, response_model=UserRead, + **kwargs, ) - def update_user(self, request: UserUpdate) -> UserRead: - return self._patch( - self.api_base, - "/admin/backend/v1/users", - body=request, + async def login_password(self, body: PasswordLoginRequest, **kwargs) -> UserRead: + return await self._post( + "/v2/auth/login/password", + body=body, response_model=UserRead, + **kwargs, ) - def list_users( + async def change_password(self, body: PasswordChangeRequest, **kwargs) -> UserRead: + return await self._patch( + "/v2/auth/login/password", + body=body, + response_model=UserRead, + **kwargs, + ) + + +class _PricesAsync(_ClientAsync): + """Prices methods.""" + + async def create_price_plan(self, body: PricePlanCreate, **kwargs) -> PricePlanRead: + return await self._post( + "/v2/prices/plans", + body=body, + response_model=PricePlanRead, + **kwargs, + ) + + 
async def list_price_plans( self, + *, offset: int = 0, limit: int = 100, - order_by: str = AdminOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[UserRead]: - """ - List users. - - Args: - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of users to return (min 1, max 100). Defaults to 100. - order_by (str, optional): Sort users by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - - Returns: - response (Page[UserRead]): The paginated user metadata response. - """ - return self._get( - self.api_base, - "/admin/backend/v1/users", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[PricePlanRead]: + return await self._get( + "/v2/prices/plans/list", params=dict( offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, + order_ascending=order_ascending, ), - response_model=Page[UserRead], - ) - - def get_user(self, user_id: str) -> UserRead: - return self._get( - self.api_base, - f"/admin/backend/v1/users/{quote(user_id)}", - params=None, - response_model=UserRead, + response_model=Page[PricePlanRead], + **kwargs, ) - def delete_user( + async def get_price_plan( self, - user_id: str, - *, - missing_ok: bool = True, - ) -> OkResponse: - response = self._delete( - self.api_base, - f"/admin/backend/v1/users/{quote(user_id)}", - params=None, - response_model=None, - ignore_code=404 if missing_ok else None, - ) - if response.status_code == 404 and missing_ok: - return OkResponse() - else: - return OkResponse.model_validate_json(response.text) - - def create_pat(self, request: PATCreate) -> PATRead: - return self._post( - self.api_base, - "/admin/backend/v1/pats", - body=request, - response_model=PATRead, + plan_id: str, + **kwargs, + ) -> PricePlanRead: + return await self._get( + "/v2/prices/plans", + params=dict(price_plan_id=plan_id), + response_model=PricePlanRead, + **kwargs, ) - def get_pat(self, pat: str) -> PATRead: - return self._get( - self.api_base, - f"/admin/backend/v1/pats/{quote(pat)}", - params=None, - response_model=PATRead, + async def update_price_plan( + self, + plan_id: str, + body: PricePlanUpdate, + **kwargs, + ) -> PricePlanRead: + return await self._patch( + "/v2/prices/plans", + params=dict(price_plan_id=plan_id), + body=body, + response_model=PricePlanRead, + **kwargs, ) - def delete_pat( + async def delete_price_plan( self, - pat: str, + price_plan_id: str, *, missing_ok: bool = True, + **kwargs, ) -> OkResponse: - response = self._delete( - self.api_base, - f"/admin/backend/v1/pats/{quote(pat)}", - params=None, + response = await self._delete( + "/v2/prices/plans", + params=dict(price_plan_id=price_plan_id), response_model=None, ignore_code=404 if missing_ok else None, + **kwargs, ) if response.status_code == 404 and missing_ok: return OkResponse() else: return OkResponse.model_validate_json(response.text) - def create_organization(self, request: OrganizationCreate) -> OrganizationRead: - return self._post( - self.api_base, - "/admin/backend/v1/organizations", - body=request, - response_model=OrganizationRead, + async def list_model_prices(self, **kwargs) -> ModelPrice: + return await self._get( + "/v2/prices/models/list", + response_model=ModelPrice, + **kwargs, ) - def update_organization(self, request: OrganizationUpdate) -> OrganizationRead: - return self._patch( - self.api_base, - "/admin/backend/v1/organizations", - body=request, - response_model=OrganizationRead, + 
+class _UsersAsync(_ClientAsync): + """Users methods.""" + + async def create_user(self, body: UserCreate, **kwargs) -> UserRead: + return await self._post( + "/v2/users", + body=body, + response_model=UserRead, + process_body_kwargs={"exclude_unset": True}, + **kwargs, ) - def list_organizations( + async def list_users( self, + *, offset: int = 0, limit: int = 100, - order_by: str = AdminOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[OrganizationRead]: - return self._get( - self.api_base, - "/admin/backend/v1/organizations", + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + search_columns: list[str] | None = None, + after: str | None = None, + **kwargs, + ) -> Page[UserRead]: + params = dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + after=after, + ) + if search_columns: + params["search_columns"] = search_columns + return await self._get( + "/v2/users/list", + params=params, + response_model=Page[UserRead], + **kwargs, + ) + + async def get_user( + self, + user_id: str | None = None, + **kwargs, + ) -> UserRead: + return await self._get( + "/v2/users", + params=dict(user_id=user_id), + response_model=UserRead, + **kwargs, + ) + + async def update_user( + self, + body: UserUpdate, + **kwargs, + ) -> UserRead: + return await self._patch( + "/v2/users", + body=body, + response_model=UserRead, + **kwargs, + ) + + async def delete_user( + self, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + response = await self._delete( + "/v2/users", + response_model=None, + ignore_code=404 if missing_ok else None, + **kwargs, + ) + if response.status_code == 404 and missing_ok: + return OkResponse() + else: + return OkResponse.model_validate_json(response.text) + + async def create_pat(self, body: ProjectKeyCreate, **kwargs) -> ProjectKeyRead: + return await self._post( + "/v2/pats", + body=body, + response_model=ProjectKeyRead, + **kwargs, + ) + + async def list_pats( + self, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ProjectKeyRead]: + return await self._get( + "/v2/pats/list", params=dict( offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, + order_ascending=order_ascending, ), - response_model=Page[OrganizationRead], + response_model=Page[ProjectKeyRead], + **kwargs, ) - def get_organization(self, organization_id: str) -> OrganizationRead: - return self._get( - self.api_base, - f"/admin/backend/v1/organizations/{quote(organization_id)}", - params=None, - response_model=OrganizationRead, + async def update_pat( + self, + pat_id: str, + body: ProjectKeyUpdate, + **kwargs, + ) -> ProjectKeyRead: + return await self._patch( + "/v2/pats", + params=dict(pat_id=pat_id), + body=body, + response_model=ProjectKeyRead, + **kwargs, ) - def delete_organization( + async def delete_pat( self, - organization_id: str, + pat_id: str, *, missing_ok: bool = True, + **kwargs, ) -> OkResponse: - response = self._delete( - self.api_base, - f"/admin/backend/v1/organizations/{quote(organization_id)}", - params=None, + response = await self._delete( + "/v2/pats", + params=dict(pat_id=pat_id), response_model=None, ignore_code=404 if missing_ok else None, + **kwargs, ) if response.status_code == 404 and missing_ok: return OkResponse() else: return OkResponse.model_validate_json(response.text) - def generate_invite_token( + async def 
create_email_verification_code( self, - organization_id: str, - user_email: str = "", - user_role: str = "", + *, valid_days: int = 7, - ) -> str: + **kwargs, + ) -> VerificationCodeRead: """ - Generates an invite token to join an organization. + Generates an email verification code. Args: - organization_id (str): Organization ID. - user_email (str, optional): User email. - Leave blank to disable email check and generate a public invite. Defaults to "". - user_role (str, optional): User role. - Leave blank to default to guest. Defaults to "". - valid_days (int, optional): How many days should this link be valid for. Defaults to 7. + valid_days (int, optional): Code validity in days. Defaults to 7. Returns: - token (str): _description_ + code (VerificationCodeRead): Verification code. """ - response = self._get( - self.api_base, - "/admin/backend/v1/invite_tokens", - params=dict( - organization_id=organization_id, - user_email=user_email, - user_role=user_role, - valid_days=valid_days, - ), - response_model=None, - ) - return response.text - - def join_organization(self, request: OrgMemberCreate) -> OrgMemberRead: - return self._post( - self.api_base, - "/admin/backend/v1/organizations/link", - body=request, - response_model=OrgMemberRead, + return await self._post( + "/v2/users/verify/email/code", + params=dict(valid_days=valid_days), + body=None, + response_model=VerificationCodeRead, + **kwargs, ) - def leave_organization(self, user_id: str, organization_id: str) -> OkResponse: - return self._delete( - self.api_base, - f"/admin/backend/v1/organizations/link/{quote(user_id)}/{quote(organization_id)}", - params=None, - response_model=OkResponse, + async def list_email_verification_codes( + self, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + search_columns: list[str] | None = None, + after: str | None = None, + **kwargs, + ) -> Page[VerificationCodeRead]: + params = dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + after=after, ) - - def create_api_key(self, request: ApiKeyCreate) -> ApiKeyRead: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - return self._post( - self.api_base, - "/admin/backend/v1/api_keys", - body=request, - response_model=ApiKeyRead, + if search_columns: + params["search_columns"] = search_columns + return await self._get( + "/v2/users/verify/email/code/list", + params=params, + response_model=Page[VerificationCodeRead], + **kwargs, ) - def get_api_key(self, api_key: str) -> ApiKeyRead: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - return self._get( - self.api_base, - f"/admin/backend/v1/api_keys/{quote(api_key)}", - params=None, - response_model=ApiKeyRead, + async def get_email_verification_code( + self, + verification_code: str, + **kwargs, + ) -> VerificationCodeRead: + return await self._get( + "/v2/users/verify/email/code", + params=dict(verification_code=verification_code), + response_model=VerificationCodeRead, + **kwargs, ) - def delete_api_key( + async def revoke_email_verification_code( self, - api_key: str, + verification_code: str, *, missing_ok: bool = True, + **kwargs, ) -> OkResponse: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - response = self._delete( - self.api_base, - f"/admin/backend/v1/api_keys/{quote(api_key)}", - params=None, + response = await self._delete( + "/v2/users/verify/email/code", + params=dict(verification_code=verification_code),
response_model=None, ignore_code=404 if missing_ok else None, + **kwargs, ) if response.status_code == 404 and missing_ok: return OkResponse() else: return OkResponse.model_validate_json(response.text) - def refresh_quota( + @deprecated( + "`delete_email_verification_code` is deprecated, use `revoke_email_verification_code` instead.", + category=FutureWarning, + stacklevel=1, + ) + async def delete_email_verification_code( self, - organization_id: str, - reset_usage: bool = True, - ) -> OrganizationRead: - return self._post( - self.api_base, - f"/admin/backend/v1/quotas/refresh/{quote(organization_id)}", - body=None, - params=dict(reset_usage=reset_usage), - response_model=OrganizationRead, - ) - - def get_event(self, event_id: str) -> EventRead: - return self._get( - self.api_base, - f"/admin/backend/v1/events/{quote(event_id)}", - params=None, - response_model=EventRead, - ) - - def add_event(self, request: EventCreate) -> OkResponse: - return self._post( - self.api_base, - "/admin/backend/v1/events", - body=request, - response_model=OkResponse, - ) - - def mark_event_as_done(self, event_id: str) -> OkResponse: - return self._patch( - self.api_base, - f"/admin/backend/v1/events/done/{quote(event_id)}", - body=None, - response_model=OkResponse, - ) - - def get_internal_organization_id(self) -> StringResponse: - return self._get( - self.api_base, - "/admin/backend/v1/internal_organization_id", - params=None, - response_model=StringResponse, - ) - - def set_internal_organization_id(self, organization_id: str) -> OkResponse: - return self._patch( - self.api_base, - f"/admin/backend/v1/internal_organization_id/{quote(organization_id)}", - body=None, - response_model=OkResponse, - ) - - def get_pricing(self) -> Price: - return self._get( - self.api_base, - "/public/v1/prices/plans", - params=None, - response_model=Price, - ) - - def set_pricing(self, request: Price) -> OkResponse: - return self._patch( - self.api_base, - "/admin/backend/v1/prices/plans", - body=request, - response_model=OkResponse, - ) - - def get_model_pricing(self) -> ModelPrice: - return self._get( - self.api_base, - "/public/v1/prices/models", - params=None, - response_model=ModelPrice, - ) - - def get_model_config(self) -> ModelListConfig: - return self._get( - self.api_base, - "/admin/backend/v1/models", - params=None, - response_model=ModelListConfig, - ) - - def set_model_config(self, request: ModelListConfig) -> OkResponse: - return self._patch( - self.api_base, - "/admin/backend/v1/models", - body=request, - response_model=OkResponse, + verification_code: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return await self.revoke_email_verification_code( + verification_code=verification_code, + missing_ok=missing_ok, + **kwargs, ) - def add_template( + async def verify_email( self, - source: str | BinaryIO, - template_id_dst: str, - exist_ok: bool = False, + verification_code: str, + **kwargs, ) -> OkResponse: """ - Upload a template Parquet file to add a new template into gallery. - - Args: - source (str | BinaryIO): The path to the template Parquet file or a file-like object. - template_id_dst (str): The ID of the new template. - exist_ok (bool, optional): Whether to overwrite existing template. Defaults to False. - - Returns: - response (OkResponse): The response indicating success. 
- """ - kwargs = dict( - address=self.api_base, - endpoint="/admin/backend/v1/templates/import", - body=None, - response_model=OkResponse, - data={"template_id_dst": template_id_dst, "exist_ok": exist_ok}, - timeout=self.file_upload_timeout, - ) - mime_type = "application/octet-stream" - if isinstance(source, str): - filename = split(source)[-1] - # Open the file in binary mode - with open(source, "rb") as f: - return self._post(files={"file": (filename, f, mime_type)}, **kwargs) - else: - filename = "import.parquet" - return self._post(files={"file": (filename, source, mime_type)}, **kwargs) - - def populate_templates(self, timeout: float = 30.0) -> OkResponse: - """ - Re-populates the template gallery. + Verify and update user email. Args: - timeout (float, optional): Timeout in seconds, must be >= 0. Defaults to 30.0. + verification_code (str): Verification code. Returns: - response (OkResponse): The response indicating success. + ok (OkResponse): Success. """ - return self._post( - self.api_base, - "/admin/backend/v1/templates/populate", - body=None, - params=dict(timeout=timeout), - response_model=OkResponse, - ) - - -class _OrgAdminClient(_Client): - """Organization administration methods.""" - - def get_org_model_config(self, organization_id: str) -> ModelListConfig: - return self._get( - self.api_base, - f"/admin/org/v1/models/{quote(organization_id)}", - params=None, - response_model=ModelListConfig, - ) - - def set_org_model_config( - self, - organization_id: str, - config: ModelListConfig, - ) -> OkResponse: - return self._patch( - self.api_base, - f"/admin/org/v1/models/{quote(organization_id)}", - body=config, + return await self._post( + "/v2/users/verify/email", + params=dict(verification_code=verification_code), response_model=OkResponse, + **kwargs, ) - def create_project(self, request: ProjectCreate) -> ProjectRead: - return self._post( - self.api_base, - "/admin/org/v1/projects", - body=request, - response_model=ProjectRead, - ) - def update_project(self, request: ProjectUpdate) -> ProjectRead: - return self._patch( - self.api_base, - "/admin/org/v1/projects", - body=request, - response_model=ProjectRead, - ) +class _ModelsAsync(_ClientAsync): + """Models methods.""" - def set_project_updated_at( - self, - project_id: str, - updated_at: str | None = None, - ) -> OkResponse: - return self._patch( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}", - body=None, - params=dict(updated_at=updated_at), - response_model=OkResponse, + async def create_model_config(self, body: ModelConfigCreate, **kwargs) -> ModelConfigRead: + return await self._post( + "/v2/models/configs", + body=body, + response_model=ModelConfigRead, + **kwargs, ) - def list_projects( + async def list_model_configs( self, - organization_id: str = "default", - search_query: str = "", + *, + organization_id: str | None = None, offset: int = 0, limit: int = 100, - order_by: str = AdminOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[ProjectRead]: - return self._get( - self.api_base, - "/admin/org/v1/projects", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ModelConfigRead]: + return await self._get( + "/v2/models/configs/list", params=dict( - organization_id=organization_id, - search_query=search_query, offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, + order_ascending=order_ascending, + organization_id=organization_id, ), - response_model=Page[ProjectRead], + response_model=Page[ModelConfigRead], + 
**kwargs, ) - def get_project(self, project_id: str) -> ProjectRead: - return self._get( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}", - params=None, - response_model=ProjectRead, + async def get_model_config( + self, + model_id: str, + **kwargs, + ) -> ModelConfigRead: + return await self._get( + "/v2/models/configs", + params=dict(model_id=model_id), + response_model=ModelConfigRead, + **kwargs, ) - def delete_project( + async def update_model_config( self, - project_id: str, - *, + model_id: str, + body: ModelConfigUpdate, + **kwargs, + ) -> ModelConfigRead: + return await self._patch( + "/v2/models/configs", + params=dict(model_id=model_id), + body=body, + response_model=ModelConfigRead, + **kwargs, + ) + + async def delete_model_config( + self, + model_id: str, missing_ok: bool = True, + **kwargs, ) -> OkResponse: - response = self._delete( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}", - params=None, + response = await self._delete( + "/v2/models/configs", + params=dict(model_id=model_id), response_model=None, ignore_code=404 if missing_ok else None, + **kwargs, ) if response.status_code == 404 and missing_ok: return OkResponse() else: return OkResponse.model_validate_json(response.text) - def import_project( + async def create_deployment( self, - source: str | BinaryIO, - organization_id: str, - project_id_dst: str = "", - ) -> ProjectRead: - """ - Imports a project. - - Args: - source (str | BinaryIO): The parquet file path or file-like object. - It can be a Project or Template file. - organization_id (str): Organization ID "org_xxx". - project_id_dst (str, optional): ID of the project to import tables into. - Defaults to creating new project. - - Returns: - response (ProjectRead): The imported project. - """ - kwargs = dict( - address=self.api_base, - endpoint=f"/admin/org/v1/projects/import/{quote(organization_id)}", - body=None, - response_model=ProjectRead, - data={"project_id_dst": project_id_dst}, - timeout=self.file_upload_timeout, + body: DeploymentCreate, + timeout: float | None = 300.0, + **kwargs, + ) -> DeploymentRead: + return await self._post( + "/v2/models/deployments/cloud", + body=body, + response_model=DeploymentRead, + timeout=self.timeout if timeout is None else timeout, + **kwargs, ) - mime_type = "application/octet-stream" - if isinstance(source, str): - filename = split(source)[-1] - # Open the file in binary mode - with open(source, "rb") as f: - return self._post(files={"file": (filename, f, mime_type)}, **kwargs) - else: - filename = "import.parquet" - return self._post(files={"file": (filename, source, mime_type)}, **kwargs) - def export_project( + async def list_deployments( self, - project_id: str, - compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", - ) -> bytes: - """ - Exports a project as a Project Parquet file. - - Args: - project_id (str): Project ID "proj_xxx". - compression (str, optional): Parquet compression codec. Defaults to "ZSTD". - - Returns: - response (bytes): The Parquet file. 
- """ - response = self._get( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}/export", - params=dict(compression=compression), - response_model=None, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[DeploymentRead]: + return await self._get( + "/v2/models/deployments/list", + params=dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + ), + response_model=Page[DeploymentRead], + **kwargs, ) - return response.content - def import_project_from_template( + async def get_deployment( self, - organization_id: str, - template_id: str, - project_id_dst: str = "", - ) -> ProjectRead: - """ - Imports a project from a template. - - Args: - organization_id (str): Organization ID "org_xxx". - template_id (str): ID of the template to import from. - project_id_dst (str, optional): ID of the project to import tables into. - Defaults to creating new project. - - Returns: - response (ProjectRead): The imported project. - """ - return self._post( - self.api_base, - f"/admin/org/v1/projects/import/{quote(organization_id)}/templates/{quote(template_id)}", - body=None, - params=dict(project_id_dst=project_id_dst), - response_model=ProjectRead, + deployment_id: str, + **kwargs, + ) -> DeploymentRead: + return await self._get( + "/v2/models/deployments", + params=dict(deployment_id=deployment_id), + response_model=DeploymentRead, + **kwargs, ) - def export_project_as_template( + async def update_deployment( self, - project_id: str, - *, - name: str, - tags: list[str], - description: str, - compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", - ) -> bytes: - """ - Exports a project as a template Parquet file. - - Args: - project_id (str): Project ID "proj_xxx". - name (str): Template name. - tags (list[str]): Template tags. - description (str): Template description. - compression (str, optional): Parquet compression codec. Defaults to "ZSTD". - - Returns: - response (bytes): The template Parquet file. - """ - response = self._get( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}/export/template", - params=dict( - name=name, - tags=tags, - description=description, - compression=compression, - ), - response_model=None, + deployment_id: str, + body: DeploymentUpdate, + **kwargs, + ) -> DeploymentRead: + return await self._patch( + "/v2/models/deployments", + params=dict(deployment_id=deployment_id), + body=body, + response_model=DeploymentRead, + **kwargs, ) - return response.content - - -class _AdminClient(_Client): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.backend = _BackendAdminClient(*args, **kwargs) - self.organization = _OrgAdminClient(*args, **kwargs) - - -class _TemplateClient(_Client): - """Template methods.""" - def list_templates(self, search_query: str = "") -> Page[Template]: - """ - List all templates. - - Args: - search_query (str, optional): A string to search for within template names. - - Returns: - templates (Page[Template]): A page of templates. 
- """ - return self._get( - self.api_base, - "/public/v1/templates", - params=dict(search_query=search_query), - response_model=Page[Template], + async def delete_deployment(self, deployment_id: str, **kwargs) -> OkResponse: + return await self._delete( + "/v2/models/deployments", + params=dict(deployment_id=deployment_id), + response_model=OkResponse, + **kwargs, ) - def get_template(self, template_id: str) -> Template: - """ - Get a template by its ID. - Args: - template_id (str): Template ID. +class _OrganizationsAsync(_ClientAsync): + """Organization methods.""" - Returns: - template (Template): The template. - """ - return self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}", - params=None, - response_model=Template, + async def create_organization( + self, + body: OrganizationCreate, + **kwargs, + ) -> OrganizationRead: + return await self._post( + "/v2/organizations", + body=body, + response_model=OrganizationRead, + **kwargs, ) - def list_tables( + async def list_organizations( self, - template_id: str, - table_type: str, *, offset: int = 0, limit: int = 100, - search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[TableMetaResponse]: - """ - List all tables in a template. - - Args: - template_id (str): Template ID. - table_type (str): Table type. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. - search_query (str, optional): A string to search for within table IDs as a filter. - Defaults to "" (no filter). - order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - - Returns: - tables (Page[TableMetaResponse]): A page of tables. - """ - return self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}/gen_tables/{quote(table_type)}", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[OrganizationRead]: + return await self._get( + "/v2/organizations/list", params=dict( offset=offset, limit=limit, - search_query=search_query, order_by=order_by, - order_descending=order_descending, + order_ascending=order_ascending, ), - response_model=Page[TableMetaResponse], + response_model=Page[OrganizationRead], + **kwargs, ) - def get_table(self, template_id: str, table_type: str, table_id: str) -> TableMetaResponse: - """ - Get a table in a template. + async def get_organization( + self, + organization_id: str, + **kwargs, + ) -> OrganizationRead: + return await self._get( + "/v2/organizations", + params=dict(organization_id=organization_id), + response_model=OrganizationRead, + **kwargs, + ) - Args: - template_id (str): Template ID. - table_type (str): Table type. - table_id (str): Table ID. + async def update_organization( + self, + organization_id: str, + body: OrganizationUpdate, + **kwargs, + ) -> OrganizationRead: + return await self._patch( + "/v2/organizations", + body=body, + params=dict(organization_id=organization_id), + response_model=OrganizationRead, + **kwargs, + ) - Returns: - table (TableMetaResponse): The table. 
- """ - return self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}/gen_tables/{quote(table_type)}/{quote(table_id)}", - params=None, - response_model=TableMetaResponse, + async def delete_organization( + self, + organization_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + response = await self._delete( + "/v2/organizations", + params=dict(organization_id=organization_id), + response_model=None, + ignore_code=404 if missing_ok else None, + **kwargs, ) + if response.status_code == 404 and missing_ok: + return OkResponse() + else: + return OkResponse.model_validate_json(response.text) - def list_table_rows( + async def join_organization( self, - template_id: str, - table_type: str, - table_id: str, + user_id: str, + *, + invite_code: str | None = None, + organization_id: str | None = None, + role: str | None = None, + **kwargs, + ) -> OrgMemberRead: + return await self._post( + "/v2/organizations/members", + params=dict( + user_id=user_id, + organization_id=organization_id, + role=role, + invite_code=invite_code, + ), + body=None, + response_model=OrgMemberRead, + **kwargs, + ) + + async def list_members( + self, + organization_id: str, *, - starting_after: str | None = None, offset: int = 0, limit: int = 100, - order_by: str = "Updated at", - order_descending: bool = True, - float_decimals: int = 0, - vec_decimals: int = 0, - ) -> Page[dict[str, Any]]: - """ - List rows in a template table. - - Args: - template_id (str): Template ID. - table_type (str): Table type. - table_id (str): Table ID. - starting_after (str | None, optional): A cursor for use in pagination. - Only rows with ID > `starting_after` will be returned. - For instance, if your call receives 100 rows ending with ID "x", - your subsequent call can include `starting_after="x"` in order to fetch the next page of the list. - Defaults to None. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. - order_by (str, optional): Sort rows by this column. Defaults to "Updated at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). - - Returns: - rows (Page[dict[str, Any]]): The rows. 
- """ - return self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}/gen_tables/{quote(table_type)}/{quote(table_id)}/rows", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[OrgMemberRead]: + return await self._get( + "/v2/organizations/members/list", params=dict( - starting_after=starting_after, offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, - float_decimals=float_decimals, - vec_decimals=vec_decimals, + order_ascending=order_ascending, + organization_id=organization_id, ), - response_model=Page[dict[str, Any]], + response_model=Page[OrgMemberRead], + **kwargs, ) + async def get_member( + self, + *, + user_id: str, + organization_id: str, + **kwargs, + ) -> OrgMemberRead: + return await self._get( + "/v2/organizations/members", + params=dict(user_id=user_id, organization_id=organization_id), + response_model=OrgMemberRead, + **kwargs, + ) -class _FileClient(_Client): - """File methods.""" - - def upload_file(self, file_path: str) -> FileUploadResponse: - """ - Uploads a file to the server. + async def update_member_role( + self, + *, + user_id: str, + organization_id: str, + role: Role, + **kwargs, + ) -> OrgMemberRead: + return await self._patch( + "/v2/organizations/members/role", + params=dict(user_id=user_id, organization_id=organization_id, role=role), + response_model=OrgMemberRead, + **kwargs, + ) - Args: - file_path (str): Path to the file to be uploaded. - - Returns: - response (FileUploadResponse): The response containing the file URI. - """ - filename = split(file_path)[-1] - mime_type = filetype.guess(file_path).mime - if mime_type is None: - mime_type = "application/octet-stream" # Default MIME type - - with open(file_path, "rb") as f: - return self._post( - self.api_base, - "/v1/files/upload", - body=None, - response_model=FileUploadResponse, - files={ - "file": (filename, f, mime_type), - }, - timeout=self.file_upload_timeout, - ) - - def get_raw_urls(self, uris: list[str]) -> GetURLResponse: - """ - Get download URLs for raw files. - - Args: - uris (List[str]): List of file URIs to download. - - Returns: - response (GetURLResponse): The response containing download information for the files. - """ - return self._post( - self.api_base, - "/v1/files/url/raw", - body=GetURLRequest(uris=uris), - response_model=GetURLResponse, - ) - - def get_thumbnail_urls(self, uris: list[str]) -> GetURLResponse: - """ - Get download URLs for file thumbnails. - - Args: - uris (List[str]): List of file URIs to get thumbnails for. - - Returns: - response (GetURLResponse): The response containing download information for the thumbnails. - """ - return self._post( - self.api_base, - "/v1/files/url/thumb", - body=GetURLRequest(uris=uris), - response_model=GetURLResponse, - ) - - -class _GenTableClient(_Client): - """Generative Table methods.""" - - def create_action_table(self, request: ActionTableSchemaCreate) -> TableMetaResponse: - """ - Create an Action Table. - - Args: - request (ActionTableSchemaCreate): The action table schema. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - "/v1/gen_tables/action", - body=request, - response_model=TableMetaResponse, - ) - - def create_knowledge_table(self, request: KnowledgeTableSchemaCreate) -> TableMetaResponse: - """ - Create a Knowledge Table. - - Args: - request (KnowledgeTableSchemaCreate): The knowledge table schema. 
- - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - "/v1/gen_tables/knowledge", - body=request, - response_model=TableMetaResponse, + async def leave_organization( + self, + user_id: str, + organization_id: str, + **kwargs, + ) -> OkResponse: + return await self._delete( + "/v2/organizations/members", + params=dict( + user_id=user_id, + organization_id=organization_id, + ), + response_model=OkResponse, + **kwargs, ) - def create_chat_table(self, request: ChatTableSchemaCreate) -> TableMetaResponse: - """ - Create a Chat Table. - - Args: - request (ChatTableSchemaCreate): The chat table schema. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - "/v1/gen_tables/chat", - body=request, - response_model=TableMetaResponse, + async def model_catalogue( + self, + *, + organization_id: str, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ModelConfigRead]: + return await self._get( + "/v2/organizations/models/catalogue", + params=dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + organization_id=organization_id, + ), + response_model=Page[ModelConfigRead], + **kwargs, ) - def get_table( + async def create_invite( self, - table_type: str | TableType, - table_id: str, - ) -> TableMetaResponse: + *, + user_email: str, + organization_id: str, + role: str, + valid_days: int = 7, + **kwargs, + ) -> VerificationCodeRead: """ - Get metadata for a specific Generative Table. + Generates an invite token to join an organization. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. + user_email (str): User email. + organization_id (str): Organization ID. + role (str): Organization role. + valid_days (int, optional): Code validity in days. Defaults to 7. Returns: - response (TableMetaResponse): The table metadata response. + code (InviteCodeRead): Invite code. """ - return self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}", - params=None, - response_model=TableMetaResponse, + return await self._post( + "/v2/organizations/invites", + params=dict( + user_email=user_email, + organization_id=organization_id, + role=role, + valid_days=valid_days, + ), + body=None, + response_model=VerificationCodeRead, + **kwargs, ) - def list_tables( + async def generate_invite_token(self, *_, **__): + raise NotImplementedError("This method is deprecated, use `create_invite` instead.") + + async def list_invites( self, - table_type: str | TableType, + organization_id: str, *, offset: int = 0, limit: int = 100, - parent_id: str | None = None, - search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, - count_rows: bool = False, - ) -> Page[TableMetaResponse]: - """ - List Generative Tables of a specific type. - - Args: - table_type (str | TableType): The type of the table. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. - parent_id (str | None, optional): Parent ID of tables to return. - Additionally for Chat Table, you can list: - (1) all chat agents by passing in "_agent_"; or - (2) all chats by passing in "_chat_". - Defaults to None (return all tables). - search_query (str, optional): A string to search for within table IDs as a filter. 
- Defaults to "" (no filter). - order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - count_rows (bool, optional): Whether to count the rows of the tables. Defaults to False. - - Returns: - response (Page[TableMetaResponse]): The paginated table metadata response. - """ - return self._get( - self.api_base, - f"/v1/gen_tables/{table_type}", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[VerificationCodeRead]: + return await self._get( + "/v2/organizations/invites/list", params=dict( offset=offset, limit=limit, - parent_id=parent_id, - search_query=search_query, order_by=order_by, - order_descending=order_descending, - count_rows=count_rows, + order_ascending=order_ascending, + organization_id=organization_id, ), - response_model=Page[TableMetaResponse], + response_model=Page[VerificationCodeRead], + **kwargs, ) - def delete_table( + async def revoke_invite( self, - table_type: str | TableType, - table_id: str, + invite_id: str, *, missing_ok: bool = True, + **kwargs, ) -> OkResponse: - """ - Delete a specific table. - - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - missing_ok (bool, optional): Ignore resource not found error. - - Returns: - response (OkResponse): The response indicating success. - """ - response = self._delete( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}", - params=None, + response = await self._delete( + "/v2/organizations/invites", + params=dict(invite_id=invite_id), response_model=None, ignore_code=404 if missing_ok else None, + **kwargs, ) if response.status_code == 404 and missing_ok: return OkResponse() else: return OkResponse.model_validate_json(response.text) - def duplicate_table( + @deprecated( + "`delete_invite` is deprecated, use `revoke_invite` instead.", + category=FutureWarning, + stacklevel=1, + ) + async def delete_invite( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str | None = None, + invite_id: str, *, - include_data: bool = True, - create_as_child: bool = False, + missing_ok: bool = True, **kwargs, - ) -> TableMetaResponse: - """ - Duplicate a table. - - Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str | None, optional): The destination / new table ID. - Defaults to None (create a new table ID automatically). - include_data (bool, optional): Whether to include data in the duplicated table. Defaults to True. - create_as_child (bool, optional): Whether the new table is a child table. - If this is True, then `include_data` will be set to True. Defaults to False. + ) -> OkResponse: + return await self.revoke_invite(invite_id=invite_id, missing_ok=missing_ok, **kwargs) - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - if "deploy" in kwargs: - warn( - 'The "deploy" argument is deprecated, use "create_as_child" instead.', - FutureWarning, - stacklevel=2, - ) - create_as_child = create_as_child or kwargs.pop("deploy") - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/duplicate/{quote(table_id_src)}", + async def subscribe_plan( + self, + organization_id: str, + price_plan_id: str, + **kwargs, + ) -> StripePaymentInfo: + return await self._patch( + "/v2/organizations/plan", + params=dict(organization_id=organization_id, price_plan_id=price_plan_id), body=None, - params=dict( - table_id_dst=table_id_dst, - include_data=include_data, - create_as_child=create_as_child, - ), - response_model=TableMetaResponse, + response_model=StripePaymentInfo, + **kwargs, ) - def rename_table( + async def refresh_quota( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str, - ) -> TableMetaResponse: - """ - Rename a table. - - Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str): The destination / new table ID. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rename/{quote(table_id_src)}/{quote(table_id_dst)}", + organization_id: str, + **kwargs, + ) -> OrganizationRead: + return await self._post( + "/v2/organizations/plan/refresh", + params=dict(organization_id=organization_id), body=None, - response_model=TableMetaResponse, + response_model=OrganizationRead, + **kwargs, ) - def update_gen_config( + async def purchase_credits( self, - table_type: str | TableType, - request: GenConfigUpdateRequest, - ) -> TableMetaResponse: - """ - Update the generation configuration for a table. - - Args: - table_type (str | TableType): The type of the table. - request (GenConfigUpdateRequest): The generation configuration update request. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/gen_config/update", - body=request, - response_model=TableMetaResponse, + organization_id: str, + amount: float, + *, + confirm: bool = False, + off_session: bool = False, + **kwargs, + ) -> StripePaymentInfo: + return await self._post( + "/v2/organizations/credits", + params=dict( + organization_id=organization_id, + amount=amount, + confirm=confirm, + off_session=off_session, + ), + body=None, + response_model=StripePaymentInfo, + **kwargs, ) - def add_action_columns(self, request: AddActionColumnSchema) -> TableMetaResponse: - """ - Add columns to an Action Table. - - Args: - request (AddActionColumnSchema): The action column schema. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - "/v1/gen_tables/action/columns/add", - body=request, - response_model=TableMetaResponse, + async def set_credit_grant( + self, + organization_id: str, + amount: float, + **kwargs, + ) -> OkResponse: + return await self._put( + "/v2/organizations/credit_grant", + params=dict(organization_id=organization_id, amount=amount), + body=None, + response_model=OkResponse, + **kwargs, ) - def add_knowledge_columns(self, request: AddKnowledgeColumnSchema) -> TableMetaResponse: - """ - Add columns to a Knowledge Table. - - Args: - request (AddKnowledgeColumnSchema): The knowledge column schema. - - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - return self._post( - self.api_base, - "/v1/gen_tables/knowledge/columns/add", - body=request, - response_model=TableMetaResponse, + async def add_credit_grant( + self, + organization_id: str, + amount: float, + **kwargs, + ) -> OkResponse: + return await self._patch( + "/v2/organizations/credit_grant", + params=dict(organization_id=organization_id, amount=amount), + body=None, + response_model=OkResponse, + **kwargs, ) - def add_chat_columns(self, request: AddChatColumnSchema) -> TableMetaResponse: - """ - Add columns to a Chat Table. - - Args: - request (AddChatColumnSchema): The chat column schema. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - "/v1/gen_tables/chat/columns/add", - body=request, - response_model=TableMetaResponse, + async def get_organization_metrics( + self, + metric_id: str, + from_: datetime, + org_id: str, + window_size: str | None = None, + proj_ids: list[str] | None = None, + to: datetime | None = None, + group_by: list[str] | None = None, + data_source: Literal["clickhouse", "victoriametrics"] = "clickhouse", + **kwargs, + ) -> UsageResponse: + params = { + "metricId": metric_id, + "from": from_.isoformat(), # Use string key to avoid keyword conflict + "orgId": org_id, + "windowSize": window_size, + "projIds": proj_ids, + "to": to.isoformat() if to else None, + "groupBy": group_by, + "dataSource": data_source, + } + return await self._get( + "/v2/organizations/meters/query", + params=params, + response_model=UsageResponse, + ) + + # async def get_billing_metrics( + # self, + # from_: datetime, + # window_size: str, + # org_id: str, + # proj_ids: list[str] | None = None, + # to: datetime | None = None, + # group_by: list[str] | None = None, + # **kwargs, + # ) -> dict: + # params = { + # "from": from_.isoformat(), + # "window_size": window_size, + # "org_id": org_id, + # "proj_ids": proj_ids, + # "to": to, + # "group_by": group_by, + # } + # return await self._get( + # "/v2/organizations/meters/billings", + # params=params, + # **kwargs, + # ) + + +class _ProjectsAsync(_ClientAsync): + """Project methods.""" + + async def create_project(self, body: ProjectCreate, **kwargs) -> ProjectRead: + return await self._post( + "/v2/projects", + body=body, + response_model=ProjectRead, + **kwargs, ) - def drop_columns( + async def list_projects( self, - table_type: str | TableType, - request: ColumnDropRequest, - ) -> TableMetaResponse: - """ - Drop columns from a table. - - Args: - table_type (str | TableType): The type of the table. - request (ColumnDropRequest): The column drop request. - - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/columns/drop", - body=request, - response_model=TableMetaResponse, + organization_id: str, + *, + offset: int = 0, + limit: int = 100, + search_query: str = "", + order_by: str = "updated_at", + order_ascending: bool = True, + list_chat_agents: bool = False, + **kwargs, + ) -> Page[ProjectRead]: + return await self._get( + "/v2/projects/list", + params=dict( + offset=offset, + limit=limit, + search_query=search_query, + order_by=order_by, + order_ascending=order_ascending, + organization_id=organization_id, + list_chat_agents=list_chat_agents, + ), + response_model=Page[ProjectRead], + **kwargs, ) - def rename_columns( + async def get_project( self, - table_type: str | TableType, - request: ColumnRenameRequest, - ) -> TableMetaResponse: - """ - Rename columns in a table. + project_id: str, + **kwargs, + ) -> ProjectRead: + return await self._get( + "/v2/projects", + params=dict(project_id=project_id), + response_model=ProjectRead, + **kwargs, + ) - Args: - table_type (str | TableType): The type of the table. - request (ColumnRenameRequest): The column rename request. + async def update_project( + self, + project_id: str, + body: ProjectUpdate, + **kwargs, + ) -> ProjectRead: + return await self._patch( + "/v2/projects", + body=body, + params=dict(project_id=project_id), + response_model=ProjectRead, + **kwargs, + ) - Returns: - response (TableMetaResponse): The table metadata response. - """ - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/columns/rename", - body=request, - response_model=TableMetaResponse, + async def delete_project( + self, + project_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + response = await self._delete( + "/v2/projects", + params=dict(project_id=project_id), + response_model=None, + ignore_code=404 if missing_ok else None, + **kwargs, ) + if response.status_code == 404 and missing_ok: + return OkResponse() + else: + return OkResponse.model_validate_json(response.text) - def reorder_columns( + async def create_invite( self, - table_type: str | TableType, - request: ColumnReorderRequest, - ) -> TableMetaResponse: + *, + user_email: str, + project_id: str, + role: str, + valid_days: int = 7, + **kwargs, + ) -> VerificationCodeRead: """ - Reorder columns in a table. + Generates an invite token to join a project. Args: - table_type (str | TableType): The type of the table. - request (ColumnReorderRequest): The column reorder request. + user_email (str): User email. + project_id (str): Project ID. + role (str): Project role. + valid_days (int, optional): Code validity in days. Defaults to 7. Returns: - response (TableMetaResponse): The table metadata response. + code (InviteCodeRead): Invite code. """ - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/columns/reorder", - body=request, - response_model=TableMetaResponse, + return await self._post( + "/v2/projects/invites", + params=dict( + user_email=user_email, + project_id=project_id, + role=role, + valid_days=valid_days, + ), + body=None, + response_model=VerificationCodeRead, + **kwargs, ) - def list_table_rows( + async def list_invites( self, - table_type: str | TableType, - table_id: str, + project_id: str, *, offset: int = 0, limit: int = 100, - search_query: str = "", - columns: list[str] | None = None, - float_decimals: int = 0, - vec_decimals: int = 0, - order_descending: bool = True, - ) -> Page[dict[str, Any]]: - """ - List rows in a table. 
- - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. - search_query (str, optional): A string to search for within the rows as a filter. - Defaults to "" (no filter). - columns (list[str] | None, optional): List of column names to include in the response. - Defaults to None (all columns). - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - - Returns: - response (Page[dict[str, Any]]): The paginated rows response. - """ - return self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[VerificationCodeRead]: + return await self._get( + "/v2/projects/invites/list", params=dict( offset=offset, limit=limit, - search_query=search_query, - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - order_descending=order_descending, + order_by=order_by, + order_ascending=order_ascending, + project_id=project_id, ), - response_model=Page[dict[str, Any]], + response_model=Page[VerificationCodeRead], + **kwargs, ) - def get_table_row( + async def revoke_invite( self, - table_type: str | TableType, - table_id: str, - row_id: str, + invite_id: str, *, - columns: list[str] | None = None, - float_decimals: int = 0, - vec_decimals: int = 0, - ) -> dict[str, Any]: - """ - Get a specific row in a table. - - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - row_id (str): The ID of the row. - columns (list[str] | None, optional): List of column names to include in the response. - Defaults to None (all columns). - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). - - Returns: - response (dict[str, Any]): The row data. - """ - response = self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows/{quote(row_id)}", - params=dict( - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ), + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + response = await self._delete( + "/v2/projects/invites", + params=dict(invite_id=invite_id), response_model=None, + ignore_code=404 if missing_ok else None, + **kwargs, ) - return json_loads(response.text) + if response.status_code == 404 and missing_ok: + return OkResponse() + else: + return OkResponse.model_validate_json(response.text) - def add_table_rows( + @deprecated( + "`delete_invite` is deprecated, use `revoke_invite` instead.", + category=FutureWarning, + stacklevel=1, + ) + async def delete_invite( self, - table_type: str | TableType, - request: RowAddRequest, - ) -> GenTableChatResponseType: - """ - Add rows to a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowAddRequest): The row add request. - - Returns: - response (GenTableChatResponseType): The row completion. 
- In streaming mode, it is a generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. - """ - if request.stream: - - def gen(): - for chunk in self._stream( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/add", - body=request, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "gen_table.references": - yield GenTableStreamReferences.model_validate(chunk) - elif chunk["object"] == "gen_table.completion.chunk": - yield GenTableStreamChatCompletionChunk.model_validate(chunk) - else: - raise RuntimeError(f"Unexpected SSE chunk: {chunk}") - - return gen() - else: - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/add", - body=request, - response_model=GenTableRowsChatCompletionChunks, - ) + invite_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return await self.revoke_invite(invite_id, missing_ok=missing_ok, **kwargs) - def regen_table_rows( + async def join_project( self, - table_type: str | TableType, - request: RowRegenRequest, - ) -> GenTableChatResponseType: - """ - Regenerate rows in a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowRegenRequest): The row regenerate request. + user_id: str, + *, + invite_code: str | None = None, + project_id: str | None = None, + role: str | None = None, + **kwargs, + ) -> ProjectMemberRead: + return await self._post( + "/v2/projects/members", + params=dict( + user_id=user_id, + project_id=project_id, + role=role, + invite_code=invite_code, + ), + body=None, + response_model=ProjectMemberRead, + **kwargs, + ) - Returns: - response (GenTableChatResponseType): The row completion. - In streaming mode, it is a generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. 
- """ - if request.stream: + async def list_members( + self, + project_id: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ProjectMemberRead]: + return await self._get( + "/v2/projects/members/list", + params=dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + project_id=project_id, + ), + response_model=Page[ProjectMemberRead], + **kwargs, + ) - def gen(): - for chunk in self._stream( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/regen", - body=request, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "gen_table.references": - yield GenTableStreamReferences.model_validate(chunk) - elif chunk["object"] == "gen_table.completion.chunk": - yield GenTableStreamChatCompletionChunk.model_validate(chunk) - else: - raise RuntimeError(f"Unexpected SSE chunk: {chunk}") + async def get_member( + self, + *, + user_id: str, + project_id: str, + **kwargs, + ) -> ProjectMemberRead: + return await self._get( + "/v2/projects/members", + params=dict(user_id=user_id, project_id=project_id), + response_model=ProjectMemberRead, + **kwargs, + ) - return gen() - else: - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/regen", - body=request, - response_model=GenTableRowsChatCompletionChunks, - ) + async def update_member_role( + self, + *, + user_id: str, + project_id: str, + role: Role, + **kwargs, + ) -> ProjectMemberRead: + return await self._patch( + "/v2/projects/members/role", + params=dict(user_id=user_id, project_id=project_id, role=role), + response_model=ProjectMemberRead, + **kwargs, + ) - def update_table_row( + async def leave_project( self, - table_type: str | TableType, - request: RowUpdateRequest, + user_id: str, + project_id: str, + **kwargs, ) -> OkResponse: - """ - Update a specific row in a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowUpdateRequest): The row update request. - - Returns: - response (OkResponse): The response indicating success. - """ - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/update", - body=request, + return await self._delete( + "/v2/projects/members", + params=dict( + user_id=user_id, + project_id=project_id, + ), response_model=OkResponse, + **kwargs, ) - def delete_table_rows( + async def import_project( self, - table_type: str | TableType, - request: RowDeleteRequest, - ) -> OkResponse: + source: str | BinaryIO, + *, + project_id: str = "", + organization_id: str = "", + **kwargs, + ) -> ProjectRead: """ - Delete rows from a table. + Import a project. Args: - table_type (str | TableType): The type of the table. - request (RowDeleteRequest): The row delete request. + source (str | BinaryIO): The parquet file path or file-like object. + It can be a Project or Template file. + project_id (str, optional): If given, import tables into this project. + Defaults to "" (create new project). + organization_id (str): Organization ID of the new project. + Only required if creating a new project. Returns: - response (OkResponse): The response indicating success. + response (ProjectRead): The imported project. 
""" - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/delete", - body=request, - response_model=OkResponse, + migrate = kwargs.pop("migrate", False) # Temporary, may be removed anytime + timeout = None if migrate else (kwargs.pop("timeout", None) or self.file_upload_timeout) + kw = dict( + endpoint=f"/v2/projects/import/parquet{'/migration' if migrate else ''}", + body=None, + response_model=ProjectRead, + data=dict(project_id=project_id, organization_id=organization_id), + timeout=timeout, + **kwargs, ) + mime_type = "application/octet-stream" + if isinstance(source, str): + filename = split(source)[-1] + # Open the file in binary mode + with open(source, "rb") as f: + return await self._post( + files={"file": (filename, f, mime_type)}, + **kw, + ) + else: + filename = "import.parquet" + return await self._post( + files={"file": (filename, source, mime_type)}, + **kw, + ) - def delete_table_row( + async def export_project( self, - table_type: str | TableType, - table_id: str, - row_id: str, - ) -> OkResponse: + project_id: str, + **kwargs, + ) -> bytes: """ - Delete a specific row from a table. + Export a project as a Project Parquet file. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - row_id (str): The ID of the row. + project_id (str): Project ID "proj_xxx". Returns: - response (OkResponse): The response indicating success. + response (bytes): The Parquet file. """ - return self._delete( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows/{quote(row_id)}", - params=None, - response_model=OkResponse, + response = await self._get( + "/v2/projects/export", + params=dict(project_id=project_id), + response_model=None, + **kwargs, ) + return response.content - def get_conversation_thread( + async def import_template( self, - table_type: str | TableType, - table_id: str, - column_id: str, + template_id: str, *, - row_id: str = "", - include: bool = True, - ) -> ChatThread: + project_id: str = "", + organization_id: str = "", + **kwargs, + ) -> ProjectRead: """ - Get the conversation thread for a chat table. + Import a Template. Args: - table_type (str | TableType): The type of the table. - table_id (str): ID / name of the chat table. - column_id (str): ID / name of the column to fetch. - row_id (str, optional): ID / name of the last row in the thread. - Defaults to "" (export all rows). - include (bool, optional): Whether to include the row specified by `row_id`. - Defaults to True. + template_id (str): Template ID "proj_xxx". + project_id (str, optional): If given, import tables into this project. + Defaults to "" (create new project). + organization_id (str): Organization ID of the new project. + Only required if creating a new project. Returns: - response (ChatThread): The conversation thread. + response (ProjectRead): The imported project. """ - return self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/thread", - params=dict(column_id=column_id, row_id=row_id, include=include), - response_model=ChatThread, + return await self._post( + "/v2/projects/import/template", + body=None, + params=dict( + template_id=template_id, + project_id=project_id, + organization_id=organization_id, + ), + response_model=ProjectRead, + **kwargs, ) - def hybrid_search( - self, - table_type: str | TableType, - request: SearchRequest, - ) -> list[dict[str, Any]]: - """ - Perform a hybrid search on a table. 
+ # async def get_usage_metrics( + # self, + # type: str, + # from_: datetime, + # window_size: str, + # proj_id: str, + # to: datetime | None = None, + # group_by: list[str] | None = None, + # **kwargs, + # ) -> dict: + # params = { + # "type": type, + # "from": from_.isoformat(), + # "window_size": window_size, + # "proj_id": proj_id, + # "to": to, + # "group_by": group_by, + # } + # return await self._get( + # "/v2/projects/meters/usages", + # params=params, + # **kwargs, + # ) + + # async def get_billing_metrics( + # self, + # from_: datetime, + # window_size: str, + # proj_id: str, + # to: datetime | None = None, + # group_by: list[str] | None = None, + # **kwargs, + # ) -> dict: + # params = { + # "from": from_.isoformat(), + # "window_size": window_size, + # "proj_id": proj_id, + # "to": to, + # "group_by": group_by, + # } + # return await self._get( + # "/v2/projects/meters/billings", + # params=params, + # **kwargs, + # ) + + +class _TemplatesAsync(_ClientAsync): + """Template methods.""" - Args: - table_type (str | TableType): The type of the table. - request (SearchRequest): The search request. + async def list_templates( + self, + *, + offset: int = 0, + limit: int = 100, + search_query: str = "", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ProjectRead]: + return await self._get( + "/v2/templates/list", + params=dict( + offset=offset, + limit=limit, + search_query=search_query, + order_by=order_by, + order_ascending=order_ascending, + ), + response_model=Page[ProjectRead], + **kwargs, + ) - Returns: - response (list[dict[str, Any]]): The search results. - """ - response = self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/hybrid_search", - body=request, - response_model=None, + async def get_template(self, template_id: str, **kwargs) -> ProjectRead: + return await self._get( + "/v2/templates", + params=dict(template_id=template_id), + response_model=ProjectRead, + **kwargs, ) - return json_loads(response.text) - def embed_file_options(self) -> httpx.Response: - """ - Get options for embedding a file to a Knowledge Table. + async def list_tables( + self, + template_id: str, + table_type: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + parent_id: str | None = None, + count_rows: bool = False, + **kwargs, + ) -> Page[TableMetaResponse]: + return await self._get( + f"/v2/templates/gen_tables/{table_type}/list", + params=dict( + template_id=template_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + parent_id=parent_id, + count_rows=count_rows, + ), + response_model=Page[TableMetaResponse], + **kwargs, + ) - Returns: - response (httpx.Response): The response containing options information. - """ - response = self._options( - self.api_base, - "/v1/gen_tables/knowledge/embed_file", + async def get_table( + self, + template_id: str, + table_type: str, + table_id: str, + **kwargs, + ) -> TableMetaResponse: + return await self._get( + f"/v2/templates/gen_tables/{table_type}", + params=dict( + template_id=template_id, + table_id=table_id, + ), + response_model=TableMetaResponse, + **kwargs, ) - return response - def embed_file( + async def list_table_rows( self, - file_path: str, + template_id: str, + table_type: str, table_id: str, *, - chunk_size: int = 1000, - chunk_overlap: int = 200, - ) -> OkResponse: - """ - Embed a file into a Knowledge Table. 
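# --------------------------------------------------------------------------
# Hedged sketch of the read-only template endpoints defined above
# (list_templates / get_template / list_tables). The `templates` attribute name
# is assumed from the `_TemplatesAsync` class, and the `items` / `id` / `name`
# response fields are assumptions.
import asyncio

from jamaibase import JamAIAsync  # assumed async client


async def browse_templates() -> None:
    client = JamAIAsync(token="your-pat")
    templates = await client.templates.list_templates(limit=10)
    for tmpl in templates.items:
        print(tmpl.id, tmpl.name)
    if templates.items:
        tmpl_id = templates.items[0].id
        # Tables inside a template are listed per table type
        tables = await client.templates.list_tables(tmpl_id, "action", count_rows=True)
        print([t.id for t in tables.items])


asyncio.run(browse_templates())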
- - Args: - file_path (str): File path of the document to be embedded. - table_id (str): Knowledge Table ID / name. - chunk_size (int, optional): Maximum chunk size (number of characters). Must be > 0. - Defaults to 1000. - chunk_overlap (int, optional): Overlap in characters between chunks. Must be >= 0. - Defaults to 200. - - Returns: - response (OkResponse): The response indicating success. - """ - # Guess the MIME type of the file based on its extension - mime_type, _ = guess_type(file_path) - if mime_type is None: - mime_type = ( - "application/jsonl" if file_path.endswith(".jsonl") else "application/octet-stream" - ) # Default MIME type - # Extract the filename from the file path - filename = split(file_path)[-1] - # Open the file in binary mode - with open(file_path, "rb") as f: - response = self._post( - self.api_base, - "/v1/gen_tables/knowledge/embed_file", - body=None, - response_model=OkResponse, - files={ - "file": (filename, f, mime_type), - }, - data={ - "table_id": table_id, - "chunk_size": chunk_size, - "chunk_overlap": chunk_overlap, - # "overwrite": request.overwrite, - }, - timeout=self.file_upload_timeout, - ) - return response - - def import_table_data( - self, - table_type: str | TableType, - request: TableDataImportRequest, - ) -> GenTableChatResponseType: + offset: int = 0, + limit: int = 100, + order_by: str = "ID", + order_ascending: bool = True, + columns: list[str] | None = None, + search_query: str = "", + search_columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> Page[dict[str, Any]]: """ - Imports CSV or TSV data into a table. + List rows in a table. Args: - file_path (str): CSV or TSV file path. - table_type (str | TableType): Table type. - request (TableDataImportRequest): Data import request. - - Returns: - response (OkResponse): The response indicating success. + template_id (str): The ID of the template. + table_type (str): The type of the table. + table_id (str): The ID of the table. + offset (int, optional): Item offset. Defaults to 0. + limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. + order_by (str, optional): Column name to order by. Defaults to "ID". + order_ascending (bool, optional): Whether to sort by ascending order. Defaults to True. + columns (list[str] | None, optional): List of column names to include in the response. + Defaults to None (all columns). + search_query (str, optional): A string to search for within the rows as a filter. + Defaults to "" (no filter). + search_columns (list[str] | None, optional): A list of column names to search for `search_query`. + Defaults to None (search all columns). + float_decimals (int, optional): Number of decimals for float values. + Defaults to 0 (no rounding). + vec_decimals (int, optional): Number of decimals for vectors. + If its negative, exclude vector columns. Defaults to 0 (no rounding). 
""" - # Guess the MIME type of the file based on its extension - mime_type, _ = guess_type(request.file_path) - if mime_type is None: - mime_type = "application/octet-stream" # Default MIME type - # Extract the filename from the file path - filename = split(request.file_path)[-1] - data = { - "table_id": request.table_id, - "stream": request.stream, - # "column_names": request.column_names, - # "columns": request.columns, - "delimiter": request.delimiter, - } - if request.stream: - - def gen(): - # Open the file in binary mode - with open(request.file_path, "rb") as f: - for chunk in self._stream( - self.api_base, - f"/v1/gen_tables/{table_type}/import_data", - body=None, - files={ - "file": (filename, f, mime_type), - }, - data=data, - timeout=self.file_upload_timeout, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "gen_table.references": - yield GenTableStreamReferences.model_validate(chunk) - elif chunk["object"] == "gen_table.completion.chunk": - yield GenTableStreamChatCompletionChunk.model_validate(chunk) - else: - raise RuntimeError(f"Unexpected SSE chunk: {chunk}") - - return gen() - else: - # Open the file in binary mode - with open(request.file_path, "rb") as f: - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/import_data", - body=None, - response_model=GenTableRowsChatCompletionChunks, - files={ - "file": (filename, f, mime_type), - }, - data=data, - timeout=self.file_upload_timeout, - ) + if columns is not None and not isinstance(columns, list): + raise TypeError("`columns` must be None or a list.") + return await self._get( + f"/v2/templates/gen_tables/{table_type}/rows/list", + params=dict( + template_id=template_id, + table_id=table_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + columns=columns, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + ), + response_model=Page[dict[str, Any]], + **kwargs, + ) - def export_table_data( + async def get_table_row( self, - table_type: str | TableType, + template_id: str, + table_type: str, table_id: str, + row_id: str, *, columns: list[str] | None = None, - delimiter: Literal[",", "\t"] = ",", - ) -> bytes: + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> dict[str, Any]: """ - Exports the row data of a table as a CSV or TSV file. + Get a specific row in a table. Args: - table_type (str | TableType): Table type. - table_id (str): ID or name of the table to be exported. - delimiter (str, optional): The delimiter of the file: can be "," or "\\t". Defaults to ",". - columns (list[str], optional): A list of columns to be exported. Defaults to None (export all columns). + template_id (str): The ID of the template. + table_type (str): The type of the table. + table_id (str): The ID of the table. + row_id (str): The ID of the row. + columns (list[str] | None, optional): List of column names to include in the response. + Defaults to None (all columns). + float_decimals (int, optional): Number of decimals for float values. + Defaults to 0 (no rounding). + vec_decimals (int, optional): Number of decimals for vectors. + If its negative, exclude vector columns. Defaults to 0 (no rounding). Returns: - response (list[dict[str, Any]]): The search results. + response (dict[str, Any]): The row data. 
""" - response = self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/export_data", - params=dict(delimiter=delimiter, columns=columns), + if columns is not None and not isinstance(columns, list): + raise TypeError("`columns` must be None or a list.") + response = await self._get( + f"/v2/templates/gen_tables/{table_type}/rows", + params=dict( + template_id=template_id, + table_id=table_id, + row_id=row_id, + columns=columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + ), response_model=None, + **kwargs, ) - return response.content + return json_loads(response.text) - def import_table( - self, - table_type: str | TableType, - request: TableImportRequest, - ) -> TableMetaResponse: + +class _FileClientAsync(_ClientAsync): + """File methods.""" + + async def upload_file(self, file_path: str, **kwargs) -> FileUploadResponse: """ - Imports a table (data and schema) from a parquet file. + Uploads a file to the server. Args: - file_path (str): The parquet file path. - table_type (str | TableType): Table type. - request (TableImportRequest): Table import request. + file_path (str): Path to the file to be uploaded. Returns: - response (TableMetaResponse): The table metadata response. + response (FileUploadResponse): The response containing the file URI. """ - mime_type = "application/octet-stream" - filename = split(request.file_path)[-1] - data = {"table_id_dst": request.table_id_dst} - # Open the file in binary mode - with open(request.file_path, "rb") as f: - return self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/import", + with open(file_path, "rb") as f: + return await self._post( + "/v2/files/upload", body=None, - response_model=TableMetaResponse, + response_model=FileUploadResponse, files={ - "file": (filename, f, mime_type), + "file": (basename(file_path), f, guess_mime(file_path)), }, - data=data, timeout=self.file_upload_timeout, + **kwargs, ) - def export_table( - self, - table_type: str | TableType, - table_id: str, - ) -> bytes: + async def get_raw_urls(self, uris: list[str], **kwargs) -> GetURLResponse: """ - Exports a table (data and schema) as a parquet file. + Get download URLs for raw files. Args: - table_type (str | TableType): Table type. - table_id (str): ID or name of the table to be exported. + uris (List[str]): List of file URIs to download. Returns: - response (list[dict[str, Any]]): The search results. + response (GetURLResponse): The response containing download information for the files. """ - response = self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/export", - params=None, - response_model=None, + return await self._post( + "/v2/files/url/raw", + body=GetURLRequest(uris=uris), + response_model=GetURLResponse, + **kwargs, ) - return response.content - -class JamAI(_Client): - def __init__( - self, - project_id: str = ENV_CONFIG.jamai_project_id, - token: str = ENV_CONFIG.jamai_token_plain, - api_base: str = ENV_CONFIG.jamai_api_base, - headers: dict | None = None, - timeout: float | None = ENV_CONFIG.jamai_timeout_sec, - file_upload_timeout: float | None = ENV_CONFIG.jamai_file_upload_timeout_sec, - *, - api_key: str = "", - ) -> None: + async def get_thumbnail_urls(self, uris: list[str], **kwargs) -> GetURLResponse: """ - Initialize the JamAI client. + Get download URLs for file thumbnails. Args: - project_id (str, optional): The project ID. - Defaults to "default", but can be overridden via - `JAMAI_PROJECT_ID` var in environment or `.env` file. 
- token (str, optional): Your Personal Access Token or organization API key (deprecated) for authentication. - Defaults to "", but can be overridden via - `JAMAI_TOKEN` var in environment or `.env` file. - api_base (str, optional): The base URL for the API. - Defaults to "https://api.jamaibase.com/api", but can be overridden via - `JAMAI_API_BASE` var in environment or `.env` file. - headers (dict | None, optional): Additional headers to include in requests. - Defaults to None. - timeout (float | None, optional): The timeout to use when sending requests. - Defaults to 15 minutes, but can be overridden via - `JAMAI_TIMEOUT_SEC` var in environment or `.env` file. - file_upload_timeout (float | None, optional): The timeout to use when sending file upload requests. - Defaults to 60 minutes, but can be overridden via - `JAMAI_FILE_UPLOAD_TIMEOUT_SEC` var in environment or `.env` file. - api_key (str, optional): (Deprecated) Organization API key for authentication. - """ - if api_key: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - http_client = httpx.Client( - timeout=timeout, - transport=httpx.HTTPTransport(retries=3), - ) - kwargs = dict( - project_id=project_id, - token=token or api_key, - api_base=api_base, - headers=headers, - http_client=http_client, - file_upload_timeout=file_upload_timeout, - ) - super().__init__(**kwargs) - self.admin = _AdminClient(**kwargs) - self.template = _TemplateClient(**kwargs) - self.file = _FileClient(**kwargs) - self.table = _GenTableClient(**kwargs) - - def health(self) -> dict[str, Any]: - """ - Get health status. + uris (List[str]): List of file URIs to get thumbnails for. Returns: - response (dict[str, Any]): Health status. + response (GetURLResponse): The response containing download information for the thumbnails. """ - response = self._get(self.api_base, "/health", response_model=None) - return json_loads(response.text) + return await self._post( + "/v2/files/url/thumb", + body=GetURLRequest(uris=uris), + response_model=GetURLResponse, + **kwargs, + ) - # --- Models and chat --- # - def model_info( +class _GenTableClientAsync(_ClientAsync): + """Generative Table methods.""" + + # Table CRUD + async def create_action_table( self, - name: str = "", - capabilities: list[ - Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"] - ] - | None = None, - ) -> ModelInfoResponse: + request: ActionTableSchemaCreate, + **kwargs, + ) -> TableMetaResponse: """ - Get information about available models. + Create an Action Table. Args: - name (str, optional): The model name. Defaults to "". - capabilities (list[Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"]] | None, optional): - List of model capabilities to filter by. Defaults to None. + request (ActionTableSchemaCreate): The action table schema. Returns: - response (ModelInfoResponse): The model information response. + response (TableMetaResponse): The table metadata response. 
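# --------------------------------------------------------------------------
# Minimal sketch of the async file endpoints defined above
# (upload_file / get_raw_urls / get_thumbnail_urls). The `files` attribute name
# and the `uri` field on FileUploadResponse are assumptions; check the response
# models for the exact field names.
import asyncio

from jamaibase import JamAIAsync  # assumed async client


async def upload_and_resolve(path: str) -> None:
    client = JamAIAsync(token="your-pat", project_id="proj_xxx")
    uploaded = await client.files.upload_file(path)
    uri = uploaded.uri  # field name assumed
    raw_urls = await client.files.get_raw_urls([uri])
    thumb_urls = await client.files.get_thumbnail_urls([uri])
    print(raw_urls, thumb_urls)


asyncio.run(upload_and_resolve("photo.jpg"))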
""" - params = {"model": name, "capabilities": capabilities} - return self._get( - self.api_base, - "/v1/models", - params=params, - response_model=ModelInfoResponse, + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/action", + body=request, + response_model=TableMetaResponse, + **kwargs, ) - def model_names( + async def create_knowledge_table( self, - prefer: str = "", - capabilities: list[ - Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"] - ] - | None = None, - ) -> list[str]: - """ - Get the names of available models. - - Args: - prefer (str, optional): Preferred model name. Defaults to "". - capabilities (list[Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"]] | None, optional): - List of model capabilities to filter by. Defaults to None. - - Returns: - response (list[str]): List of model names. - """ - params = {"prefer": prefer, "capabilities": capabilities} - response = self._get( - self.api_base, - "/v1/model_names", - params=params, - response_model=None, - ) - return json_loads(response.text) - - def generate_chat_completions( - self, request: ChatRequest - ) -> ChatCompletionChunk | Generator[References | ChatCompletionChunk, None, None]: - """ - Generates chat completions. - - Args: - request (ChatRequest): The request. - - Returns: - completion (ChatCompletionChunk | Generator): The chat completion. - In streaming mode, it is a generator that yields a `References` object - followed by zero or more `ChatCompletionChunk` objects. - In non-streaming mode, it is a `ChatCompletionChunk` object. - """ - if request.stream: - - def gen(): - for chunk in self._stream( - self.api_base, - "/v1/chat/completions", - body=request, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "chat.references": - yield References.model_validate(chunk) - elif chunk["object"] == "chat.completion.chunk": - yield ChatCompletionChunk.model_validate(chunk) - else: - raise RuntimeError(f"Unexpected SSE chunk: {chunk}") - - return gen() - else: - return self._post( - self.api_base, - "/v1/chat/completions", - body=request, - response_model=ChatCompletionChunk, - ) - - def generate_embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + request: KnowledgeTableSchemaCreate, + **kwargs, + ) -> TableMetaResponse: """ - Generate embeddings for the given input. + Create a Knowledge Table. Args: - request (EmbeddingRequest): The embedding request. + request (KnowledgeTableSchemaCreate): The knowledge table schema. Returns: - response (EmbeddingResponse): The embedding response. + response (TableMetaResponse): The table metadata response. """ - return self._post( - self.api_base, - "/v1/embeddings", + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/knowledge", body=request, - response_model=EmbeddingResponse, + response_model=TableMetaResponse, + **kwargs, ) - # --- Gen Table --- # - - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def create_action_table(self, request: ActionTableSchemaCreate) -> TableMetaResponse: - """ - Create an Action Table. - - Args: - request (ActionTableSchemaCreate): The action table schema. - - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - return self.table.create_action_table(request) - - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def create_knowledge_table(self, request: KnowledgeTableSchemaCreate) -> TableMetaResponse: + async def create_chat_table( + self, + request: ChatTableSchemaCreate, + **kwargs, + ) -> TableMetaResponse: """ - Create a Knowledge Table. + Create a Chat Table. Args: - request (KnowledgeTableSchemaCreate): The knowledge table schema. + request (ChatTableSchemaCreate): The chat table schema. Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.create_knowledge_table(request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/chat", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def create_chat_table(self, request: ChatTableSchemaCreate) -> TableMetaResponse: + async def duplicate_table( + self, + table_type: str, + table_id_src: str, + table_id_dst: str | None = None, + *, + include_data: bool = True, + create_as_child: bool = False, + **kwargs, + ) -> TableMetaResponse: """ - Create a Chat Table. + Duplicate a table. Args: - request (ChatTableSchemaCreate): The chat table schema. + table_type (str): The type of the table. + table_id_src (str): The source table ID. + table_id_dst (str | None, optional): The destination / new table ID. + Defaults to None (create a new table ID automatically). + include_data (bool, optional): Whether to include data in the duplicated table. Defaults to True. + create_as_child (bool, optional): Whether the new table is a child table. + If this is True, then `include_data` will be set to True. Defaults to False. Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.create_chat_table(request) + if (deploy := kwargs.pop("deploy", None)) is not None: + warn( + 'The "deploy" argument is deprecated, use "create_as_child" instead.', + FutureWarning, + stacklevel=2, + ) + create_as_child = create_as_child or deploy + return await self._post( + f"/v1/gen_tables/{table_type}/duplicate/{quote(table_id_src)}" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/duplicate", + body=None, + params=dict( + table_id_src=table_id_src, + table_id_dst=table_id_dst, + include_data=include_data, + create_as_child=create_as_child, + ), + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def get_table( + async def get_table( self, - table_type: str | TableType, + table_type: str, table_id: str, + **kwargs, ) -> TableMetaResponse: """ Get metadata for a specific Generative Table. Args: - table_type (str | TableType): The type of the table. + table_type (str): The type of the table. table_id (str): The ID of the table. Returns: response (TableMetaResponse): The table metadata response. 
""" - return self.table.get_table(table_type, table_id) + return await self._get( + f"/v1/gen_tables/{table_type}/{quote(table_id)}" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}", + params=dict(table_id=table_id), + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def list_tables( + async def list_tables( self, - table_type: str | TableType, + table_type: str, + *, offset: int = 0, limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + created_by: str | None = None, parent_id: str | None = None, search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, count_rows: bool = False, + **kwargs, ) -> Page[TableMetaResponse]: """ List Generative Tables of a specific type. Args: - table_type (str | TableType): The type of the table. + table_type (str): The type of the table. offset (int, optional): Item offset. Defaults to 0. limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. + order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". + order_ascending (bool, optional): Whether to sort by ascending order. Defaults to True. + created_by (str | None, optional): Return tables created by this user. + Defaults to None (return all tables). parent_id (str | None, optional): Parent ID of tables to return. Additionally for Chat Table, you can list: (1) all chat agents by passing in "_agent_"; or @@ -2409,120 +2466,106 @@ def list_tables( Defaults to None (return all tables). search_query (str, optional): A string to search for within table IDs as a filter. Defaults to "" (no filter). - order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. count_rows (bool, optional): Whether to count the rows of the tables. Defaults to False. Returns: response (Page[TableMetaResponse]): The paginated table metadata response. """ - return self.table.list_tables( - table_type, - offset=offset, - limit=limit, - parent_id=parent_id, - search_query=search_query, - order_by=order_by, - order_descending=order_descending, - count_rows=count_rows, + if (order_descending := kwargs.pop("order_descending", None)) is not None: + warn( + 'The "order_descending" argument is deprecated, use "order_ascending" instead.', + FutureWarning, + stacklevel=2, + ) + order_ascending = not order_descending + return await self._get( + f"/v1/gen_tables/{table_type}" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/list", + params=dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + created_by=created_by, + parent_id=parent_id, + search_query=search_query, + count_rows=count_rows, + ), + response_model=Page[TableMetaResponse], + **kwargs, ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def delete_table( - self, - table_type: str | TableType, - table_id: str, - *, - missing_ok: bool = True, - ) -> OkResponse: - """ - Delete a specific table. - - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - missing_ok (bool, optional): Ignore resource not found error. - - Returns: - response (OkResponse): The response indicating success. 
- """ - return self.table.delete_table(table_type, table_id, missing_ok=missing_ok) - - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def duplicate_table( + async def rename_table( self, - table_type: str | TableType, + table_type: str, table_id_src: str, - table_id_dst: str | None = None, - *, - include_data: bool = True, - create_as_child: bool = False, + table_id_dst: str, **kwargs, ) -> TableMetaResponse: """ - Duplicate a table. + Rename a table. Args: - table_type (str | TableType): The type of the table. + table_type (str): The type of the table. table_id_src (str): The source table ID. - table_id_dst (str | None, optional): The destination / new table ID. - Defaults to None (create a new table ID automatically). - include_data (bool, optional): Whether to include data in the duplicated table. Defaults to True. - create_as_child (bool, optional): Whether the new table is a child table. - If this is True, then `include_data` will be set to True. Defaults to False. + table_id_dst (str): The destination / new table ID. Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.duplicate_table( - table_type, - table_id_src, - table_id_dst, - include_data=include_data, - create_as_child=create_as_child, + return await self._post( + f"/v1/gen_tables/{table_type}/rename/{quote(table_id_src)}/{quote(table_id_dst)}" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/rename", + params=dict( + table_id_src=table_id_src, + table_id_dst=table_id_dst, + ), + body=None, + response_model=TableMetaResponse, **kwargs, ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def rename_table( + async def delete_table( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str, - ) -> TableMetaResponse: + table_type: str, + table_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: """ - Rename a table. + Delete a specific table. Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str): The destination / new table ID. + table_type (str): The type of the table. + table_id (str): The ID of the table. + missing_ok (bool, optional): Ignore resource not found error. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. """ - return self.table.rename_table(table_type, table_id_src, table_id_dst) + response = await self._delete( + f"/v1/gen_tables/{table_type}/{quote(table_id)}" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}", + params=dict(table_id=table_id), + response_model=None, + ignore_code=404 if missing_ok else None, + **kwargs, + ) + if response.status_code == 404 and missing_ok: + return OkResponse() + else: + return OkResponse.model_validate_json(response.text) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def update_gen_config( + # Column CRUD + async def add_action_columns( self, - table_type: str | TableType, - request: GenConfigUpdateRequest, + request: AddActionColumnSchema, + **kwargs, ) -> TableMetaResponse: - """ - Update the generation configuration for a table. - - Args: - table_type (str | TableType): The type of the table. - request (GenConfigUpdateRequest): The generation configuration update request. - - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - return self.table.update_gen_config(table_type, request) - - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def add_action_columns(self, request: AddActionColumnSchema) -> TableMetaResponse: """ Add columns to an Action Table. @@ -2532,10 +2575,19 @@ def add_action_columns(self, request: AddActionColumnSchema) -> TableMetaRespons Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.add_action_columns(request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/action/columns/add", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def add_knowledge_columns(self, request: AddKnowledgeColumnSchema) -> TableMetaResponse: + async def add_knowledge_columns( + self, + request: AddKnowledgeColumnSchema, + **kwargs, + ) -> TableMetaResponse: """ Add columns to a Knowledge Table. @@ -2545,10 +2597,19 @@ def add_knowledge_columns(self, request: AddKnowledgeColumnSchema) -> TableMetaR Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.add_knowledge_columns(request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/knowledge/columns/add", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def add_chat_columns(self, request: AddChatColumnSchema) -> TableMetaResponse: + async def add_chat_columns( + self, + request: AddChatColumnSchema, + **kwargs, + ) -> TableMetaResponse: """ Add columns to a Chat Table. @@ -2558,122 +2619,246 @@ def add_chat_columns(self, request: AddChatColumnSchema) -> TableMetaResponse: Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.add_chat_columns(request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/chat/columns/add", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def drop_columns( + async def rename_columns( self, - table_type: str | TableType, - request: ColumnDropRequest, + table_type: str, + request: ColumnRenameRequest, + **kwargs, ) -> TableMetaResponse: """ - Drop columns from a table. + Rename columns in a table. Args: - table_type (str | TableType): The type of the table. - request (ColumnDropRequest): The column drop request. + table_type (str): The type of the table. + request (ColumnRenameRequest): The column rename request. Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.drop_columns(table_type, request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/{table_type}/columns/rename", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def rename_columns( + async def update_gen_config( self, - table_type: str | TableType, - request: ColumnRenameRequest, + table_type: str, + request: GenConfigUpdateRequest, + **kwargs, ) -> TableMetaResponse: """ - Rename columns in a table. + Update the generation configuration for a table. Args: - table_type (str | TableType): The type of the table. - request (ColumnRenameRequest): The column rename request. + table_type (str): The type of the table. 
+ request (GenConfigUpdateRequest): The generation configuration update request. Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.rename_columns(table_type, request) + if kwargs.pop("v1", False): + return await self._post( + f"/v1/gen_tables/{table_type}/gen_config/update", + body=request, + response_model=TableMetaResponse, + process_body_kwargs={"exclude_unset": True}, + **kwargs, + ) + return await self._patch( + f"/v2/gen_tables/{table_type}/gen_config", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def reorder_columns( + async def reorder_columns( self, - table_type: str | TableType, + table_type: str, request: ColumnReorderRequest, + **kwargs, ) -> TableMetaResponse: """ Reorder columns in a table. Args: - table_type (str | TableType): The type of the table. + table_type (str): The type of the table. request (ColumnReorderRequest): The column reorder request. Returns: response (TableMetaResponse): The table metadata response. """ - return self.table.reorder_columns(table_type, request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/{table_type}/columns/reorder", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def list_table_rows( + async def drop_columns( + self, + table_type: str, + request: ColumnDropRequest, + **kwargs, + ) -> TableMetaResponse: + """ + Drop columns from a table. + + Args: + table_type (str): The type of the table. + request (ColumnDropRequest): The column drop request. + + Returns: + response (TableMetaResponse): The table metadata response. + """ + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/{table_type}/columns/drop", + body=request, + response_model=TableMetaResponse, + **kwargs, + ) + + # Row CRUD + async def add_table_rows( + self, + table_type: str, + request: MultiRowAddRequest, + **kwargs, + ) -> ( + MultiRowCompletionResponse + | AsyncGenerator[CellReferencesResponse | CellCompletionResponse, None] + ): + """ + Add rows to a table. + + Args: + table_type (str): The type of the table. + request (MultiRowAddRequest): The row add request. + + Returns: + response (MultiRowCompletionResponse | AsyncGenerator): The row completion. + In streaming mode, it is an async generator that yields a `CellReferencesResponse` object + followed by zero or more `CellCompletionResponse` objects. + In non-streaming mode, it is a `MultiRowCompletionResponse` object. 
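+
+        Example:
+            A minimal non-streaming sketch, assuming a configured async client named
+            `client` and an existing Action Table. The `MultiRowAddRequest` fields shown
+            (`table_id`, `data`) are illustrative assumptions; consult the request model
+            for the authoritative schema.
+
+            >>> request = MultiRowAddRequest(
+            ...     table_id="my-action-table",                  # placeholder table ID
+            ...     data=[{"question": "What is JamAI Base?"}],  # hypothetical column
+            ...     stream=False,
+            ... )
+            >>> completion = await client.table.add_table_rows("action", request)
+            >>> print(type(completion).__name__)
+            MultiRowCompletionResponse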
+        """
+        v = "v1" if kwargs.pop("v1", False) else "v2"
+        if request.stream:
+            agen = self._stream(
+                f"/{v}/gen_tables/{table_type}/rows/add",
+                body=request,
+                **kwargs,
+            )
+            return await self._return_async_iterator(
+                agen, [CellCompletionResponse, CellReferencesResponse]
+            )
+        else:
+            return await self._post(
+                f"/{v}/gen_tables/{table_type}/rows/add",
+                body=request,
+                response_model=MultiRowCompletionResponse,
+                **kwargs,
+            )
+
+    async def list_table_rows(
         self,
-        table_type: str | TableType,
+        table_type: str,
         table_id: str,
         *,
         offset: int = 0,
         limit: int = 100,
-        search_query: str = "",
+        order_by: str = "ID",
+        order_ascending: bool = True,
         columns: list[str] | None = None,
+        where: str = "",
+        search_query: str = "",
+        search_columns: list[str] | None = None,
         float_decimals: int = 0,
         vec_decimals: int = 0,
-        order_descending: bool = True,
+        **kwargs,
     ) -> Page[dict[str, Any]]:
         """
         List rows in a table.
 
         Args:
-            table_type (str | TableType): The type of the table.
+            table_type (str): The type of the table.
             table_id (str): The ID of the table.
             offset (int, optional): Item offset. Defaults to 0.
             limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100.
-            search_query (str, optional): A string to search for within the rows as a filter.
-                Defaults to "" (no filter).
+            order_by (str, optional): Column name to order by. Defaults to "ID".
+            order_ascending (bool, optional): Whether to sort by ascending order. Defaults to True.
             columns (list[str] | None, optional): List of column names to include in the response.
                 Defaults to None (all columns).
+            where (str, optional): SQL where clause. Can be nested, e.g. `x = '1' AND ("y (1)" = 2 OR z = '3')`.
+                It will be combined with other filters using `AND`. Defaults to "" (no filter).
+            search_query (str, optional): A string to search for within the rows as a filter.
+                Defaults to "" (no filter).
+            search_columns (list[str] | None, optional): A list of column names in which to search for `search_query`.
+                Defaults to None (search all columns).
             float_decimals (int, optional): Number of decimals for float values.
                 Defaults to 0 (no rounding).
             vec_decimals (int, optional): Number of decimals for vectors.
-                If its negative, exclude vector columns. Defaults to 0 (no rounding).
+                If it is negative, vector columns are excluded. Defaults to 0 (no rounding).
-            order_descending (bool, optional): Whether to sort by descending order. Defaults to True.
""" - return self.table.list_table_rows( - table_type, - table_id, - offset=offset, - limit=limit, - search_query=search_query, - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - order_descending=order_descending, + if (order_descending := kwargs.pop("order_descending", None)) is not None: + warn( + 'The "order_descending" argument is deprecated, use "order_ascending" instead.', + FutureWarning, + stacklevel=2, + ) + order_ascending = not order_descending + if columns is not None and not isinstance(columns, list): + raise TypeError("`columns` must be None or a list.") + if search_columns is not None and not isinstance(search_columns, list): + raise TypeError("`search_columns` must be None or a list.") + return await self._get( + f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/rows/list", + params=dict( + table_id=table_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + columns=columns, + where=where, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + ), + response_model=Page[dict[str, Any]], + **kwargs, ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def get_table_row( + async def get_table_row( self, - table_type: str | TableType, + table_type: str, table_id: str, row_id: str, *, columns: list[str] | None = None, float_decimals: int = 0, vec_decimals: int = 0, + **kwargs, ) -> dict[str, Any]: """ Get a specific row in a table. Args: - table_type (str | TableType): The type of the table. + table_type (str): The type of the table. table_id (str): The ID of the table. row_id (str): The ID of the row. columns (list[str] | None, optional): List of column names to include in the response. @@ -2686,228 +2871,401 @@ def get_table_row( Returns: response (dict[str, Any]): The row data. """ - return self.table.get_table_row( - table_type, - table_id, - row_id, - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, + if columns is not None and not isinstance(columns, list): + raise TypeError("`columns` must be None or a list.") + response = await self._get( + f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows/{quote(row_id)}" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/rows", + params=dict( + table_id=table_id, + row_id=row_id, + columns=columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + ), + response_model=None, + **kwargs, ) + return json_loads(response.text) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def add_table_rows( + @deprecated( + "This method is deprecated, use `get_conversation_threads` instead.", + category=FutureWarning, + stacklevel=1, + ) + async def get_conversation_thread( self, - table_type: str | TableType, - request: RowAddRequest, - ) -> ( - GenTableRowsChatCompletionChunks - | AsyncGenerator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None] - ): + table_type: str, + table_id: str, + column_id: str, + *, + row_id: str = "", + include: bool = True, + **kwargs, + ) -> ChatThreadResponse: """ - Add rows to a table. + Get the conversation thread for a column in a table. Args: - table_type (str | TableType): The type of the table. - request (RowAddRequest): The row add request. + table_type (str): The type of the table. + table_id (str): ID / name of the chat table. 
+ column_id (str): ID / name of the column to fetch. + row_id (str, optional): ID / name of the last row in the thread. + Defaults to "" (export all rows). + include (bool, optional): Whether to include the row specified by `row_id`. + Defaults to True. Returns: - response (GenTableRowsChatCompletionChunks | AsyncGenerator): The row completion. - In streaming mode, it is a generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. + response (ChatThreadResponse): The conversation thread. """ - return self.table.add_table_rows(table_type, request) + return await self._get( + f"/v1/gen_tables/{table_type}/{quote(table_id)}/thread", + params=dict( + table_id=table_id, + column_id=column_id, + row_id=row_id, + include=include, + ), + response_model=ChatThreadResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def regen_table_rows( + async def get_conversation_threads( self, - table_type: str | TableType, - request: RowRegenRequest, - ) -> ( - GenTableRowsChatCompletionChunks - | AsyncGenerator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None] - ): + table_type: str, + table_id: str, + column_ids: list[str] | None = None, + *, + row_id: str = "", + include_row: bool = True, + **kwargs, + ) -> ChatThreadsResponse: """ - Regenerate rows in a table. + Get all multi-turn / conversation threads from a table. Args: - table_type (str | TableType): The type of the table. - request (RowRegenRequest): The row regenerate request. + table_type (str): The type of the table. + table_id (str): ID / name of the chat table. + column_ids (list[str] | None): Columns to fetch as conversation threads. + row_id (str, optional): ID / name of the last row in the thread. + Defaults to "" (export all rows). + include_row (bool, optional): Whether to include the row specified by `row_id`. + Defaults to True. Returns: - response (GenTableRowsChatCompletionChunks | AsyncGenerator): The row completion. - In streaming mode, it is a generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. + response (ChatThreadsResponse): The conversation threads. """ - return self.table.regen_table_rows(table_type, request) + return await self._get( + f"/v2/gen_tables/{table_type}/threads", + params=dict( + table_id=table_id, + column_ids=column_ids, + row_id=row_id, + include_row=include_row, + ), + response_model=ChatThreadsResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def update_table_row( + async def hybrid_search( self, - table_type: str | TableType, - request: RowUpdateRequest, - ) -> OkResponse: + table_type: str, + request: SearchRequest, + **kwargs, + ) -> list[dict[str, Any]]: """ - Update a specific row in a table. + Perform a hybrid search on a table. Args: - table_type (str | TableType): The type of the table. - request (RowUpdateRequest): The row update request. + table_type (str): The type of the table. + request (SearchRequest): The search request. Returns: - response (OkResponse): The response indicating success. + response (list[dict[str, Any]]): The search results. 
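+
+        Example:
+            A minimal sketch, assuming a configured async client named `client` and an
+            existing Knowledge Table. The `SearchRequest` fields shown (`table_id`,
+            `query`, `limit`) are assumptions about the request model, not confirmed
+            by this diff.
+
+            >>> request = SearchRequest(
+            ...     table_id="my-knowledge-table",  # placeholder table ID
+            ...     query="hybrid search",          # hypothetical field
+            ...     limit=5,                        # hypothetical field
+            ... )
+            >>> results = await client.table.hybrid_search("knowledge", request)
+            >>> isinstance(results, list)
+            True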
""" - return self.table.update_table_row(table_type, request) + v = "v1" if kwargs.pop("v1", False) else "v2" + response = await self._post( + f"/{v}/gen_tables/{table_type}/hybrid_search", + body=request, + response_model=None, + **kwargs, + ) + return json_loads(response.text) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def delete_table_rows( + async def regen_table_rows( self, - table_type: str | TableType, - request: RowDeleteRequest, - ) -> OkResponse: + table_type: str, + request: MultiRowRegenRequest, + **kwargs, + ) -> ( + MultiRowCompletionResponse + | AsyncGenerator[CellReferencesResponse | CellCompletionResponse, None] + ): """ - Delete rows from a table. + Regenerate rows in a table. Args: - table_type (str | TableType): The type of the table. - request (RowDeleteRequest): The row delete request. + table_type (str): The type of the table. + request (MultiRowRegenRequest): The row regenerate request. Returns: - response (OkResponse): The response indicating success. + response (MultiRowCompletionResponse | AsyncGenerator): The row completion. + In streaming mode, it is an async generator that yields a `CellReferencesResponse` object + followed by zero or more `CellCompletionResponse` objects. + In non-streaming mode, it is a `MultiRowCompletionResponse` object. """ - return self.table.delete_table_rows(table_type, request) + v = "v1" if kwargs.pop("v1", False) else "v2" + if request.stream: + agen = self._stream( + f"/{v}/gen_tables/{table_type}/rows/regen", + body=request, + **kwargs, + ) + return await self._return_async_iterator( + agen, [CellCompletionResponse, CellReferencesResponse] + ) + else: + return await self._post( + f"/{v}/gen_tables/{table_type}/rows/regen", + body=request, + response_model=MultiRowCompletionResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def delete_table_row( + async def update_table_rows( self, - table_type: str | TableType, - table_id: str, - row_id: str, + table_type: str, + request: MultiRowUpdateRequest, + **kwargs, ) -> OkResponse: """ - Delete a specific row from a table. + Update rows in a table. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - row_id (str): The ID of the row. + table_type (str): The type of the table. + request (MultiRowUpdateRequest): The row update request. Returns: response (OkResponse): The response indicating success. """ - return self.table.delete_table_row(table_type, table_id, row_id) + return await self._patch( + f"/v2/gen_tables/{table_type}/rows", + body=request, + response_model=OkResponse, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def get_conversation_thread( + @deprecated( + "This method is deprecated, use `update_table_rows` instead.", + category=FutureWarning, + stacklevel=1, + ) + async def update_table_row( self, - table_type: str | TableType, - table_id: str, - column_id: str, - row_id: str = "", - include: bool = True, - ) -> ChatThread: + table_type: str, + request: RowUpdateRequest, + **kwargs, + ) -> OkResponse: """ - Get the conversation thread for a chat table. + Update a specific row in a table. Args: - table_type (str | TableType): The type of the table. - table_id (str): ID / name of the chat table. - column_id (str): ID / name of the column to fetch. - row_id (str, optional): ID / name of the last row in the thread. - Defaults to "" (export all rows). 
- include (bool, optional): Whether to include the row specified by `row_id`. - Defaults to True. + table_type (str): The type of the table. + request (RowUpdateRequest): The row update request. Returns: - response (ChatThread): The conversation thread. + response (OkResponse): The response indicating success. """ - return self.table.get_conversation_thread( - table_type, table_id, column_id, row_id=row_id, include=include + return await self._post( + f"/v1/gen_tables/{table_type}/rows/update", + body=request, + response_model=OkResponse, + **kwargs, ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def hybrid_search( + async def delete_table_rows( self, - table_type: str | TableType, - request: SearchRequest, - ) -> list[dict[str, Any]]: + table_type: str, + request: MultiRowDeleteRequest, + **kwargs, + ) -> OkResponse: """ - Perform a hybrid search on a table. + Delete rows from a table. Args: - table_type (str | TableType): The type of the table. - request (SearchRequest): The search request. + table_type (str): The type of the table. + request (MultiRowDeleteRequest): The row delete request. Returns: - response (list[dict[str, Any]]): The search results. + response (OkResponse): The response indicating success. """ - return self.table.hybrid_search(table_type, request) + v = "v1" if kwargs.pop("v1", False) else "v2" + return await self._post( + f"/{v}/gen_tables/{table_type}/rows/delete", + body=request, + response_model=OkResponse, + **kwargs, + ) @deprecated( - "This method is deprecated, use `client.table.embed_file_options` instead.", + "This method is deprecated, use `delete_table_rows` instead.", category=FutureWarning, stacklevel=1, ) - def upload_file_options(self) -> httpx.Response: - """ - Get options for uploading a file to a Knowledge Table. + async def delete_table_row( + self, + table_type: str, + table_id: str, + row_id: str, + **kwargs, + ) -> OkResponse: + """ + Delete a specific row from a table. + + Args: + table_type (str): The type of the table. + table_id (str): The ID of the table. + row_id (str): The ID of the row. + + Returns: + response (OkResponse): The response indicating success. + """ + return await self._delete( + f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows/{quote(row_id)}", + params=None, + response_model=OkResponse, + **kwargs, + ) + + async def embed_file_options(self, **kwargs) -> httpx.Response: + """ + Get CORS preflight options for file embedding endpoint. Returns: response (httpx.Response): The response containing options information. """ - return self.table.embed_file_options() + v = "v1" if kwargs.pop("v1", False) else "v2" + response = await self._options( + f"/{v}/gen_tables/knowledge/embed_file", + **kwargs, + ) + return response - @deprecated( - "This method is deprecated, use `client.table.embed_file` instead.", - category=FutureWarning, - stacklevel=1, - ) - def upload_file(self, request: FileUploadRequest) -> OkResponse: + async def embed_file( + self, + file_path: str, + table_id: str, + *, + chunk_size: int = 1000, + chunk_overlap: int = 200, + **kwargs, + ) -> OkResponse: """ - Upload a file to a Knowledge Table. + Embed a file into a Knowledge Table. Args: - request (FileUploadRequest): The file upload request. + file_path (str): File path of the document to be embedded. + table_id (str): Knowledge Table ID / name. + chunk_size (int, optional): Maximum chunk size (number of characters). Must be > 0. + Defaults to 1000. + chunk_overlap (int, optional): Overlap in characters between chunks. 
Must be >= 0. + Defaults to 200. Returns: response (OkResponse): The response indicating success. """ - return self.table.embed_file(request) + v = "v1" if kwargs.pop("v1", False) else "v2" + # Open the file in binary mode + with open(file_path, "rb") as f: + response = await self._post( + f"/{v}/gen_tables/knowledge/embed_file", + body=None, + response_model=OkResponse, + files={ + "file": (basename(file_path), f, guess_mime(file_path)), + }, + data={ + "table_id": table_id, + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + # "overwrite": request.overwrite, + }, + timeout=self.file_upload_timeout, + **kwargs, + ) + return response - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def import_table_data( + # Import export + async def import_table_data( self, - table_type: str | TableType, + table_type: str, request: TableDataImportRequest, + **kwargs, ) -> GenTableChatResponseType: """ Imports CSV or TSV data into a table. Args: file_path (str): CSV or TSV file path. - table_type (str | TableType): Table type. + table_type (str): Table type. request (TableDataImportRequest): Data import request. Returns: response (OkResponse): The response indicating success. """ - return self.table.import_table_data(table_type, request) + v = "v1" if kwargs.pop("v1", False) else "v2" + data = { + "table_id": request.table_id, + "stream": request.stream, + # "column_names": request.column_names, + # "columns": request.columns, + "delimiter": request.delimiter, + } + file_path = request.file_path + if request.stream: + # Open the file in binary mode + with open(file_path, "rb") as f: + agen = self._stream( + f"/{v}/gen_tables/{table_type}/import_data", + body=None, + files={"file": (basename(file_path), f, guess_mime(file_path))}, + data=data, + timeout=self.file_upload_timeout, + **kwargs, + ) + return await self._return_async_iterator( + agen, [CellCompletionResponse, CellReferencesResponse] + ) + else: + # Open the file in binary mode + with open(request.file_path, "rb") as f: + return await self._post( + f"/{v}/gen_tables/{table_type}/import_data", + body=None, + response_model=MultiRowCompletionResponse, + files={ + "file": (basename(file_path), f, guess_mime(file_path)), + }, + data=data, + timeout=self.file_upload_timeout, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def export_table_data( + async def export_table_data( self, - table_type: str | TableType, + table_type: str, table_id: str, + *, columns: list[str] | None = None, delimiter: Literal[",", "\t"] = ",", + **kwargs, ) -> bytes: """ Exports the row data of a table as a CSV or TSV file. Args: - table_type (str | TableType): Table type. + table_type (str): Table type. table_id (str): ID or name of the table to be exported. delimiter (str, optional): The delimiter of the file: can be "," or "\\t". Defaults to ",". columns (list[str], optional): A list of columns to be exported. Defaults to None (export all columns). @@ -2915,2824 +3273,3154 @@ def export_table_data( Returns: response (list[dict[str, Any]]): The search results. 
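+
+        Example:
+            A minimal sketch, assuming a configured async client named `client` and an
+            existing Action Table (the table name is a placeholder). The returned value
+            is the raw CSV/TSV file content as bytes.
+
+            >>> csv_bytes = await client.table.export_table_data(
+            ...     "action",
+            ...     "my-action-table",
+            ...     delimiter=",",
+            ... )
+            >>> with open("my-action-table.csv", "wb") as f:
+            ...     _ = f.write(csv_bytes)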
""" - return self.table.export_table_data( - table_type, table_id, columns=columns, delimiter=delimiter + if columns is not None and not isinstance(columns, list): + raise TypeError("`columns` must be None or a list.") + response = await self._get( + f"/v1/gen_tables/{table_type}/{quote(table_id)}/export_data" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/export_data", + params=dict(table_id=table_id, delimiter=delimiter, columns=columns), + response_model=None, + **kwargs, ) + return response.content - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def import_table( + async def import_table( self, - table_type: str | TableType, + table_type: str, request: TableImportRequest, - ) -> TableMetaResponse: + **kwargs, + ) -> TableMetaResponse | OkResponse: """ Imports a table (data and schema) from a parquet file. Args: file_path (str): The parquet file path. - table_type (str | TableType): Table type. + table_type (str): Table type. request (TableImportRequest): Table import request. Returns: - response (TableMetaResponse): The table metadata response. + response (TableMetaResponse | OkResponse): The table metadata response if blocking is True, + otherwise OkResponse. """ - return self.table.import_table(table_type, request) + migrate = kwargs.pop("migrate", False) # Temporary, may be removed anytime + timeout = None if migrate else (kwargs.pop("timeout", None) or self.file_upload_timeout) + v = "v1" if kwargs.pop("v1", False) else "v2" + mime_type = "application/octet-stream" + filename = split(request.file_path)[-1] + # Open the file in binary mode + with open(request.file_path, "rb") as f: + return await self._post( + f"/{v}/gen_tables/{table_type}/import", + body=None, + response_model=TableMetaResponse if request.blocking else OkResponse, + files={ + "file": (filename, f, mime_type), + }, + data=dict(**self._process_body(request), migrate=migrate), + timeout=timeout, + **kwargs, + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - def export_table( + async def export_table( self, - table_type: str | TableType, + table_type: str, table_id: str, + **kwargs, ) -> bytes: """ Exports a table (data and schema) as a parquet file. Args: - table_type (str | TableType): Table type. + table_type (str): Table type. table_id (str): ID or name of the table to be exported. Returns: response (list[dict[str, Any]]): The search results. """ - return self.table.export_table(table_type, table_id) - - -class _ClientAsync(_Client): - async def close(self) -> None: - """ - Close the HTTP async client. - """ - await self.http_client.aclose() + response = await self._get( + f"/v1/gen_tables/{table_type}/{quote(table_id)}/export" + if kwargs.pop("v1", False) + else f"/v2/gen_tables/{table_type}/export", + params=dict(table_id=table_id), + response_model=None, + **kwargs, + ) + return response.content - @staticmethod - async def raise_exception( - response: httpx.Response, - *, - ignore_code: int | None = None, - ) -> httpx.Response: - """ - Raise an exception if the response status code is not 200. - Args: - response (httpx.Response): The HTTP response. - ignore_code (int | None, optional): HTTP code to ignore. 
+class _MeterClientAsync(_ClientAsync): + """Meter methods.""" + + async def get_usage_metrics( + self, + type: Literal["llm", "embedding", "reranking"], + from_: datetime, + window_size: str, + org_ids: list[str] | None = None, + proj_ids: list[str] | None = None, + to: datetime | None = None, + group_by: list[str] | None = None, + data_source: Literal["clickhouse", "victoriametrics"] = "clickhouse", + ) -> UsageResponse: + params = { + "type": type, + "from": from_.isoformat(), # Use string key to avoid keyword conflict + "orgIds": org_ids, + "windowSize": window_size, + "projIds": proj_ids, + "to": to.isoformat() if to else None, + "groupBy": group_by, + "dataSource": data_source, + } + return await self._get( + "/v2/meters/usages", + params=params, + response_model=UsageResponse, + ) + + async def get_billing_metrics( + self, + from_: datetime, + window_size: str, + org_ids: list[str] | None = None, + proj_ids: list[str] | None = None, + to: datetime | None = None, + group_by: list[str] | None = None, + data_source: Literal["clickhouse", "victoriametrics"] = "clickhouse", + ) -> UsageResponse: + params = { + "from": from_.isoformat(), # Use string key to avoid keyword conflict + "orgIds": org_ids, + "windowSize": window_size, + "projIds": proj_ids, + "to": to.isoformat() if to else None, + "groupBy": group_by, + "dataSource": data_source, + } + return await self._get( + "/v2/meters/billings", + params=params, + response_model=UsageResponse, + ) + + async def get_bandwidth_metrics( + self, + from_: datetime, + window_size: str, + org_ids: list[str] | None = None, + proj_ids: list[str] | None = None, + to: datetime | None = None, + group_by: list[str] | None = None, + data_source: Literal["clickhouse", "victoriametrics"] = "clickhouse", + ) -> UsageResponse: + params = { + "from": from_.isoformat(), # Use string key to avoid keyword conflict + "orgIds": org_ids, + "windowSize": window_size, + "projIds": proj_ids, + "to": to.isoformat() if to else None, + "groupBy": group_by, + "dataSource": data_source, + } + return await self._get( + "/v2/meters/bandwidths", + params=params, + response_model=UsageResponse, + ) + + async def get_storage_metrics( + self, + from_: datetime, + window_size: str, + org_ids: list[str] | None = None, + proj_ids: list[str] | None = None, + to: datetime | None = None, + group_by: list[str] | None = None, + ) -> UsageResponse: + params = { + "from": from_.isoformat(), # Use string key to avoid keyword conflict + "orgIds": org_ids, + "windowSize": window_size, + "projIds": proj_ids, + "to": to.isoformat() if to else None, + "groupBy": group_by, + } + return await self._get( + "/v2/meters/storages", + params=params, + response_model=UsageResponse, + ) - Raises: - RuntimeError: If the response status code is not 200 and is not ignored by `ignore_code`. - Returns: - response (httpx.Response): The HTTP response. 
- """ - if "warning" in response.headers: - warn(response.headers["warning"], stacklevel=2) - code = response.status_code - if (200 <= code < 300) or code == ignore_code: - return response - try: - error = response.text - except httpx.ResponseNotRead: - error = (await response.aread()).decode() - error = json_loads(error) - err_mssg = error.get("message", error.get("detail", str(error))) - if code == 404: - exc = ResourceNotFoundError - else: - exc = RuntimeError - raise exc(err_mssg) +class _TaskClientAsync(_ClientAsync): + """Task methods.""" - async def _get( + async def get_progress( self, - address: str, - endpoint: str, - *, - params: dict[str, Any] | None = None, - response_model: Type[BaseModel] | None = None, + key: str, **kwargs, - ) -> httpx.Response | BaseModel: - """ - Make an asynchronous GET request to the specified endpoint. - - Args: - address (str): The base address of the API. - endpoint (str): The API endpoint. - params (dict[str, Any] | None, optional): Query parameters. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. - **kwargs (Any): Keyword arguments for `httpx.get`. - - Returns: - response (httpx.Response | BaseModel): The response text or Pydantic response object. - """ - response = await self.http_client.get( - f"{address}{endpoint}", - params=self._filter_params(params), - headers=self.headers, + ) -> dict[str, Any]: + response = await self._get( + "/v2/progress", + params=dict(key=key), + response_model=None, **kwargs, ) - response = await self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) + return json_loads(response.text) - async def _post( + async def poll_progress( self, - address: str, - endpoint: str, + key: str, *, - body: BaseModel | None, - response_model: Type[BaseModel] | None = None, - params: dict[str, Any] | None = None, + initial_wait: float = 0.5, + max_wait: float = 30 * 60.0, + verbose: bool = False, **kwargs, - ) -> httpx.Response | BaseModel: - """ - Make an asynchronous POST request to the specified endpoint. + ) -> dict[str, Any] | None: + from asyncio import sleep + + i = 1 + t0 = perf_counter() + while (perf_counter() - t0) < max_wait: + await sleep(min(initial_wait * i, 5.0)) + prog = await self.get_progress(key, **kwargs) + state = prog.get("state", None) + error = prog.get("error", None) + if verbose: + logger.info( + f"{self.__class__.__name__}: Progress: key={key} state={state}" + + (f" error={error}" if error else "") + ) + if state == ProgressState.COMPLETED: + return prog + elif state == ProgressState.FAILED: + raise JamaiException(prog.get("error", "Unknown error")) + i += 1 + return None - Args: - address (str): The base address of the API. - endpoint (str): The API endpoint. - body (BaseModel | None): The request body. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. - params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - **kwargs (Any): Keyword arguments for `httpx.post`. - Returns: - response (httpx.Response | BaseModel): The response text or Pydantic response object. 
- """ - if body is not None: - body = body.model_dump() - response = await self.http_client.post( - f"{address}{endpoint}", - json=body, - headers=self.headers, - params=self._filter_params(params), - **kwargs, - ) - response = await self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) +class _ConversationClientAsync(_ClientAsync): + """Conversation methods.""" - async def _options( + async def create_conversation( self, - address: str, - endpoint: str, - *, - params: dict[str, Any] | None = None, - response_model: Type[BaseModel] | None = None, + request: ConversationCreateRequest, **kwargs, - ) -> httpx.Response | BaseModel: + ) -> AsyncGenerator[ + ConversationMetaResponse | CellReferencesResponse | CellCompletionResponse, None + ]: """ - Make an asynchronous OPTIONS request to the specified endpoint. + Creates a new conversation and sends the first message. + Yields metadata first, then the message stream. + """ + agen = self._stream("/v2/conversations", body=request, **kwargs) + current_event = None + # Get the first chunk outside of the loop so that errors can be raised immediately + try: + chunk = await anext(agen) + except StopAsyncIteration: + # Return empty async generator + return self._empty_async_generator() + + def _process( + _chunk: str, + ) -> ConversationMetaResponse | CellCompletionResponse | CellReferencesResponse | None: + nonlocal current_event + if _chunk.startswith("event:"): + current_event = _chunk[6:].strip() + return None + + if _chunk.startswith("data:"): + data_obj = json_loads(_chunk[5:]) + + if current_event == "metadata": + # This is the special metadata event + current_event = None # Reset for next events + return ConversationMetaResponse.model_validate(data_obj) + else: + # This is a standard gen_table chunk + if data_obj.get("object") == "gen_table.completion.chunk": + return CellCompletionResponse.model_validate(data_obj) + elif data_obj.get("object") == "gen_table.references": + return CellReferencesResponse.model_validate(data_obj) + else: + pass - Args: - address (str): The base address of the API. - endpoint (str): The API endpoint. - params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. - **kwargs (Any): Keyword arguments for `httpx.options`. + async def gen(): + nonlocal chunk + res = _process(chunk) + if res is not None: + yield res + async for chunk in agen: + res = _process(chunk) + if res is not None: + yield res - Returns: - response (httpx.Response | BaseModel): The response or Pydantic response object. - """ - response = await self.http_client.options( - f"{address}{endpoint}", - params=await self._filter_params(params), - headers=self.headers, - **kwargs, - ) - response = await self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) + return gen() - async def _patch( + async def list_conversations( self, - address: str, - endpoint: str, - *, - body: BaseModel | None, - response_model: Type[BaseModel] | None = None, - params: dict[str, Any] | None = None, - **kwargs, - ) -> httpx.Response | BaseModel: - """ - Make an asynchronous PATCH request to the specified endpoint. - - Args: - address (str): The base address of the API. - endpoint (str): The API endpoint. - body (BaseModel | None): The request body. 
- response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. Defaults to None. - params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - **kwargs (Any): Keyword arguments for `httpx.patch`. - - Returns: - response (httpx.Response | BaseModel): The response text or Pydantic response object. - """ - if body is not None: - body = body.model_dump() - response = await self.http_client.patch( - f"{address}{endpoint}", - json=body, - headers=self.headers, - params=self._filter_params(params), - **kwargs, - ) - response = await self.raise_exception(response) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) - - async def _stream( - self, - address: str, - endpoint: str, - *, - body: BaseModel | None, - params: dict[str, Any] | None = None, - **kwargs, - ) -> AsyncGenerator[str, None]: - """ - Make an asynchronous streaming POST request to the specified endpoint. - - Args: - address (str): The base address of the API. - endpoint (str): The API endpoint. - body (BaseModel | None): The request body. - params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - **kwargs (Any): Keyword arguments for `httpx.stream`. - - Yields: - str: The response chunks. - """ - if body is not None: - body = body.model_dump() - async with self.http_client.stream( - "POST", - f"{address}{endpoint}", - json=body, - headers=self.headers, - params=self._filter_params(params), - **kwargs, - ) as response: - response = await self.raise_exception(response) - async for chunk in response.aiter_lines(): - chunk = chunk.strip() - if chunk == "" or chunk == "data: [DONE]": - continue - yield chunk - - async def _delete( - self, - address: str, - endpoint: str, - *, - params: dict[str, Any] | None = None, - response_model: Type[BaseModel] | None = None, - ignore_code: int | None = None, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", **kwargs, - ) -> httpx.Response | BaseModel: - """ - Make a DELETE request to the specified endpoint. - - Args: - address (str): The base address of the API. - endpoint (str): The API endpoint. - params (dict[str, Any] | None, optional): Query parameters. Defaults to None. - response_model (Type[pydantic.BaseModel] | None, optional): The response model to return. - ignore_code (int | None, optional): HTTP code to ignore. - **kwargs (Any): Keyword arguments for `httpx.delete`. - - Returns: - response (httpx.Response | BaseModel): The response text or Pydantic response object. 
- """ - response = await self.http_client.delete( - f"{address}{endpoint}", - params=self._filter_params(params), - headers=self.headers, + ) -> Page[ConversationMetaResponse]: + """Lists all conversations for the authenticated user.""" + return await self._get( + "/v2/conversations/list", + params=dict( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + ), + response_model=Page[ConversationMetaResponse], **kwargs, ) - response = await self.raise_exception(response, ignore_code=ignore_code) - if response_model is None: - return response - else: - return response_model.model_validate_json(response.text) - - -class _BackendAdminClientAsync(_ClientAsync): - """Backend administration methods.""" - - async def create_user(self, request: UserCreate) -> UserRead: - return await self._post( - self.api_base, - "/admin/backend/v1/users", - body=request, - response_model=UserRead, - ) - - async def update_user(self, request: UserUpdate) -> UserRead: - return await self._patch( - self.api_base, - "/admin/backend/v1/users", - body=request, - response_model=UserRead, - ) - async def list_users( + async def list_agents( self, offset: int = 0, limit: int = 100, - order_by: str = AdminOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[UserRead]: + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + **kwargs, + ) -> Page[ConversationMetaResponse]: + """Lists all available agents for the authenticated user.""" return await self._get( - self.api_base, - "/admin/backend/v1/users", + "/v2/conversations/agents/list", params=dict( offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, + order_ascending=order_ascending, + search_query=search_query, ), - response_model=Page[UserRead], + response_model=Page[ConversationMetaResponse], + **kwargs, ) - async def get_user(self, user_id: str) -> UserRead: + async def get_conversation(self, conversation_id: str, **kwargs) -> ConversationMetaResponse: + """Fetches metadata for a single conversation.""" return await self._get( - self.api_base, - f"/admin/backend/v1/users/{quote(user_id)}", - params=None, - response_model=UserRead, + "/v2/conversations", + params={"conversation_id": conversation_id}, + response_model=ConversationMetaResponse, + **kwargs, ) - async def delete_user( - self, - user_id: str, - *, - missing_ok: bool = True, - ) -> OkResponse: - response = await self._delete( - self.api_base, - f"/admin/backend/v1/users/{quote(user_id)}", - params=None, - response_model=None, - ignore_code=404 if missing_ok else None, + async def get_agent(self, agent_id: str, **kwargs) -> AgentMetaResponse: + """Fetches metadata for a single agent.""" + return await self._get( + "/v2/conversations/agents", + params={"agent_id": agent_id}, + response_model=AgentMetaResponse, + **kwargs, ) - if response.status_code == 404 and missing_ok: - return OkResponse() - else: - return OkResponse.model_validate_json(response.text) - async def create_pat(self, request: PATCreate) -> PATRead: + async def generate_title( + self, + conversation_id: str, + **kwargs, + ) -> ConversationMetaResponse: + """Generates a title for a conversation.""" return await self._post( - self.api_base, - "/admin/backend/v1/pats", - body=request, - response_model=PATRead, + "/v2/conversations/title", + params=dict(conversation_id=conversation_id), + body=None, + response_model=ConversationMetaResponse, + **kwargs, ) - async def get_pat(self, pat: str) -> PATRead: - 
return await self._get( - self.api_base, - f"/admin/backend/v1/pats/{quote(pat)}", - params=None, - response_model=PATRead, + async def rename_conversation_title( + self, + conversation_id: str, + title: str, + **kwargs, + ) -> ConversationMetaResponse: + """Renames conversation title.""" + return await self._patch( + "/v2/conversations/title", + params=dict(conversation_id=conversation_id, title=title), + body=None, + response_model=ConversationMetaResponse, + **kwargs, ) - async def delete_pat( + async def delete_conversation( self, - pat: str, + conversation_id: str, *, missing_ok: bool = True, + **kwargs, ) -> OkResponse: + """Deletes a conversation permanently.""" response = await self._delete( - self.api_base, - f"/admin/backend/v1/pats/{quote(pat)}", - params=None, + "/v2/conversations", + params={"conversation_id": conversation_id}, response_model=None, ignore_code=404 if missing_ok else None, + **kwargs, ) if response.status_code == 404 and missing_ok: return OkResponse() else: return OkResponse.model_validate_json(response.text) - async def create_organization(self, request: OrganizationCreate) -> OrganizationRead: - return await self._post( - self.api_base, - "/admin/backend/v1/organizations", + async def send_message( + self, + request: MessageAddRequest, + **kwargs, + ) -> AsyncGenerator[CellReferencesResponse | CellCompletionResponse, None]: + """ + Sends a message to a conversation and streams back the response. + Note: This endpoint currently only supports streaming responses from the server. + """ + agen = self._stream( + "/v2/conversations/messages", body=request, - response_model=OrganizationRead, + **kwargs, ) - - async def update_organization(self, request: OrganizationUpdate) -> OrganizationRead: - return await self._patch( - self.api_base, - "/admin/backend/v1/organizations", - body=request, - response_model=OrganizationRead, + return await self._return_async_iterator( + agen, [CellCompletionResponse, CellReferencesResponse] ) - async def list_organizations( + async def list_messages( self, + conversation_id: str, offset: int = 0, limit: int = 100, - order_by: str = AdminOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[OrganizationRead]: + order_by: str = "ID", + order_ascending: bool = True, + columns: list[str] | None = None, + search_query: str = "", + search_columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> Page[dict[str, Any]]: + """Fetches all messages in a conversation.""" return await self._get( - self.api_base, - "/admin/backend/v1/organizations", + "/v2/conversations/messages/list", params=dict( + conversation_id=conversation_id, offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, + order_ascending=order_ascending, + columns=columns, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, ), - response_model=Page[OrganizationRead], + response_model=Page[dict[str, Any]], + **kwargs, ) - async def get_organization(self, organization_id: str) -> OrganizationRead: - return await self._get( - self.api_base, - f"/admin/backend/v1/organizations/{quote(organization_id)}", - params=None, - response_model=OrganizationRead, + async def regen_message( + self, + request: MessagesRegenRequest, + **kwargs, + ) -> AsyncGenerator[CellReferencesResponse | CellCompletionResponse, None]: + """ + Regenerates a message in a conversation and streams back the response. 
+ """ + agen = self._stream( + "/v2/conversations/messages/regen", + body=request, + **kwargs, + ) + return await self._return_async_iterator( + agen, [CellCompletionResponse, CellReferencesResponse] ) - async def delete_organization( + async def update_message( self, - organization_id: str, - *, - missing_ok: bool = True, + request: MessageUpdateRequest, + **kwargs, ) -> OkResponse: - response = await self._delete( - self.api_base, - f"/admin/backend/v1/organizations/{quote(organization_id)}", - params=None, - response_model=None, - ignore_code=404 if missing_ok else None, + """Updates a specific message within a conversation.""" + return await self._patch( + "/v2/conversations/messages", + body=request, + response_model=OkResponse, + **kwargs, ) - if response.status_code == 404 and missing_ok: - return OkResponse() - else: - return OkResponse.model_validate_json(response.text) - async def generate_invite_token( + async def get_threads( self, - organization_id: str, - user_email: str = "", - valid_days: int = 7, - ) -> str: + conversation_id: str, + column_ids: list[str] | None = None, + **kwargs, + ) -> ConversationThreadsResponse: """ - Generates an invite token to join an organization. + Get all threads from a conversation. Args: - organization_id (str): Organization ID. - user_email (str, optional): User email. - Leave blank to disable email check and generate a public invite. Defaults to "". - valid_days (int, optional): How many days should this link be valid for. Defaults to 7. + conversation_id (str): Conversation ID. + column_ids (list[str] | None): Columns to fetch as conversation threads. Returns: - token (str): _description_ + response (ConversationThreadsResponse): The conversation threads. """ - response = await self._get( - self.api_base, - "/admin/backend/v1/invite_tokens", + return await self._get( + "/v2/conversations/threads", params=dict( - organization_id=organization_id, user_email=user_email, valid_days=valid_days + conversation_id=conversation_id, + column_ids=column_ids, ), - response_model=None, - ) - return response.text - - async def join_organization(self, request: OrgMemberCreate) -> OrgMemberRead: - return await self._post( - self.api_base, - "/admin/backend/v1/organizations/link", - body=request, - response_model=OrgMemberRead, - ) - - async def leave_organization(self, user_id: str, organization_id: str) -> OkResponse: - return await self._delete( - self.api_base, - f"/admin/backend/v1/organizations/link/{quote(user_id)}/{quote(organization_id)}", - params=None, - response_model=OkResponse, - ) - - async def create_api_key(self, request: ApiKeyCreate) -> ApiKeyRead: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - return await self._post( - self.api_base, - "/admin/backend/v1/api_keys", - body=request, - response_model=ApiKeyRead, + response_model=ConversationThreadsResponse, + **kwargs, ) - async def get_api_key(self, api_key: str) -> ApiKeyRead: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - return await self._get( - self.api_base, - f"/admin/backend/v1/api_keys/{quote(api_key)}", - params=None, - response_model=ApiKeyRead, - ) - async def delete_api_key( +class JamAIAsync(_ClientAsync): + def __init__( self, - api_key: str, + project_id: str = ENV_CONFIG.project_id, + token: str = ENV_CONFIG.token_plain, + api_base: str = ENV_CONFIG.api_base, + headers: dict | None = None, + timeout: float | None = ENV_CONFIG.timeout_sec, + file_upload_timeout: float | None = ENV_CONFIG.file_upload_timeout_sec, *, - missing_ok: bool = True, - ) 
-> OkResponse: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - response = await self._delete( - self.api_base, - f"/admin/backend/v1/api_keys/{quote(api_key)}", - params=None, - response_model=None, - ignore_code=404 if missing_ok else None, - ) - if response.status_code == 404 and missing_ok: - return OkResponse() - else: - return OkResponse.model_validate_json(response.text) + user_id: str = "", + ) -> None: + """ + Initialize the JamAI async client. - async def refresh_quota( - self, - organization_id: str, - reset_usage: bool = True, - ) -> OrganizationRead: - return await self._post( - self.api_base, - f"/admin/backend/v1/quotas/refresh/{quote(organization_id)}", - body=None, - params=dict(reset_usage=reset_usage), - response_model=OrganizationRead, + Args: + project_id (str, optional): The project ID. + Defaults to "default", but can be overridden via + `JAMAI_PROJECT_ID` var in environment or `.env` file. + token (str, optional): Your Personal Access Token or organization API key (deprecated) for authentication. + Defaults to "", but can be overridden via + `JAMAI_TOKEN` var in environment or `.env` file. + api_base (str, optional): The base URL for the API. + Defaults to "https://api.jamaibase.com/api", but can be overridden via + `JAMAI_API_BASE` var in environment or `.env` file. + headers (dict | None, optional): Additional headers to include in requests. + Defaults to None. + timeout (float | None, optional): The timeout to use when sending requests. + Defaults to 15 minutes, but can be overridden via + `JAMAI_TIMEOUT_SEC` var in environment or `.env` file. + file_upload_timeout (float | None, optional): The timeout to use when sending file upload requests. + Defaults to 60 minutes, but can be overridden via + `JAMAI_FILE_UPLOAD_TIMEOUT_SEC` var in environment or `.env` file. + user_id (str, optional): User ID. For development purposes. + Defaults to "". 
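+
+        Example:
+            A minimal construction sketch (the import path is an assumption and the
+            credential values are placeholders):
+
+            >>> from jamaibase import JamAIAsync
+            >>> client = JamAIAsync(
+            ...     project_id="proj_xxx",   # placeholder project ID
+            ...     token="pat_xxx",         # placeholder Personal Access Token
+            ... )
+            >>> status = await client.health()   # simple connectivity check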
+ """ + if not isinstance(project_id, str): + raise TypeError("`project_id` must be a string.") + if not isinstance(token, str): + raise TypeError("`token` must be a string.") + if not isinstance(api_base, str): + raise TypeError("`api_base` must be a string.") + if not (isinstance(headers, dict) or headers is None): + raise TypeError("`headers` must be a dict or None.") + if not (isinstance(timeout, (float, int)) or timeout is None): + raise TypeError("`timeout` must be a float, int or None.") + if not (isinstance(file_upload_timeout, (float, int)) or file_upload_timeout is None): + raise TypeError("`file_upload_timeout` must be a float, int or None.") + if not isinstance(user_id, str): + raise TypeError("`user_id` must be a string.") + http_client = httpx.AsyncClient( + timeout=timeout, + transport=httpx.AsyncHTTPTransport(retries=3), ) - - async def get_event(self, event_id: str) -> EventRead: - return await self._get( - self.api_base, - f"/admin/backend/v1/events/{quote(event_id)}", - params=None, - response_model=EventRead, + kwargs = dict( + user_id=user_id, + project_id=project_id, + token=token, + api_base=api_base, + headers=headers, + http_client=http_client, + timeout=timeout, + file_upload_timeout=file_upload_timeout, ) + super().__init__(**kwargs) + self.auth = _AuthAsync(**kwargs) + self.prices = _PricesAsync(**kwargs) + self.users = _UsersAsync(**kwargs) + self.models = _ModelsAsync(**kwargs) + self.organizations = _OrganizationsAsync(**kwargs) + self.projects = _ProjectsAsync(**kwargs) + self.templates = _TemplatesAsync(**kwargs) + self.file = _FileClientAsync(**kwargs) + self.table = _GenTableClientAsync(**kwargs) + self.meters = _MeterClientAsync(**kwargs) + self.tasks = _TaskClientAsync(**kwargs) + self.conversations = _ConversationClientAsync(**kwargs) - async def add_event(self, request: EventCreate) -> OkResponse: - return await self._post( - self.api_base, - "/admin/backend/v1/events", - body=request, - response_model=OkResponse, - ) + async def health(self) -> dict[str, Any]: + """ + Get health status. - async def mark_event_as_done(self, event_id: str) -> OkResponse: - return await self._patch( - self.api_base, - f"/admin/backend/v1/events/done/{quote(event_id)}", - body=None, - response_model=OkResponse, - ) + Returns: + response (dict[str, Any]): Health status. + """ + response = await self._get("/health", response_model=None) + return json_loads(response.text) - async def get_internal_organization_id(self) -> StringResponse: - return await self._get( - self.api_base, - "/admin/backend/v1/internal_organization_id", - params=None, - response_model=StringResponse, - ) + # --- Models and chat --- # - async def set_internal_organization_id(self, organization_id: str) -> OkResponse: - return await self._patch( - self.api_base, - f"/admin/backend/v1/internal_organization_id/{quote(organization_id)}", - body=None, - response_model=OkResponse, - ) + async def model_info( + self, + model: str = "", + capabilities: list[ + Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"] + ] + | None = None, + **kwargs, + ) -> ModelInfoListResponse: + """ + Get information about available models. + + Args: + name (str, optional): The model name. Defaults to "". + capabilities (list[Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"]] | None, optional): + List of model capabilities to filter by. Defaults to None. 
- async def get_pricing(self) -> Price: + Returns: + response (ModelInfoListResponse): The model information response. + """ + if (name := kwargs.pop("name", None)) is not None: + warnings.warn( + "'name' parameter is deprecated, use 'model' instead.", + DeprecationWarning, + stacklevel=2, + ) + model = name return await self._get( - self.api_base, - "/public/v1/prices/plans", - params=None, - response_model=Price, + "/v1/models", + params=dict(model=model, capabilities=capabilities), + response_model=ModelInfoListResponse, + **kwargs, ) - async def set_pricing(self, request: Price) -> OkResponse: - return await self._patch( - self.api_base, - "/admin/backend/v1/prices/plans", - body=request, - response_model=OkResponse, - ) + async def model_ids( + self, + prefer: str = "", + capabilities: list[ + Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"] + ] + | None = None, + **kwargs, + ) -> list[str]: + """ + Get the IDs of available models. - async def get_model_pricing(self) -> ModelPrice: - return await self._get( - self.api_base, - "/public/v1/prices/models", - params=None, - response_model=ModelPrice, - ) + Args: + prefer (str, optional): Preferred model ID. Defaults to "". + capabilities (list[Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"]] | None, optional): + List of model capabilities to filter by. Defaults to None. - async def get_model_config(self) -> ModelListConfig: - return await self._get( - self.api_base, - "/admin/backend/v1/models", - params=None, - response_model=ModelListConfig, + Returns: + response (list[str]): List of model IDs. + """ + params = {"prefer": prefer, "capabilities": capabilities} + response = await self._get( + "/v1/models/ids", + params=params, + response_model=None, + **kwargs, ) + return json_loads(response.text) - async def set_model_config(self, request: ModelListConfig) -> OkResponse: - return await self._patch( - self.api_base, - "/admin/backend/v1/models", - body=request, - response_model=OkResponse, - ) + @deprecated( + "This method is deprecated, use `model_ids` instead.", category=FutureWarning, stacklevel=1 + ) + async def model_names( + self, + prefer: str = "", + capabilities: list[ + Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"] + ] + | None = None, + **kwargs, + ) -> list[str]: + return await self.model_ids(prefer=prefer, capabilities=capabilities, **kwargs) - async def add_template( + async def generate_chat_completions( self, - source: str | BinaryIO, - template_id_dst: str, - exist_ok: bool = False, - ) -> OkResponse: + request: ChatRequest, + **kwargs, + ) -> ChatCompletionResponse | AsyncGenerator[References | ChatCompletionChunkResponse, None]: """ - Upload a template Parquet file to add a new template into gallery. + Generates chat completions. Args: - source (str | BinaryIO): The path to the template Parquet file or a file-like object. - template_id_dst (str): The ID of the new template. - exist_ok (bool, optional): Whether to overwrite existing template. Defaults to False. + request (ChatRequest): The request. Returns: - response (OkResponse): The response indicating success. + completion (ChatCompletionChunkResponse | AsyncGenerator): The chat completion. + In streaming mode, it is an async generator that yields a `References` object + followed by zero or more `ChatCompletionChunkResponse` objects. + In non-streaming mode, it is a `ChatCompletionChunkResponse` object. 
""" - kwargs = dict( - address=self.api_base, - endpoint="/admin/backend/v1/templates/import", - body=None, - response_model=OkResponse, - data={"template_id_dst": template_id_dst, "exist_ok": exist_ok}, - timeout=self.file_upload_timeout, - ) - mime_type = "application/octet-stream" - if isinstance(source, str): - filename = split(source)[-1] - # Open the file in binary mode - with open(source, "rb") as f: - return await self._post(files={"file": (filename, f, mime_type)}, **kwargs) + body = self._process_body(request) + if request.stream: + agen = self._stream("/v1/chat/completions", body=body, **kwargs) + return await self._return_async_iterator( + agen, [ChatCompletionChunkResponse, References] + ) else: - filename = "import.parquet" - return await self._post(files={"file": (filename, source, mime_type)}, **kwargs) + return await self._post( + "/v1/chat/completions", + body=body, + response_model=ChatCompletionResponse, + **kwargs, + ) - async def populate_templates(self, timeout: float = 30.0) -> OkResponse: + async def generate_embeddings( + self, + request: EmbeddingRequest, + **kwargs, + ) -> EmbeddingResponse: """ - Re-populates the template gallery. + Generate embeddings for the given input. Args: - timeout (float, optional): Timeout in seconds, must be >= 0. Defaults to 30.0. + request (EmbeddingRequest): The embedding request. Returns: - response (OkResponse): The response indicating success. + response (EmbeddingResponse): The embedding response. """ return await self._post( - self.api_base, - "/admin/backend/v1/templates/populate", - body=None, - params=dict(timeout=timeout), - response_model=OkResponse, + "/v1/embeddings", + body=request, + response_model=EmbeddingResponse, + **kwargs, ) + async def rerank(self, request: RerankingRequest, **kwargs) -> RerankingResponse: + """ + Generate similarity rankings for the given query and documents. -class _OrgAdminClientAsync(_ClientAsync): - """Organization administration methods.""" - - async def get_org_model_config(self, organization_id: str) -> ModelListConfig: - return await self._get( - self.api_base, - f"/admin/org/v1/models/{quote(organization_id)}", - params=None, - response_model=ModelListConfig, - ) - - async def set_org_model_config( - self, - organization_id: str, - config: ModelListConfig, - ) -> OkResponse: - return await self._patch( - self.api_base, - f"/admin/org/v1/models/{quote(organization_id)}", - body=config, - response_model=OkResponse, - ) + Args: + request (RerankingRequest): The reranking request body. - async def create_project(self, request: ProjectCreate) -> ProjectRead: + Returns: + RerankingResponse: The reranking response. 
+ """ return await self._post( - self.api_base, - "/admin/org/v1/projects", + "/v1/rerank", body=request, - response_model=ProjectRead, + response_model=RerankingResponse, + **kwargs, ) - async def update_project(self, request: ProjectUpdate) -> ProjectRead: - return await self._patch( - self.api_base, - "/admin/org/v1/projects", - body=request, - response_model=ProjectRead, - ) - async def set_project_updated_at( +class _Auth(_AuthAsync): + """Auth methods.""" + + def register_password(self, body: UserCreate, **kwargs) -> UserRead: + return LOOP.run(super().register_password(body, **kwargs)) + + def login_password(self, body: PasswordLoginRequest, **kwargs) -> UserRead: + return LOOP.run(super().login_password(body, **kwargs)) + + def change_password(self, body: PasswordChangeRequest, **kwargs) -> UserRead: + return LOOP.run(super().change_password(body, **kwargs)) + + +class _Prices(_PricesAsync): + """Prices methods.""" + + def create_price_plan( self, - project_id: str, - updated_at: str | None = None, - ) -> OkResponse: - return await self._patch( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}", - body=None, - params=dict(updated_at=updated_at), - response_model=OkResponse, - ) + body: PricePlanCreate, + **kwargs, + ) -> PricePlanRead: + return LOOP.run(super().create_price_plan(body, **kwargs)) - async def list_projects( + def list_price_plans( self, - organization_id: str = "default", - search_query: str = "", + *, offset: int = 0, limit: int = 100, - order_by: str = AdminOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[ProjectRead]: - return await self._get( - self.api_base, - "/admin/org/v1/projects", - params=dict( - organization_id=organization_id, - search_query=search_query, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[PricePlanRead]: + return LOOP.run( + super().list_price_plans( offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, - ), - response_model=Page[ProjectRead], + order_ascending=order_ascending, + **kwargs, + ) ) - async def get_project(self, project_id: str) -> ProjectRead: - return await self._get( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}", - params=None, - response_model=ProjectRead, - ) + def get_price_plan( + self, + plan_id: str, + **kwargs, + ) -> PricePlanRead: + return LOOP.run(super().get_price_plan(plan_id=plan_id, **kwargs)) - async def delete_project( + def update_price_plan( self, - project_id: str, + plan_id: str, + body: PricePlanUpdate, + **kwargs, + ) -> PricePlanRead: + return LOOP.run(super().update_price_plan(plan_id, body, **kwargs)) + + def delete_price_plan( + self, + price_plan_id: str, *, missing_ok: bool = True, - ) -> OkResponse: - response = await self._delete( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}", - params=None, - response_model=None, - ignore_code=404 if missing_ok else None, - ) - if response.status_code == 404 and missing_ok: - return OkResponse() - else: - return OkResponse.model_validate_json(response.text) + **kwargs, + ) -> None: + return LOOP.run(super().delete_price_plan(price_plan_id, missing_ok=missing_ok, **kwargs)) - async def import_project( - self, - source: str | BinaryIO, - organization_id: str, - project_id_dst: str = "", - ) -> ProjectRead: - """ - Imports a project. + def list_model_prices(self, **kwargs) -> ModelPrice: + return LOOP.run(super().list_model_prices(**kwargs)) - Args: - source (str | BinaryIO): The parquet file path or file-like object. 
- It can be a Project or Template file. - organization_id (str): Organization ID "org_xxx". - project_id_dst (str, optional): ID of the project to import tables into. - Defaults to creating new project. - Returns: - response (ProjectRead): The imported project. - """ - kwargs = dict( - address=self.api_base, - endpoint=f"/admin/org/v1/projects/import/{quote(organization_id)}", - body=None, - response_model=ProjectRead, - data={"project_id_dst": project_id_dst}, - timeout=self.file_upload_timeout, - ) - mime_type = "application/octet-stream" - if isinstance(source, str): - filename = split(source)[-1] - # Open the file in binary mode - with open(source, "rb") as f: - return await self._post(files={"file": (filename, f, mime_type)}, **kwargs) - else: - filename = "import.parquet" - return await self._post(files={"file": (filename, source, mime_type)}, **kwargs) +class _Users(_UsersAsync): + """Users methods.""" - async def export_project( + def create_user(self, body: UserCreate, **kwargs) -> UserRead: + return LOOP.run(super().create_user(body, **kwargs)) + + def list_users( self, - project_id: str, - compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", - ) -> bytes: - """ - Exports a project as a Project Parquet file. + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + search_columns: list[str] | None = None, + after: str | None = None, + **kwargs, + ) -> Page[UserRead]: + return LOOP.run( + super().list_users( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + after=after, + **kwargs, + ) + ) - Args: - project_id (str): Project ID "proj_xxx". - compression (str, optional): Parquet compression codec. Defaults to "ZSTD". + def get_user( + self, + user_id: str | None = None, + **kwargs, + ) -> UserRead: + return LOOP.run(super().get_user(user_id, **kwargs)) - Returns: - response (bytes): The Parquet file. - """ - response = await self._get( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}/export", - params=dict(compression=compression), - response_model=None, - ) - return response.content + def update_user( + self, + body: UserUpdate, + **kwargs, + ) -> UserRead: + return LOOP.run(super().update_user(body, **kwargs)) - async def import_project_from_template( + def delete_user( self, - organization_id: str, - template_id: str, - project_id_dst: str = "", - ) -> ProjectRead: - """ - Imports a project from a template. + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run(super().delete_user(missing_ok=missing_ok, **kwargs)) - Args: - organization_id (str): Organization ID "org_xxx". - template_id (str): ID of the template to import from. - project_id_dst (str, optional): ID of the project to import tables into. - Defaults to creating new project. + def create_pat(self, body: ProjectKeyCreate, **kwargs) -> ProjectKeyRead: + return LOOP.run(super().create_pat(body, **kwargs)) - Returns: - response (ProjectRead): The imported project. 
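Every synchronous class in this part of the diff follows the same shape: subclass the async client and drive each coroutine to completion through the module-level LOOP helper. A sketch of that convention with deliberately hypothetical names (`_WidgetsAsync`, `frobnicate`), purely to make the pattern explicit:

# Sync-over-async wrapper convention used by _Auth, _Prices, _Users, etc.
# `_WidgetsAsync` and `frobnicate` are hypothetical, for illustration only.
class _Widgets(_WidgetsAsync):
    """Widgets methods (synchronous version)."""

    def frobnicate(self, widget_id: str, **kwargs) -> OkResponse:
        # Delegate to the async implementation and block until it resolves.
        return LOOP.run(super().frobnicate(widget_id, **kwargs))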
- """ - return await self._post( - self.api_base, - f"/admin/org/v1/projects/import/{quote(organization_id)}/templates/{quote(template_id)}", - body=None, - params=dict(project_id_dst=project_id_dst), - response_model=ProjectRead, + def list_pats( + self, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ProjectKeyRead]: + return LOOP.run( + super().list_pats( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, + ) ) - async def export_project_as_template( + def update_pat( self, - project_id: str, + pat_id: str, + body: ProjectKeyUpdate, + **kwargs, + ) -> ProjectKeyRead: + return LOOP.run(super().update_pat(pat_id, body, **kwargs)) + + def delete_pat( + self, + pat_id: str, *, - name: str, - tags: list[str], - description: str, - compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", - ) -> bytes: + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run(super().delete_pat(pat_id, missing_ok=missing_ok, **kwargs)) + + def create_email_verification_code( + self, + *, + valid_days: int = 7, + **kwargs, + ) -> VerificationCodeRead: """ - Exports a project as a template Parquet file. + Generates an email verification code. Args: - project_id (str): Project ID "proj_xxx". - name (str): Template name. - tags (list[str]): Template tags. - description (str): Template description. - compression (str, optional): Parquet compression codec. Defaults to "ZSTD". + valid_days (int, optional): Code validity in days. Defaults to 7. Returns: - response (bytes): The template Parquet file. + code (InviteCodeRead): Verification code. """ - response = await self._get( - self.api_base, - f"/admin/org/v1/projects/{quote(project_id)}/export/template", - params=dict( - name=name, - tags=tags, - description=description, - compression=compression, - ), - response_model=None, - ) - return response.content + return LOOP.run(super().create_email_verification_code(valid_days=valid_days, **kwargs)) + def list_email_verification_codes( + self, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + search_columns: list[str] | None = None, + after: str | None = None, + **kwargs, + ) -> Page[VerificationCodeRead]: + return LOOP.run( + super().list_email_verification_codes( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + after=after, + **kwargs, + ) + ) -class _AdminClientAsync(_ClientAsync): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.backend = _BackendAdminClientAsync(*args, **kwargs) - self.organization = _OrgAdminClientAsync(*args, **kwargs) + def get_email_verification_code( + self, + verification_code: str, + **kwargs, + ) -> VerificationCodeRead: + return LOOP.run( + super().get_email_verification_code( + verification_code=verification_code, + **kwargs, + ) + ) + def revoke_email_verification_code( + self, + verification_code: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().revoke_email_verification_code( + verification_code=verification_code, + missing_ok=missing_ok, + **kwargs, + ) + ) -class _TemplateClientAsync(_ClientAsync): - """Template methods.""" + @deprecated( + "`delete_email_verification_code` is deprecated, use `revoke_email_verification_code` instead.", + 
category=FutureWarning, + stacklevel=1, + ) + def delete_email_verification_code( + self, + verification_code: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().delete_email_verification_code( + verification_code=verification_code, + missing_ok=missing_ok, + **kwargs, + ) + ) - async def list_templates(self, search_query: str = "") -> Page[Template]: + def verify_email( + self, + verification_code: str, + **kwargs, + ) -> OkResponse: """ - List all templates. + Verify and update user email. Args: - search_query (str, optional): A string to search for within template names. + verification_code (str): Verification code. Returns: - templates (Page[Template]): A page of templates. + ok (OkResponse): Success. """ - return await self._get( - self.api_base, - "/public/v1/templates", - params=dict(search_query=search_query), - response_model=Page[Template], - ) + return LOOP.run(super().verify_email(verification_code=verification_code, **kwargs)) - async def get_template(self, template_id: str) -> Template: - """ - Get a template by its ID. - Args: - template_id (str): Template ID. +class _Models(_ModelsAsync): + """Models methods.""" - Returns: - template (Template): The template. - """ - return await self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}", - params=None, - response_model=Template, - ) + def create_model_config(self, body: ModelConfigCreate, **kwargs) -> ModelConfigRead: + return LOOP.run(super().create_model_config(body, **kwargs)) - async def list_tables( + def list_model_configs( self, - template_id: str, - table_type: str, *, + organization_id: str | None = None, offset: int = 0, limit: int = 100, - search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, - ) -> Page[TableMetaResponse]: - """ - List all tables in a template. - - Args: - template_id (str): Template ID. - table_type (str): Table type. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. - search_query (str, optional): A string to search for within table IDs as a filter. - Defaults to "" (no filter). - order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - - Returns: - tables (Page[TableMetaResponse]): A page of tables. - """ - return await self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}/gen_tables/{quote(table_type)}", - params=dict( + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ModelConfigRead]: + return LOOP.run( + super().list_model_configs( + organization_id=organization_id, offset=offset, limit=limit, - search_query=search_query, order_by=order_by, - order_descending=order_descending, - ), - response_model=Page[TableMetaResponse], + order_ascending=order_ascending, + **kwargs, + ) ) - async def get_table( - self, template_id: str, table_type: str, table_id: str - ) -> TableMetaResponse: - """ - Get a table in a template. + def get_model_config( + self, + model_id: str, + **kwargs, + ) -> ModelConfigRead: + return LOOP.run(super().get_model_config(model_id, **kwargs)) - Args: - template_id (str): Template ID. - table_type (str): Table type. - table_id (str): Table ID. 
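A short sketch of the model-config surface from the synchronous side. The top-level `JamAI` import and the `client.models` attribute are assumptions mirroring the async constructor earlier in this diff; `.items` and `.id` are assumed fields on the Page and ModelConfigRead models.

from jamaibase import JamAI  # assumed top-level export of the sync client

client = JamAI(project_id="proj_xxx", token="jamai_pat_xxx")

# Page through an organization's model configs, newest first.
configs = client.models.list_model_configs(
    organization_id="org_xxx",
    limit=20,
    order_by="updated_at",
    order_ascending=False,
)
for cfg in configs.items:  # `.items` assumed on Page[...]
    print(cfg.id)          # `.id` assumed on ModelConfigRead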
+ def update_model_config( + self, + model_id: str, + body: ModelConfigUpdate, + **kwargs, + ) -> ModelConfigRead: + return LOOP.run(super().update_model_config(model_id, body, **kwargs)) - Returns: - table (TableMetaResponse): The table. - """ - return await self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}/gen_tables/{quote(table_type)}/{quote(table_id)}", - params=None, - response_model=TableMetaResponse, - ) + def delete_model_config( + self, + model_id: str, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run(super().delete_model_config(model_id, missing_ok=missing_ok, **kwargs)) - async def list_table_rows( + def create_deployment( + self, + body: DeploymentCreate, + timeout: float | None = 300.0, + **kwargs, + ) -> DeploymentRead: + return LOOP.run(super().create_deployment(body, timeout=timeout, **kwargs)) + + def list_deployments( self, - template_id: str, - table_type: str, - table_id: str, *, - starting_after: str | None = None, offset: int = 0, limit: int = 100, - order_by: str = "Updated at", - order_descending: bool = True, - float_decimals: int = 0, - vec_decimals: int = 0, - ) -> Page[dict[str, Any]]: - """ - List rows in a template table. - - Args: - template_id (str): Template ID. - table_type (str): Table type. - table_id (str): Table ID. - starting_after (str | None, optional): A cursor for use in pagination. - Only rows with ID > `starting_after` will be returned. - For instance, if your call receives 100 rows ending with ID "x", - your subsequent call can include `starting_after="x"` in order to fetch the next page of the list. - Defaults to None. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. - order_by (str, optional): Sort rows by this column. Defaults to "Updated at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). - - Returns: - rows (Page[dict[str, Any]]): The rows. - """ - return await self._get( - self.api_base, - f"/public/v1/templates/{quote(template_id)}/gen_tables/{quote(table_type)}/{quote(table_id)}/rows", - params=dict( - starting_after=starting_after, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[DeploymentRead]: + return LOOP.run( + super().list_deployments( offset=offset, limit=limit, order_by=order_by, - order_descending=order_descending, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ), - response_model=Page[dict[str, Any]], + order_ascending=order_ascending, + **kwargs, + ) ) + def get_deployment( + self, + deployment_id: str, + **kwargs, + ) -> DeploymentRead: + return LOOP.run(super().get_deployment(deployment_id, **kwargs)) -class _FileClientAsync(_ClientAsync): - """File methods.""" + def update_deployment( + self, + deployment_id: str, + body: DeploymentUpdate, + **kwargs, + ) -> DeploymentRead: + return LOOP.run(super().update_deployment(deployment_id, body, **kwargs)) - async def upload_file(self, file_path: str) -> FileUploadResponse: - """ - Uploads a file to the server. 
+ def delete_deployment(self, deployment_id: str, **kwargs) -> OkResponse: + return LOOP.run(super().delete_deployment(deployment_id, **kwargs)) - Args: - file_path (str): Path to the file to be uploaded. - Returns: - response (FileUploadResponse): The response containing the file URI. - """ - filename = split(file_path)[-1] - mime_type = filetype.guess(file_path).mime - if mime_type is None: - mime_type = "application/octet-stream" # Default MIME type +class _Organizations(_OrganizationsAsync): + """Organization methods.""" - with open(file_path, "rb") as f: - return await self._post( - self.api_base, - "/v1/files/upload", - body=None, - response_model=FileUploadResponse, - files={ - "file": (filename, f, mime_type), - }, - timeout=self.file_upload_timeout, + def create_organization( + self, + body: OrganizationCreate, + **kwargs, + ) -> OrganizationRead: + return LOOP.run(super().create_organization(body, **kwargs)) + + def list_organizations( + self, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[OrganizationRead]: + return LOOP.run( + super().list_organizations( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, ) + ) - async def get_raw_urls(self, uris: list[str]) -> GetURLResponse: - """ - Get download URLs for raw files. + def get_organization( + self, + organization_id: str, + **kwargs, + ) -> OrganizationRead: + return LOOP.run(super().get_organization(organization_id, **kwargs)) - Args: - uris (List[str]): List of file URIs to download. + def update_organization( + self, + organization_id: str, + body: OrganizationUpdate, + **kwargs, + ) -> OrganizationRead: + return LOOP.run(super().update_organization(organization_id, body, **kwargs)) - Returns: - response (GetURLResponse): The response containing download information for the files. - """ - return await self._post( - self.api_base, - "/v1/files/url/raw", - body=GetURLRequest(uris=uris), - response_model=GetURLResponse, + def delete_organization( + self, + organization_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().delete_organization(organization_id, missing_ok=missing_ok, **kwargs) ) - async def get_thumbnail_urls(self, uris: list[str]) -> GetURLResponse: - """ - Get download URLs for file thumbnails. - - Args: - uris (List[str]): List of file URIs to get thumbnails for. + def join_organization( + self, + user_id: str, + *, + invite_code: str | None = None, + organization_id: str | None = None, + role: str | None = None, + **kwargs, + ) -> OrgMemberRead: + return LOOP.run( + super().join_organization( + user_id=user_id, + invite_code=invite_code, + organization_id=organization_id, + role=role, + **kwargs, + ) + ) - Returns: - response (GetURLResponse): The response containing download information for the thumbnails. 
- """ - return await self._post( - self.api_base, - "/v1/files/url/thumb", - body=GetURLRequest(uris=uris), - response_model=GetURLResponse, + def list_members( + self, + organization_id: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[OrgMemberRead]: + return LOOP.run( + super().list_members( + organization_id=organization_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, + ) ) + def get_member( + self, + *, + user_id: str, + organization_id: str, + **kwargs, + ) -> OrgMemberRead: + return LOOP.run( + super().get_member( + user_id=user_id, + organization_id=organization_id, + **kwargs, + ) + ) -class _GenTableClientAsync(_ClientAsync): - """Generative Table methods.""" + def update_member_role( + self, + *, + user_id: str, + organization_id: str, + role: Role, + **kwargs, + ) -> OrgMemberRead: + return LOOP.run( + super().update_member_role( + user_id=user_id, + organization_id=organization_id, + role=role, + **kwargs, + ) + ) - async def create_action_table(self, request: ActionTableSchemaCreate) -> TableMetaResponse: - """ - Create an Action Table. + def leave_organization( + self, + user_id: str, + organization_id: str, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().leave_organization( + user_id=user_id, + organization_id=organization_id, + **kwargs, + ) + ) - Args: - request (ActionTableSchemaCreate): The action table schema. + def model_catalogue( + self, + *, + organization_id: str, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ModelConfigRead]: + return LOOP.run( + super().model_catalogue( + organization_id=organization_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - "/v1/gen_tables/action", - body=request, - response_model=TableMetaResponse, + def create_invite( + self, + *, + user_email: str, + organization_id: str, + role: str, + valid_days: int = 7, + **kwargs, + ) -> VerificationCodeRead: + return LOOP.run( + super().create_invite( + user_email=user_email, + organization_id=organization_id, + role=role, + valid_days=valid_days, + **kwargs, + ) ) - async def create_knowledge_table( - self, request: KnowledgeTableSchemaCreate - ) -> TableMetaResponse: - """ - Create a Knowledge Table. + def generate_invite_token(self, *_, **__): + raise NotImplementedError("This method is deprecated, use `create_invite` instead.") - Args: - request (KnowledgeTableSchemaCreate): The knowledge table schema. + def list_invites( + self, + organization_id: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[VerificationCodeRead]: + return LOOP.run( + super().list_invites( + organization_id=organization_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. 
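The organization invite flow reads naturally as a pair of calls; a sketch re-using the sync `client` from the model-config example above. The `role` literal and the `invite.code` attribute are assumptions; the method signatures come from this diff.

# Invite a user into an organization, then accept on their behalf (sketch).
invite = client.organizations.create_invite(
    user_email="new.member@example.com",
    organization_id="org_xxx",
    role="member",            # assumed role literal
    valid_days=3,
)
member = client.organizations.join_organization(
    user_id="user_xxx",
    invite_code=invite.code,  # `.code` assumed on VerificationCodeRead
)
print(member)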
- """ - return await self._post( - self.api_base, - "/v1/gen_tables/knowledge", - body=request, - response_model=TableMetaResponse, + def revoke_invite( + self, + invite_id: str, + *, + missing_ok: bool = True, + **kwargs, + ): + return LOOP.run( + super().revoke_invite(invite_id=invite_id, missing_ok=missing_ok, **kwargs) ) - async def create_chat_table(self, request: ChatTableSchemaCreate) -> TableMetaResponse: - """ - Create a Chat Table. + @deprecated( + "`delete_invite` is deprecated, use `revoke_invite` instead.", + category=FutureWarning, + stacklevel=1, + ) + def delete_invite( + self, + invite_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().delete_invite( + invite_id=invite_id, + missing_ok=missing_ok, + **kwargs, + ) + ) - Args: - request (ChatTableSchemaCreate): The chat table schema. + def subscribe_plan( + self, + organization_id: str, + price_plan_id: str, + **kwargs, + ) -> StripePaymentInfo: + return LOOP.run( + super().subscribe_plan( + organization_id=organization_id, + price_plan_id=price_plan_id, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - "/v1/gen_tables/chat", - body=request, - response_model=TableMetaResponse, + def refresh_quota( + self, + organization_id: str, + **kwargs, + ) -> OrganizationRead: + return LOOP.run( + super().refresh_quota( + organization_id=organization_id, + **kwargs, + ) ) - async def get_table( + def purchase_credits( self, - table_type: str | TableType, - table_id: str, - ) -> TableMetaResponse: - """ - Get metadata for a specific Generative Table. + organization_id: str, + amount: float, + *, + confirm: bool = False, + off_session: bool = False, + **kwargs, + ) -> StripePaymentInfo: + return LOOP.run( + super().purchase_credits( + organization_id=organization_id, + amount=amount, + confirm=confirm, + off_session=off_session, + **kwargs, + ) + ) - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. + def set_credit_grant( + self, + organization_id: str, + amount: float, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().set_credit_grant( + organization_id=organization_id, + amount=amount, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - return await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}", - params=None, - response_model=TableMetaResponse, + def add_credit_grant( + self, + organization_id: str, + amount: float, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().add_credit_grant( + organization_id=organization_id, + amount=amount, + **kwargs, + ) ) - async def list_tables( + def get_organization_metrics( + self, + metric_id: str, + from_: datetime, + org_id: str, + window_size: str | None = None, + proj_ids: list[str] | None = None, + to: datetime | None = None, + group_by: list[str] | None = None, + data_source: Literal["clickhouse", "victoriametrics"] = "clickhouse", + **kwargs, + ) -> UsageResponse: + return LOOP.run( + super().get_organization_metrics( + metric_id=metric_id, + from_=from_, + org_id=org_id, + window_size=window_size, + proj_ids=proj_ids, + to=to, + group_by=group_by, + data_source=data_source, + **kwargs, + ) + ) + + # def get_billing_metrics( + # self, + # from_: datetime, + # window_size: str, + # org_id: str, + # proj_ids: list[str] | None = None, + # to: datetime | None = None, + # group_by: list[str] | None = None, + # **kwargs, + # ) -> dict: + # return LOOP.run( + # super().get_billing_metrics( + # from_=from_, + # window_size=window_size, + # org_id=org_id, + # proj_ids=proj_ids, + # to=to, + # group_by=group_by, + # **kwargs, + # ) + # ) + + +class _Projects(_ProjectsAsync): + """Project methods.""" + + def create_project(self, body: ProjectCreate, **kwargs) -> ProjectRead: + return LOOP.run(super().create_project(body, **kwargs)) + + def list_projects( self, - table_type: str | TableType, + organization_id: str, *, offset: int = 0, limit: int = 100, - parent_id: str | None = None, search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, - count_rows: bool = False, - ) -> Page[TableMetaResponse]: - """ - List Generative Tables of a specific type. - - Args: - table_type (str | TableType): The type of the table. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. - parent_id (str | None, optional): Parent ID of tables to return. - Additionally for Chat Table, you can list: - (1) all chat agents by passing in "_agent_"; or - (2) all chats by passing in "_chat_". - Defaults to None (return all tables). - search_query (str, optional): A string to search for within table IDs as a filter. - Defaults to "" (no filter). - order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - count_rows (bool, optional): Whether to count the rows of the tables. Defaults to False. - - Returns: - response (Page[TableMetaResponse]): The paginated table metadata response. 
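`get_organization_metrics` takes a start time plus optional end, window, and grouping; a sketch with placeholder IDs, re-using the sync client above. The metric ID, window format, and group-by key are assumptions.

from datetime import datetime, timedelta, timezone

# Per-project usage for the last 7 days (sketch).
usage = client.organizations.get_organization_metrics(
    metric_id="<metric-id>",                               # placeholder
    from_=datetime.now(timezone.utc) - timedelta(days=7),
    org_id="org_xxx",
    window_size="1d",                                      # assumed window format
    group_by=["proj_id"],                                  # assumed group-by key
    data_source="clickhouse",
)
print(usage)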
- """ - return await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}", - params=dict( + order_by: str = "updated_at", + order_ascending: bool = True, + list_chat_agents: bool = False, + **kwargs, + ) -> Page[ProjectRead]: + return LOOP.run( + super().list_projects( + organization_id=organization_id, offset=offset, limit=limit, - parent_id=parent_id, search_query=search_query, order_by=order_by, - order_descending=order_descending, - count_rows=count_rows, - ), - response_model=Page[TableMetaResponse], + order_ascending=order_ascending, + list_chat_agents=list_chat_agents, + **kwargs, + ) ) - async def delete_table( + def get_project( self, - table_type: str | TableType, - table_id: str, + project_id: str, + **kwargs, + ) -> ProjectRead: + return LOOP.run(super().get_project(project_id, **kwargs)) + + def update_project( + self, + project_id: str, + body: ProjectUpdate, + **kwargs, + ) -> ProjectRead: + return LOOP.run(super().update_project(project_id, body, **kwargs)) + + def delete_project( + self, + project_id: str, *, missing_ok: bool = True, + **kwargs, ) -> OkResponse: - """ - Delete a specific table. - - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - missing_ok (bool, optional): Ignore resource not found error. + return LOOP.run(super().delete_project(project_id, missing_ok=missing_ok, **kwargs)) - Returns: - response (OkResponse): The response indicating success. - """ - response = await self._delete( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}", - params=None, - response_model=None, - ignore_code=404 if missing_ok else None, + def join_project( + self, + user_id: str, + *, + invite_code: str | None = None, + project_id: str | None = None, + role: str | None = None, + **kwargs, + ) -> ProjectMemberRead: + return LOOP.run( + super().join_project( + user_id=user_id, + invite_code=invite_code, + project_id=project_id, + role=role, + **kwargs, + ) ) - if response.status_code == 404 and missing_ok: - return OkResponse() - else: - return OkResponse.model_validate_json(response.text) - async def duplicate_table( + def list_members( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str | None = None, + project_id: str, *, - include_data: bool = True, - create_as_child: bool = False, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, **kwargs, - ) -> TableMetaResponse: - """ - Duplicate a table. - - Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str | None, optional): The destination / new table ID. - Defaults to None (create a new table ID automatically). - include_data (bool, optional): Whether to include data in the duplicated table. Defaults to True. - create_as_child (bool, optional): Whether the new table is a child table. - If this is True, then `include_data` will be set to True. Defaults to False. - - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - if "deploy" in kwargs: - warn( - 'The "deploy" argument is deprecated, use "create_as_child" instead.', - FutureWarning, - stacklevel=2, + ) -> Page[ProjectMemberRead]: + return LOOP.run( + super().list_members( + project_id=project_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, ) - create_as_child = create_as_child or kwargs.pop("deploy") - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/duplicate/{quote(table_id_src)}", - body=None, - params=dict( - table_id_dst=table_id_dst, - include_data=include_data, - create_as_child=create_as_child, - ), - response_model=TableMetaResponse, ) - async def rename_table( + def get_member( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str, - ) -> TableMetaResponse: - """ - Rename a table. - - Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str): The destination / new table ID. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rename/{quote(table_id_src)}/{quote(table_id_dst)}", - body=None, - response_model=TableMetaResponse, + *, + user_id: str, + project_id: str, + **kwargs, + ) -> ProjectMemberRead: + return LOOP.run( + super().get_member( + user_id=user_id, + project_id=project_id, + **kwargs, + ) ) - async def update_gen_config( + def update_member_role( self, - table_type: str | TableType, - request: GenConfigUpdateRequest, - ) -> TableMetaResponse: - """ - Update the generation configuration for a table. - - Args: - table_type (str | TableType): The type of the table. - request (GenConfigUpdateRequest): The generation configuration update request. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/gen_config/update", - body=request, - response_model=TableMetaResponse, + *, + user_id: str, + project_id: str, + role: Role, + **kwargs, + ) -> ProjectMemberRead: + return LOOP.run( + super().update_member_role( + user_id=user_id, + project_id=project_id, + role=role, + **kwargs, + ) ) - async def add_action_columns(self, request: AddActionColumnSchema) -> TableMetaResponse: - """ - Add columns to an Action Table. - - Args: - request (AddActionColumnSchema): The action column schema. + def leave_project( + self, + user_id: str, + project_id: str, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().leave_project( + user_id=user_id, + project_id=project_id, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - "/v1/gen_tables/action/columns/add", - body=request, - response_model=TableMetaResponse, + def import_project( + self, + source: str | BinaryIO, + *, + project_id: str = "", + organization_id: str = "", + **kwargs, + ) -> ProjectRead: + return LOOP.run( + super().import_project( + source=source, + project_id=project_id, + organization_id=organization_id, + **kwargs, + ) ) - async def add_knowledge_columns(self, request: AddKnowledgeColumnSchema) -> TableMetaResponse: - """ - Add columns to a Knowledge Table. + def export_project( + self, + project_id: str, + **kwargs, + ) -> bytes: + return LOOP.run( + super().export_project( + project_id=project_id, + **kwargs, + ) + ) - Args: - request (AddKnowledgeColumnSchema): The knowledge column schema. 
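`export_project` returns the project as Parquet bytes and `import_project` accepts a path or file-like object, so a backup/restore round trip is a two-step sketch (same sync client; `restored.id` is an assumed field on ProjectRead).

# Back up a project to Parquet, then restore it into the same organization.
data = client.projects.export_project("proj_xxx")
with open("proj_backup.parquet", "wb") as f:
    f.write(data)

restored = client.projects.import_project(
    "proj_backup.parquet",
    organization_id="org_xxx",
)
print(restored.id)  # `.id` assumed on ProjectRead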
+ def import_template( + self, + template_id: str, + *, + project_id: str = "", + organization_id: str = "", + **kwargs, + ) -> ProjectRead: + return LOOP.run( + super().import_template( + template_id=template_id, + project_id=project_id, + organization_id=organization_id, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - "/v1/gen_tables/knowledge/columns/add", - body=request, - response_model=TableMetaResponse, + def create_invite( + self, + *, + user_email: str, + project_id: str, + role: str, + valid_days: int = 7, + **kwargs, + ) -> VerificationCodeRead: + return LOOP.run( + super().create_invite( + user_email=user_email, + project_id=project_id, + role=role, + valid_days=valid_days, + **kwargs, + ) ) - async def add_chat_columns(self, request: AddChatColumnSchema) -> TableMetaResponse: - """ - Add columns to a Chat Table. - - Args: - request (AddChatColumnSchema): The chat column schema. + def list_invites( + self, + project_id: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[VerificationCodeRead]: + return LOOP.run( + super().list_invites( + project_id=project_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, + ) + ) - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - "/v1/gen_tables/chat/columns/add", - body=request, - response_model=TableMetaResponse, + def revoke_invite( + self, + invite_id: str, + *, + missing_ok: bool = True, + **kwargs, + ): + return LOOP.run( + super().revoke_invite(invite_id=invite_id, missing_ok=missing_ok, **kwargs) ) - async def drop_columns( + @deprecated( + "`delete_invite` is deprecated, use `revoke_invite` instead.", + category=FutureWarning, + stacklevel=1, + ) + def delete_invite( self, - table_type: str | TableType, - request: ColumnDropRequest, - ) -> TableMetaResponse: - """ - Drop columns from a table. + invite_id: str, + *, + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: + return LOOP.run( + super().delete_invite( + invite_id=invite_id, + missing_ok=missing_ok, + **kwargs, + ) + ) - Args: - table_type (str | TableType): The type of the table. - request (ColumnDropRequest): The column drop request. - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/columns/drop", - body=request, - response_model=TableMetaResponse, - ) +class _Templates(_TemplatesAsync): + """Template methods.""" - async def rename_columns( + def list_templates( self, - table_type: str | TableType, - request: ColumnRenameRequest, - ) -> TableMetaResponse: - """ - Rename columns in a table. + *, + offset: int = 0, + limit: int = 100, + search_query: str = "", + order_by: str = "updated_at", + order_ascending: bool = True, + **kwargs, + ) -> Page[ProjectRead]: + return LOOP.run( + super().list_templates( + offset=offset, + limit=limit, + search_query=search_query, + order_by=order_by, + order_ascending=order_ascending, + **kwargs, + ) + ) - Args: - table_type (str | TableType): The type of the table. - request (ColumnRenameRequest): The column rename request. + def get_template(self, template_id: str, **kwargs) -> ProjectRead: + return LOOP.run(super().get_template(template_id, **kwargs)) - Returns: - response (TableMetaResponse): The table metadata response. 
- """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/columns/rename", - body=request, - response_model=TableMetaResponse, + def list_tables( + self, + template_id: str, + table_type: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + parent_id: str | None = None, + count_rows: bool = False, + **kwargs, + ) -> Page[TableMetaResponse]: + return LOOP.run( + super().list_tables( + template_id=template_id, + table_type=table_type, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + parent_id=parent_id, + count_rows=count_rows, + **kwargs, + ) ) - async def reorder_columns( + def get_table( self, - table_type: str | TableType, - request: ColumnReorderRequest, + template_id: str, + table_type: str, + table_id: str, + **kwargs, ) -> TableMetaResponse: - """ - Reorder columns in a table. - - Args: - table_type (str | TableType): The type of the table. - request (ColumnReorderRequest): The column reorder request. - - Returns: - response (TableMetaResponse): The table metadata response. - """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/columns/reorder", - body=request, - response_model=TableMetaResponse, + return LOOP.run( + super().get_table( + template_id=template_id, + table_type=table_type, + table_id=table_id, + **kwargs, + ) ) - async def list_table_rows( + def list_table_rows( self, - table_type: str | TableType, + template_id: str, + table_type: str, table_id: str, *, offset: int = 0, limit: int = 100, - search_query: str = "", + order_by: str = "ID", + order_ascending: bool = True, columns: list[str] | None = None, + search_query: str = "", + search_columns: list[str] | None = None, float_decimals: int = 0, vec_decimals: int = 0, - order_descending: bool = True, + **kwargs, ) -> Page[dict[str, Any]]: """ List rows in a table. Args: - table_type (str | TableType): The type of the table. + template_id (str): The ID of the template. + table_type (str): The type of the table. table_id (str): The ID of the table. offset (int, optional): Item offset. Defaults to 0. limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. - search_query (str, optional): A string to search for within the rows as a filter. - Defaults to "" (no filter). + order_by (str, optional): Column name to order by. Defaults to "ID". + order_ascending (bool, optional): Whether to sort by ascending order. Defaults to True. columns (list[str] | None, optional): List of column names to include in the response. Defaults to None (all columns). + search_query (str, optional): A string to search for within the rows as a filter. + Defaults to "" (no filter). + search_columns (list[str] | None, optional): A list of column names to search for `search_query`. + Defaults to None (search all columns). float_decimals (int, optional): Number of decimals for float values. Defaults to 0 (no rounding). vec_decimals (int, optional): Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding). - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. 
""" - return await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows", - params=dict( + return LOOP.run( + super().list_table_rows( + template_id=template_id, + table_type=table_type, + table_id=table_id, offset=offset, limit=limit, - search_query=search_query, - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - order_descending=order_descending, - ), - response_model=Page[dict[str, Any]], - ) - - async def get_table_row( - self, - table_type: str | TableType, - table_id: str, - row_id: str, - columns: list[str] | None = None, - float_decimals: int = 0, - vec_decimals: int = 0, - ) -> dict[str, Any]: - """ - Get a specific row in a table. - - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - row_id (str): The ID of the row. - columns (list[str] | None, optional): List of column names to include in the response. - Defaults to None (all columns). - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). - - Returns: - response (dict[str, Any]): The row data. - """ - response = await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows/{quote(row_id)}", - params=dict( + order_by=order_by, + order_ascending=order_ascending, columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ), - response_model=None, - ) - return json_loads(response.text) - - async def add_table_rows( - self, - table_type: str | TableType, - request: RowAddRequest, - ) -> ( - GenTableRowsChatCompletionChunks - | AsyncGenerator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None] - ): - """ - Add rows to a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowAddRequest): The row add request. - - Returns: - response (GenTableRowsChatCompletionChunks | AsyncGenerator): The row completion. - In streaming mode, it is an async generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. - """ - if request.stream: - - async def gen(): - async for chunk in self._stream( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/add", - body=request, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "gen_table.references": - yield GenTableStreamReferences.model_validate(chunk) - elif chunk["object"] == "gen_table.completion.chunk": - yield GenTableStreamChatCompletionChunk.model_validate(chunk) - - return gen() - else: - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/add", - body=request, - response_model=GenTableRowsChatCompletionChunks, - ) - - async def regen_table_rows( - self, - table_type: str | TableType, - request: RowRegenRequest, - ) -> ( - GenTableRowsChatCompletionChunks - | AsyncGenerator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None] - ): - """ - Regenerate rows in a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowRegenRequest): The row regenerate request. - - Returns: - response (GenTableRowsChatCompletionChunks | AsyncGenerator): The row completion. 
- In streaming mode, it is an async generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. - """ - if request.stream: - - async def gen(): - async for chunk in self._stream( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/regen", - body=request, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "gen_table.references": - yield GenTableStreamReferences.model_validate(chunk) - elif chunk["object"] == "gen_table.completion.chunk": - yield GenTableStreamChatCompletionChunk.model_validate(chunk) - - return gen() - else: - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/regen", - body=request, - response_model=GenTableRowsChatCompletionChunks, - ) - - async def update_table_row( - self, - table_type: str | TableType, - request: RowUpdateRequest, - ) -> OkResponse: - """ - Update a specific row in a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowUpdateRequest): The row update request. - - Returns: - response (OkResponse): The response indicating success. - """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/update", - body=request, - response_model=OkResponse, - ) - - async def delete_table_rows( - self, - table_type: str | TableType, - request: RowDeleteRequest, - ) -> OkResponse: - """ - Delete rows from a table. - - Args: - table_type (str | TableType): The type of the table. - request (RowDeleteRequest): The row delete request. - - Returns: - response (OkResponse): The response indicating success. - """ - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/rows/delete", - body=request, - response_model=OkResponse, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) ) - async def delete_table_row( + def get_table_row( self, - table_type: str | TableType, + template_id: str, + table_type: str, table_id: str, row_id: str, - ) -> OkResponse: + *, + columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> dict[str, Any]: """ - Delete a specific row from a table. + Get a specific row in a table. Args: - table_type (str | TableType): The type of the table. + template_id (str): The ID of the template. + table_type (str): The type of the table. table_id (str): The ID of the table. row_id (str): The ID of the row. + columns (list[str] | None, optional): List of column names to include in the response. + Defaults to None (all columns). + float_decimals (int, optional): Number of decimals for float values. + Defaults to 0 (no rounding). + vec_decimals (int, optional): Number of decimals for vectors. + If its negative, exclude vector columns. Defaults to 0 (no rounding). Returns: - response (OkResponse): The response indicating success. + response (dict[str, Any]): The row data. 
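Reading rows out of a template table with the new ordering and search parameters; a sketch with placeholder IDs, re-using the same sync client. Per the docstring above, a negative `vec_decimals` drops vector columns.

# List the first 10 rows of a template's knowledge table, without vectors.
rows = client.templates.list_table_rows(
    template_id="<template-id>",
    table_type="knowledge",
    table_id="<table-id>",
    limit=10,
    order_by="ID",
    order_ascending=False,
    vec_decimals=-1,        # negative value excludes vector columns
)
for row in rows.items:      # `.items` assumed on Page[...]
    print(row["ID"])        # rows are dicts keyed by column name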
""" - return await self._delete( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/rows/{quote(row_id)}", - params=None, - response_model=OkResponse, + return LOOP.run( + super().get_table_row( + template_id=template_id, + table_type=table_type, + table_id=table_id, + row_id=row_id, + columns=columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) ) - async def get_conversation_thread( - self, - table_type: str | TableType, - table_id: str, - column_id: str, - *, - row_id: str = "", - include: bool = True, - ) -> ChatThread: - """ - Get the conversation thread for a chat table. - Args: - table_type (str | TableType): The type of the table. - table_id (str): ID / name of the chat table. - column_id (str): ID / name of the column to fetch. - row_id (str, optional): ID / name of the last row in the thread. - Defaults to "" (export all rows). - include (bool, optional): Whether to include the row specified by `row_id`. - Defaults to True. +class _FileClient(_FileClientAsync): + """File methods (synchronous version).""" - Returns: - response (ChatThread): The conversation thread. - """ - return await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/thread", - params=dict(column_id=column_id, row_id=row_id, include=include), - response_model=ChatThread, - ) + def upload_file(self, file_path: str, **kwargs) -> FileUploadResponse: + return LOOP.run(super().upload_file(file_path, **kwargs)) - async def hybrid_search( + def get_raw_urls(self, uris: list[str], **kwargs) -> GetURLResponse: + return LOOP.run(super().get_raw_urls(uris, **kwargs)) + + def get_thumbnail_urls(self, uris: list[str], **kwargs) -> GetURLResponse: + return LOOP.run(super().get_thumbnail_urls(uris, **kwargs)) + + +class _GenTableClient(_GenTableClientAsync): + """Generative Table methods (synchronous version).""" + + # Table CRUD + def create_action_table( self, - table_type: str | TableType, - request: SearchRequest, - ) -> list[dict[str, Any]]: + request: ActionTableSchemaCreate, + **kwargs, + ) -> TableMetaResponse: """ - Perform a hybrid search on a table. + Create an Action Table. Args: - table_type (str | TableType): The type of the table. - request (SearchRequest): The search request. + request (ActionTableSchemaCreate): The action table schema. Returns: - response (list[dict[str, Any]]): The search results. + response (TableMetaResponse): The table metadata response. """ - response = await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/hybrid_search", - body=request, - response_model=None, - ) - return json_loads(response.text) + return LOOP.run(super().create_action_table(request, **kwargs)) - async def embed_file_options(self) -> httpx.Response: + def create_knowledge_table( + self, + request: KnowledgeTableSchemaCreate, + **kwargs, + ) -> TableMetaResponse: """ - Get options for embedding a file to a Knowledge Table. + Create a Knowledge Table. + + Args: + request (KnowledgeTableSchemaCreate): The knowledge table schema. Returns: - response (httpx.Response): The response containing options information. + response (TableMetaResponse): The table metadata response. 
""" - response = await self._options( - self.api_base, - "/v1/gen_tables/knowledge/embed_file", - ) - return response + return LOOP.run(super().create_knowledge_table(request, **kwargs)) - async def embed_file( + def create_chat_table( self, - file_path: str, - table_id: str, - *, - chunk_size: int = 1000, - chunk_overlap: int = 200, - ) -> OkResponse: + request: ChatTableSchemaCreate, + **kwargs, + ) -> TableMetaResponse: """ - Embed a file into a Knowledge Table. + Create a Chat Table. Args: - file_path (str): File path of the document to be embedded. - table_id (str): Knowledge Table ID / name. - chunk_size (int, optional): Maximum chunk size (number of characters). Must be > 0. - Defaults to 1000. - chunk_overlap (int, optional): Overlap in characters between chunks. Must be >= 0. - Defaults to 200. + request (ChatTableSchemaCreate): The chat table schema. Returns: - response (OkResponse): The response indicating success. + response (TableMetaResponse): The table metadata response. """ - # Guess the MIME type of the file based on its extension - mime_type, _ = guess_type(file_path) - if mime_type is None: - mime_type = ( - "application/jsonl" if file_path.endswith(".jsonl") else "application/octet-stream" - ) # Default MIME type - # Extract the filename from the file path - filename = split(file_path)[-1] - # Open the file in binary mode - with open(file_path, "rb") as f: - response = await self._post( - self.api_base, - "/v1/gen_tables/knowledge/embed_file", - body=None, - response_model=OkResponse, - files={ - "file": (filename, f, mime_type), - }, - data={ - "table_id": table_id, - "chunk_size": chunk_size, - "chunk_overlap": chunk_overlap, - # "overwrite": request.overwrite, - }, - timeout=self.file_upload_timeout, - ) - return response + return LOOP.run(super().create_chat_table(request, **kwargs)) - async def import_table_data( + def duplicate_table( self, - table_type: str | TableType, - request: TableDataImportRequest, - ) -> GenTableChatResponseType: + table_type: str, + table_id_src: str, + table_id_dst: str | None = None, + *, + include_data: bool = True, + create_as_child: bool = False, + **kwargs, + ) -> TableMetaResponse: """ - Imports CSV or TSV data into a table. + Duplicate a table. Args: - file_path (str): CSV or TSV file path. - table_type (str | TableType): Table type. - request (TableDataImportRequest): Data import request. + table_type (str): The type of the table. + table_id_src (str): The source table ID. + table_id_dst (str | None, optional): The destination / new table ID. + Defaults to None (create a new table ID automatically). + include_data (bool, optional): Whether to include data in the duplicated table. Defaults to True. + create_as_child (bool, optional): Whether the new table is a child table. + If this is True, then `include_data` will be set to True. Defaults to False. Returns: - response (OkResponse): The response indicating success. + response (TableMetaResponse): The table metadata response. 
""" - # Guess the MIME type of the file based on its extension - mime_type, _ = guess_type(request.file_path) - if mime_type is None: - mime_type = "application/octet-stream" # Default MIME type - # Extract the filename from the file path - filename = split(request.file_path)[-1] - data = { - "table_id": request.table_id, - "stream": request.stream, - # "column_names": request.column_names, - # "columns": request.columns, - "delimiter": request.delimiter, - } - if request.stream: - - async def gen(): - # Open the file in binary mode - with open(request.file_path, "rb") as f: - async for chunk in self._stream( - self.api_base, - f"/v1/gen_tables/{table_type}/import_data", - body=None, - files={ - "file": (filename, f, mime_type), - }, - data=data, - timeout=self.file_upload_timeout, - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "gen_table.references": - yield GenTableStreamReferences.model_validate(chunk) - elif chunk["object"] == "gen_table.completion.chunk": - yield GenTableStreamChatCompletionChunk.model_validate(chunk) - else: - raise RuntimeError(f"Unexpected SSE chunk: {chunk}") - - return gen() - else: - # Open the file in binary mode - with open(request.file_path, "rb") as f: - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/import_data", - body=None, - response_model=GenTableRowsChatCompletionChunks, - files={ - "file": (filename, f, mime_type), - }, - data=data, - timeout=self.file_upload_timeout, - ) + return LOOP.run( + super().duplicate_table( + table_type, + table_id_src, + table_id_dst=table_id_dst, + include_data=include_data, + create_as_child=create_as_child, + **kwargs, + ) + ) - async def export_table_data( + def get_table( self, - table_type: str | TableType, + table_type: str, table_id: str, - *, - columns: list[str] | None = None, - delimiter: Literal[",", "\t"] = ",", - ) -> bytes: + **kwargs, + ) -> TableMetaResponse: """ - Exports the row data of a table as a CSV or TSV file. + Get metadata for a specific Generative Table. Args: - table_type (str | TableType): Table type. - table_id (str): ID or name of the table to be exported. - delimiter (str, optional): The delimiter of the file: can be "," or "\\t". Defaults to ",". - columns (list[str], optional): A list of columns to be exported. Defaults to None (export all columns). + table_type (str): The type of the table. + table_id (str): The ID of the table. Returns: - response (list[dict[str, Any]]): The search results. + response (TableMetaResponse): The table metadata response. """ - response = await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/export_data", - params=dict(delimiter=delimiter, columns=columns), - response_model=None, - ) - return response.content + return LOOP.run(super().get_table(table_type, table_id, **kwargs)) - async def import_table( + def list_tables( self, - table_type: str | TableType, - request: TableImportRequest, - ) -> TableMetaResponse: + table_type: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + created_by: str | None = None, + parent_id: str | None = None, + search_query: str = "", + count_rows: bool = False, + **kwargs, + ) -> Page[TableMetaResponse]: """ - Imports a table (data and schema) from a parquet file. + List Generative Tables of a specific type. Args: - file_path (str): The parquet file path. - table_type (str | TableType): Table type. - request (TableImportRequest): Table import request. + table_type (str): The type of the table. 
+ offset (int, optional): Item offset. Defaults to 0. + limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. + order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". + order_ascending (bool, optional): Whether to sort by ascending order. Defaults to True. + created_by (str | None, optional): Return tables created by this user. + Defaults to None (return all tables). + parent_id (str | None, optional): Parent ID of tables to return. + Additionally for Chat Table, you can list: + (1) all chat agents by passing in "_agent_"; or + (2) all chats by passing in "_chat_". + Defaults to None (return all tables). + search_query (str, optional): A string to search for within table IDs as a filter. + Defaults to "" (no filter). + count_rows (bool, optional): Whether to count the rows of the tables. Defaults to False. Returns: - response (TableMetaResponse): The table metadata response. + response (Page[TableMetaResponse]): The paginated table metadata response. """ - mime_type = "application/octet-stream" - filename = split(request.file_path)[-1] - data = {"table_id_dst": request.table_id_dst} - # Open the file in binary mode - with open(request.file_path, "rb") as f: - return await self._post( - self.api_base, - f"/v1/gen_tables/{table_type}/import", - body=None, - response_model=TableMetaResponse, - files={ - "file": (filename, f, mime_type), - }, - data=data, - timeout=self.file_upload_timeout, + return LOOP.run( + super().list_tables( + table_type, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + created_by=created_by, + parent_id=parent_id, + search_query=search_query, + count_rows=count_rows, + **kwargs, ) + ) - async def export_table( + def rename_table( self, - table_type: str | TableType, - table_id: str, - ) -> bytes: + table_type: str, + table_id_src: str, + table_id_dst: str, + **kwargs, + ) -> TableMetaResponse: """ - Exports a table (data and schema) as a parquet file. + Rename a table. Args: - table_type (str | TableType): Table type. - table_id (str): ID or name of the table to be exported. + table_type (str): The type of the table. + table_id_src (str): The source table ID. + table_id_dst (str): The destination / new table ID. Returns: - response (list[dict[str, Any]]): The search results. + response (TableMetaResponse): The table metadata response. """ - response = await self._get( - self.api_base, - f"/v1/gen_tables/{table_type}/{quote(table_id)}/export", - params=None, - response_model=None, - ) - return response.content + return LOOP.run(super().rename_table(table_type, table_id_src, table_id_dst, **kwargs)) - -class JamAIAsync(_ClientAsync): - def __init__( + def delete_table( self, - project_id: str = ENV_CONFIG.jamai_project_id, - token: str = ENV_CONFIG.jamai_token_plain, - api_base: str = ENV_CONFIG.jamai_api_base, - headers: dict | None = None, - timeout: float | None = ENV_CONFIG.jamai_timeout_sec, - file_upload_timeout: float | None = ENV_CONFIG.jamai_file_upload_timeout_sec, + table_type: str, + table_id: str, *, - api_key: str = "", - ) -> None: + missing_ok: bool = True, + **kwargs, + ) -> OkResponse: """ - Initialize the JamAI client. + Delete a specific table. Args: - project_id (str, optional): The project ID. - Defaults to "default", but can be overridden via - `JAMAI_PROJECT_ID` var in environment or `.env` file. - token (str, optional): Your Personal Access Token or organization API key (deprecated) for authentication. 
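A sketch of the listing and lifecycle helpers above, using only parameters shown in their signatures (the "action" table-type string and the `Page.items` attribute are assumptions):

from jamaibase import JamAI  # assumed import path

client = JamAI(token="your-pat")

# Page through Action Tables, newest first, including row counts.
page = client.table.list_tables(
    "action",
    limit=20,
    order_by="updated_at",
    order_ascending=False,
    count_rows=True,
)
for meta in page.items:  # `.items` is an assumed Page attribute
    print(meta.id)

client.table.rename_table("action", "summarise", "summarise-v2")
client.table.delete_table("action", "summarise-v2", missing_ok=True)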
- Defaults to "", but can be overridden via - `JAMAI_TOKEN` var in environment or `.env` file. - api_base (str, optional): The base URL for the API. - Defaults to "https://api.jamaibase.com/api", but can be overridden via - `JAMAI_API_BASE` var in environment or `.env` file. - headers (dict | None, optional): Additional headers to include in requests. - Defaults to None. - timeout (float | None, optional): The timeout to use when sending requests. - Defaults to 15 minutes, but can be overridden via - `JAMAI_TIMEOUT_SEC` var in environment or `.env` file. - file_upload_timeout (float | None, optional): The timeout to use when sending file upload requests. - Defaults to 60 minutes, but can be overridden via - `JAMAI_FILE_UPLOAD_TIMEOUT_SEC` var in environment or `.env` file. - api_key (str, optional): (Deprecated) Organization API key for authentication. - """ - if api_key: - warn(ORG_API_KEY_DEPRECATE, FutureWarning, stacklevel=2) - http_client = httpx.AsyncClient( - timeout=timeout, - transport=httpx.AsyncHTTPTransport(retries=3), - ) - kwargs = dict( - project_id=project_id, - token=token or api_key, - api_base=api_base, - headers=headers, - http_client=http_client, - file_upload_timeout=file_upload_timeout, - ) - super().__init__(**kwargs) - self.admin = _AdminClientAsync(**kwargs) - self.template = _TemplateClientAsync(**kwargs) - self.file = _FileClientAsync(**kwargs) - self.table = _GenTableClientAsync(**kwargs) - - async def health(self) -> dict[str, Any]: - """ - Get health status. + table_type (str): The type of the table. + table_id (str): The ID of the table. + missing_ok (bool, optional): Ignore resource not found error. Returns: - response (dict[str, Any]): Health status. + response (OkResponse): The response indicating success. """ - response = await self._get(self.api_base, "/health", response_model=None) - return json_loads(response.text) - - # --- Models and chat --- # + return LOOP.run( + super().delete_table( + table_type, + table_id, + missing_ok=missing_ok, + **kwargs, + ) + ) - async def model_info( + # Column CRUD + def add_action_columns( self, - name: str = "", - capabilities: list[ - Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"] - ] - | None = None, - ) -> ModelInfoResponse: + request: AddActionColumnSchema, + **kwargs, + ) -> TableMetaResponse: """ - Get information about available models. + Add columns to an Action Table. Args: - name (str, optional): The model name. Defaults to "". - capabilities (list[Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"]] | None, optional): - List of model capabilities to filter by. Defaults to None. + request (AddActionColumnSchema): The action column schema. Returns: - response (ModelInfoResponse): The model information response. + response (TableMetaResponse): The table metadata response. """ - params = {"model": name, "capabilities": capabilities} - return await self._get( - self.api_base, - "/v1/models", - params=params, - response_model=ModelInfoResponse, - ) + return LOOP.run(super().add_action_columns(request, **kwargs)) - async def model_names( + def add_knowledge_columns( self, - prefer: str = "", - capabilities: list[ - Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"] - ] - | None = None, - ) -> list[str]: + request: AddKnowledgeColumnSchema, + **kwargs, + ) -> TableMetaResponse: """ - Get the names of available models. + Add columns to a Knowledge Table. Args: - prefer (str, optional): Preferred model name. Defaults to "". 
- capabilities (list[Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"]] | None, optional): - List of model capabilities to filter by. Defaults to None. + request (AddKnowledgeColumnSchema): The knowledge column schema. Returns: - response (list[str]): List of model names. + response (TableMetaResponse): The table metadata response. """ - params = {"prefer": prefer, "capabilities": capabilities} - response = await self._get( - self.api_base, - "/v1/model_names", - params=params, - response_model=None, - ) - return json_loads(response.text) + return LOOP.run(super().add_knowledge_columns(request, **kwargs)) - async def generate_chat_completions( - self, request: ChatRequest - ) -> ChatCompletionChunk | AsyncGenerator[References | ChatCompletionChunk, None]: + def add_chat_columns( + self, + request: AddChatColumnSchema, + **kwargs, + ) -> TableMetaResponse: """ - Generates chat completions. + Add columns to a Chat Table. Args: - request (ChatRequest): The request. + request (AddChatColumnSchema): The chat column schema. Returns: - completion (ChatCompletionChunk | AsyncGenerator): The chat completion. - In streaming mode, it is an async generator that yields a `References` object - followed by zero or more `ChatCompletionChunk` objects. - In non-streaming mode, it is a `ChatCompletionChunk` object. + response (TableMetaResponse): The table metadata response. """ - if request.stream: - - async def gen(): - async for chunk in self._stream( - self.api_base, "/v1/chat/completions", body=request - ): - chunk = json_loads(chunk[5:]) - if chunk["object"] == "chat.references": - yield References.model_validate(chunk) - elif chunk["object"] == "chat.completion.chunk": - yield ChatCompletionChunk.model_validate(chunk) + return LOOP.run(super().add_chat_columns(request, **kwargs)) - return gen() - else: - return await self._post( - self.api_base, - "/v1/chat/completions", - body=request, - response_model=ChatCompletionChunk, - ) - - async def generate_embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + def rename_columns( + self, + table_type: str, + request: ColumnRenameRequest, + **kwargs, + ) -> TableMetaResponse: """ - Generate embeddings for the given input. + Rename columns in a table. Args: - request (EmbeddingRequest): The embedding request. + table_type (str): The type of the table. + request (ColumnRenameRequest): The column rename request. Returns: - response (EmbeddingResponse): The embedding response. + response (TableMetaResponse): The table metadata response. """ - return await self._post( - self.api_base, - "/v1/embeddings", - body=request, - response_model=EmbeddingResponse, - ) + return LOOP.run(super().rename_columns(table_type, request, **kwargs)) - # --- Gen Table --- # - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def create_action_table(self, request: ActionTableSchemaCreate) -> TableMetaResponse: + def update_gen_config( + self, + table_type: str, + request: GenConfigUpdateRequest, + **kwargs, + ) -> TableMetaResponse: """ - Create an Action Table. + Update the generation configuration for a table. Args: - request (ActionTableSchemaCreate): The action table schema. + table_type (str): The type of the table. + request (GenConfigUpdateRequest): The generation configuration update request. Returns: response (TableMetaResponse): The table metadata response. 
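A brief sketch of the column helpers above; the request field names (`id`, `cols`, `table_id`, `column_map`) are assumptions, so check `jamaibase.protocol` for the exact schemas:

from jamaibase import JamAI  # assumed import path
from jamaibase import protocol as p  # assumed location of the request models

client = JamAI(token="your-pat")

# Add a column to an existing Action Table (field names are illustrative).
client.table.add_action_columns(
    p.AddActionColumnSchema(
        id="summarise",
        cols=[p.ColumnSchemaCreate(id="keywords", dtype="str")],
    )
)

# Rename a column (field names are illustrative).
client.table.rename_columns(
    "action",
    p.ColumnRenameRequest(table_id="summarise", column_map={"keywords": "tags"}),
)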
""" - return await self.table.create_action_table(request) + return LOOP.run(super().update_gen_config(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def create_knowledge_table( - self, request: KnowledgeTableSchemaCreate + def reorder_columns( + self, + table_type: str, + request: ColumnReorderRequest, + **kwargs, ) -> TableMetaResponse: """ - Create a Knowledge Table. + Reorder columns in a table. Args: - request (KnowledgeTableSchemaCreate): The knowledge table schema. + table_type (str): The type of the table. + request (ColumnReorderRequest): The column reorder request. Returns: response (TableMetaResponse): The table metadata response. """ - return await self.table.create_knowledge_table(request) + return LOOP.run(super().reorder_columns(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def create_chat_table(self, request: ChatTableSchemaCreate) -> TableMetaResponse: + def drop_columns( + self, + table_type: str, + request: ColumnDropRequest, + **kwargs, + ) -> TableMetaResponse: """ - Create a Chat Table. + Drop columns from a table. Args: - request (ChatTableSchemaCreate): The chat table schema. + table_type (str): The type of the table. + request (ColumnDropRequest): The column drop request. Returns: response (TableMetaResponse): The table metadata response. """ - return await self.table.create_chat_table(request) + return LOOP.run(super().drop_columns(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def get_table( + # Row CRUD + def add_table_rows( self, - table_type: str | TableType, - table_id: str, - ) -> TableMetaResponse: + table_type: str, + request: MultiRowAddRequest, + **kwargs, + ) -> ( + MultiRowCompletionResponse + | Generator[CellReferencesResponse | CellCompletionResponse, None, None] + ): """ - Get metadata for a specific Generative Table. + Add rows to a table. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. + table_type (str): The type of the table. + request (MultiRowAddRequest): The row add request. Returns: - response (TableMetaResponse): The table metadata response. + response (MultiRowCompletionResponse | AsyncGenerator): The row completion. + In streaming mode, it is an async generator that yields a `CellReferencesResponse` object + followed by zero or more `CellCompletionResponse` objects. + In non-streaming mode, it is a `MultiRowCompletionResponse` object. """ - return await self.table.get_table(table_type, table_id) + agen = LOOP.run(super().add_table_rows(table_type, request, **kwargs)) + return self._return_iterator(agen, request.stream) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def list_tables( + def list_table_rows( self, - table_type: str | TableType, + table_type: str, + table_id: str, + *, offset: int = 0, limit: int = 100, - parent_id: str | None = None, + order_by: str = "ID", + order_ascending: bool = True, + columns: list[str] | None = None, + where: str = "", search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, - count_rows: bool = False, - ) -> Page[TableMetaResponse]: + search_columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> Page[dict[str, Any]]: """ - List Generative Tables of a specific type. + List rows in a table. 
Args: - table_type (str | TableType): The type of the table. + table_type (str): The type of the table. + table_id (str): The ID of the table. offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of tables to return (min 1, max 100). Defaults to 100. - parent_id (str | None, optional): Parent ID of tables to return. - Additionally for Chat Table, you can list: - (1) all chat agents by passing in "_agent_"; or - (2) all chats by passing in "_chat_". - Defaults to None (return all tables). - search_query (str, optional): A string to search for within table IDs as a filter. + limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. + order_by (str, optional): Column name to order by. Defaults to "ID". + order_ascending (bool, optional): Whether to sort by ascending order. Defaults to True. + columns (list[str] | None, optional): List of column names to include in the response. + Defaults to None (all columns). + where (str, optional): SQL where clause. Can be nested ie `x = '1' AND ("y (1)" = 2 OR z = '3')`. + It will be combined other filters using `AND`. Defaults to "" (no filter). + search_query (str, optional): A string to search for within the rows as a filter. Defaults to "" (no filter). - order_by (str, optional): Sort tables by this attribute. Defaults to "updated_at". - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. - count_rows (bool, optional): Whether to count the rows of the tables. Defaults to False. + search_columns (list[str] | None, optional): A list of column names to search for `search_query`. + Defaults to None (search all columns). + float_decimals (int, optional): Number of decimals for float values. + Defaults to 0 (no rounding). + vec_decimals (int, optional): Number of decimals for vectors. + If its negative, exclude vector columns. Defaults to 0 (no rounding). + """ + return LOOP.run( + super().list_table_rows( + table_type, + table_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + columns=columns, + where=where, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) + ) + + def get_table_row( + self, + table_type: str, + table_id: str, + row_id: str, + *, + columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> dict[str, Any]: + """ + Get a specific row in a table. + + Args: + table_type (str): The type of the table. + table_id (str): The ID of the table. + row_id (str): The ID of the row. + columns (list[str] | None, optional): List of column names to include in the response. + Defaults to None (all columns). + float_decimals (int, optional): Number of decimals for float values. + Defaults to 0 (no rounding). + vec_decimals (int, optional): Number of decimals for vectors. + If its negative, exclude vector columns. Defaults to 0 (no rounding). Returns: - response (Page[TableMetaResponse]): The paginated table metadata response. + response (dict[str, Any]): The row data. 
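The row helpers above switch between streaming and blocking behaviour via the request's `stream` flag; a sketch of both, with assumed `MultiRowAddRequest` field names:

from jamaibase import JamAI  # assumed import path
from jamaibase import protocol as p  # assumed location of the request/response models

client = JamAI(token="your-pat")

# Streaming add: the sync wrapper converts the async generator into a plain generator.
chunks = client.table.add_table_rows(
    "action",
    p.MultiRowAddRequest(  # field names are illustrative
        table_id="summarise",
        data=[{"text": "JamAI Base stores rows in Generative Tables."}],
        stream=True,
    ),
)
for chunk in chunks:
    if isinstance(chunk, p.CellCompletionResponse):
        print(chunk)  # cell-level completion deltas

# Blocking read with a SQL-style filter, as documented above.
rows = client.table.list_table_rows(
    "action",
    "summarise",
    where='"summary" IS NOT NULL',
    limit=10,
)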
""" - return await self.table.list_tables( - table_type, - offset=offset, - limit=limit, - parent_id=parent_id, - search_query=search_query, - order_by=order_by, - order_descending=order_descending, - count_rows=count_rows, + return LOOP.run( + super().get_table_row( + table_type, + table_id, + row_id, + columns=columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def delete_table( + @deprecated( + "This method is deprecated, use `get_conversation_threads` instead.", + category=FutureWarning, + stacklevel=1, + ) + def get_conversation_thread( self, - table_type: str | TableType, + table_type: str, table_id: str, + column_id: str, *, - missing_ok: bool = True, - ) -> OkResponse: + row_id: str = "", + include: bool = True, + **kwargs, + ) -> ChatThreadResponse: """ - Delete a specific table. + Get the conversation thread for a column in a table. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - missing_ok (bool, optional): Ignore resource not found error. + table_type (str): The type of the table. + table_id (str): ID / name of the chat table. + column_id (str): ID / name of the column to fetch. + row_id (str, optional): ID / name of the last row in the thread. + Defaults to "" (export all rows). + include (bool, optional): Whether to include the row specified by `row_id`. + Defaults to True. Returns: - response (OkResponse): The response indicating success. + response (ChatThreadResponse): The conversation thread. """ - return await self.table.delete_table(table_type, table_id, missing_ok=missing_ok) + return LOOP.run( + super().get_conversation_thread( + table_type, + table_id, + column_id, + row_id=row_id, + include=include, + **kwargs, + ) + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def duplicate_table( + def get_conversation_threads( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str | None = None, + table_type: str, + table_id: str, + column_ids: list[str] | None = None, *, - include_data: bool = True, - create_as_child: bool = False, + row_id: str = "", + include_row: bool = True, **kwargs, - ) -> TableMetaResponse: + ) -> ChatThreadsResponse: """ - Duplicate a table. + Get all multi-turn / conversation threads from a table. Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str | None, optional): The destination / new table ID. - Defaults to None (create a new table ID automatically). - include_data (bool, optional): Whether to include data in the duplicated table. Defaults to True. - create_as_child (bool, optional): Whether the new table is a child table. - If this is True, then `include_data` will be set to True. Defaults to False. + table_type (str): The type of the table. + table_id (str): ID / name of the chat table. + column_ids (list[str] | None): Columns to fetch as conversation threads. + row_id (str, optional): ID / name of the last row in the thread. + Defaults to "" (export all rows). + include_row (bool, optional): Whether to include the row specified by `row_id`. + Defaults to True. Returns: - response (TableMetaResponse): The table metadata response. + response (ChatThreadsResponse): The conversation threads. 
""" - return await self.table.duplicate_table( - table_type, - table_id_src, - table_id_dst, - include_data=include_data, - create_as_child=create_as_child, - **kwargs, + return LOOP.run( + super().get_conversation_threads( + table_type, + table_id, + column_ids, + row_id=row_id, + include_row=include_row, + **kwargs, + ) ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def rename_table( + def hybrid_search( self, - table_type: str | TableType, - table_id_src: str, - table_id_dst: str, - ) -> TableMetaResponse: + table_type: str, + request: SearchRequest, + **kwargs, + ) -> list[dict[str, Any]]: """ - Rename a table. + Perform a hybrid search on a table. Args: - table_type (str | TableType): The type of the table. - table_id_src (str): The source table ID. - table_id_dst (str): The destination / new table ID. + table_type (str): The type of the table. + request (SearchRequest): The search request. Returns: - response (TableMetaResponse): The table metadata response. + response (list[dict[str, Any]]): The search results. """ - return await self.table.rename_table(table_type, table_id_src, table_id_dst) + return LOOP.run(super().hybrid_search(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def update_gen_config( + def regen_table_rows( self, - table_type: str | TableType, - request: GenConfigUpdateRequest, - ) -> TableMetaResponse: + table_type: str, + request: MultiRowRegenRequest, + **kwargs, + ) -> ( + MultiRowCompletionResponse + | Generator[CellReferencesResponse | CellCompletionResponse, None, None] + ): """ - Update the generation configuration for a table. + Regenerate rows in a table. Args: - table_type (str | TableType): The type of the table. - request (GenConfigUpdateRequest): The generation configuration update request. + table_type (str): The type of the table. + request (MultiRowRegenRequest): The row regenerate request. Returns: - response (TableMetaResponse): The table metadata response. + response (MultiRowCompletionResponse | AsyncGenerator): The row completion. + In streaming mode, it is an async generator that yields a `CellReferencesResponse` object + followed by zero or more `CellCompletionResponse` objects. + In non-streaming mode, it is a `MultiRowCompletionResponse` object. """ - return await self.table.update_gen_config(table_type, request) + agen = LOOP.run(super().regen_table_rows(table_type, request, **kwargs)) + return self._return_iterator(agen, request.stream) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def add_action_columns(self, request: AddActionColumnSchema) -> TableMetaResponse: + def update_table_rows( + self, + table_type: str, + request: MultiRowUpdateRequest, + **kwargs, + ) -> OkResponse: """ - Add columns to an Action Table. + Update rows in a table. Args: - request (AddActionColumnSchema): The action column schema. + table_type (str): The type of the table. + request (MultiRowUpdateRequest): The row update request. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. 
""" - return await self.table.add_action_columns(request) + return LOOP.run(super().update_table_rows(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def add_knowledge_columns(self, request: AddKnowledgeColumnSchema) -> TableMetaResponse: + @deprecated( + "This method is deprecated, use `update_table_rows` instead.", + category=FutureWarning, + stacklevel=1, + ) + def update_table_row( + self, + table_type: str, + request: RowUpdateRequest, + **kwargs, + ) -> OkResponse: """ - Add columns to a Knowledge Table. + Update a specific row in a table. Args: - request (AddKnowledgeColumnSchema): The knowledge column schema. + table_type (str): The type of the table. + request (RowUpdateRequest): The row update request. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. """ - return await self.table.add_knowledge_columns(request) + return LOOP.run(super().update_table_row(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def add_chat_columns(self, request: AddChatColumnSchema) -> TableMetaResponse: + def delete_table_rows( + self, + table_type: str, + request: MultiRowDeleteRequest, + **kwargs, + ) -> OkResponse: """ - Add columns to a Chat Table. + Delete rows from a table. Args: - request (AddChatColumnSchema): The chat column schema. + table_type (str): The type of the table. + request (MultiRowDeleteRequest): The row delete request. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. """ - return await self.table.add_chat_columns(request) + return LOOP.run(super().delete_table_rows(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def drop_columns( + @deprecated( + "This method is deprecated, use `delete_table_rows` instead.", + category=FutureWarning, + stacklevel=1, + ) + def delete_table_row( self, - table_type: str | TableType, - request: ColumnDropRequest, - ) -> TableMetaResponse: + table_type: str, + table_id: str, + row_id: str, + **kwargs, + ) -> OkResponse: """ - Drop columns from a table. + Delete a specific row from a table. Args: - table_type (str | TableType): The type of the table. - request (ColumnDropRequest): The column drop request. + table_type (str): The type of the table. + table_id (str): The ID of the table. + row_id (str): The ID of the row. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. """ - return await self.table.drop_columns(table_type, request) + return LOOP.run(super().delete_table_row(table_type, table_id, row_id, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def rename_columns( + def embed_file_options(self, **kwargs) -> httpx.Response: + """ + Get CORS preflight options for file embedding endpoint. + + Returns: + response (httpx.Response): The response containing options information. + """ + return LOOP.run(super().embed_file_options(**kwargs)) + + def embed_file( self, - table_type: str | TableType, - request: ColumnRenameRequest, - ) -> TableMetaResponse: + file_path: str, + table_id: str, + *, + chunk_size: int = 1000, + chunk_overlap: int = 200, + **kwargs, + ) -> OkResponse: """ - Rename columns in a table. + Embed a file into a Knowledge Table. 
Args: - table_type (str | TableType): The type of the table. - request (ColumnRenameRequest): The column rename request. + file_path (str): File path of the document to be embedded. + table_id (str): Knowledge Table ID / name. + chunk_size (int, optional): Maximum chunk size (number of characters). Must be > 0. + Defaults to 1000. + chunk_overlap (int, optional): Overlap in characters between chunks. Must be >= 0. + Defaults to 200. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. """ - return await self.table.rename_columns(table_type, request) + return LOOP.run( + super().embed_file( + file_path, table_id, chunk_size=chunk_size, chunk_overlap=chunk_overlap, **kwargs + ) + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def reorder_columns( + # Import export + def import_table_data( self, - table_type: str | TableType, - request: ColumnReorderRequest, - ) -> TableMetaResponse: + table_type: str, + request: TableDataImportRequest, + **kwargs, + ) -> GenTableChatResponseType: """ - Reorder columns in a table. + Imports CSV or TSV data into a table. Args: - table_type (str | TableType): The type of the table. - request (ColumnReorderRequest): The column reorder request. + file_path (str): CSV or TSV file path. + table_type (str): Table type. + request (TableDataImportRequest): Data import request. Returns: - response (TableMetaResponse): The table metadata response. + response (OkResponse): The response indicating success. """ - return await self.table.reorder_columns(table_type, request) + agen = LOOP.run(super().import_table_data(table_type, request, **kwargs)) + return self._return_iterator(agen, request.stream) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def list_table_rows( + def export_table_data( self, - table_type: str | TableType, + table_type: str, table_id: str, *, - offset: int = 0, - limit: int = 100, - search_query: str = "", columns: list[str] | None = None, - float_decimals: int = 0, - vec_decimals: int = 0, - order_descending: bool = True, - ) -> Page[dict[str, Any]]: + delimiter: Literal[",", "\t"] = ",", + **kwargs, + ) -> bytes: """ - List rows in a table. + Exports the row data of a table as a CSV or TSV file. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - offset (int, optional): Item offset. Defaults to 0. - limit (int, optional): Number of rows to return (min 1, max 100). Defaults to 100. - search_query (str, optional): A string to search for within the rows as a filter. - Defaults to "" (no filter). - columns (list[str] | None, optional): List of column names to include in the response. - Defaults to None (all columns). - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). - order_descending (bool, optional): Whether to sort by descending order. Defaults to True. + table_type (str): Table type. + table_id (str): ID or name of the table to be exported. + delimiter (str, optional): The delimiter of the file: can be "," or "\\t". Defaults to ",". + columns (list[str], optional): A list of columns to be exported. Defaults to None (export all columns). + + Returns: + response (list[dict[str, Any]]): The search results. 
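A sketch of document embedding and CSV export using the helpers above (the table and column names are illustrative):

from jamaibase import JamAI  # assumed import path

client = JamAI(token="your-pat")

# Chunk and embed a document into an existing Knowledge Table.
client.table.embed_file(
    "docs/handbook.pdf",
    "docs",  # Knowledge Table ID
    chunk_size=1000,
    chunk_overlap=200,
)

# Export selected columns as CSV bytes and write them to disk.
csv_bytes = client.table.export_table_data(
    "knowledge", "docs", columns=["Title", "Text"], delimiter=","
)
with open("docs_export.csv", "wb") as f:
    f.write(csv_bytes)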
""" - return await self.table.list_table_rows( - table_type, - table_id, - offset=offset, - limit=limit, - search_query=search_query, - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - order_descending=order_descending, + return LOOP.run( + super().export_table_data( + table_type, table_id, columns=columns, delimiter=delimiter, **kwargs + ) ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def get_table_row( + def import_table( self, - table_type: str | TableType, - table_id: str, - row_id: str, - *, - columns: list[str] | None = None, - float_decimals: int = 0, - vec_decimals: int = 0, - ) -> dict[str, Any]: + table_type: str, + request: TableImportRequest, + **kwargs, + ) -> TableMetaResponse | OkResponse: """ - Get a specific row in a table. + Imports a table (data and schema) from a parquet file. Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - row_id (str): The ID of the row. - columns (list[str] | None, optional): List of column names to include in the response. - Defaults to None (all columns). - float_decimals (int, optional): Number of decimals for float values. - Defaults to 0 (no rounding). - vec_decimals (int, optional): Number of decimals for vectors. - If its negative, exclude vector columns. Defaults to 0 (no rounding). + file_path (str): The parquet file path. + table_type (str): Table type. + request (TableImportRequest): Table import request. Returns: - response (dict[str, Any]): The row data. + response (TableMetaResponse | OkResponse): The table metadata response if blocking is True, + otherwise OkResponse. """ - return await self.table.get_table_row( - table_type, - table_id, - row_id, - columns=columns, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ) + return LOOP.run(super().import_table(table_type, request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def add_table_rows( + def export_table( self, - table_type: str | TableType, - request: RowAddRequest, - ) -> ( - GenTableRowsChatCompletionChunks - | AsyncGenerator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None] - ): + table_type: str, + table_id: str, + **kwargs, + ) -> bytes: """ - Add rows to a table. + Exports a table (data and schema) as a parquet file. Args: - table_type (str | TableType): The type of the table. - request (RowAddRequest): The row add request. + table_type (str): Table type. + table_id (str): ID or name of the table to be exported. Returns: - response (GenTableRowsChatCompletionChunks | AsyncGenerator): The row completion. - In streaming mode, it is an async generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. + response (list[dict[str, Any]]): The search results. """ - return await self.table.add_table_rows(table_type, request) + return LOOP.run(super().export_table(table_type, table_id, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def regen_table_rows( + +class _MeterClient(_MeterClientAsync): + def get_usage_metrics( self, - table_type: str | TableType, - request: RowRegenRequest, - ) -> ( - GenTableRowsChatCompletionChunks - | AsyncGenerator[GenTableStreamReferences | GenTableStreamChatCompletionChunk, None] - ): - """ - Regenerate rows in a table. 
+ type, + from_, + window_size, + org_ids=None, + proj_ids=None, + to=None, + group_by=None, + data_source=None, + ) -> UsageResponse: + return LOOP.run( + super().get_usage_metrics( + type, from_, window_size, org_ids, proj_ids, to, group_by, data_source + ) + ) - Args: - table_type (str | TableType): The type of the table. - request (RowRegenRequest): The row regenerate request. + def get_billing_metrics( + self, + from_, + window_size, + org_ids=None, + proj_ids=None, + to=None, + group_by=None, + data_source=None, + ) -> UsageResponse: + return LOOP.run( + super().get_billing_metrics( + from_, window_size, org_ids, proj_ids, to, group_by, data_source + ) + ) - Returns: - response (GenTableRowsChatCompletionChunks | AsyncGenerator): The row completion. - In streaming mode, it is an async generator that yields a `GenTableStreamReferences` object - followed by zero or more `GenTableStreamChatCompletionChunk` objects. - In non-streaming mode, it is a `GenTableRowsChatCompletionChunks` object. - """ - return await self.table.regen_table_rows(table_type, request) + def get_bandwidth_metrics( + self, + from_, + window_size, + org_ids=None, + proj_ids=None, + to=None, + group_by=None, + data_source=None, + ) -> UsageResponse: + return LOOP.run( + super().get_bandwidth_metrics( + from_, window_size, org_ids, proj_ids, to, group_by, data_source + ) + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def update_table_row( + def get_storage_metrics( + self, from_, window_size, org_ids=None, proj_ids=None, to=None, group_by=None + ) -> UsageResponse: + return LOOP.run( + super().get_storage_metrics(from_, window_size, org_ids, proj_ids, to, group_by) + ) + + +class _TaskClient(_TaskClientAsync): + """Task methods.""" + + def get_progress( self, - table_type: str | TableType, - request: RowUpdateRequest, - ) -> OkResponse: - """ - Update a specific row in a table. + key: str, + **kwargs, + ) -> dict[str, Any]: + return LOOP.run(super().get_progress(key, **kwargs)) - Args: - table_type (str | TableType): The type of the table. - request (RowUpdateRequest): The row update request. 
+ def poll_progress( + self, + key: str, + *, + initial_wait: float = 0.5, + max_wait: float = 30 * 60.0, + verbose: bool = False, + **kwargs, + ) -> dict[str, Any] | None: + from time import sleep + + i = 1 + t0 = perf_counter() + while (perf_counter() - t0) < max_wait: + sleep(min(initial_wait * i, 5.0)) + prog = self.get_progress(key, **kwargs) + state = prog.get("state", None) + error = prog.get("error", None) + if verbose: + logger.info( + f"{self.__class__.__name__}: Progress: key={key} state={state}" + + (f" error={error}" if error else "") + ) + if state == ProgressState.COMPLETED: + return prog + elif state == ProgressState.FAILED: + raise JamaiException(prog.get("error", "Unknown error")) + i += 1 + return None + + # def poll_progress( + # self, + # key: str, + # *, + # initial_wait: float = 0.5, + # max_wait: float = 30 * 60.0, + # **kwargs, + # ) -> dict[str, Any] | None: + # return LOOP.run( + # super().poll_progress( + # key, + # initial_wait=initial_wait, + # max_wait=max_wait, + # **kwargs, + # ) + # ) + + +class _ConversationClient(_ConversationClientAsync): + """Conversation methods (synchronous version).""" + + def create_conversation( + self, + request: ConversationCreateRequest, + **kwargs, + ) -> Generator[ + ConversationMetaResponse | CellReferencesResponse | CellCompletionResponse, None, None + ]: + agen = LOOP.run(super().create_conversation(request, **kwargs)) + return self._return_iterator(agen, True) - Returns: - response (OkResponse): The response indicating success. - """ - return await self.table.update_table_row(table_type, request) + def list_conversations( + self, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + **kwargs, + ) -> Page[ConversationMetaResponse]: + return LOOP.run( + super().list_conversations( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + **kwargs, + ) + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def delete_table_rows( + def list_agents( self, - table_type: str | TableType, - request: RowDeleteRequest, - ) -> OkResponse: - """ - Delete rows from a table. + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + search_query: str = "", + **kwargs, + ) -> Page[ConversationMetaResponse]: + return LOOP.run( + super().list_agents( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + **kwargs, + ) + ) - Args: - table_type (str | TableType): The type of the table. - request (RowDeleteRequest): The row delete request. + def get_conversation(self, conversation_id: str, **kwargs) -> ConversationMetaResponse: + return LOOP.run(super().get_conversation(conversation_id, **kwargs)) - Returns: - response (OkResponse): The response indicating success. 
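A sketch of the task polling helper above. The progress `key` is illustrative; in practice it would come from the response of whichever long-running call started the task:

from jamaibase import JamAI  # assumed import path

client = JamAI(token="your-pat")

# Poll until the task completes, fails (raises JamaiException), or times out (returns None).
progress = client.tasks.poll_progress(
    "task-key",  # illustrative key
    initial_wait=0.5,
    max_wait=10 * 60.0,
    verbose=True,
)
if progress is None:
    print("Timed out waiting for the task to finish.")
else:
    print(progress.get("state"))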
- """ - return await self.table.delete_table_rows(table_type, request) + def get_agent(self, agent_id: str, **kwargs) -> AgentMetaResponse: + return LOOP.run(super().get_agent(agent_id, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def delete_table_row( + def generate_title( self, - table_type: str | TableType, - table_id: str, - row_id: str, + conversation_id: str, + **kwargs, + ) -> ConversationMetaResponse: + """Generates a title for a conversation.""" + return LOOP.run(super().generate_title(conversation_id, **kwargs)) + + def rename_conversation_title( + self, + conversation_id: str, + title: str, + **kwargs, + ) -> ConversationMetaResponse: + return LOOP.run(super().rename_conversation_title(conversation_id, title, **kwargs)) + + def delete_conversation( + self, + conversation_id: str, + *, + missing_ok: bool = True, + **kwargs, ) -> OkResponse: - """ - Delete a specific row from a table. + return LOOP.run( + super().delete_conversation(conversation_id, missing_ok=missing_ok, **kwargs) + ) - Args: - table_type (str | TableType): The type of the table. - table_id (str): The ID of the table. - row_id (str): The ID of the row. + def send_message( + self, + request: MessageAddRequest, + **kwargs, + ) -> Generator[CellReferencesResponse | CellCompletionResponse, None, None]: + agen = LOOP.run(super().send_message(request, **kwargs)) + return self._return_iterator(agen, True) - Returns: - response (OkResponse): The response indicating success. - """ - return await self.table.delete_table_row(table_type, table_id, row_id) + def list_messages( + self, + conversation_id: str, + offset: int = 0, + limit: int = 100, + order_by: str = "ID", + order_ascending: bool = True, + columns: list[str] | None = None, + search_query: str = "", + search_columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, + ) -> Page[dict[str, Any]]: + return LOOP.run( + super().list_messages( + conversation_id=conversation_id, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + columns=columns, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) + ) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def get_conversation_thread( + def regen_message( self, - table_type: str | TableType, - table_id: str, - column_id: str, - row_id: str = "", - include: bool = True, - ) -> ChatThread: + request: MessagesRegenRequest, + **kwargs, + ) -> Generator[CellReferencesResponse | CellCompletionResponse, None, None]: + """Regenerates a message in a conversation and streams back the response.""" + agen = LOOP.run(super().regen_message(request, **kwargs)) + return self._return_iterator(agen, True) + + def update_message( + self, + request: MessageUpdateRequest, + **kwargs, + ) -> OkResponse: + """Updates a specific message within a conversation.""" + return LOOP.run(super().update_message(request, **kwargs)) + + def get_threads( + self, + conversation_id: str, + column_ids: list[str] | None = None, + **kwargs, + ) -> ConversationThreadsResponse: """ - Get the conversation thread for a chat table. + Get all threads from a conversation. Args: - table_type (str | TableType): The type of the table. - table_id (str): ID / name of the chat table. - column_id (str): ID / name of the column to fetch. - row_id (str, optional): ID / name of the last row in the thread. 
- Defaults to "" (export all rows). - include (bool, optional): Whether to include the row specified by `row_id`. - Defaults to True. + conversation_id (str): Conversation ID. + column_ids (list[str] | None): Columns to fetch as conversation threads. Returns: - response (ChatThread): The conversation thread. + response (ConversationThreadsResponse): The conversation threads. """ - return await self.table.get_conversation_thread( - table_type, table_id, column_id, row_id=row_id, include=include - ) + return LOOP.run(super().get_threads(conversation_id, column_ids, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def hybrid_search( + +class JamAI(JamAIAsync): + def __init__( self, - table_type: str | TableType, - request: SearchRequest, - ) -> list[dict[str, Any]]: + project_id: str = ENV_CONFIG.project_id, + token: str = ENV_CONFIG.token_plain, + api_base: str = ENV_CONFIG.api_base, + headers: dict | None = None, + timeout: float | None = ENV_CONFIG.timeout_sec, + file_upload_timeout: float | None = ENV_CONFIG.file_upload_timeout_sec, + *, + user_id: str = "", + ) -> None: """ - Perform a hybrid search on a table. + Initialize the JamAI client. Args: - table_type (str | TableType): The type of the table. - request (SearchRequest): The search request. - - Returns: - response (list[dict[str, Any]]): The search results. + project_id (str, optional): The project ID. + Defaults to "default", but can be overridden via + `JAMAI_PROJECT_ID` var in environment or `.env` file. + token (str, optional): Your Personal Access Token or organization API key (deprecated) for authentication. + Defaults to "", but can be overridden via + `JAMAI_TOKEN` var in environment or `.env` file. + api_base (str, optional): The base URL for the API. + Defaults to "https://api.jamaibase.com/api", but can be overridden via + `JAMAI_API_BASE` var in environment or `.env` file. + headers (dict | None, optional): Additional headers to include in requests. + Defaults to None. + timeout (float | None, optional): The timeout to use when sending requests. + Defaults to 15 minutes, but can be overridden via + `JAMAI_TIMEOUT_SEC` var in environment or `.env` file. + file_upload_timeout (float | None, optional): The timeout to use when sending file upload requests. + Defaults to 60 minutes, but can be overridden via + `JAMAI_FILE_UPLOAD_TIMEOUT_SEC` var in environment or `.env` file. + user_id (str, optional): User ID. For development purposes. + Defaults to "". 
""" - return await self.table.hybrid_search(table_type, request) + super().__init__( + project_id=project_id, + token=token, + api_base=api_base, + headers=headers, + timeout=timeout, + file_upload_timeout=file_upload_timeout, + user_id=user_id, + ) + kwargs = dict( + user_id=self.user_id, + project_id=self.project_id, + token=self.token, + api_base=self.api_base, + headers=self.headers, + http_client=self.http_client, + timeout=self.timeout, + file_upload_timeout=self.file_upload_timeout, + ) + self.auth = _Auth(**kwargs) + self.prices = _Prices(**kwargs) + self.users = _Users(**kwargs) + self.models = _Models(**kwargs) + self.organizations = _Organizations(**kwargs) + self.projects = _Projects(**kwargs) + self.templates = _Templates(**kwargs) + self.file = _FileClient(**kwargs) + self.table = _GenTableClient(**kwargs) + self.meters = _MeterClient(**kwargs) + self.tasks = _TaskClient(**kwargs) + self.conversations = _ConversationClient(**kwargs) - @deprecated( - "This method is deprecated, use `client.table.embed_file_options` instead.", - category=FutureWarning, - stacklevel=1, - ) - async def upload_file_options(self) -> httpx.Response: + def health(self) -> dict[str, Any]: """ - Get options for uploading a file to a Knowledge Table. + Get health status. Returns: - response (httpx.Response): The response containing options information. + response (dict[str, Any]): Health status. """ - return await self.table.embed_file_options() + return LOOP.run(super().health()) - @deprecated( - "This method is deprecated, use `client.table.embed_file` instead.", - category=FutureWarning, - stacklevel=1, - ) - async def upload_file(self, request: FileUploadRequest) -> OkResponse: + # --- Models and chat --- # + + def model_info( + self, + model: str = "", + capabilities: list[ + Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"] + ] + | None = None, + **kwargs, + ) -> ModelInfoListResponse: """ - Upload a file to a Knowledge Table. + Get information about available models. Args: - request (FileUploadRequest): The file upload request. + name (str, optional): The model name. Defaults to "". + capabilities (list[Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"]] | None, optional): + List of model capabilities to filter by. Defaults to None. Returns: - response (OkResponse): The response indicating success. + response (ModelInfoListResponse): The model information response. """ - return await self.table.embed_file(request) + return LOOP.run(super().model_info(model=model, capabilities=capabilities, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def import_table_data( + def model_ids( self, - table_type: str | TableType, - request: TableDataImportRequest, - ) -> GenTableChatResponseType: + prefer: str = "", + capabilities: list[ + Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"] + ] + | None = None, + **kwargs, + ) -> list[str]: """ - Imports CSV or TSV data into a table. + Get the IDs of available models. Args: - file_path (str): CSV or TSV file path. - table_type (str | TableType): Table type. - request (TableDataImportRequest): Data import request. + prefer (str, optional): Preferred model ID. Defaults to "". + capabilities (list[Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"]] | None, optional): + List of model capabilities to filter by. Defaults to None. 
Returns: - response (OkResponse): The response indicating success. + response (list[str]): List of model IDs. """ - return await self.table.import_table_data(table_type, request) + return LOOP.run(super().model_ids(prefer=prefer, capabilities=capabilities, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def export_table_data( + @deprecated( + "This method is deprecated, use `model_ids` instead.", category=FutureWarning, stacklevel=1 + ) + def model_names( self, - table_type: str | TableType, - table_id: str, - columns: list[str] | None = None, - delimiter: Literal[",", "\t"] = ",", - ) -> bytes: + prefer: str = "", + capabilities: list[ + Literal["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"] + ] + | None = None, + **kwargs, + ) -> list[str]: + return self.model_ids(prefer=prefer, capabilities=capabilities, **kwargs) + + def generate_chat_completions( + self, + request: ChatRequest, + **kwargs, + ) -> ChatCompletionResponse | Generator[References | ChatCompletionChunkResponse, None, None]: """ - Exports the row data of a table as a CSV or TSV file. + Generates chat completions. Args: - table_type (str | TableType): Table type. - table_id (str): ID or name of the table to be exported. - delimiter (str, optional): The delimiter of the file: can be "," or "\\t". Defaults to ",". - columns (list[str], optional): A list of columns to be exported. Defaults to None (export all columns). + request (ChatRequest): The request. Returns: - response (list[dict[str, Any]]): The search results. + completion (ChatCompletionChunkResponse | AsyncGenerator): The chat completion. + In streaming mode, it is an async generator that yields a `References` object + followed by zero or more `ChatCompletionChunkResponse` objects. + In non-streaming mode, it is a `ChatCompletionChunkResponse` object. """ - return await self.table.export_table_data( - table_type, table_id, columns=columns, delimiter=delimiter - ) + agen = LOOP.run(super().generate_chat_completions(request=request, **kwargs)) + return self._return_iterator(agen, request.stream) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def import_table( + def generate_embeddings( self, - table_type: str | TableType, - request: TableImportRequest, - ) -> TableMetaResponse: + request: EmbeddingRequest, + **kwargs, + ) -> EmbeddingResponse: """ - Imports a table (data and schema) from a parquet file. + Generate embeddings for the given input. Args: - file_path (str): The parquet file path. - table_type (str | TableType): Table type. - request (TableImportRequest): Table import request. + request (EmbeddingRequest): The embedding request. Returns: - response (TableMetaResponse): The table metadata response. + response (EmbeddingResponse): The embedding response. """ - return await self.table.import_table(table_type, request) + return LOOP.run(super().generate_embeddings(request=request, **kwargs)) - @deprecated(TABLE_METHOD_DEPRECATE, category=FutureWarning, stacklevel=1) - async def export_table( - self, - table_type: str | TableType, - table_id: str, - ) -> bytes: + def rerank(self, request: RerankingRequest, **kwargs) -> RerankingResponse: """ - Exports a table (data and schema) as a parquet file. + Generate similarity rankings for the given query and documents. Args: - table_type (str | TableType): Table type. - table_id (str): ID or name of the table to be exported. + request (RerankingRequest): The reranking request body. 
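Finally, a sketch of the top-level generation helpers above; the `ChatRequest`, `ChatEntry`, and `EmbeddingRequest` field names are assumptions modelled on OpenAI-style schemas:

from jamaibase import JamAI  # assumed import path
from jamaibase import protocol as p  # assumed location of the request models

client = JamAI(token="your-pat")

# Streaming chat completion: the sync wrapper yields References then chunk objects.
request = p.ChatRequest(  # field names are illustrative
    model="openai/gpt-4o-mini",
    messages=[p.ChatEntry(role="user", content="Say hello.")],
    stream=True,
)
for chunk in client.generate_chat_completions(request):
    print(chunk)

# Embeddings and reranking are plain blocking calls.
emb = client.generate_embeddings(
    p.EmbeddingRequest(model="openai/text-embedding-3-small-512", input="hello")  # illustrative fields
)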
Returns: - response (list[dict[str, Any]]): The search results. + RerankingResponse: The reranking response. """ - return await self.table.export_table(table_type, table_id) + return LOOP.run(super().rerank(request=request, **kwargs)) diff --git a/clients/python/src/jamaibase/protocol.py b/clients/python/src/jamaibase/protocol.py index 1cc79ec..25f4014 100644 --- a/clients/python/src/jamaibase/protocol.py +++ b/clients/python/src/jamaibase/protocol.py @@ -1,2385 +1,9 @@ -""" -NOTES: - -- Pydantic supports setting mutable values as default. - This is in contrast to native `dataclasses` where it is not supported. - -- Pydantic supports setting default fields in any order. - This is in contrast to native `dataclasses` where fields with default values must be defined after non-default fields. -""" - -from __future__ import annotations - -import re -from datetime import datetime -from decimal import Decimal -from enum import Enum, EnumMeta -from os.path import splitext -from typing import Annotated, Any, Generic, Literal, Sequence, TypeVar, Union from warnings import warn -import numpy as np -from pydantic import ( - BaseModel, - ConfigDict, - Discriminator, - Field, - Tag, - computed_field, - field_validator, - model_validator, -) -from pydantic.functional_validators import AfterValidator -from typing_extensions import Self, deprecated - -from jamaibase.utils import datetime_now_iso -from jamaibase.version import __version__ as jamaibase_version - -PositiveInt = Annotated[int, Field(ge=0, description="Positive integer.")] -PositiveNonZeroInt = Annotated[int, Field(gt=0, description="Positive non-zero integer.")] - - -def sanitise_document_id(v: str) -> str: - if v.startswith('"') and v.endswith('"'): - v = v[1:-1] - return v - - -def sanitise_document_id_list(v: list[str]) -> list[str]: - return [sanitise_document_id(vv) for vv in v] - - -DocumentID = Annotated[str, AfterValidator(sanitise_document_id)] -DocumentIDList = Annotated[list[str], AfterValidator(sanitise_document_id_list)] - -EXAMPLE_CHAT_MODEL_IDS = ["openai/gpt-4o-mini"] -# for openai embedding models doc: https://platform.openai.com/docs/guides/embeddings -# for cohere embedding models doc: https://docs.cohere.com/reference/embed -# for jina embedding models doc: https://jina.ai/embeddings/ -# for voyage embedding models doc: https://docs.voyageai.com/docs/embeddings -# for hf embedding models doc: check the respective hf model page, name should be ellm/{org}/{model} -EXAMPLE_EMBEDDING_MODEL_IDS = [ - "openai/text-embedding-3-small-512", - "ellm/sentence-transformers/all-MiniLM-L6-v2", -] -# for cohere reranking models doc: https://docs.cohere.com/reference/rerank-1 -# for jina reranking models doc: https://jina.ai/reranker -# for colbert reranking models doc: https://docs.voyageai.com/docs/reranker -# for hf embedding models doc: check the respective hf model page, name should be ellm/{org}/{model} -EXAMPLE_RERANKING_MODEL_IDS = [ - "cohere/rerank-multilingual-v3.0", - "ellm/cross-encoder/ms-marco-TinyBERT-L-2", -] - -IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".gif", ".webp"] -AUDIO_FILE_EXTENSIONS = [".mp3", ".wav"] -DOCUMENT_FILE_EXTENSIONS = [ - ".pdf", - ".txt", - ".md", - ".docx", - ".xml", - ".html", - ".json", - ".csv", - ".tsv", - ".jsonl", - ".xlsx", - ".xls", -] - - -class OkResponse(BaseModel): - ok: bool = True - - -class StringResponse(BaseModel): - object: Literal["string"] = Field( - default="string", - description='The object type, which is always "string".', - examples=["string"], - ) - data: str = 
Field( - description="The string data.", - examples=["text"], - ) - - -class AdminOrderBy(str, Enum): - ID = "id" - """Sort by `id` column.""" - NAME = "name" - """Sort by `name` column.""" - CREATED_AT = "created_at" - """Sort by `created_at` column.""" - UPDATED_AT = "updated_at" - """Sort by `updated_at` column.""" - - def __str__(self) -> str: - return self.value - - -class GenTableOrderBy(str, Enum): - ID = "id" - """Sort by `id` column.""" - UPDATED_AT = "updated_at" - """Sort by `updated_at` column.""" - - def __str__(self) -> str: - return self.value - - -class Tier(BaseModel): - """ - https://docs.stripe.com/api/prices/object#price_object-tiers - """ - - unit_amount_decimal: Decimal = Field( - description="Per unit price for units relevant to the tier.", - ) - up_to: float | None = Field( - description=( - "Up to and including to this quantity will be contained in the tier. " - "None means infinite quantity." - ), - ) - - -class Product(BaseModel): - name: str = Field( - min_length=1, - description="Plan name.", - ) - included: Tier = Tier(unit_amount_decimal=0, up_to=0) - tiers: list[Tier] - unit: str = Field( - description="Unit of measurement.", - ) - - -class Plan(BaseModel): - name: str - stripe_price_id_live: str - stripe_price_id_test: str - flat_amount_decimal: Decimal = Field( - description="Base price for the entire tier.", - ) - credit_grant: float = Field( - description="Credit amount included in USD.", - ) - max_users: int = Field( - description="Maximum number of users per organization.", - ) - products: dict[str, Product] = Field( - description="Mapping of price name to tier list where each element represents a pricing tier.", - ) - - -class Price(BaseModel): - plans: dict[str, Plan] = Field( - description="Mapping of price plan name to price plan.", - ) - - -class _ModelPrice(BaseModel): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - "Users will specify this to select a model." - ), - examples=[ - EXAMPLE_CHAT_MODEL_IDS[0], - EXAMPLE_EMBEDDING_MODEL_IDS[0], - EXAMPLE_RERANKING_MODEL_IDS[0], - ], - ) - name: str = Field( - description="Name of the model.", - examples=["OpenAI GPT-4o Mini"], - ) - - -class LLMModelPrice(_ModelPrice): - input_cost_per_mtoken: float = Field( - description="Cost in USD per million input / prompt token.", - ) - output_cost_per_mtoken: float = Field( - description="Cost in USD per million output / completion token.", - ) - - -class EmbeddingModelPrice(_ModelPrice): - cost_per_mtoken: float = Field( - description="Cost in USD per million embedding tokens.", - ) - - -class RerankingModelPrice(_ModelPrice): - cost_per_ksearch: float = Field(description="Cost in USD for a thousand (kilo) searches.") - - -class ModelPrice(BaseModel): - object: str = Field( - default="prices.models", - description="Type of API response object.", - examples=["prices.models"], - ) - llm_models: list[LLMModelPrice] = [] - embed_models: list[EmbeddingModelPrice] = [] - rerank_models: list[RerankingModelPrice] = [] - - -class _OrgMemberBase(BaseModel): - user_id: str = Field(description="User ID. Must be unique.") - organization_id: str = Field( - default="", - description="Organization ID. 
Must be unique.", - ) - role: Literal["admin", "member", "guest"] = "admin" - """User role.""" - - -class OrgMemberCreate(_OrgMemberBase): - invite_token: str = Field( - default="", - description="User-org link creation datetime (ISO 8601 UTC).", - ) - - -class OrgMemberRead(_OrgMemberBase): - created_at: str = Field( - description="User-org link creation datetime (ISO 8601 UTC).", - ) - updated_at: str = Field( - description="User-org link update datetime (ISO 8601 UTC).", - ) - organization_name: str = "" - """Organization name. To be populated later.""" - - -class UserUpdate(BaseModel): - id: str - """User ID. Must be unique.""" - name: str | None = None - """The user's full name or business name.""" - description: str | None = None - """An arbitrary string that you can attach to a customer object.""" - email: Annotated[str, Field(min_length=1, max_length=512)] | None = None - """User's email address. This may be up to 512 characters.""" - meta: dict | None = None - """ - Additional metadata about the user. - """ - - -class UserCreate(BaseModel): - id: str - """User ID. Must be unique.""" - name: str - """The user's full name or business name.""" - description: str = "" - """An arbitrary string that you can attach to a customer object.""" - email: Annotated[str, Field(min_length=1, max_length=512)] - """User's email address. This may be up to 512 characters.""" - meta: dict = {} - """ - Additional metadata about the user. - """ - - -class UserRead(UserCreate): - created_at: str = Field(description="User creation datetime (ISO 8601 UTC).") - updated_at: str = Field(description="User update datetime (ISO 8601 UTC).") - member_of: list[OrgMemberRead] - """List of organizations that this user is associated with and their role.""" - - -class PATCreate(BaseModel): - user_id: str = Field(description="User ID.") - expiry: str = Field( - default="", - description="PAT expiry datetime (ISO 8601 UTC). 
If empty, never expires.", - ) - - -class PATRead(PATCreate): - id: str = Field(description="The token.") - created_at: str = Field(description="Creation datetime (ISO 8601 UTC).") - # user: UserRead = Field(description="User that this Personal Access Token is associated with.") - - -class ProjectCreate(BaseModel): - name: str = Field( - description="Project name.", - ) - organization_id: str = Field( - description="Organization ID.", - ) - - -class ProjectUpdate(BaseModel): - id: str - """Project ID.""" - name: str | None = Field( - default=None, - description="Project name.", - ) - - -class ProjectRead(ProjectCreate): - id: str = Field( - description="Project ID.", - ) - created_at: str = Field( - description="Project creation datetime (ISO 8601 UTC).", - ) - updated_at: str = Field( - description="Project update datetime (ISO 8601 UTC).", - ) - organization: Union["OrganizationRead", None] = Field( - default=None, - description="Organization that this project is associated with.", - ) - - -class OrganizationCreate(BaseModel): - creator_user_id: str = Field( - default="", - description="User that created this organization.", - ) - name: str = Field( - description="Organization name.", - ) - external_keys: dict[str, str] = Field( - default={}, - description="Mapping of service provider to its API key.", - ) - tier: str = Field( - default="", - description="Subscribed tier.", - ) - active: bool = Field( - default=True, - description="Whether the organization's quota is active (paid).", - ) - timezone: str | None = Field( - default=None, - description="Timezone specifier.", - ) - credit: float = Field( - default=0.0, - description="Credit paid by the customer. Unused credit will be carried forward to the next billing cycle.", - ) - credit_grant: float = Field( - default=0.0, - description="Credit granted to the customer. Unused credit will NOT be carried forward.", - ) - llm_tokens_quota_mtok: float = Field( - default=0.0, - description="LLM token quota in millions of tokens.", - ) - llm_tokens_usage_mtok: float = Field( - default=0.0, - description="LLM token usage in millions of tokens.", - ) - embedding_tokens_quota_mtok: float = Field( - default=0.0, - description="Embedding token quota in millions of tokens", - ) - embedding_tokens_usage_mtok: float = Field( - default=0.0, - description="Embedding token quota in millions of tokens", - ) - reranker_quota_ksearch: float = Field( - default=0.0, - description="Reranker quota for every thousand searches", - ) - reranker_usage_ksearch: float = Field( - default=0.0, - description="Reranker usage for every thousand searches", - ) - db_quota_gib: float = Field( - default=0.0, - description="DB storage quota in GiB.", - ) - db_usage_gib: float = Field( - default=0.0, - description="DB storage usage in GiB.", - ) - file_quota_gib: float = Field( - default=0.0, - description="File storage quota in GiB.", - ) - file_usage_gib: float = Field( - default=0.0, - description="File storage usage in GiB.", - ) - egress_quota_gib: float = Field( - default=0.0, - description="Egress quota in GiB.", - ) - egress_usage_gib: float = Field( - default=0.0, - description="Egress usage in GiB.", - ) - models: dict[str, Any] = Field( - default={}, - description="The organization's custom model list, in addition to the provided default list.", - ) - - -class OrganizationRead(OrganizationCreate): - id: str = Field( - description="Organization ID.", - ) - quota_reset_at: str = Field( - default="", - description="Previous quota reset date. 
Could be used as event key.", - ) - stripe_id: str | None = Field( - default=None, - description="Organization Stripe ID.", - ) - openmeter_id: str | None = Field( - default=None, - description="Organization OpenMeter ID.", - ) - created_at: str = Field( - description="Organization creation datetime (ISO 8601 UTC).", - ) - updated_at: str = Field( - description="Organization update datetime (ISO 8601 UTC).", - ) - members: list[OrgMemberRead] | None = Field( - default=None, - description="List of organization members and roles.", - ) - api_keys: list["ApiKeyRead"] | None = Field( - default=None, - description="List of API keys.", - ) - projects: list[ProjectRead] | None = Field( - default=None, - description="List of projects.", - ) - quotas: dict[str, dict[str, float]] = Field( - default=None, - description="Entitlements.", - ) - - -class OrganizationUpdate(BaseModel): - id: str - """Organization ID.""" - name: str | None = None - """Organization name.""" - external_keys: dict[str, str] | None = Field( - default=None, - description="Mapping of service provider to its API key.", - ) - credit: float | None = Field( - default=None, - description="Credit paid by the customer. Unused credit will be carried forward to the next billing cycle.", - ) - credit_grant: float | None = Field( - default=None, - description="Credit granted to the customer. Unused credit will NOT be carried forward.", - ) - llm_tokens_quota_mtok: float | None = Field( - default=None, - description="LLM token quota in millions of tokens.", - ) - llm_tokens_usage_mtok: float | None = Field( - default=None, - description="LLM token usage in millions of tokens.", - ) - embedding_tokens_quota_mtok: float | None = Field( - default=None, - description="Embedding token quota in millions of tokens", - ) - embedding_tokens_usage_mtok: float | None = Field( - default=None, - description="Embedding token quota in millions of tokens", - ) - reranker_quota_ksearch: float | None = Field( - default=None, - description="Reranker quota for every thousand searches", - ) - reranker_usage_ksearch: float | None = Field( - default=None, - description="Reranker usage for every thousand searches", - ) - db_quota_gib: float | None = Field( - default=None, - description="DB storage quota in GiB.", - ) - db_usage_gib: float | None = Field( - default=None, - description="DB storage usage in GiB.", - ) - file_quota_gib: float | None = Field( - default=None, - description="File storage quota in GiB.", - ) - file_usage_gib: float | None = Field( - default=None, - description="File storage usage in GiB.", - ) - egress_quota_gib: float | None = Field( - default=None, - description="Egress quota in GiB.", - ) - egress_usage_gib: float | None = Field( - default=None, - description="Egress usage in GiB.", - ) - tier: str | None = Field( - default=None, - description="Subscribed tier.", - ) - active: bool | None = Field( - default=None, - description="Whether the organization's quota is active (paid).", - ) - timezone: str | None = Field(default=None) - """ - Timezone specifier. 
- """ - stripe_id: str | None = Field(default=None) - """Organization Stripe ID.""" - openmeter_id: str | None = Field(default=None) - """Organization OpenMeter ID.""" - - -class ApiKeyCreate(BaseModel): - organization_id: str = Field(description="Organization ID.") - - -class ApiKeyRead(ApiKeyCreate): - id: str = Field(description="The key.") - created_at: str = Field(description="Creation datetime (ISO 8601 UTC).") - - -class EventCreate(BaseModel): - id: str = Field( - min_length=1, - description="Event ID for idempotency. Must be unique.", - ) - organization_id: str = Field( - description="Organization ID.", - ) - deltas: dict[str, float | int] = Field( - default={}, - description="Delta changes to the values.", - ) - values: dict[str, float | int] = Field( - default={}, - description="New values (in-place update). Note that this will override any delta changes.", - ) - pending: bool = Field( - default=False, - description="Whether the event is pending (in-progress)", - ) - meta: dict[str, Any] = Field( - default={}, - description="Metadata.", - ) - - -class EventRead(EventCreate): - created_at: str = Field( - description="Event creation datetime (ISO 8601 UTC).", - ) - - -class TemplateTag(BaseModel): - id: str = Field(description="Tag ID.") - - -class Template(BaseModel): - id: str = Field(description="Template ID.") - name: str = Field(description="Template name.") - created_at: str = Field(description="Template creation datetime (ISO 8601 UTC).") - tags: list[TemplateTag] = Field(description="List of template tags") - - -class Chunk(BaseModel): - """Class for storing a piece of text and associated metadata.""" - - text: str = Field(description="Chunk text.") - title: str = Field(default="", description='Document title. Defaults to "".') - page: int | None = Field(default=None, description="Document page the chunk text from.") - file_name: str = Field(default="", description="File name.") - file_path: str = Field(default="", description="File path.") - document_id: str = Field(default="", description="Document ID.") - chunk_id: str = Field(default="", description="Chunk ID.") - metadata: dict = Field( - default_factory=dict, - description="Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.).", - ) - - -class SplitChunksParams(BaseModel): - method: str = Field( - default="RecursiveCharacterTextSplitter", - description="Name of the splitter.", - examples=["RecursiveCharacterTextSplitter"], - ) - chunk_size: PositiveNonZeroInt = Field( - default=1000, - description="Maximum chunk size (number of characters). Must be > 0.", - examples=[1000], - ) - chunk_overlap: PositiveInt = Field( - default=200, - description="Overlap in characters between chunks. 
Must be >= 0.", - examples=[200], - ) - - -class SplitChunksRequest(BaseModel): - id: str = Field( - default="", - description="Request ID for logging purposes.", - examples=["018ed5f1-6399-71f7-86af-fc18d4a3e3f5"], - ) - chunks: list[Chunk] = Field( - description="List of `Chunk` where each will be further split into chunks.", - examples=[ - [ - Chunk( - text="The Name of the Title is Hope\n\n...", - title="The Name of the Title is Hope", - page=0, - file_name="sample_tables.pdf", - file_path="amagpt/sample_tables.pdf", - metadata={ - "total_pages": 3, - "Author": "Ben Trovato", - "CreationDate": "D:20231031072817Z", - "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", - "Keywords": "Image Captioning, Deep Learning", - "ModDate": "D:20231031073146Z", - "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5", - "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", - "Trapped": "False", - }, - ) - ] - ], - ) - params: SplitChunksParams = Field( - default=SplitChunksParams(), - description="How to split each document. Defaults to `RecursiveCharacterTextSplitter` with chunk_size = 1000 and chunk_overlap = 200.", - examples=[SplitChunksParams()], - ) - - def str_trunc(self) -> str: - return f"id={self.id} len(chunks)={len(self.chunks)} params={self.params}" - - -class RAGParams(BaseModel): - table_id: str = Field(description="Knowledge Table ID", examples=["my-dataset"], min_length=2) - reranking_model: str | None = Field( - default=None, - description="Reranking model to use for hybrid search.", - examples=[EXAMPLE_RERANKING_MODEL_IDS[0], None], - ) - search_query: str = Field( - default="", - description="Query used to retrieve items from the KB database. If not provided (default), it will be generated using LLM.", - ) - k: Annotated[int, Field(gt=0, le=1024)] = Field( - default=3, - gt=0, - le=1024, - description="Top-k closest text in terms of embedding distance. Must be in [1, 1024]. Defaults to 3.", - examples=[3], - ) - rerank: bool = Field( - default=True, - description="Flag to perform rerank on the retrieved results. Defaults to True.", - examples=[True, False], - ) - concat_reranker_input: bool = Field( - default=False, - description="Flag to concat title and content as reranker input. 
Defaults to False.", - examples=[True, False], - ) - - -class VectorSearchRequest(RAGParams): - id: str = Field( - default="", - description="Request ID for logging purposes.", - examples=["018ed5f1-6399-71f7-86af-fc18d4a3e3f5"], - ) - search_query: str = Field(description="Query used to retrieve items from the KB database.") - - -class VectorSearchResponse(BaseModel): - object: str = Field( - default="kb.search_response", - description="Type of API response object.", - examples=["kb.search_response"], - ) - chunks: list[Chunk] = Field( - default=[], - description="A list of `Chunk`.", - examples=[ - [ - Chunk( - text="The Name of the Title is Hope\n\n...", - title="The Name of the Title is Hope", - page=0, - file_name="sample_tables.pdf", - file_path="amagpt/sample_tables.pdf", - metadata={ - "total_pages": 3, - "Author": "Ben Trovato", - "CreationDate": "D:20231031072817Z", - "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", - "Keywords": "Image Captioning, Deep Learning", - "ModDate": "D:20231031073146Z", - "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5", - "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", - "Trapped": "False", - }, - ) - ] - ], - ) - - -class ModelInfo(BaseModel): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - "Users will specify this to select a model." - ), - examples=EXAMPLE_CHAT_MODEL_IDS, - ) - object: str = Field( - default="model", - description="Type of API response object.", - examples=["model"], - ) - name: str = Field( - description="Name of the model.", - examples=["OpenAI GPT-4o Mini"], - ) - context_length: int = Field( - description="Context length of model.", - examples=[16384], - ) - languages: list[str] = Field( - description="List of languages which the model is well-versed in.", - examples=[["en"]], - ) - owned_by: str = Field( - description="The organization that owns the model.", - examples=["openai"], - ) - capabilities: list[ - Literal["completion", "chat", "image", "audio", "tool", "embed", "rerank"] - ] = Field( - description="List of capabilities of model.", - examples=[["chat"]], - ) - - -class ModelDeploymentConfig(BaseModel): - litellm_id: str = Field( - default="", - description=( - "LiteLLM routing / mapping ID. " - 'For example, you can map "openai/gpt-4o" calls to "openai/gpt-4o-2024-08-06". ' - 'For vLLM with OpenAI compatible server, use "openai/".' - ), - examples=EXAMPLE_CHAT_MODEL_IDS, - ) - api_base: str = Field( - default="", - description="Hosting url for the model.", - ) - provider: str = Field( - default="", - description="Provider of the model.", - ) - - -class ModelConfig(ModelInfo): - priority: int = Field( - default=0, - ge=0, - description="Priority when assigning default model. 
Larger number means higher priority.", - ) - deployments: list[ModelDeploymentConfig] = Field( - description="List of model deployment configs.", - min_length=1, - ) - - -class LLMModelConfig(ModelConfig): - input_cost_per_mtoken: float = Field( - default=-1.0, - description="Cost in USD per million (mega) input / prompt token.", - ) - output_cost_per_mtoken: float = Field( - default=-1.0, - description="Cost in USD per million (mega) output / completion token.", - ) - - @model_validator(mode="after") - def check_cost_per_mtoken(self) -> Self: - # GPT-4o-mini pricing (2024-08-10) - if self.input_cost_per_mtoken <= 0: - self.input_cost_per_mtoken = 0.150 - if self.output_cost_per_mtoken <= 0: - self.output_cost_per_mtoken = 0.600 - return self - - -class EmbeddingModelConfig(ModelConfig): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - 'For self-hosted models with Infinity, use "ellm/{org}/{model}". ' - "Users will specify this to select a model." - ), - examples=EXAMPLE_EMBEDDING_MODEL_IDS, - ) - embedding_size: int = Field( - description="Embedding size of the model", - ) - # Currently only useful for openai - dimensions: int | None = Field( - default=None, - description="Dimensions, a reduced embedding size (openai specs).", - ) - # Most likely only useful for hf models - transform_query: str | None = Field( - default=None, - description="Transform query that might be needed, esp. for hf models", - ) - cost_per_mtoken: float = Field( - default=-1, description="Cost in USD per million embedding tokens." - ) - - @model_validator(mode="after") - def check_cost_per_mtoken(self) -> Self: - # OpenAI text-embedding-3-small pricing (2024-09-09) - if self.cost_per_mtoken < 0: - self.cost_per_mtoken = 0.022 - return self - - -class RerankingModelConfig(ModelConfig): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - 'For self-hosted models with Infinity, use "ellm/{org}/{model}". ' - "Users will specify this to select a model." 
- ), - examples=EXAMPLE_RERANKING_MODEL_IDS, - ) - capabilities: list[Literal["rerank"]] = Field( - default=["rerank"], - description="List of capabilities of model.", - examples=[["rerank"]], - ) - cost_per_ksearch: float = Field(default=-1, description="Cost in USD for a thousand searches.") - - @model_validator(mode="after") - def check_cost_per_ksearch(self) -> Self: - # Cohere rerank-multilingual-v3.0 pricing (2024-09-09) - if self.cost_per_ksearch < 0: - self.cost_per_ksearch = 2.0 - return self - - -class ModelListConfig(BaseModel): - llm_models: list[LLMModelConfig] = [] - embed_models: list[EmbeddingModelConfig] = [] - rerank_models: list[RerankingModelConfig] = [] - - @property - def models(self) -> list[LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig]: - """A list of all the models.""" - return self.llm_models + self.embed_models + self.rerank_models - - def __add__(self, other: ModelListConfig) -> ModelListConfig: - if isinstance(other, ModelListConfig): - self_ids = set(m.id for m in self.models) - other_ids = set(m.id for m in other.models) - repeated_ids = self_ids.intersection(other_ids) - if len(repeated_ids) != 0: - raise ValueError( - f"There are repeated model IDs among the two configs: {list(repeated_ids)}" - ) - return ModelListConfig( - llm_models=self.llm_models + other.llm_models, - embed_models=self.embed_models + other.embed_models, - rerank_models=self.rerank_models + other.rerank_models, - ) - else: - raise TypeError( - f"Unsupported operand type(s) for +: 'ModelListConfig' and '{type(other)}'" - ) - - -class ModelInfoResponse(BaseModel): - object: str = Field( - default="chat.model_info", - description="Type of API response object.", - examples=["chat.model_info"], - ) - data: list[ModelInfo] = Field( - description="List of model information.", - ) - - -class ChatRole(str, Enum): - """Represents who said a chat message.""" - - SYSTEM = "system" - """The message is from the system (usually a steering prompt).""" - USER = "user" - """The message is from the user.""" - ASSISTANT = "assistant" - """The message is from the language model.""" - # FUNCTION = "function" - # """The message is the result of a function call.""" - - def __str__(self) -> str: - return self.value - - -def sanitise_name(v: str) -> str: - """Replace any non-alphanumeric and dash characters with space. - - Args: - v (str): Raw name string. - - Returns: - out (str): Sanitised name string that is safe for OpenAI. 
- """ - return re.sub(r"[^a-zA-Z0-9_-]", "_", v).strip() - - -MessageName = Annotated[str, AfterValidator(sanitise_name)] - - -class MessageToolCallFunction(BaseModel): - arguments: str - name: str | None - - -class MessageToolCall(BaseModel): - id: str | None - function: MessageToolCallFunction - type: str - - -class ChatEntry(BaseModel): - """Represents a message in the chat context.""" - - model_config = ConfigDict(use_enum_values=True) - - role: ChatRole - """Who said the message?""" - content: str | list[dict[str, str | dict[str, str]]] - """The content of the message.""" - name: MessageName | None = None - """The name of the user who sent the message, if set (user messages only).""" - - @classmethod - def system(cls, content: str, **kwargs): - """Create a new system message.""" - return cls(role=ChatRole.SYSTEM, content=content, **kwargs) - - @classmethod - def user(cls, content: str, **kwargs): - """Create a new user message.""" - return cls(role=ChatRole.USER, content=content, **kwargs) - - @classmethod - def assistant(cls, content: str | list[dict[str, str]] | None, **kwargs): - """Create a new assistant message.""" - return cls(role=ChatRole.ASSISTANT, content=content, **kwargs) - - @field_validator("content", mode="before") - @classmethod - def coerce_input(cls, value: Any) -> str | list[dict[str, str | dict[str, str]]]: - if isinstance(value, list): - return [cls.coerce_input(v) for v in value] - if isinstance(value, dict): - return {k: cls.coerce_input(v) for k, v in value.items()} - if isinstance(value, str): - return value - if value is None: - return "" - return str(value) - - -class ChatCompletionChoiceOutput(ChatEntry): - tool_calls: list[MessageToolCall] | None = None - """List of tool calls if the message includes tool call responses.""" - - -class ChatThread(BaseModel): - object: str = Field( - default="chat.thread", - description="Type of API response object.", - examples=["chat.thread"], - ) - thread: list[ChatEntry] = Field( - default=[], - description="List of chat messages.", - examples=[ - [ - ChatEntry.system(content="You are an assistant."), - ChatEntry.user(content="Hello."), - ] - ], - ) - - -class CompletionUsage(BaseModel): - prompt_tokens: int = Field(default=0, description="Number of tokens in the prompt.") - completion_tokens: int = Field( - default=0, description="Number of tokens in the generated completion." - ) - total_tokens: int = Field( - default=0, description="Total number of tokens used in the request (prompt + completion)." - ) - - -class ChatCompletionChoice(BaseModel): - message: ChatEntry | ChatCompletionChoiceOutput = Field( - description="A chat completion message generated by the model." - ) - index: int = Field(description="The index of the choice in the list of choices.") - finish_reason: str | None = Field( - default=None, - description=( - "The reason the model stopped generating tokens. " - "This will be stop if the model hit a natural stop point or a provided stop sequence, " - "length if the maximum number of tokens specified in the request was reached." 
- ), - ) - - @property - def text(self) -> str: - """The text of the most recent chat completion.""" - return self.message.content - - -class ChatCompletionChoiceDelta(ChatCompletionChoice): - @computed_field - @property - def delta(self) -> ChatEntry | ChatCompletionChoiceOutput: - return self.message - - -class References(BaseModel): - object: str = Field( - default="chat.references", - description="Type of API response object.", - examples=["chat.references"], - ) - chunks: list[Chunk] = Field( - default=[], - description="A list of `Chunk`.", - examples=[ - [ - Chunk( - text="The Name of the Title is Hope\n\n...", - title="The Name of the Title is Hope", - page=0, - file_name="sample_tables.pdf", - file_path="amagpt/sample_tables.pdf", - metadata={ - "total_pages": 3, - "Author": "Ben Trovato", - "CreationDate": "D:20231031072817Z", - "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", - "Keywords": "Image Captioning, Deep Learning", - "ModDate": "D:20231031073146Z", - "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5", - "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", - "Trapped": "False", - }, - ) - ] - ], - ) - search_query: str = Field(description="Query used to retrieve items from the KB database.") - finish_reason: Literal["stop", "context_overflow"] | None = Field( - default=None, - description=""" -In streaming mode, reference chunk will be streamed first. -However, if the model's context length is exceeded, then there will be no further completion chunks. -In this case, "finish_reason" will be set to "context_overflow". -Otherwise, it will be None or null. -""", - ) - - def remove_contents(self): - copy = self.model_copy(deep=True) - for d in copy.documents: - d.page_content = "" - return copy +from jamaibase.types import * # noqa: F403 - -class GenTableStreamReferences(References): - object: str = Field( - default="gen_table.references", - description="Type of API response object.", - examples=["gen_table.references"], - ) - output_column_name: str - - -class GenTableChatCompletionChunks(BaseModel): - object: str = Field( - default="gen_table.completion.chunks", - description="Type of API response object.", - examples=["gen_table.completion.chunks"], - ) - columns: dict[str, ChatCompletionChunk] - row_id: str - - -class GenTableRowsChatCompletionChunks(BaseModel): - object: str = Field( - default="gen_table.completion.rows", - description="Type of API response object.", - examples=["gen_table.completion.rows"], - ) - rows: list[GenTableChatCompletionChunks] - - -class ChatCompletionChunk(BaseModel): - id: str = Field( - description="A unique identifier for the chat completion. Each chunk has the same ID." - ) - object: str = Field( - default="chat.completion.chunk", - description="Type of API response object.", - examples=["chat.completion.chunk"], - ) - created: int = Field( - description="The Unix timestamp (in seconds) of when the chat completion was created." - ) - model: str = Field(description="The model used for the chat completion.") - usage: CompletionUsage | None = Field( - description="Number of tokens consumed for the completion request.", - examples=[CompletionUsage(), None], - ) - choices: list[ChatCompletionChoice | ChatCompletionChoiceDelta] = Field( - description="A list of chat completion choices. 
Can be more than one if `n` is greater than 1." - ) - references: References | None = Field( - default=None, - description="Contains the references retrieved from database when performing chat completion with RAG.", - ) - - @property - def message(self) -> ChatEntry | ChatCompletionChoiceOutput | None: - return self.choices[0].message if len(self.choices) > 0 else None - - @property - def prompt_tokens(self) -> int: - return self.usage.prompt_tokens - - @property - def completion_tokens(self) -> int: - return self.usage.completion_tokens - - @property - def text(self) -> str: - """The text of the most recent chat completion.""" - return self.message.content if len(self.choices) > 0 else "" - - @property - def finish_reason(self) -> str | None: - return self.choices[0].finish_reason if len(self.choices) > 0 else None - - -class GenTableStreamChatCompletionChunk(ChatCompletionChunk): - object: str = Field( - default="gen_table.completion.chunk", - description="Type of API response object.", - examples=["gen_table.completion.chunk"], - ) - output_column_name: str - row_id: str - - -class FunctionParameter(BaseModel): - type: str = Field( - default="", description="The type of the parameter, e.g., 'string', 'number'." - ) - description: str = Field(default="", description="A description of the parameter.") - enum: list[str] = Field( - default=[], description="An optional list of allowed values for the parameter." - ) - - -class FunctionParameters(BaseModel): - type: str = Field( - default="object", description="The type of the parameters object, usually 'object'." - ) - properties: dict[str, FunctionParameter] = Field( - description="The properties of the parameters object." - ) - required: list[str] = Field(description="A list of required parameter names.") - additionalProperties: bool = Field( - default=False, description="Whether additional properties are allowed." - ) - - -class Function(BaseModel): - name: str = Field(default="", description="The name of the function.") - description: str = Field(default="", description="A description of what the function does.") - parameters: FunctionParameters = Field(description="The parameters for the function.") - - -class Tool(BaseModel): - type: str = Field(default="function", description="The type of the tool, e.g., 'function'.") - function: Function = Field(description="The function details of the tool.") - - -class ToolChoiceFunction(BaseModel): - name: str = Field(default="", description="The name of the function.") - - -class ToolChoice(BaseModel): - type: str = Field(default="function", description="The type of the tool, e.g., 'function'.") - function: ToolChoiceFunction = Field(description="Select a tool for the chat model to use.") - - -class ChatRequest(BaseModel): - id: str = Field( - default="", - description='Chat ID. Will be replaced with request ID. Defaults to "".', - ) - model: str = Field( - default="", - description="ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - ) - messages: list[ChatEntry] = Field( - description="A list of messages comprising the conversation so far.", - min_length=1, - ) - rag_params: RAGParams | None = Field( - default=None, - description="Retrieval Augmented Generation search params. Defaults to None (disabled).", - examples=[None], - ) - temperature: Annotated[float, Field(ge=0.001, le=2.0)] = Field( - default=0.2, - description=""" -What sampling temperature to use, in [0.001, 2.0]. 
-Higher values like 0.8 will make the output more random, -while lower values like 0.2 will make it more focused and deterministic. -""", - examples=[0.2], - ) - top_p: Annotated[float, Field(ge=0.001, le=1.0)] = Field( - default=0.6, - description=""" -An alternative to sampling with temperature, called nucleus sampling, -where the model considers the results of the tokens with top_p probability mass. -So 0.1 means only the tokens comprising the top 10% probability mass are considered. -Must be in [0.001, 1.0]. -""", - examples=[0.6], - ) - n: int = Field( - default=1, - description="How many chat completion choices to generate for each input message.", - examples=[1], - ) - stream: bool = Field( - default=True, - description=""" -If set, partial message deltas will be sent, like in ChatGPT. -Tokens will be sent as server-sent events as they become available, -with the stream terminated by a 'data: [DONE]' message. -""", - examples=[True], - ) - stop: list[str] | None = Field( - default=None, - description="Up to 4 sequences where the API will stop generating further tokens.", - examples=[None], - ) - max_tokens: PositiveNonZeroInt = Field( - default=2048, - description=""" -The maximum number of tokens to generate in the chat completion. -Must be in [1, context_length - 1). Default is 2048. -The total length of input tokens and generated tokens is limited by the model's context length. -""", - examples=[2048], - ) - presence_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, -increasing the model's likelihood to talk about new topics. -""", - examples=[0.0], - ) - frequency_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, -decreasing the model's likelihood to repeat the same line verbatim. -""", - examples=[0.0], - ) - logit_bias: dict = Field( - default={}, - description=""" -Modify the likelihood of specified tokens appearing in the completion. -Accepts a json object that maps tokens (specified by their token ID in the tokenizer) -to an associated bias value from -100 to 100. -Mathematically, the bias is added to the logits generated by the model prior to sampling. -The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; -values like -100 or 100 should result in a ban or exclusive selection of the relevant token. -""", - examples=[{}], - ) - user: str = Field( - default="", - description="A unique identifier representing your end-user. For monitoring and debugging purposes.", - examples=[""], - ) - - @field_validator("stop", mode="after") - @classmethod - def convert_stop(cls, v: list[str] | None) -> list[str] | None: - if isinstance(v, list) and len(v) == 0: - v = None - return v - - -class ChatRequestWithTools(ChatRequest): - tools: list[Tool] = Field( - description="A list of tools available for the chat model to use.", - min_length=1, - examples=[ - # --- [Tool Function] --- - # def get_delivery_date(order_id: str) -> datetime: - # # Connect to the database - # conn = sqlite3.connect('ecommerce.db') - # cursor = conn.cursor() - # # ... 
- [ - Tool( - type="function", - function=Function( - name="get_delivery_date", - description="Get the delivery date for a customer's order.", - parameters=FunctionParameters( - type="object", - properties={ - "order_id": FunctionParameter( - type="string", description="The customer's order ID." - ) - }, - required=["order_id"], - additionalProperties=False, - ), - ), - ) - ], - ], - ) - tool_choice: str | ToolChoice = Field( - default="auto", - description="Set `auto` to let chat model pick a tool or select a tool for the chat model to use.", - examples=[ - "auto", - ToolChoice(type="function", function=ToolChoiceFunction(name="get_delivery_date")), - ], - ) - - -class EmbeddingRequest(BaseModel): - input: str | list[str] = Field( - description=( - "Input text to embed, encoded as a string or array of strings " - "(to embed multiple inputs in a single request). " - "The input must not exceed the max input tokens for the model, and cannot contain empty string." - ), - examples=["What is a llama?", ["What is a llama?", "What is an alpaca?"]], - ) - model: str = Field( - description=( - "The ID of the model to use. " - "You can use the List models API to see all of your available models." - ), - examples=EXAMPLE_EMBEDDING_MODEL_IDS, - ) - type: Literal["query", "document"] = Field( - default="document", - description=( - 'Whether the input text is a "query" (used to retrieve) or a "document" (to be retrieved).' - ), - examples=["query", "document"], - ) - encoding_format: Literal["float", "base64"] = Field( - default="float", - description=( - '_Optional_. The format to return the embeddings in. Can be either "float" or "base64". ' - "`base64` string should be decoded as a `float32` array. " - "Example: `np.frombuffer(base64.b64decode(response), dtype=np.float32)`" - ), - examples=["float", "base64"], - ) - - -class EmbeddingResponseData(BaseModel): - object: str = Field( - default="embedding", - description="Type of API response object.", - examples=["embedding"], - ) - embedding: list[float] | str = Field( - description=( - "The embedding vector, which is a list of floats or a base64-encoded string. " - "The length of vector depends on the model." 
- ), - examples=[[0.0, 1.0, 2.0], []], - ) - index: int = Field( - default=0, - description="The index of the embedding in the list of embeddings.", - examples=[0, 1], - ) - - -class EmbeddingResponse(BaseModel): - object: str = Field( - default="list", - description="Type of API response object.", - examples=["list"], - ) - data: list[EmbeddingResponseData] = Field( - description="List of `EmbeddingResponseData`.", - examples=[[EmbeddingResponseData(embedding=[0.0, 1.0, 2.0])]], - ) - model: str = Field( - description="The ID of the model used.", - examples=["openai/text-embedding-3-small-512"], - ) - usage: CompletionUsage = Field( - default=CompletionUsage(), - description="The number of tokens consumed.", - examples=[CompletionUsage()], - ) - - -class ClipInputData(BaseModel): - """Data model for Clip input data, assume if image_filename is None then it have to be text, otherwise, the input is an image with bytes content""" - - content: str | bytes - """content of this input data, either be str of text or an """ - image_filename: str | None - """image filename of the content, None if the content is text""" - - -T = TypeVar("T") - - -class Page(BaseModel, Generic[T]): - items: Annotated[ - Sequence[T], Field(description="List of items paginated items.", examples=[[]]) - ] = [] - offset: Annotated[int, Field(description="Number of skipped items.", examples=[0])] = 0 - limit: Annotated[int, Field(description="Number of items per page.", examples=[0])] = 0 - total: Annotated[int, Field(description="Total number of items.", examples=[0])] = 0 - - -def nd_array_before_validator(x): - return np.array(x) if isinstance(x, list) else x - - -def datetime_str_before_validator(x): - return x.isoformat() if isinstance(x, datetime) else str(x) - - -ODD_SINGLE_QUOTE = r"(? "int".' +warn( + "`jamaibase.protocol` is deprecated, use `jamaibase.types` instead.", + FutureWarning, + stacklevel=2, ) - - -@deprecated(ENUM_DEPRECATE_MSSG, category=FutureWarning, stacklevel=1) -class DtypeCreateEnum(str, Enum, metaclass=MetaEnum): - int_ = "int" - float_ = "float" - bool_ = "bool" - str_ = "str" - image_ = "image" - audio_ = "audio" - - def __getattribute__(cls, *args, **kwargs): - warn(ENUM_DEPRECATE_MSSG, FutureWarning, stacklevel=1) - return super().__getattribute__(*args, **kwargs) - - def __getitem__(cls, *args, **kwargs): - warn(ENUM_DEPRECATE_MSSG, FutureWarning, stacklevel=1) - return super().__getitem__(*args, **kwargs) - - def __call__(cls, *args, **kwargs): - warn(ENUM_DEPRECATE_MSSG, FutureWarning, stacklevel=1) - return super().__call__(*args, **kwargs) - - def __str__(self) -> str: - return self.value - - -class TableType(str, Enum, metaclass=MetaEnum): - action = "action" - """Action table.""" - knowledge = "knowledge" - """Knowledge table.""" - chat = "chat" - """Chat table.""" - - def __str__(self) -> str: - return self.value - - -class LLMGenConfig(BaseModel): - object: Literal["gen_config.llm"] = Field( - default="gen_config.llm", - description='The object type, which is always "gen_config.llm".', - examples=["gen_config.llm"], - ) - model: str = Field( - default="", - description="ID of the model to use. 
See the model endpoint compatibility table for details on which models work with the Chat API.", - ) - system_prompt: str = Field( - default="", - description="System prompt for the LLM.", - ) - prompt: str = Field( - default="", - description="Prompt for the LLM.", - ) - multi_turn: bool = Field( - default=False, - description="Whether this column is a multi-turn chat with history along the entire column.", - ) - rag_params: RAGParams | None = Field( - default=None, - description="Retrieval Augmented Generation search params. Defaults to None (disabled).", - examples=[None], - ) - temperature: Annotated[float, Field(ge=0.001, le=2.0)] = Field( - default=0.2, - description=""" -What sampling temperature to use, in [0.001, 2.0]. -Higher values like 0.8 will make the output more random, -while lower values like 0.2 will make it more focused and deterministic. -""", - examples=[0.2], - ) - top_p: Annotated[float, Field(ge=0.001, le=1.0)] = Field( - default=0.6, - description=""" -An alternative to sampling with temperature, called nucleus sampling, -where the model considers the results of the tokens with top_p probability mass. -So 0.1 means only the tokens comprising the top 10% probability mass are considered. -Must be in [0.001, 1.0]. -""", - examples=[0.6], - ) - stop: list[str] | None = Field( - default=None, - description="Up to 4 sequences where the API will stop generating further tokens.", - examples=[None], - ) - max_tokens: PositiveNonZeroInt = Field( - default=2048, - description=""" -The maximum number of tokens to generate in the chat completion. -Must be in [1, context_length - 1). Default is 2048. -The total length of input tokens and generated tokens is limited by the model's context length. -""", - examples=[2048], - ) - presence_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, -increasing the model's likelihood to talk about new topics. -""", - examples=[0.0], - ) - frequency_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, -decreasing the model's likelihood to repeat the same line verbatim. -""", - examples=[0.0], - ) - logit_bias: dict = Field( - default={}, - description=""" -Modify the likelihood of specified tokens appearing in the completion. -Accepts a json object that maps tokens (specified by their token ID in the tokenizer) -to an associated bias value from -100 to 100. -Mathematically, the bias is added to the logits generated by the model prior to sampling. -The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; -values like -100 or 100 should result in a ban or exclusive selection of the relevant token. -""", - examples=[{}], - ) - - @model_validator(mode="before") - @classmethod - def compat(cls, data: Any) -> Any: - data_type = type(data).__name__ - if isinstance(data, BaseModel): - data = data.model_dump() - if not isinstance(data, dict): - raise TypeError( - f"Input to `LLMGenConfig` must be a dict or BaseModel, received: {data_type}" - ) - if data.get("system_prompt", None) or data.get("prompt", None): - return data - warn( - ( - f'Using {data_type} as input to "gen_config" is deprecated and will be disabled in v0.4, ' - f"use {cls.__name__} instead." 
- ), - FutureWarning, - stacklevel=3, - ) - messages: list[dict[str, Any]] = data.get("messages", []) - num_prompts = len(messages) - if num_prompts >= 2: - data["system_prompt"] = messages[0]["content"] - data["prompt"] = messages[1]["content"] - elif num_prompts == 1: - if messages[0]["role"] == "system": - data["system_prompt"] = messages[0]["content"] - data["prompt"] = "" - elif messages[0]["role"] == "user": - data["system_prompt"] = "" - data["prompt"] = messages[0]["content"] - else: - raise ValueError( - f'Attribute "messages" cannot contain only assistant messages: {messages}' - ) - data["object"] = "gen_config.llm" - return data - - @field_validator("stop", mode="after") - @classmethod - def convert_stop(cls, v: list[str] | None) -> list[str] | None: - if isinstance(v, list) and len(v) == 0: - v = None - return v - - -class EmbedGenConfig(BaseModel): - object: Literal["gen_config.embed"] = Field( - default="gen_config.embed", - description='The object type, which is always "gen_config.embed".', - examples=["gen_config.embed"], - ) - embedding_model: str = Field( - description="The embedding model to use.", - examples=EXAMPLE_EMBEDDING_MODEL_IDS, - ) - source_column: str = Field( - description="The source column for embedding.", - examples=["text_column"], - ) - - -class CodeGenConfig(BaseModel): - object: Literal["gen_config.code"] = Field( - default="gen_config.code", - description='The object type, which is always "gen_config.code".', - examples=["gen_config.code"], - ) - source_column: str = Field( - description="The source column for python code to execute.", - examples=["code_column"], - ) - - -def _gen_config_discriminator(x: Any) -> str | None: - object_attr = getattr(x, "object", None) - if object_attr: - return object_attr - if isinstance(x, BaseModel): - x = x.model_dump() - if isinstance(x, dict): - if "object" in x: - return x["object"] - if "embedding_model" in x: - return "gen_config.embed" - else: - return "gen_config.llm" - return None - - -GenConfig = LLMGenConfig | EmbedGenConfig | CodeGenConfig -DiscriminatedGenConfig = Annotated[ - Union[ - Annotated[CodeGenConfig, Tag("gen_config.code")], - Annotated[LLMGenConfig, Tag("gen_config.llm")], - Annotated[LLMGenConfig, Tag("gen_config.chat")], - Annotated[EmbedGenConfig, Tag("gen_config.embed")], - ], - Discriminator(_gen_config_discriminator), -] - - -class ColumnSchema(BaseModel): - id: str = Field(description="Column name.") - dtype: str = Field( - default="str", - description="Column data type.", - ) - vlen: PositiveInt = Field( # type: ignore - default=0, - description=( - "_Optional_. Vector length. " - "If this is larger than zero, then `dtype` must be one of the floating data types. Defaults to zero." - ), - ) - index: bool = Field( - default=True, - description=( - "_Optional_. Whether to build full-text-search (FTS) or vector index for this column. " - "Only applies to string and vector columns. Defaults to True." - ), - ) - gen_config: DiscriminatedGenConfig | None = Field( - default=None, - description=( - '_Optional_. Generation config. If provided, then this column will be an "Output Column". ' - "Table columns on its left can be referenced by `${column-name}`." - ), - ) - - -class ColumnSchemaCreate(ColumnSchema): - id: str = Field(description="Column name.") - dtype: Literal["int", "float", "bool", "str", "file", "image", "audio"] = Field( - default="str", - description=( - 'Column data type, one of ["int", "float", "bool", "str", "file", "image", "audio"]' - ". 
Data type 'file' is deprecated, use 'image' instead." - ), - ) - - @model_validator(mode="before") - @classmethod - def compat(cls, data: Any) -> Any: - data_type = type(data).__name__ - if isinstance(data, BaseModel): - data = data.model_dump() - if not isinstance(data, dict): - raise TypeError( - f"Input to `ColumnSchemaCreate` must be a dict or BaseModel, received: {data_type}" - ) - if isinstance(data.get("dtype", None), DtypeCreateEnum): - data["dtype"] = data["dtype"].value - return data - - -class TableBase(BaseModel): - id: str = Field(primary_key=True, description="Table name.") - version: str = Field( - default=jamaibase_version, description="Table version, following jamaibase version." - ) - meta: dict[str, Any] = Field( - default={}, - description="Additional metadata about the table.", - ) - - -class TableSchema(TableBase): - cols: list[ColumnSchema] = Field(description="List of column schema.") - - -class TableSchemaCreate(TableSchema): - id: str = Field(description="Table name.") - cols: list[ColumnSchemaCreate] = Field(description="List of column schema.") - - @model_validator(mode="after") - def check_cols(self) -> Self: - if len(set(c.id.lower() for c in self.cols)) != len(self.cols): - raise ValueError("There are repeated column names (case-insensitive) in the schema.") - if sum(c.id.lower() in ("id", "updated at") for c in self.cols) > 0: - raise ValueError("Schema cannot contain column names: 'ID' or 'Updated at'.") - if sum(c.vlen > 0 for c in self.cols) > 0: - raise ValueError("Schema cannot contain columns with `vlen` > 0.") - return self - - -class ActionTableSchemaCreate(TableSchemaCreate): - pass - - -class AddActionColumnSchema(ActionTableSchemaCreate): - # TODO: Deprecate this - pass - - -class KnowledgeTableSchemaCreate(TableSchemaCreate): - # TODO: Maybe deprecate this and use EmbedGenConfig instead ? - embedding_model: str - - -class AddKnowledgeColumnSchema(TableSchemaCreate): - # TODO: Deprecate this - pass - - -class ChatTableSchemaCreate(TableSchemaCreate): - pass - - -class AddChatColumnSchema(TableSchemaCreate): - # TODO: Deprecate this - pass - - -class TableMeta(TableBase): - cols: list[dict[str, Any]] = Field(description="List of column schema.") - parent_id: str | None = Field( - default=None, - description="The parent table ID. If None (default), it means this is a template table.", - ) - title: str = Field( - default="", - description="Chat title. Defaults to ''.", - ) - updated_at: str = Field( - default_factory=datetime_now_iso, - description="Table last update timestamp (ISO 8601 UTC).", - ) # SQLite does not support TZ - indexed_at_fts: str | None = Field( - default=None, description="Table last FTS index timestamp (ISO 8601 UTC)." - ) - indexed_at_vec: str | None = Field( - default=None, description="Table last vector index timestamp (ISO 8601 UTC)." - ) - indexed_at_sca: str | None = Field( - default=None, description="Table last scalar index timestamp (ISO 8601 UTC)." - ) - - @property - def cols_schema(self) -> list[ColumnSchema]: - return [ColumnSchema.model_validate(c) for c in self.cols] - - @property - def regular_cols(self) -> list[ColumnSchema]: - return [c for c in self.cols_schema if not c.id.endswith("_")] - - -class TableMetaResponse(TableSchema): - parent_id: str | None = Field( - description="The parent table ID. If None (default), it means this is a template table.", - ) - title: str = Field(description="Chat title. 
Defaults to ''.") - updated_at: str = Field( - description="Table last update timestamp (ISO 8601 UTC).", - ) # SQLite does not support TZ - indexed_at_fts: str | None = Field( - description="Table last FTS index timestamp (ISO 8601 UTC)." - ) - indexed_at_vec: str | None = Field( - description="Table last vector index timestamp (ISO 8601 UTC)." - ) - indexed_at_sca: str | None = Field( - description="Table last scalar index timestamp (ISO 8601 UTC)." - ) - num_rows: int = Field(description="Number of rows in the table.") - - @model_validator(mode="after") - def remove_state_cols(self) -> Self: - self.cols = [c for c in self.cols if not c.id.endswith("_")] - return self - - -class GenConfigUpdateRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - column_map: dict[str, DiscriminatedGenConfig | None] = Field( - description=( - "Mapping of column ID to generation config JSON in the form of `GenConfig`. " - "Table columns on its left can be referenced by `${column-name}`." - ) - ) - - @model_validator(mode="after") - def check_column_map(self) -> Self: - if sum(n.lower() in ("id", "updated at") for n in self.column_map) > 0: - raise ValueError("column_map cannot contain keys: 'ID' or 'Updated at'.") - return self - - -class ColumnRenameRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - column_map: dict[str, str] = Field( - description="Mapping of old column names to new column names." - ) - - @model_validator(mode="after") - def check_column_map(self) -> Self: - if sum(n.lower() in ("id", "updated at") for n in self.column_map) > 0: - raise ValueError("`column_map` cannot contain keys: 'ID' or 'Updated at'.") - return self - - -class ColumnReorderRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - column_names: list[str] = Field(description="List of column ID in the desired order.") - - -class ColumnDropRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - column_names: list[str] = Field(description="List of column ID to drop.") - - @model_validator(mode="after") - def check_column_names(self) -> Self: - if sum(n.lower() in ("id", "updated at") for n in self.column_names) > 0: - raise ValueError("`column_names` cannot contain keys: 'ID' or 'Updated at'.") - return self - - -class RowAddRequest(BaseModel): - table_id: str = Field( - description="Table name or ID.", - ) - data: list[dict[str, Any]] = Field( - min_length=1, - description=( - "List of mapping of column names to its value. " - "In other words, each item in the list is a row, and each item is a mapping. " - "Minimum 1 row, maximum 100 rows." - ), - ) - stream: bool = Field( - default=True, - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. 
Whether or not to concurrently generate the output rows and columns.", - ) - - def __repr__(self): - _data = [ - { - k: ( - {"type": type(v), "shape": v.shape, "dtype": v.dtype} - if isinstance(v, np.ndarray) - else v - ) - } - for k, v in self.data.items() - ] - return ( - f"{self.__class__.__name__}(" - f"table_id={self.table_id} stream={self.stream} reindex={self.reindex} " - f"concurrent={self.concurrent} data={_data}" - ")" - ) - - @model_validator(mode="after") - def check_data(self) -> Self: - for row in self.data: - for value in row.values(): - if isinstance(value, str) and ( - value.startswith("s3://") or value.startswith("file://") - ): - extension = splitext(value)[1].lower() - if extension not in IMAGE_FILE_EXTENSIONS + AUDIO_FILE_EXTENSIONS: - raise ValueError( - "Unsupported file type. Make sure the file belongs to " - "one of the following formats: \n" - f"[Image File Types]: \n{IMAGE_FILE_EXTENSIONS} \n" - f"[Audio File Types]: \n{AUDIO_FILE_EXTENSIONS}" - ) - return self - - -class RowAddRequestWithLimit(RowAddRequest): - data: list[dict[str, Any]] = Field( - min_length=1, - max_length=100, - description=( - "List of mapping of column names to its value. " - "In other words, each item in the list is a row, and each item is a mapping. " - "Minimum 1 row, maximum 100 rows." - ), - ) - - -class RowUpdateRequest(BaseModel): - table_id: str = Field( - description="Table name or ID.", - ) - row_id: str = Field( - description="ID of the row to update.", - ) - data: dict[str, Any] = Field( - description="Mapping of column names to its value.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - - @model_validator(mode="after") - def check_data(self) -> Self: - for value in self.data.values(): - if isinstance(value, str) and ( - value.startswith("s3://") or value.startswith("file://") - ): - extension = splitext(value)[1].lower() - if extension not in IMAGE_FILE_EXTENSIONS + AUDIO_FILE_EXTENSIONS: - raise ValueError( - "Unsupported file type. Make sure the file belongs to " - "one of the following formats: \n" - f"[Image File Types]: \n{IMAGE_FILE_EXTENSIONS} \n" - f"[Audio File Types]: \n{AUDIO_FILE_EXTENSIONS}" - ) - return self - - -class RegenStrategy(str, Enum): - """Strategies for selecting columns during row regeneration.""" - - RUN_ALL = "run_all" - RUN_BEFORE = "run_before" - RUN_SELECTED = "run_selected" - RUN_AFTER = "run_after" - - def __str__(self) -> str: - return self.value - - -class RowRegen(BaseModel): - table_id: str = Field( - description="Table name or ID.", - ) - row_id: str = Field( - description="ID of the row to regenerate.", - ) - regen_strategy: RegenStrategy = Field( - default=RegenStrategy.RUN_ALL, - description=( - "_Optional_. Strategy for selecting columns to regenerate." - "Choose `run_all` to regenerate all columns in the specified row; " - "Choose `run_before` to regenerate columns up to the specified column_id; " - "Choose `run_selected` to regenerate only the specified column_id; " - "Choose `run_after` to regenerate columns starting from the specified column_id; " - ), - ) - output_column_id: str | None = Field( - default=None, - description=( - "_Optional_. Output column name to indicate the starting or ending point of regen for `run_before`, " - "`run_selected` and `run_after` strategies. Required if `regen_strategy` is not 'run_all'. 
" - "Given columns are 'C1', 'C2', 'C3' and 'C4', if column_id is 'C3': " - "`run_before` regenerate columns 'C1', 'C2' and 'C3'; " - "`run_selected` regenerate only column 'C3'; " - "`run_after` regenerate columns 'C3' and 'C4'; " - ), - ) - stream: bool = Field( - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. Whether or not to concurrently generate the output columns.", - ) - - -class RowRegenRequest(BaseModel): - table_id: str = Field( - description="Table name or ID.", - ) - row_ids: list[str] = Field( - min_length=1, - max_length=100, - description="List of ID of the row to regenerate. Minimum 1 row, maximum 100 rows.", - ) - regen_strategy: RegenStrategy = Field( - default=RegenStrategy.RUN_ALL, - description=( - "_Optional_. Strategy for selecting columns to regenerate." - "Choose `run_all` to regenerate all columns in the specified row; " - "Choose `run_before` to regenerate columns up to the specified column_id; " - "Choose `run_selected` to regenerate only the specified column_id; " - "Choose `run_after` to regenerate columns starting from the specified column_id; " - ), - ) - output_column_id: str | None = Field( - default=None, - description=( - "_Optional_. Output column name to indicate the starting or ending point of regen for `run_before`, " - "`run_selected` and `run_after` strategies. Required if `regen_strategy` is not 'run_all'. " - "Given columns are 'C1', 'C2', 'C3' and 'C4', if column_id is 'C3': " - "`run_before` regenerate columns 'C1', 'C2' and 'C3'; " - "`run_selected` regenerate only column 'C3'; " - "`run_after` regenerate columns 'C3' and 'C4'; " - ), - ) - stream: bool = Field( - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. Whether or not to concurrently generate the output rows and columns.", - ) - - @model_validator(mode="after") - def check_output_column_id_provided(self) -> Self: - if self.regen_strategy != RegenStrategy.RUN_ALL and self.output_column_id is None: - raise ValueError( - "`output_column_id` is required for regen_strategy other than 'run_all'." - ) - return self - - @model_validator(mode="after") - def sort_row_ids(self) -> Self: - self.row_ids = sorted(self.row_ids) - return self - - -class RowDeleteRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - row_ids: list[str] | None = Field( - min_length=1, - max_length=100, - default=None, - description="List of ID of the row to delete. Minimum 1 row, maximum 100 rows.", - ) - where: str | None = Field( - default=None, - description="_Optional_. SQL where clause. If not provided, will match all rows and thus deleting all table content.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." 
- ), - ) - - -class EmbedFileRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - file_id: str = Field(description="ID of the file.") - chunk_size: Annotated[ - int, Field(description="Maximum chunk size (number of characters). Must be > 0.", gt=0) - ] = 1000 - chunk_overlap: Annotated[ - int, Field(description="Overlap in characters between chunks. Must be >= 0.", ge=0) - ] = 200 - # stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( - # True - # ) - - -class SearchRequest(BaseModel): - table_id: str = Field(description="Table name or ID.") - query: str = Field( - min_length=1, - description="Query for full-text-search (FTS) and vector search. Must not be empty.", - ) - where: str | None = Field( - default=None, - description="_Optional_. SQL where clause. If not provided, will match all rows.", - ) - limit: Annotated[int, Field(gt=0, le=1_000)] = Field( - default=100, description="_Optional_. Min 1, max 1000. Number of rows to return." - ) - metric: str = Field( - default="cosine", - description='_Optional_. Vector search similarity metric. Defaults to "cosine".', - ) - nprobes: Annotated[int, Field(gt=0, le=1000)] = Field( - default=50, - description=( - "_Optional_. Set the number of partitions to search (probe)." - "This argument is only used when the vector column has an IVF PQ index. If there is no index then this value is ignored. " - "The IVF stage of IVF PQ divides the input into partitions (clusters) of related values. " - "The partition whose centroids are closest to the query vector will be exhaustively searched to find matches. " - "This parameter controls how many partitions should be searched. " - "Increasing this value will increase the recall of your query but will also increase the latency of your query. Defaults to 50." - ), - ) - refine_factor: Annotated[int, Field(gt=0, le=1000)] = Field( - default=20, - description=( - "_Optional_. A multiplier to control how many additional rows are taken during the refine step. " - "This argument is only used when the vector column has an IVF PQ index. " - "If there is no index then this value is ignored. " - "An IVF PQ index stores compressed (quantized) values. " - "They query vector is compared against these values and, since they are compressed, the comparison is inaccurate. " - "This parameter can be used to refine the results. " - "It can improve both improve recall and correct the ordering of the nearest results. " - "To refine results LanceDb will first perform an ANN search to find the nearest limit * refine_factor results. " - "In other words, if refine_factor is 3 and limit is the default (10) then the first 30 results will be selected. " - "LanceDb then fetches the full, uncompressed, values for these 30 results. " - "The results are then reordered by the true distance and only the nearest 10 are kept. Defaults to 50." - ), - ) - float_decimals: int = Field( - default=0, - description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).", - ) - vec_decimals: int = Field( - default=0, - description="_Optional_. Number of decimals for vectors. If its negative, exclude vector columns. 
Defaults to 0 (no rounding).", - ) - reranking_model: Annotated[ - str | None, Field(description="Reranking model to use for hybrid search.") - ] = None - - -class FileUploadRequest(BaseModel): - file_path: Annotated[str, Field(description="File path of the document to be uploaded.")] - table_id: Annotated[str, Field(description="Knowledge Table name / ID.")] - chunk_size: Annotated[ - int, Field(description="Maximum chunk size (number of characters). Must be > 0.", gt=0) - ] = 1000 - chunk_overlap: Annotated[ - int, Field(description="Overlap in characters between chunks. Must be >= 0.", ge=0) - ] = 200 - # overwrite: Annotated[ - # bool, - # Field( - # description="Whether to overwrite the file.", - # examples=[True, False], - # ), - # ] = False - - -class TableDataImportRequest(BaseModel): - file_path: Annotated[str, Field(description="CSV or TSV file path.")] - table_id: Annotated[ - str, Field(description="ID or name of the table that the data should be imported into.") - ] - stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( - True - ) - # column_names: Annotated[ - # list[str] | None, - # Field( - # description="A list of columns names if the CSV does not have header row. Defaults to None (read from CSV)." - # ), - # ] = None - # columns: Annotated[ - # list[str] | None, - # Field( - # description="A list of columns to be imported. Defaults to None (import all columns except 'ID' and 'Updated at')." - # ), - # ] = None - delimiter: Annotated[ - Literal[",", "\t"], - Field(description='The delimiter of the file: can be "," or "\\t". Defaults to ",".'), - ] = "," - - -class TableImportRequest(BaseModel): - file_path: Annotated[str, Field(description="The parquet file path.")] - table_id_dst: Annotated[ - str | None, Field(description="_Optional_. The ID or name of the new table.") - ] = None - table_id_dst: Annotated[str, Field(description="The ID or name of the new table.")] - - -class FileUploadResponse(BaseModel): - object: Literal["file.upload"] = Field( - default="file.upload", - description='The object type, which is always "file.upload".', - examples=["file.upload"], - ) - uri: str = Field( - description="The URI of the uploaded file.", - examples=[ - "s3://bucket-name/raw/org_id/project_id/uuid/filename.ext", - "file:///path/to/raw/file.ext", - ], - ) - - -class GetURLRequest(BaseModel): - uris: list[str] = Field( - description=( - "A list of file URIs for which pre-signed URLs or local file paths are requested. " - "The service will return a corresponding list of pre-signed URLs or local file paths." 
- ), - ) - - -class GetURLResponse(BaseModel): - object: Literal["file.urls"] = Field( - default="file.urls", - description='The object type, which is always "file.urls".', - examples=["file.urls"], - ) - urls: list[str] = Field( - description="A list of pre-signed URLs or local file paths.", - examples=[ - "https://presigned-url-for-file1.ext", - "/path/to/file2.ext", - ], - ) diff --git a/clients/python/src/jamaibase/types/__init__.py b/clients/python/src/jamaibase/types/__init__.py new file mode 100644 index 0000000..68ab44e --- /dev/null +++ b/clients/python/src/jamaibase/types/__init__.py @@ -0,0 +1,508 @@ +import re +from decimal import Decimal +from typing import Annotated, Generic, Self, TypeVar + +from pydantic import BaseModel, EmailStr, Field, computed_field + +from jamaibase.types.billing import ( # noqa: F401 + DBStorageUsageData, + EgressUsageData, + EmbedUsageData, + FileStorageUsageData, + LlmUsageData, + RerankUsageData, + UsageData, +) +from jamaibase.types.common import ( # noqa: F401 + DEFAULT_MUL_LANGUAGES, + EXAMPLE_CHAT_MODEL_IDS, + EXAMPLE_EMBEDDING_MODEL_IDS, + EXAMPLE_RERANKING_MODEL_IDS, + DatetimeUTC, + EmptyIfNoneStr, + FilePath, + JSONInput, + JSONInputBin, + JSONOutput, + JSONOutputBin, + LanguageCodeList, + NullableStr, + PositiveInt, + PositiveNonZeroInt, + Progress, + ProgressStage, + ProgressState, + SanitisedMultilineStr, + SanitisedNonEmptyStr, + SanitisedStr, + TableImportProgress, + YAMLInput, + YAMLOutput, + empty_string_to_none, + none_to_empty_string, +) +from jamaibase.types.compat import ( # noqa: F401 + AdminOrderBy, + ChatCompletionChoiceDelta, + ChatCompletionChoiceOutput, + ChatCompletionChunk, + ChatRequestWithTools, + ChatThread, + CompletionUsage, + GenTableChatCompletionChunks, + GenTableOrderBy, + GenTableRowsChatCompletionChunks, + GenTableStreamChatCompletionChunk, + GenTableStreamReferences, + MessageToolCall, + MessageToolCallFunction, + ModelInfoResponse, + RowAddRequest, + RowDeleteRequest, + RowRegenRequest, + ToolFunction, +) +from jamaibase.types.conversation import ( # noqa: F401 + AgentMetaResponse, + ConversationCreateRequest, + ConversationMetaResponse, + MessageAddRequest, + MessagesRegenRequest, + MessageUpdateRequest, +) +from jamaibase.types.db import ( # noqa: F401 + CloudProvider, + Deployment_, + DeploymentCreate, + DeploymentRead, + DeploymentUpdate, + ModelCapability, + ModelConfig_, + ModelConfigCreate, + ModelConfigRead, + ModelConfigUpdate, + ModelInfo, + ModelInfoRead, + ModelProvider, + ModelType, + OnPremProvider, + Organization_, + OrganizationCreate, + OrganizationRead, + OrganizationUpdate, + OrgMember_, + OrgMemberCreate, + OrgMemberRead, + OrgMemberUpdate, + PaymentState, + PricePlan_, + PricePlanCreate, + PricePlanRead, + PricePlanUpdate, + PriceTier, + Product, + Products, + ProductType, + Project_, + ProjectCreate, + ProjectKey_, + ProjectKeyCreate, + ProjectKeyRead, + ProjectKeyUpdate, + ProjectMember_, + ProjectMemberCreate, + ProjectMemberRead, + ProjectMemberUpdate, + ProjectRead, + ProjectUpdate, + RankedRole, + Role, + User_, + UserAuth, + UserCreate, + UserRead, + UserReadObscured, + UserUpdate, + VerificationCode_, + VerificationCodeCreate, + VerificationCodeRead, + VerificationCodeUpdate, +) +from jamaibase.types.file import ( # noqa: F401 + FileUploadResponse, + GetURLRequest, + GetURLResponse, +) +from jamaibase.types.gen_table import ( # noqa: F401 + ActionTableSchemaCreate, + AddActionColumnSchema, + AddChatColumnSchema, + AddKnowledgeColumnSchema, + CellCompletionResponse, + 
CellReferencesResponse, + ChatTableSchemaCreate, + CodeGenConfig, + ColumnDropRequest, + ColumnRenameRequest, + ColumnReorderRequest, + ColumnSchema, + ColumnSchemaCreate, + CSVDelimiter, + DiscriminatedGenConfig, + EmbedGenConfig, + GenConfigUpdateRequest, + KnowledgeTableSchemaCreate, + LLMGenConfig, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowDeleteRequest, + MultiRowRegenRequest, + MultiRowUpdateRequest, + MultiRowUpdateRequestWithLimit, + PythonGenConfig, + RowCompletionResponse, + RowRegen, + RowUpdateRequest, + SearchRequest, + TableDataImportRequest, + TableImportRequest, + TableMeta, + TableMetaResponse, + TableSchemaCreate, + TableType, +) +from jamaibase.types.legacy import ( # noqa: F401 + VectorSearchRequest, + VectorSearchResponse, +) +from jamaibase.types.lm import ( # noqa: F401 + CITATION_PATTERN, + AudioContent, + AudioContentData, + AudioResponse, + ChatCompletionChoice, + ChatCompletionChunkResponse, + ChatCompletionDelta, + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionUsage, + ChatContent, + ChatContentS3, + ChatEntry, + ChatRequest, + ChatRole, + ChatThreadEntry, + ChatThreadResponse, + ChatThreadsResponse, + Chunk, + CodeInterpreterTool, + CompletionUsageDetails, + ConversationThreadsResponse, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + EmbeddingUsage, + Function, + FunctionCall, + FunctionParameters, + ImageContent, + ImageContentData, + LogProbs, + LogProbToken, + PromptUsageDetails, + RAGParams, + References, + RerankingApiVersion, + RerankingBilledUnits, + RerankingData, + RerankingMeta, + RerankingMetaUsage, + RerankingRequest, + RerankingResponse, + RerankingUsage, + S3Content, + SplitChunksParams, + SplitChunksRequest, + TextContent, + ToolCall, + ToolChoice, + ToolChoiceFunction, + ToolUsageDetails, + WebSearchTool, +) +from jamaibase.types.logs import LogQueryResponse # noqa: F401 +from jamaibase.types.model import ( # noqa: F401 + EmbeddingModelPrice, + LLMModelPrice, + ModelInfoListResponse, + ModelPrice, + RerankingModelPrice, +) +from jamaibase.types.telemetry import ( # noqa: F401 + Host, + Metric, + Usage, + UsageResponse, +) + + +class OkResponse(BaseModel): + ok: bool = True + progress_key: str = "" + + +T = TypeVar("T") + + +class Page(BaseModel, Generic[T]): + items: Annotated[ + list[T], Field(description="List of items paginated items.", examples=[[]]) + ] = [] + offset: Annotated[int, Field(description="Number of skipped items.", examples=[0])] = 0 + limit: Annotated[int, Field(description="Number of items per page.", examples=[0])] = 0 + total: Annotated[int, Field(description="Total number of items.", examples=[0])] = 0 + # start_cursor: Annotated[ + # str | None, + # Field( + # description=( + # "Opaque token for the first item in this page. " + # "Pass it as `before=` to request the page that precedes the current window." + # ) + # ), + # ] = None + end_cursor: Annotated[ + str | None, + Field( + description=( + "Opaque cursor token for the last item in this page. " + "Pass it as `after=` to request the page that follows the current window." 
+ ) + ), + ] = None + + +class UserAgent(BaseModel): + is_browser: bool = Field( + True, + description="Whether the request originates from a browser or an app.", + examples=[True, False], + ) + agent: str = Field( + description="The agent, such as 'SDK', 'Chrome', 'Firefox', 'Edge', or an empty string if it cannot be determined.", + examples=["", "SDK", "Chrome", "Firefox", "Edge"], + ) + agent_version: str = Field( + "", + description="The agent version, or an empty string if it cannot be determined.", + examples=["", "5.0", "0.3.0"], + ) + os: str = Field( + "", + description="The system/OS name and release, such as 'Windows NT 10.0', 'Linux 5.15.0-113-generic', or an empty string if it cannot be determined.", + examples=["", "Windows NT 10.0", "Linux 5.15.0-113-generic"], + ) + architecture: str = Field( + "", + description="The machine type, such as 'AMD64', 'x86_64', or an empty string if it cannot be determined.", + examples=["", "AMD64", "x86_64"], + ) + language: str = Field( + "", + description="The SDK language, such as 'TypeScript', 'Python', or an empty string if it is not applicable.", + examples=["", "TypeScript", "Python"], + ) + language_version: str = Field( + "", + description="The SDK language version, such as '4.9', '3.10.14', or an empty string if it is not applicable.", + examples=["", "4.9", "3.10.14"], + ) + + @computed_field( + description="The system/OS name, such as 'Linux', 'Darwin', 'Java', 'Windows', or an empty string if it cannot be determined.", + examples=["", "Windows NT", "Linux"], + ) + @property + def system(self) -> str: + return self._split_os_string()[0] + + @computed_field( + description="The system's release, such as '2.2.0', 'NT', or an empty string if it cannot be determined.", + examples=["", "10", "5.15.0-113-generic"], + ) + @property + def system_version(self) -> str: + return self._split_os_string()[1] + + def _split_os_string(self) -> tuple[str, str]: + match = re.match(r"([^\d]+) ([\d.]+).*$", self.os) + if match: + os_name = match.group(1).strip() + os_version = match.group(2).strip() + return os_name, os_version + else: + return "", "" + + @classmethod + def from_user_agent_string(cls, ua_string: str) -> Self: + if not ua_string: + return cls(is_browser=False, agent="") + + # SDK pattern + sdk_match = re.match(r"SDK/(\S+) \((\w+)/(\S+); ([^;]+); (\w+)\)", ua_string) + if sdk_match: + return cls( + is_browser=False, + agent="SDK", + agent_version=sdk_match.group(1), + os=sdk_match.group(4), + architecture=sdk_match.group(5), + language=sdk_match.group(2), + language_version=sdk_match.group(3), + ) + + # Browser pattern + browser_match = re.match(r"Mozilla/5.0 \(([^)]+)\).*", ua_string) + if browser_match: + os_info = browser_match.group(1).split(";") + # Microsoft Edge + match = re.match(r".+(Edg/.+)$", ua_string) + if match: + return cls( + agent="Edge", + agent_version=match.group(1).split("/")[-1].strip(), + os=os_info[0].strip(), + architecture=os_info[-1].strip() if len(os_info) == 3 else "", + language="", + language_version="", + ) + # Firefox + match = re.match(r".+(Firefox/.+)$", ua_string) + if match: + return cls( + agent="Firefox", + agent_version=match.group(1).split("/")[-1].strip(), + os=os_info[0].strip(), + architecture=os_info[-1].strip() if len(os_info) == 3 else "", + language="", + language_version="", + ) + # Chrome + match = re.match(r".+(Chrome/.+)$", ua_string) + if match: + return cls( + agent="Chrome", + agent_version=match.group(1).split("/")[-1].strip(), + os=os_info[0].strip(), + 
architecture=os_info[-1].strip() if len(os_info) == 3 else "", + language="", + language_version="", + ) + return cls(is_browser="mozilla" in ua_string.lower(), agent="") + + +class PasswordLoginRequest(BaseModel): + email: EmailStr = Field(min_length=1, description="Email.") + password: str = Field(min_length=1, max_length=72, description="Password.") + + +class PasswordChangeRequest(BaseModel): + email: EmailStr = Field(min_length=1, description="Email.") + password: str = Field(min_length=1, max_length=72, description="Password.") + new_password: str = Field(min_length=1, max_length=72, description="New password.") + + +class StripePaymentInfo(BaseModel): + status: str = Field( + description="Stripe invoice payment status.", + ) + subscription_id: str | None = Field( + pattern=r"^sub_.+", + description="Stripe subscription ID.", + ) + payment_intent_id: str | None = Field( + pattern=r"^pi_.+", + description="Stripe payment intent ID.", + ) + client_secret: str | None = Field( + description="Stripe client secret.", + ) + amount_due: Decimal = Field( + decimal_places=2, + description="Amount due.", + ) + amount_overpaid: Decimal = Field( + decimal_places=2, + description="Amount overpaid.", + ) + amount_paid: Decimal = Field( + decimal_places=2, + description="Amount paid.", + ) + amount_remaining: Decimal = Field( + decimal_places=2, + description="Amount remaining.", + ) + currency: str = Field( + description="Currency.", + ) + + +class StripeEventData(BaseModel): + event_type: str = Field( + description="Stripe event type.", + ) + event_id: str = Field( + pattern=r"^evt_.+", + description="Stripe event ID.", + ) + invoice_id: str | None = Field( + pattern=r"^in_.+", + description="Stripe invoice ID.", + ) + subscription_id: str | None = Field( + pattern=r"^sub_.+", + description="Stripe subscription ID.", + ) + price_id: str | None = Field( + pattern=r"^price_.+", + description="Stripe price ID.", + ) + payment_method: str | None = Field( + pattern=r"^pm_.+", + description="Stripe payment method.", + ) + customer_id: str = Field( + pattern=r"^cus_.+", + description="Stripe customer ID.", + ) + organization_id: str = Field( + description="Organization ID.", + ) + collection_method: str = Field( + description="Stripe collection method.", + ) + billing_reason: str = Field( + description="Stripe billing reason.", + ) + amount_paid: Decimal = Field( + decimal_places=2, + description="Amount paid.", + ) + currency: str = Field( + description="Currency.", + ) + status: str = Field( + description="Stripe subscription status.", + ) + receipt_url: str = Field( + "", + description="Stripe receipt URL.", + ) + invoice_url: str = Field( + "", + description="Stripe invoice URL.", + ) + invoice_pdf: str = Field( + "", + description="Stripe invoice PDF URL.", + ) diff --git a/clients/python/src/jamaibase/types/billing.py b/clients/python/src/jamaibase/types/billing.py new file mode 100644 index 0000000..56b87f3 --- /dev/null +++ b/clients/python/src/jamaibase/types/billing.py @@ -0,0 +1,136 @@ +from datetime import datetime, timezone + +from pydantic import BaseModel, Field + +from jamaibase.utils import uuid7_str + + +class _BaseUsageData(BaseModel): + id: str = Field( + default_factory=uuid7_str, + description="UUID of the insert row.", + ) + org_id: str = Field( + description="Organization ID.", + ) + proj_id: str = Field( + description="Project ID.", + ) + user_id: str = Field( + description="User ID.", + ) + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + 
description="UTC Timestamp (microsecond precision) of the insert.", + ) + cost: float = Field( + description="Usage cost (per_million_tokens for LLM and Embedding, per_thousand_searches for Rerank).", + ) + + def as_list(self): + """Convert the instance to a list, including all fields.""" + return list(self.model_dump().values()) + + +class LlmUsageData(_BaseUsageData): + model: str = Field( + description="Model used.", + ) + input_token: int = Field( + description="Number of input tokens used.", + ) + output_token: int = Field( + description="Number of output tokens used.", + ) + input_cost: float = Field( + description="Cost in USD per million input tokens.", + ) + output_cost: float = Field( + description="Cost in USD per million output tokens.", + ) + + +class EmbedUsageData(_BaseUsageData): + model: str = Field( + description="Model used.", + ) + token: int = Field( + description="Number of tokens used.", + ) + + +class RerankUsageData(_BaseUsageData): + model: str = Field( + description="Model used.", + ) + number_of_search: int = Field( + description="Number of searches.", + ) + + +class EgressUsageData(_BaseUsageData): + amount_gib: float = Field( + description="Amount in GiB.", + ) + + +class FileStorageUsageData(_BaseUsageData): + amount_gib: float = Field( + description="Chargeable Amount in GiB.", + ) + snapshot_gib: float = Field( + description="Snapshot of amount in GiB.", + ) + + +class DBStorageUsageData(_BaseUsageData): + amount_gib: float = Field( + description="Chargeable Amount in GiB.", + ) + snapshot_gib: float = Field( + description="Snapshot of amount in GiB.", + ) + + +class UsageData(BaseModel): + llm_usage: list[LlmUsageData] = [] + embed_usage: list[EmbedUsageData] = [] + rerank_usage: list[RerankUsageData] = [] + egress_usage: list[EgressUsageData] = [] + file_storage_usage: list[FileStorageUsageData] = [] + db_storage_usage: list[DBStorageUsageData] = [] + + # A computed field to get the per type list + def as_list_by_type(self) -> dict[str, list[list]]: + """Returns a dictionary of lists, where each key is a usage type and the value is a list of lists.""" + return { + "llm_usage": [usage.as_list() for usage in self.llm_usage], + "embed_usage": [usage.as_list() for usage in self.embed_usage], + "rerank_usage": [usage.as_list() for usage in self.rerank_usage], + "egress_usage": [usage.as_list() for usage in self.egress_usage], + "file_storage_usage": [usage.as_list() for usage in self.file_storage_usage], + "db_storage_usage": [usage.as_list() for usage in self.db_storage_usage], + } + + @property + def total_usage_events(self) -> int: + """Returns the total number of usage events across all types.""" + return ( + len(self.llm_usage) + + len(self.embed_usage) + + len(self.rerank_usage) + + len(self.egress_usage) + + len(self.file_storage_usage) + + len(self.db_storage_usage) + ) + + def __add__(self, other: "UsageData") -> "UsageData": + """Overload the + operator to combine two UsageData objects.""" + combined = UsageData() + combined.llm_usage = self.llm_usage + other.llm_usage + combined.embed_usage = self.embed_usage + other.embed_usage + combined.rerank_usage = self.rerank_usage + other.rerank_usage + combined.egress_usage = self.egress_usage + other.egress_usage + combined.file_storage_usage = self.file_storage_usage + other.file_storage_usage + combined.db_storage_usage = self.db_storage_usage + other.db_storage_usage + return combined diff --git a/clients/python/src/jamaibase/types/common.py b/clients/python/src/jamaibase/types/common.py new file mode 
100644 index 0000000..bd64cf4 --- /dev/null +++ b/clients/python/src/jamaibase/types/common.py @@ -0,0 +1,263 @@ +import unicodedata +from collections import OrderedDict +from datetime import timezone +from functools import partial +from pathlib import Path +from typing import Annotated, Any, Dict, List, Tuple, Union + +from pydantic import AfterValidator, BaseModel, BeforeValidator, Field +from pydantic.types import AwareDatetime +from pydantic_extra_types.country import _index_by_alpha2 as iso_3166 +from pydantic_extra_types.language_code import _index_by_alpha2 as iso_639 + +from jamaibase.utils.types import StrEnum + +PositiveInt = Annotated[int, Field(ge=0, description="Positive integer.")] +PositiveNonZeroInt = Annotated[int, Field(gt=0, description="Positive non-zero integer.")] + + +def none_to_empty_string(v: str | None) -> str: + if v is None: + return "" + return v + + +def empty_string_to_none(v: str | None) -> str | None: + if not v: + return None + return v + + +NullableStr = Annotated[str | None, BeforeValidator(empty_string_to_none)] +EmptyIfNoneStr = Annotated[str, BeforeValidator(none_to_empty_string)] + +EXAMPLE_CHAT_MODEL_IDS = ["openai/gpt-4o-mini"] +EXAMPLE_EMBEDDING_MODEL_IDS = [ + "openai/text-embedding-3-small-512", + "ellm/sentence-transformers/all-MiniLM-L6-v2", +] +EXAMPLE_RERANKING_MODEL_IDS = [ + "cohere/rerank-multilingual-v3.0", + "ellm/cross-encoder/ms-marco-TinyBERT-L-2", +] + +# fmt: off +FilePath = Union[str, Path] +# Superficial JSON input/output types +# https://github.com/python/typing/issues/182#issuecomment-186684288 +JSONOutput = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] +JSONOutputBin = Union[bytes, str, int, float, bool, None, Dict[str, Any], List[Any]] +# For input, we also accept tuples, ordered dicts etc. +JSONInput = Union[str, int, float, bool, None, Dict[str, Any], List[Any], Tuple[Any, ...], OrderedDict] +JSONInputBin = Union[bytes, str, int, float, bool, None, Dict[str, Any], List[Any], Tuple[Any, ...], OrderedDict] +YAMLInput = JSONInput +YAMLOutput = JSONOutput +# fmt: on + + +def _to_utc(d: AwareDatetime) -> AwareDatetime: + return d.astimezone(timezone.utc) + + +DatetimeUTC = Annotated[AwareDatetime, AfterValidator(_to_utc)] + +### --- String Validator --- ### + + +def _is_bad_char(char: str, *, allow_newline: bool) -> bool: + """ + Checks if a character is disallowed. + """ + # 1. Handle newlines based on the flag + if char == "\n": + return not allow_newline # Bad if newlines are NOT allowed + + # 2. Check for other non-printable characters (like tabs, control codes) + # str.isprintable() is False for all non-printing chars except space. + if not char.isprintable(): + return True + + # 3. 
Check for specific disallowed Unicode categories and blocks + category = unicodedata.category(char) + # Combining marks (e.g., for Zalgo text) + if category.startswith("M"): + return True + # Box drawing + if "\u2500" <= char <= "\u257f": + return True + # Block elements + if "\u2580" <= char <= "\u259f": + return True + # Braille patterns + if "\u2800" <= char <= "\u28ff": + return True + + return False + + +def _str_pre_validator( + value: Any, *, disallow_empty_string: bool = False, allow_newline: bool = False +) -> str: + if not isinstance(value, str): + value = str(value) + value = value.strip() + if disallow_empty_string and len(value) == 0: + raise ValueError("Text is empty.") + + # --- Simplified and Consolidated Character Validation --- + # The generator expression is efficient as `any()` will short-circuit + # on the first bad character found. + value = "".join(char for char in value if not unicodedata.category(char).startswith("M")) + if any(_is_bad_char(char, allow_newline=allow_newline) for char in value): + raise ValueError("Text contains disallowed or non-printable characters.") + + return value + + +SanitisedStr = Annotated[ + str, + BeforeValidator(_str_pre_validator), + # Cannot use Field here due to conflict with SQLModel +] +SanitisedMultilineStr = Annotated[ + str, + BeforeValidator(partial(_str_pre_validator, disallow_empty_string=False, allow_newline=True)), + # Cannot use Field here due to conflict with SQLModel +] +SanitisedNonEmptyStr = Annotated[ + str, + BeforeValidator(partial(_str_pre_validator, disallow_empty_string=True)), + # Cannot use Field here due to conflict with SQLModel +] +SanitisedNonEmptyMultilineStr = Annotated[ + str, + BeforeValidator(partial(_str_pre_validator, disallow_empty_string=True, allow_newline=True)), + # Cannot use Field here due to conflict with SQLModel +] + +### --- Language Code Validator --- ### + + +WILDCARD_LANG_CODES = {"*", "mul"} +DEFAULT_MUL_LANGUAGES = [ + # ChatGPT supported languages + # "sq", # Albanian + # "am", # Amharic + # "ar", # Arabic + # "hy", # Armenian + # "bn", # Bengali + # "bs", # Bosnian + # "bg", # Bulgarian + # "my", # Burmese + # "ca", # Catalan + "zh", # Chinese + # "hr", # Croatian + # "cs", # Czech + # "da", # Danish + # "nl", # Dutch + "en", # English + # "et", # Estonian + # "fi", # Finnish + "fr", # French + # "ka", # Georgian + # "de", # German + # "el", # Greek + # "gu", # Gujarati + # "hi", # Hindi + # "hu", # Hungarian + # "is", # Icelandic + # "id", # Indonesian + "it", # Italian + "ja", # Japanese + # "kn", # Kannada + # "kk", # Kazakh + "ko", # Korean + # "lv", # Latvian + # "lt", # Lithuanian + # "mk", # Macedonian + # "ms", # Malay + # "ml", # Malayalam + # "mr", # Marathi + # "mn", # Mongolian + # "no", # Norwegian + # "fa", # Persian + # "pl", # Polish + # "pt", # Portuguese + # "pa", # Punjabi + # "ro", # Romanian + # "ru", # Russian + # "sr", # Serbian + # "sk", # Slovak + # "sl", # Slovenian + # "so", # Somali + "es", # Spanish + # "sw", # Swahili + # "sv", # Swedish + # "tl", # Tagalog + # "ta", # Tamil + # "te", # Telugu + # "th", # Thai + # "tr", # Turkish + # "uk", # Ukrainian + # "ur", # Urdu + # "vi", # Vietnamese +] + + +def _validate_lang(s: str) -> str: + try: + code = s.split("-") + lang = code[0] + lang = lang.lower().strip() + if lang not in iso_639(): + raise ValueError + if len(code) == 2: + country = code[1] + country = country.upper().strip() + if country not in iso_3166(): + raise ValueError + return f"{lang}-{country}" + elif len(code) == 1: + return lang + else: + 
raise ValueError + except Exception as e: + raise ValueError( + f'Language code "{s}" is not ISO 639-1 alpha-2 or BCP-47 ([ISO 639-1 alpha-2]-[ISO 3166-1 alpha-2]).' + ) from e + + +def _validate_lang_list(s: list[str]) -> list[str]: + s = {lang.strip() for lang in s} + if len(s & WILDCARD_LANG_CODES) > 0: + s = list((s - WILDCARD_LANG_CODES) | set(DEFAULT_MUL_LANGUAGES)) + return [_validate_lang(lang) for lang in s] + + +LanguageCodeList = Annotated[list[str], AfterValidator(_validate_lang_list)] + + +class ProgressState(StrEnum): + STARTED = "STARTED" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +class Progress(BaseModel): + key: str + data: dict[str, Any] = {} + state: ProgressState = ProgressState.STARTED + error: str | None = None + + +class ProgressStage(BaseModel): + name: str + progress: int = 0 + + +class TableImportProgress(Progress): + load_data: ProgressStage = ProgressStage(name="Load data") + parse_data: ProgressStage = ProgressStage(name="Parse data") + upload_files: ProgressStage = ProgressStage(name="Upload files") + add_rows: ProgressStage = ProgressStage(name="Add rows") + index: ProgressStage = ProgressStage(name="Indexing") diff --git a/clients/python/src/jamaibase/types/compat.py b/clients/python/src/jamaibase/types/compat.py new file mode 100644 index 0000000..f352394 --- /dev/null +++ b/clients/python/src/jamaibase/types/compat.py @@ -0,0 +1,210 @@ +from pydantic import Field +from typing_extensions import deprecated + +from jamaibase.types.gen_table import ( + CellCompletionResponse, + CellReferencesResponse, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowDeleteRequest, + MultiRowRegenRequest, + RowCompletionResponse, +) +from jamaibase.types.lm import ( + ChatCompletionChoice, + ChatCompletionChunkResponse, + ChatCompletionMessage, + ChatCompletionUsage, + ChatRequest, + ChatThreadResponse, + Function, + ToolCall, + ToolCallFunction, +) +from jamaibase.types.model import ModelInfoListResponse +from jamaibase.utils.types import StrEnum + + +@deprecated( + "AdminOrderBy is deprecated, use string instead.", + category=FutureWarning, + stacklevel=1, +) +class AdminOrderBy(StrEnum): + ID = "id" + """Sort by `id` column.""" + NAME = "name" + """Sort by `name` column.""" + CREATED_AT = "created_at" + """Sort by `created_at` column.""" + UPDATED_AT = "updated_at" + """Sort by `updated_at` column.""" + + +@deprecated( + "GenTableOrderBy is deprecated, use string instead.", + category=FutureWarning, + stacklevel=1, +) +class GenTableOrderBy(StrEnum): + ID = "id" + """Sort by `id` column.""" + UPDATED_AT = "updated_at" + """Sort by `updated_at` column.""" + + +@deprecated( + "ModelInfoResponse is deprecated, use ModelInfoListResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class ModelInfoResponse(ModelInfoListResponse): + object: str = Field( + default="chat.model_info", + description="Type of API response object.", + examples=["chat.model_info"], + ) + + +@deprecated( + "MessageToolCallFunction is deprecated, use ToolCallFunction instead.", + category=FutureWarning, + stacklevel=1, +) +class MessageToolCallFunction(ToolCallFunction): + pass + + +@deprecated( + "MessageToolCall is deprecated, use ToolCall instead.", + category=FutureWarning, + stacklevel=1, +) +class MessageToolCall(ToolCall): + pass + + +@deprecated( + "ChatCompletionChoiceDelta is deprecated, use ChatCompletionChoice instead.", + category=FutureWarning, + stacklevel=1, +) +class ChatCompletionChoiceDelta(ChatCompletionChoice): + pass + + +@deprecated( + "CompletionUsage 
is deprecated, use ChatCompletionUsage instead.", + category=FutureWarning, + stacklevel=1, +) +class CompletionUsage(ChatCompletionUsage): + pass + + +@deprecated( + "ChatCompletionChunk is deprecated, use ChatCompletionChunkResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class ChatCompletionChunk(ChatCompletionChunkResponse): + pass + + +@deprecated( + "ChatCompletionChoiceOutput is deprecated, use ChatCompletionMessage instead.", + category=FutureWarning, + stacklevel=1, +) +class ChatCompletionChoiceOutput(ChatCompletionMessage): + pass + + +@deprecated( + "ChatThread is deprecated, use ChatThreadResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class ChatThread(ChatThreadResponse): + pass + + +@deprecated( + "ToolFunction is deprecated, use Function instead.", + category=FutureWarning, + stacklevel=1, +) +class ToolFunction(Function): + pass + + +@deprecated( + "ChatRequestWithTools is deprecated, use ChatRequest instead.", + category=FutureWarning, + stacklevel=1, +) +class ChatRequestWithTools(ChatRequest): + pass + + +@deprecated( + "GenTableStreamReferences is deprecated, use CellReferencesResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class GenTableStreamReferences(CellReferencesResponse): + pass + + +@deprecated( + "GenTableStreamChatCompletionChunk is deprecated, use CellCompletionResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class GenTableStreamChatCompletionChunk(CellCompletionResponse): + pass + + +@deprecated( + "GenTableChatCompletionChunks is deprecated, use RowCompletionResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class GenTableChatCompletionChunks(RowCompletionResponse): + pass + + +@deprecated( + "GenTableRowsChatCompletionChunks is deprecated, use MultiRowCompletionResponse instead.", + category=FutureWarning, + stacklevel=1, +) +class GenTableRowsChatCompletionChunks(MultiRowCompletionResponse): + pass + + +@deprecated( + "RowAddRequest is deprecated, use MultiRowAddRequest instead.", + category=FutureWarning, + stacklevel=1, +) +class RowAddRequest(MultiRowAddRequest): + pass + + +@deprecated( + "RowRegenRequest is deprecated, use MultiRowRegenRequest instead.", + category=FutureWarning, + stacklevel=1, +) +class RowRegenRequest(MultiRowRegenRequest): + pass + + +@deprecated( + "RowDeleteRequest is deprecated, use MultiRowDeleteRequest instead.", + category=FutureWarning, + stacklevel=1, +) +class RowDeleteRequest(MultiRowDeleteRequest): + pass diff --git a/clients/python/src/jamaibase/types/conversation.py b/clients/python/src/jamaibase/types/conversation.py new file mode 100644 index 0000000..cba63f5 --- /dev/null +++ b/clients/python/src/jamaibase/types/conversation.py @@ -0,0 +1,105 @@ +from typing import Any, Self + +from pydantic import BaseModel, Field, model_validator + +from jamaibase.types.common import DatetimeUTC, SanitisedNonEmptyStr, SanitisedStr +from jamaibase.types.gen_table import ColumnSchema, TableMetaResponse + + +class _MetaResponse(BaseModel): + meta: dict[SanitisedNonEmptyStr, Any] | None = Field( + None, + description="Additional metadata about the table.", + ) + cols: list[ColumnSchema] = Field( + description="List of column schema.", + ) + title: SanitisedStr = Field( + description="Conversation title.", + ) + created_by: SanitisedNonEmptyStr = Field( + description="ID of the user that created this table.", + ) + updated_at: DatetimeUTC = Field( + description="Table last update datetime (UTC).", + ) + num_rows: int = Field( + -1, + 
description="Number of rows in the table. Defaults to -1 (not counted).", + ) + version: str = Field( + description="Version.", + ) + + @model_validator(mode="after") + def remove_state_cols(self) -> Self: + self.cols = [c for c in self.cols if not c.id.endswith("_")] + return self + + +class AgentMetaResponse(_MetaResponse): + agent_id: SanitisedNonEmptyStr = Field( + description="Agent ID.", + ) + + @classmethod + def from_table_meta(cls, meta: TableMetaResponse) -> Self: + """Returns an instance from TableMetaResponse.""" + return cls(agent_id=meta.id, **meta.model_dump(exclude={"id"})) + + +class ConversationMetaResponse(_MetaResponse): + conversation_id: SanitisedNonEmptyStr = Field( + description="Conversation ID.", + ) + parent_id: SanitisedNonEmptyStr | None = Field( + description="The parent table ID. If None, it means this is a parent table.", + ) + + @classmethod + def from_table_meta(cls, meta: TableMetaResponse) -> Self: + """Returns an instance from TableMetaResponse.""" + return cls(conversation_id=meta.id, **meta.model_dump(exclude={"id"})) + + +class _MessageBase(BaseModel): + data: dict[str, Any] = Field( + description="Mapping of column names to its value.", + ) + + +class ConversationCreateRequest(_MessageBase): + """Request to create a new conversation.""" + + agent_id: SanitisedNonEmptyStr = Field( + description="Agent ID (parent Chat Table ID).", + ) + title: SanitisedStr | None = Field( + None, + min_length=1, + description="The title of the conversation.", + ) + + +class MessageAddRequest(_MessageBase): + conversation_id: SanitisedNonEmptyStr = Field( + description="Conversation ID.", + ) + + +class MessageUpdateRequest(BaseModel): + """Request to update a single message in a conversation.""" + + conversation_id: str = Field(description="Unique ID of the conversation (table_id).") + row_id: str = Field(description="The ID of the message (row) to update.") + data: dict[str, Any] = Field( + description="The new data for the message, e.g. 
`{'User': 'new content'}`.",
+        min_length=1,
+    )
+
+
+class MessagesRegenRequest(BaseModel):
+    """Request to regenerate the current message (and the rest of the messages) in a conversation."""
+
+    conversation_id: str = Field(description="Unique ID of the conversation (table_id).")
+    row_id: str = Field(description="ID of the message (row) to regenerate from.")
diff --git a/clients/python/src/jamaibase/types/db.py b/clients/python/src/jamaibase/types/db.py
new file mode 100644
index 0000000..95f3273
--- /dev/null
+++ b/clients/python/src/jamaibase/types/db.py
@@ -0,0 +1,1284 @@
+from enum import IntEnum
+from typing import Annotated, Any
+
+from fastapi.exceptions import RequestValidationError
+from pydantic import (
+    AnyUrl,
+    BaseModel,
+    BeforeValidator,
+    EmailStr,
+    Field,
+    ValidationError,
+    computed_field,
+    field_validator,
+    model_validator,
+)
+from pydantic_extra_types.currency_code import ISO4217
+from pydantic_extra_types.timezone_name import TimeZoneName
+from typing_extensions import Self
+
+from jamaibase.types.common import (
+    DEFAULT_MUL_LANGUAGES,
+    DatetimeUTC,
+    LanguageCodeList,
+    PositiveNonZeroInt,
+    SanitisedMultilineStr,
+    SanitisedNonEmptyStr,
+    SanitisedStr,
+)
+from jamaibase.types.gen_table import TableMetaResponse
+from jamaibase.utils import uuid7_str
+from jamaibase.utils.dates import now
+from jamaibase.utils.exceptions import BadInputError
+from jamaibase.utils.types import StrEnum, get_enum_validator
+
+
+class _BaseModel(BaseModel, from_attributes=True, str_strip_whitespace=True):
+    meta: dict[str, Any] = Field(
+        {},
+        description="Metadata.",
+    )
+
+    @classmethod
+    def validate_updates(
+        cls,
+        base: Self,
+        updates: dict[str, Any],
+        *,
+        raise_request_error: bool = True,
+    ) -> Self:
+        try:
+            updates = {k: v for k, v in updates.items() if k in cls.model_fields}
+            new = cls.model_validate(base.model_dump() | updates)
+        except ValidationError as e:
+            if raise_request_error:
+                raise RequestValidationError(errors=e.errors()) from e
+            else:
+                raise
+        return new
+
+
+class _TableBase(BaseModel):
+    created_at: DatetimeUTC = Field(
+        description="Creation datetime (UTC).",
+    )
+    updated_at: DatetimeUTC = Field(
+        description="Update datetime (UTC).",
+    )
+
+    def allowed(
+        self,
+        filter_id: str,
+        *,
+        allow_list_attr: str = "allowed_orgs",
+        block_list_attr: str = "blocked_orgs",
+    ) -> bool:
+        allow_list: list[str] = getattr(self, allow_list_attr)
+        block_list: list[str] | None = getattr(self, block_list_attr, None)
+        # Allow list
+        allowed = len(allow_list) == 0 or filter_id in allow_list
+        if block_list is None:
+            # No block list, just allow list
+            return allowed
+        else:
+            # Block list
+            return allowed and filter_id not in block_list
+
+
+# TODO: Perhaps need to implement OveragePolicy
+
+
+class PriceTier(BaseModel):
+    """
+    https://docs.stripe.com/api/prices/object#price_object-tiers
+    """
+
+    unit_cost: float = Field(
+        description="Per unit price for units relevant to the tier.",
+    )
+    up_to: float | None = Field(
+        description=(
+            "Up to and including this quantity will be contained in the tier. "
+            "`None` means infinite quantity."
+        ),
+    )
+
+    @classmethod
+    def null(cls):
+        return cls(
+            unit_cost=0.0,
+            up_to=0.0,
+        )
+
+    @classmethod
+    def unlimited(cls, unit_cost: float = 0.0):
+        return cls(
+            unit_cost=unit_cost,
+            up_to=None,
+        )
+
+
+class Product(BaseModel):
+    name: SanitisedNonEmptyStr = Field(
+        max_length=255,
+        description="Product name.",
+    )
+    included: PriceTier = Field(
+        description="Free tier.
The `unit_cost` of this tier will always be `0.0`.", + ) + tiers: list[PriceTier] = Field( + description=( + "Additional tiers so that we may charge a different price for the first usage band versus the next. " + "For example, `included=PriceTier(unit_cost=0.0, up_to=0.5), " + "tiers=[PriceTier(unit_cost=1.0, up_to=1.0), PriceTier(unit_cost=2.0, up_to=None)]` " + "would be free for the first `0.5` units, `$1.0` per unit for the next `1.0` units, and `$2.0` per unit for the rest. " + "In this case, a usage of `2.0` units would cost `$2.0`." + ), + ) + unit: SanitisedNonEmptyStr = Field( + description="Unit of measurement for reference.", + ) + + @model_validator(mode="after") + def check_included_cost(self) -> Self: + # Included tier should be free + self.included.unit_cost = 0.0 + return self + + @classmethod + def null(cls, name: str, unit: str): + return cls( + name=name, + included=PriceTier.null(), + tiers=[], + unit=unit, + ) + + @classmethod + def unlimited(cls, name: str, unit: str, unit_cost: float = 0.0): + return cls( + name=name, + included=PriceTier.unlimited(unit_cost=unit_cost), + tiers=[], + unit=unit, + ) + + +class Products(BaseModel): + llm_tokens: Product = Field( + description="LLM token quota to this plan or tier.", + ) + embedding_tokens: Product = Field( + description="Embedding token quota to this plan or tier.", + ) + reranker_searches: Product = Field( + description="Reranker search quota to this plan or tier.", + ) + db_storage: Product = Field( + description="Database storage quota to this plan or tier.", + ) + file_storage: Product = Field( + description="File storage quota to this plan or tier.", + ) + egress: Product = Field( + description="Egress bandwidth quota to this plan or tier.", + ) + + @classmethod + def null(cls): + return cls( + llm_tokens=Product.null("ELLM tokens", "Million Tokens"), + embedding_tokens=Product.null("Embedding tokens", "Million Tokens"), + reranker_searches=Product.null("Reranker searches", "Thousand Searches"), + db_storage=Product.null("Database storage", "GiB"), + file_storage=Product.null("File storage", "GiB"), + egress=Product.null("Egress bandwidth", "GiB"), + ) + + @classmethod + def unlimited(cls, unit_cost: float = 0.0): + return cls( + llm_tokens=Product.unlimited("ELLM tokens", "Million Tokens", unit_cost), + embedding_tokens=Product.unlimited("Embedding tokens", "Million Tokens", unit_cost), + reranker_searches=Product.unlimited( + "Reranker searches", "Thousand Searches", unit_cost + ), + db_storage=Product.unlimited("Database storage", "GiB", unit_cost), + file_storage=Product.unlimited("File storage", "GiB", unit_cost), + egress=Product.unlimited("Egress bandwidth", "GiB", unit_cost), + ) + + +_product2column = dict( + credit=("credit",), + credit_grant=("credit_grant",), + llm_tokens=("llm_tokens_quota_mtok", "llm_tokens_usage_mtok"), + embedding_tokens=( + "embedding_tokens_quota_mtok", + "embedding_tokens_usage_mtok", + ), + reranker_searches=("reranker_quota_ksearch", "reranker_usage_ksearch"), + db_storage=("db_quota_gib", "db_usage_gib"), + file_storage=("file_quota_gib", "file_usage_gib"), + egress=("egress_quota_gib", "egress_usage_gib"), +) + + +class ProductType(StrEnum): + CREDIT = "credit" + CREDIT_GRANT = "credit_grant" + LLM_TOKENS = "llm_tokens" + EMBEDDING_TOKENS = "embedding_tokens" + RERANKER_SEARCHES = "reranker_searches" + DB_STORAGE = "db_storage" + FILE_STORAGE = "file_storage" + EGRESS = "egress" + + @property + def quota_column(self) -> str: + return _product2column[self.value][0] + + 
@property + def usage_column(self) -> str: + return _product2column[self.value][-1] + + @classmethod + def exclude_credits(cls) -> list["ProductType"]: + return [p for p in cls if not p.value.startswith("credit")] + + +class PricePlanUpdate(_BaseModel): + stripe_price_id_live: SanitisedNonEmptyStr = Field( + "", + description="Stripe price ID (live mode).", + ) + stripe_price_id_test: SanitisedNonEmptyStr = Field( + "", + description="Stripe price ID (test mode).", + ) + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="Price plan name.", + ) + flat_cost: float = Field( + 0.0, + ge=0.0, + description="Base price for the entire tier.", + ) + credit_grant: float = Field( + 0.0, + ge=0.0, + description="Credit amount included in USD.", + ) + max_users: int | None = Field( + 0, + ge=1, + description="Maximum number of users per organization. `None` means no limit.", + ) + products: Products = Field( + Products.null(), + description="Mapping of product ID to product.", + ) + allowed_orgs: list[str] = Field( + [], + description=( + "List of IDs of organizations allowed to use this price plan. " + "If empty, all orgs are allowed." + ), + ) + + @classmethod + def free( + cls, + stripe_price_id_live: str = "price_123", + stripe_price_id_test: str = "price_1RT2CqCcpbd72IcYEvy6U3GR", + ): + return cls( + name="Free plan", + stripe_price_id_live=stripe_price_id_live, + stripe_price_id_test=stripe_price_id_test, + flat_cost=0.0, + credit_grant=0.0, + max_users=2, # For ease of testing + products=Products( + llm_tokens=Product( + name="ELLM tokens", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="Million Tokens", + ), + embedding_tokens=Product( + name="Embedding tokens", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="Million Tokens", + ), + reranker_searches=Product( + name="Reranker searches", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="Thousand Searches", + ), + db_storage=Product( + name="Database storage", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="GiB", + ), + file_storage=Product( + name="File storage", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="GiB", + ), + egress=Product( + name="Egress bandwidth", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="GiB", + ), + ), + ) + + +class PricePlanCreate(PricePlanUpdate): + id: str = Field( + "", + description="Price plan ID.", + ) + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="Price plan name.", + ) + stripe_price_id_live: SanitisedNonEmptyStr = Field( + description="Stripe price ID (live mode).", + ) + stripe_price_id_test: SanitisedNonEmptyStr = Field( + description="Stripe price ID (test mode).", + ) + flat_cost: float = Field( + ge=0.0, + description="Base price for the entire tier.", + ) + credit_grant: float = Field( + ge=0.0, + description="Credit amount included in USD.", + ) + max_users: int | None = Field( + ge=1, + description="Maximum number of users per organization. 
`None` means no limit.", + ) + products: Products = Field( + description="Mapping of product ID to product.", + ) + + +class PricePlan_(PricePlanCreate, _TableBase): + # Computed fields + is_private: bool = Field( + description="Whether this is a private price plan visible only to select organizations.", + ) + stripe_price_id: str = Field( + description="Stripe Price ID (either live or test based on API key).", + ) + + +class PricePlanRead(PricePlan_): + pass + + +class OnPremProvider(StrEnum): + VLLM = "vllm" + VLLM_AMD = "vllm_amd" + OLLAMA = "ollama" + INFINITY = "infinity" + INFINITY_CPU = "infinity_cpu" + + @classmethod + def list_(cls) -> list[str]: + return list(map(str, cls)) + + +class CloudProvider(StrEnum): + ANTHROPIC = "anthropic" + AZURE = "azure" + AZURE_AI = "azure_ai" + BEDROCK = "bedrock" + CEREBRAS = "cerebras" + COHERE = "cohere" + DEEPSEEK = "deepseek" + ELLM = "ellm" + FIREWORKS_AI = "fireworks_ai" + GEMINI = "gemini" + GROQ = "groq" + HYPERBOLIC = "hyperbolic" + INFINITY_CLOUD = "infinity_cloud" + JINA_AI = "jina_ai" + OPENAI = "openai" + OPENROUTER = "openrouter" + SAGEMAKER = "sagemaker" + SAMBANOVA = "sambanova" + TOGETHER_AI = "together_ai" + # VERTEX_AI = "vertex_ai" + VLLM_CLOUD = "vllm_cloud" + VOYAGE = "voyage" + + @classmethod + def list_(cls) -> list[str]: + return list(map(str, cls)) + + +class ModelProvider(StrEnum): + ANTHROPIC = "anthropic" + COHERE = "cohere" + DEEPSEEK = "deepseek" + GEMINI = "gemini" + JINA_AI = "jina_ai" + OPENAI = "openai" + + @classmethod + def list_(cls) -> list[str]: + return list(map(str, cls)) + + +class DeploymentStatus(StrEnum): + ACTIVE = "active" + + +class DeploymentUpdate(_BaseModel): + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="Name for the deployment.", + ) + routing_id: SanitisedNonEmptyStr = Field( + "", + description=( + "Model ID that the inference provider expects (whereas `model_id` is what the users will see). " + "OpenAI example: `model_id` CAN be `openai/gpt-5` but `routing_id` SHOULD be `gpt-5`." + ), + ) + api_base: str = Field( + "", + description=( + "(Optional) Hosting url. " + "Required for creating external cloud deployment using custom providers. " + "Example: `http://vllm-endpoint.xyz/v1`." + ), + ) + provider: SanitisedNonEmptyStr = Field( + "", + description=( + f"Inference provider of the model. " + f"Standard cloud providers are {CloudProvider.list_()}." + ), + ) + weight: int = Field( + 1, + description="Routing weight. Must be >= 0. A deployment is selected according to its relative weight.", + ) + cooldown_until: DatetimeUTC = Field( + default_factory=now, + description="Cooldown until datetime (UTC).", + ) + + +class DeploymentCreate(DeploymentUpdate): + model_id: SanitisedNonEmptyStr = Field( + description="Model ID.", + ) + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="Name for the deployment.", + ) + + +class Deployment_(DeploymentCreate, _TableBase): + id: str = Field( + description="Deployment ID.", + ) + + +class DeploymentRead(Deployment_): + model: "ModelConfig_" = Field( + description="Model config.", + ) + + @computed_field(description='Status of the deployment. 
Will always be "ACTIVE".') + @property + def status(self) -> str: + return DeploymentStatus.ACTIVE + + +class ModelType(StrEnum): + COMPLETION = "completion" + LLM = "llm" + EMBED = "embed" + RERANK = "rerank" + + +# This is needed because DB stores Enums as keys but Pydantic loads via values +_ModelType = Annotated[ModelType, BeforeValidator(get_enum_validator(ModelType))] + + +class ModelCapability(StrEnum): + COMPLETION = "completion" + CHAT = "chat" + TOOL = "tool" + IMAGE = "image" # TODO: Maybe change to "image_in" & "image_out" + AUDIO = "audio" + EMBED = "embed" + RERANK = "rerank" + REASONING = "reasoning" + + +_ModelCapability = Annotated[ModelCapability, BeforeValidator(get_enum_validator(ModelCapability))] + + +class ModelInfo(_BaseModel): + id: SanitisedNonEmptyStr = Field( + description=( + "Unique identifier. " + "Users will specify this to select a model. " + "Must follow the following format: `{provider}/{model_id}`. " + "Examples=['openai/gpt-4o-mini', 'Qwen/Qwen2.5-0.5B']" + ), + examples=["openai/gpt-4o-mini", "Qwen/Qwen2.5-0.5B"], + ) + type: _ModelType = Field( + "", + description="Model type. Can be completion, llm, embed, or rerank.", + examples=[ModelType.LLM], + ) + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="Model name that is more user friendly.", + examples=["OpenAI GPT-4o Mini"], + ) + owned_by: SanitisedStr = Field( + "", + description="Model provider (usually organization that trained the model).", + ) + capabilities: list[_ModelCapability] = Field( + [], + min_length=1, + description="List of capabilities of model.", + examples=[[ModelCapability.CHAT], [ModelCapability.CHAT, ModelCapability.AUDIO]], + ) + context_length: int = Field( + 4096, + gt=0, + description="Context length of model.", + examples=[4096], + ) + languages: LanguageCodeList = Field( + ["en"], + description=f'List of languages which the model is well-versed in. "*" and "mul" resolves to {DEFAULT_MUL_LANGUAGES}.', + examples=[["en"], ["en", "zh-CN"]], + ) + max_output_tokens: int | None = Field( + None, + gt=0, + description="Maximum number of output tokens, if not specified, will be based on context length.", + # examples=[8192], + ) + + @field_validator("id", mode="after") + @classmethod + def validate_id(cls, v: str) -> str: + if len(v.split("/")) < 2: + raise ValueError( + "Model `id` must follow the following format: `{provider}/{model_id}`." + ) + return v + + @property + def capabilities_set(self) -> set[str]: + return set(map(str, self.capabilities)) + + +class ModelInfoRead(ModelInfo, _TableBase): + pass + + +class ModelConfigUpdate(ModelInfo): + # --- All models --- # + id: SanitisedNonEmptyStr = Field( + "", + description=( + "Unique identifier. " + "Users will specify this to select a model. " + "Must follow the following format: `{provider}/{model_id}`. " + "Examples=['openai/gpt-4o-mini', 'Qwen/Qwen2.5-0.5B']" + ), + ) + timeout: float = Field( + 30 * 60 * 60, + description="Timeout in seconds. Must be greater than 0. Defaults to 30 minutes.", + ) + priority: int = Field( + 0, + description="Priority for fallback model selection. The larger the number, the higher the priority.", + ) + allowed_orgs: list[str] = Field( + [], + description=( + "List of IDs of organizations allowed to use this model. " + "If empty, all orgs are allowed. Allow list is applied first, followed by block list." + ), + ) + blocked_orgs: list[str] = Field( + [], + description=( + "List of IDs of organizations NOT allowed to use this model. " + "If empty, no org is blocked. 
Allow list is applied first, followed by block list." + ), + ) + # --- LLM models --- # + llm_input_cost_per_mtoken: float = Field( + -1.0, + description="Cost in USD per million (mega) input / prompt token.", + ) + llm_output_cost_per_mtoken: float = Field( + -1.0, + description="Cost in USD per million (mega) output / completion token.", + ) + # --- Embedding models --- # + embedding_size: PositiveNonZeroInt | None = Field( + None, + description=( + "The default embedding size of the model. " + "For example: `openai/text-embedding-3-large` has `embedding_size` of 3072 " + "but can be shortened to `embedding_dimensions` of 256; " + "`cohere/embed-v4.0` has `embedding_size` of 1536 " + "but can be shortened to `embedding_dimensions` of 256." + ), + ) + # Matryoshka embedding dimension + embedding_dimensions: PositiveNonZeroInt | None = Field( + None, + description=( + "The number of dimensions the resulting output embeddings should have. " + "Can be overridden by `dimensions` for each request. " + "Defaults to None (no reduction). " + "Note that this parameter will only be used when using models that support Matryoshka embeddings. " + "For example: `openai/text-embedding-3-large` has `embedding_size` of 3072 " + "but can be shortened to `embedding_dimensions` of 256; " + "`cohere/embed-v4.0` has `embedding_size` of 1536 " + "but can be shortened to `embedding_dimensions` of 256." + ), + ) + # Most likely only useful for HuggingFace models + embedding_transform_query: SanitisedNonEmptyStr | None = Field( + None, + description="Transform query that might be needed, especially for HuggingFace models.", + ) + embedding_cost_per_mtoken: float = Field( + -1.0, + description="Cost in USD per million embedding tokens.", + ) + # --- Reranking models --- # + reranking_cost_per_ksearch: float = Field( + -1.0, + description="Cost in USD for a thousand searches.", + ) + + @property + def final_embedding_size(self) -> int: + embed_size = self.embedding_dimensions or self.embedding_size + if embed_size is None: + raise BadInputError( + f'Both `embedding_dimensions` and `embedding_size` are None for embedding model "{self.id}".' + ) + return embed_size + + @model_validator(mode="after") + def check_chat_cost_per_mtoken(self) -> Self: + # GPT-4o-mini pricing (2024-08-10) + if self.llm_input_cost_per_mtoken < 0: + self.llm_input_cost_per_mtoken = 0.150 + if self.llm_output_cost_per_mtoken < 0: + self.llm_output_cost_per_mtoken = 0.600 + return self + + @model_validator(mode="after") + def check_embed_cost_per_mtoken(self) -> Self: + # OpenAI text-embedding-3-small pricing (2024-09-09) + if self.embedding_cost_per_mtoken < 0: + self.embedding_cost_per_mtoken = 0.022 + return self + + @model_validator(mode="after") + def check_rerank_cost_per_ksearch(self) -> Self: + # Cohere rerank-multilingual-v3.0 pricing (2024-09-09) + if self.reranking_cost_per_ksearch < 0: + self.reranking_cost_per_ksearch = 2.0 + return self + + + class ModelConfigCreate(ModelConfigUpdate): + # Overrides to make these fields required in ModelConfigCreate. + type: _ModelType = Field( + description="Model type. Can be completion, llm, embed, or rerank.", + ) + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="Model name that is more user friendly.", + ) + context_length: int = Field( + gt=0, + description="Context length of model. 
Examples=[4096]", + ) + capabilities: list[_ModelCapability] = Field( + description="List of capabilities of model.", + ) + owned_by: SanitisedStr = Field( + "", + description="Model provider (usually organization that trained the model).", + ) + + @model_validator(mode="after") + def validate_owned_by_ellm_id_match(self) -> Self: + ellm_owned = self.owned_by == "ellm" + ellm_id = self.id.startswith("ellm/") + if (ellm_owned and not ellm_id) or (ellm_id and not ellm_owned): + raise ValueError('ELLM models must have `owned_by="ellm"` and `id="ellm/..."`.') + return self + + +class ModelConfig_(ModelConfigCreate, _TableBase): + # Computed fields + is_private: bool = Field( + False, + description="Whether this is a private model visible only to select organizations.", + ) + + @model_validator(mode="after") + def validate_owned_by_ellm_id_match(self) -> Self: + # Don't validate when reading from DB + return self + + +class ModelConfigRead(ModelConfig_): + deployments: list[Deployment_] = Field( + description="List of model deployment configs.", + ) + # Computed fields + # Since this depends on Deployment, we put here to avoid circular dependency + is_active: bool = Field( + description="Whether this model is active and ready for inference.", + ) + + +class Role(StrEnum): + ADMIN = "ADMIN" + MEMBER = "MEMBER" + GUEST = "GUEST" + + @property + def rank(self) -> "RankedRole": + return RankedRole[self.value] + + +class RankedRole(IntEnum): + GUEST = 0 + MEMBER = 1 + ADMIN = 2 + + @classmethod + def get(cls, role: str) -> int: + try: + return int(RankedRole[role]) + except KeyError: + return -1 + + +_Role = Annotated[Role, BeforeValidator(get_enum_validator(Role))] + + +class OrgMemberUpdate(_BaseModel): + role: _Role = Field( + description="Organization role.", + ) + + +class OrgMemberCreate(OrgMemberUpdate): + user_id: SanitisedNonEmptyStr = Field( + description="User ID.", + ) + organization_id: SanitisedNonEmptyStr = Field( + description="Organization ID.", + ) + + +class OrgMember_(OrgMemberCreate, _TableBase): + pass + + +class OrgMemberRead(OrgMember_): + user: "User_" = Field(description="User.") + organization: "Organization_" = Field(description="Organization.") + + +class ProjectMemberUpdate(_BaseModel): + role: _Role = Field( + description="Project role.", + ) + + +class ProjectMemberCreate(ProjectMemberUpdate): + user_id: SanitisedNonEmptyStr = Field( + description="User ID.", + ) + project_id: SanitisedNonEmptyStr = Field( + description="Project ID.", + ) + + +class ProjectMember_(ProjectMemberCreate, _TableBase): + pass + + +class ProjectMemberRead(ProjectMember_): + user: "User_" = Field(description="User.") + project: "Project_" = Field(description="Project.") + + +class _UserBase(_BaseModel): + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="User's preferred name.", + ) + email: EmailStr = Field( + "", + description="User's email.", + ) + picture_url: AnyUrl | None = Field( + None, + description="User picture URL.", + ) + google_id: SanitisedNonEmptyStr | None = Field( + None, + description="Google user ID.", + ) + google_name: SanitisedNonEmptyStr | None = Field( + None, + description="Google user's preferred name.", + ) + google_username: SanitisedNonEmptyStr | None = Field( + None, + description="Google username.", + ) + google_email: EmailStr | None = Field( + None, + description="Google email.", + ) + google_picture_url: SanitisedNonEmptyStr | None = Field( + None, + description="Google user picture URL.", + ) + google_updated_at: DatetimeUTC | None 
= Field( + None, + description="Google user info update datetime (UTC).", + ) + github_id: SanitisedNonEmptyStr | None = Field( + None, + description="GitHub user ID.", + ) + github_name: SanitisedNonEmptyStr | None = Field( + None, + description="GitHub user's preferred name.", + ) + github_username: SanitisedNonEmptyStr | None = Field( + None, + description="GitHub username.", + ) + github_email: EmailStr | None = Field( + None, + description="GitHub email.", + ) + github_picture_url: SanitisedNonEmptyStr | None = Field( + None, + description="GitHub user picture URL.", + ) + github_updated_at: DatetimeUTC | None = Field( + None, + description="GitHub user info update datetime (UTC).", + ) + + +class UserUpdate(_UserBase): + password: SanitisedNonEmptyStr = Field( + "", + max_length=72, + description="Password in plain text.", + ) + + +class UserCreate(UserUpdate): + id: SanitisedNonEmptyStr = Field( + default_factory=uuid7_str, + description="User ID.", + ) + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="User's preferred name.", + ) + email: EmailStr = Field( + description="User's email.", + ) + + +def _obscure_password_hash(value: Any) -> Any: + if value is not None: + return "***" + else: + return value + + +class User_(_UserBase, _TableBase): + id: SanitisedNonEmptyStr = Field( + default_factory=uuid7_str, + description="User ID.", + ) + email_verified: bool = Field( + description="Whether the email address is verified.", + ) + password_hash: Annotated[str | None, BeforeValidator(_obscure_password_hash)] = Field( + description="Password hash.", + ) + refresh_counter: int = Field( + 0, + description="Counter used as refresh token version for invalidation.", + ) + # Computed fields + preferred_name: str = Field( + "", + description="Name for display.", + ) + preferred_email: str = Field( + "", + description="Email for display.", + ) + preferred_picture_url: str | None = Field( + None, + description="Picture URL for display.", + ) + preferred_username: str | None = Field( + None, + description="Username for display.", + ) + + +class UserAuth(User_): + org_memberships: list[OrgMember_] = Field( + description="List of organization memberships.", + ) + proj_memberships: list[ProjectMember_] = Field( + description="List of project memberships.", + ) + + +class UserRead(UserAuth): + organizations: list["Organization_"] = Field( + description="List of organizations that this user is a member of.", + ) + projects: list["Project_"] = Field( + description="List of projects that this user is a member of.", + ) + + +class UserReadObscured(UserRead): + password_hash: Annotated[str | None, BeforeValidator(_obscure_password_hash)] = Field( + description="Password hash.", + ) + + +class PaymentState(StrEnum): + NONE = "NONE" # When an organization is created + SUCCESS = "SUCCESS" # Payment is completed + PROCESSING = "PROCESSING" # Payment is initiated but yet to complete + FAILED = "FAILED" # Payment failed + + +class OrganizationUpdate(_BaseModel): + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="Organization name.", + ) + currency: ISO4217 = Field( + "USD", + description="Currency of the organization.", + ) + timezone: TimeZoneName | None = Field( + None, + description="Timezone specifier.", + ) + external_keys: dict[SanitisedNonEmptyStr, str] = Field( + {}, + description="Mapping of external service provider to its API key.", + ) + + @field_validator("external_keys", mode="before") + @classmethod + def validate_external_keys(cls, v: dict[str, 
str]) -> dict[str, str]: + # Remove empty API keys, and ensure provider is lowercase + v = {k.strip().lower(): v.strip() for k, v in v.items() if v.strip()} + return v + + @field_validator("currency", mode="after") + @classmethod + def validate_currency(cls, v: str) -> str: + if v != "USD": + raise ValueError("Currently only USD is supported.") + return v + + + class OrganizationCreate(OrganizationUpdate): + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="Organization name.", + ) + + + class Organization_(OrganizationCreate, _TableBase): + id: str = Field( + description="Organization ID.", + ) + created_by: str = Field( + description="ID of the user that created this organization.", + ) + owner: str = Field( + description="ID of the user that owns this organization.", + ) + stripe_id: SanitisedNonEmptyStr | None = Field( + description="Stripe Customer ID.", + ) + # stripe_subscription_id: SanitisedNonEmptyStr = Field( + # "", + # description="Stripe Subscription ID.", + # ) + price_plan_id: SanitisedNonEmptyStr | None = Field( + description="Price plan ID.", + ) + payment_state: PaymentState = Field( + description="Payment state of the organization.", + ) + last_subscription_payment_at: DatetimeUTC | None = Field( + description="Datetime of the last successful subscription payment (UTC).", + ) + quota_reset_at: DatetimeUTC = Field( + description="Quota reset datetime (UTC).", + ) + credit: float = Field( + description="Credit paid by the customer. Unused credit will be carried forward to the next billing cycle.", + ) + credit_grant: float = Field( + description="Credit granted to the customer. Unused credit will NOT be carried forward.", + ) + llm_tokens_quota_mtok: float | None = Field( + description="LLM token quota in millions of tokens.", + ) + llm_tokens_usage_mtok: float = Field( + description="LLM token usage in millions of tokens.", + ) + embedding_tokens_quota_mtok: float | None = Field( + description="Embedding token quota in millions of tokens.", + ) + embedding_tokens_usage_mtok: float = Field( + description="Embedding token usage in millions of tokens.", + ) + reranker_quota_ksearch: float | None = Field( + description="Reranker quota for every thousand searches.", + ) + reranker_usage_ksearch: float = Field( + description="Reranker usage for every thousand searches.", + ) + db_quota_gib: float | None = Field( + description="DB storage quota in GiB.", + ) + db_usage_gib: float = Field( + description="DB storage usage in GiB.", + ) + db_usage_updated_at: DatetimeUTC = Field( + description="Datetime of the last successful DB usage update (UTC).", + ) + file_quota_gib: float | None = Field( + description="File storage quota in GiB.", + ) + file_usage_gib: float = Field( + description="File storage usage in GiB.", + ) + file_usage_updated_at: DatetimeUTC = Field( + description="Datetime of the last successful File usage update (UTC).", + ) + egress_quota_gib: float | None = Field( + description="Egress quota in GiB.", + ) + egress_usage_gib: float = Field( + description="Egress usage in GiB.", + ) + # Computed fields + active: bool = Field( + description="Whether the organization's quota is active (paid).", + ) + quotas: dict[str, dict[str, float | None]] = Field( + description="Quota snapshot.", + ) + + + class OrganizationRead(Organization_): + price_plan: PricePlan_ | None = Field( + description="Subscribed plan.", + ) + + + class ProjectUpdate(_BaseModel): + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="Project name.", + ) + 
description: SanitisedMultilineStr = Field( + "", + description="Project description.", + ) + tags: list[str] = Field( + [], + description="Project tags.", + ) + profile_picture_url: str | None = Field( + None, + description="URL of the profile picture.", + ) + cover_picture_url: str | None = Field( + None, + description="URL of the cover picture.", + ) + + + class ProjectCreate(ProjectUpdate): + organization_id: SanitisedNonEmptyStr = Field( + description="Organization ID.", + ) + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="Project name.", + ) + + + class Project_(ProjectCreate, _TableBase): + id: str = Field( + description="Project ID.", + ) + created_by: str = Field( + description="ID of the user that created this project.", + ) + owner: str = Field( + description="ID of the user that owns this project.", + ) + + + class ProjectRead(Project_): + organization: "Organization_" = Field( + description="Organization.", + ) + chat_agents: list[TableMetaResponse] | None = Field( + None, + description=( + "List of IDs of chat agents in this project. " + "Empty list means no chat agents are available in this project. " + "Note that by default, the list is not populated and will be None." + ), + ) + + + class VerificationCodeUpdate(_BaseModel): + name: SanitisedStr = Field( + "", + max_length=255, + description="Code name.", + ) + role: SanitisedNonEmptyStr | None = Field( + None, + description="Organization or project role.", + ) + + + class VerificationCodeCreate(VerificationCodeUpdate): + user_email: EmailStr = Field( + description="User email.", + ) + expiry: DatetimeUTC = Field( + description="Code expiry datetime (UTC).", + ) + organization_id: SanitisedNonEmptyStr | None = Field( + None, + description="Organization ID.", + ) + project_id: SanitisedNonEmptyStr | None = Field( + None, + description="Project ID.", + ) + + + class VerificationCode_(VerificationCodeCreate, _TableBase): + id: str = Field( + description="The code.", + ) + purpose: str | None = Field( + None, + description="Code purpose.", + ) + used_at: DatetimeUTC | None = Field( + None, + description="Code usage datetime (UTC).", + ) + revoked_at: DatetimeUTC | None = Field( + None, + description="Code revocation datetime (UTC).", + ) + + + class VerificationCodeRead(VerificationCode_): + pass + + + class ProjectKeyUpdate(_BaseModel): + name: SanitisedNonEmptyStr = Field( + "", + max_length=255, + description="Name.", + ) + expiry: DatetimeUTC | None = Field( + None, + description="Expiry datetime (UTC). 
If None, never expires.", + ) + + +class ProjectKeyCreate(ProjectKeyUpdate): + name: SanitisedNonEmptyStr = Field( + max_length=255, + description="Name.", + ) + project_id: SanitisedNonEmptyStr | None = Field( + None, + description="Project ID.", + ) + + +class ProjectKey_(ProjectKeyCreate, _TableBase): + id: str = Field( + description="The token.", + ) + user_id: str = Field( + description="User ID.", + ) + + +class ProjectKeyRead(ProjectKey_): + pass diff --git a/clients/python/src/jamaibase/types/file.py b/clients/python/src/jamaibase/types/file.py new file mode 100644 index 0000000..51c209d --- /dev/null +++ b/clients/python/src/jamaibase/types/file.py @@ -0,0 +1,42 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class FileUploadResponse(BaseModel): + object: Literal["file.upload"] = Field( + "file.upload", + description='The object type, which is always "file.upload".', + examples=["file.upload"], + ) + uri: str = Field( + description="The URI of the uploaded file.", + examples=[ + "s3://bucket-name/raw/org_id/project_id/uuid/filename.ext", + "file:///path/to/raw/file.ext", + ], + ) + + +class GetURLRequest(BaseModel): + uris: list[str] = Field( + description=( + "A list of file URIs for which pre-signed URLs or local file paths are requested. " + "The service will return a corresponding list of pre-signed URLs or local file paths." + ), + ) + + +class GetURLResponse(BaseModel): + object: Literal["file.urls"] = Field( + "file.urls", + description='The object type, which is always "file.urls".', + examples=["file.urls"], + ) + urls: list[str] = Field( + description="A list of pre-signed URLs or local file paths.", + examples=[ + "https://presigned-url-for-file1.ext", + "/path/to/file2.ext", + ], + ) diff --git a/clients/python/src/jamaibase/types/gen_table.py b/clients/python/src/jamaibase/types/gen_table.py new file mode 100644 index 0000000..8db753b --- /dev/null +++ b/clients/python/src/jamaibase/types/gen_table.py @@ -0,0 +1,667 @@ +from functools import cached_property +from typing import Annotated, Any, Literal, Self, Union + +import numpy as np +from pydantic import ( + BaseModel, + Discriminator, + Field, + Tag, + field_validator, + model_validator, +) + +from jamaibase.types.common import ( + EXAMPLE_EMBEDDING_MODEL_IDS, + DatetimeUTC, + EmptyIfNoneStr, + PositiveInt, + SanitisedNonEmptyMultilineStr, + SanitisedNonEmptyStr, +) +from jamaibase.types.lm import ( + ChatCompletionChunkResponse, + ChatCompletionResponse, + ChatRequestBase, + References, +) +from jamaibase.utils.types import StrEnum + + +class CSVDelimiter(StrEnum): + COMMA = "," + TAB = "\t" + + +class TableType(StrEnum): + ACTION = "action" + KNOWLEDGE = "knowledge" + CHAT = "chat" + + +class CellReferencesResponse(References): + object: Literal["gen_table.references"] = Field( + "gen_table.references", + description="Type of API response object.", + examples=["gen_table.references"], + ) + output_column_name: str + row_id: str + + +class CellCompletionResponse(ChatCompletionChunkResponse): + object: Literal["gen_table.completion.chunk"] = Field( + "gen_table.completion.chunk", + description="Type of API response object.", + examples=["gen_table.completion.chunk"], + ) + output_column_name: str + row_id: str + + +class RowCompletionResponse(BaseModel): + object: Literal["gen_table.completion.chunks"] = Field( + "gen_table.completion.chunks", + description="Type of API response object.", + examples=["gen_table.completion.chunks"], + ) + # Union just to satisfy "object" discriminator + 
# columns: dict[str, ChatCompletionResponse | ChatCompletionChunkResponse] + columns: dict[str, ChatCompletionResponse] + row_id: str + + +class MultiRowCompletionResponse(BaseModel): + object: Literal["gen_table.completion.rows"] = Field( + "gen_table.completion.rows", + description="Type of API response object.", + examples=["gen_table.completion.rows"], + ) + rows: list[RowCompletionResponse] + + +class LLMGenConfig(ChatRequestBase): + object: Literal["gen_config.llm"] = Field( + "gen_config.llm", + description='The object type, which is always "gen_config.llm".', + examples=["gen_config.llm"], + ) + system_prompt: str = Field( + "", + description="System prompt for the LLM.", + ) + prompt: str = Field( + "", + description="Prompt for the LLM.", + ) + multi_turn: bool = Field( + False, + description="Whether this column is a multi-turn chat with history along the entire column.", + ) + + @model_validator(mode="before") + @classmethod + def compat(cls, data: dict[str, Any] | BaseModel) -> dict[str, Any]: + if isinstance(data, BaseModel): + data = data.model_dump() + if not isinstance(data, dict): + raise TypeError( + f"Input to `LLMGenConfig` must be a dict or BaseModel, received: {type(data)}" + ) + if data.get("system_prompt", None) or data.get("prompt", None): + return data + messages: list[dict[str, Any]] = data.get("messages", []) + num_prompts = len(messages) + if num_prompts >= 2: + data["system_prompt"] = messages[0]["content"] + data["prompt"] = messages[1]["content"] + elif num_prompts == 1: + if messages[0]["role"] == "system": + data["system_prompt"] = messages[0]["content"] + data["prompt"] = "" + elif messages[0]["role"] == "user": + data["system_prompt"] = "" + data["prompt"] = messages[0]["content"] + else: + raise ValueError( + f'Attribute "messages" cannot contain only assistant messages: {messages}' + ) + data["object"] = "gen_config.llm" + return data + + +class EmbedGenConfig(BaseModel): + object: Literal["gen_config.embed"] = Field( + "gen_config.embed", + description='The object type, which is always "gen_config.embed".', + examples=["gen_config.embed"], + ) + embedding_model: SanitisedNonEmptyStr = Field( + description="The embedding model to use.", + examples=EXAMPLE_EMBEDDING_MODEL_IDS, + ) + source_column: SanitisedNonEmptyStr = Field( + description="The source column for embedding.", + examples=["text_column"], + ) + + +class CodeGenConfig(BaseModel): + object: Literal["gen_config.code"] = Field( + "gen_config.code", + description='The object type, which is always "gen_config.code".', + examples=["gen_config.code"], + ) + source_column: SanitisedNonEmptyStr = Field( + description="The source column for python code to execute.", + examples=["code_column"], + ) + + +class PythonGenConfig(BaseModel): + object: Literal["gen_config.python"] = Field( + "gen_config.python", + description='The object type, which is always "gen_config.python".', + examples=["gen_config.python"], + ) + python_code: SanitisedNonEmptyMultilineStr = Field( + description="The python code to execute.", + examples=["row['output_column']='Hello World!'"], + ) + + +def _gen_config_discriminator(x: Any) -> str | None: + object_attr = getattr(x, "object", None) + if object_attr: + return object_attr + if isinstance(x, BaseModel): + x = x.model_dump() + if isinstance(x, dict): + if "object" in x: + return x["object"] + if "embedding_model" in x: + return "gen_config.embed" + if "source_column" in x: + return "gen_config.code" + if "python_code" in x: + return "gen_config.python" + else: + return 
"gen_config.llm" + return None + + +DiscriminatedGenConfig = Annotated[ + Union[ + # Annotated[CodeGenConfig, Tag("gen_config.code")], + Annotated[PythonGenConfig, Tag("gen_config.python")], + Annotated[LLMGenConfig, Tag("gen_config.llm")], + Annotated[LLMGenConfig, Tag("gen_config.chat")], + Annotated[EmbedGenConfig, Tag("gen_config.embed")], + ], + Discriminator(_gen_config_discriminator), +] + + +class ColumnSchema(BaseModel): + id: str = Field(description="Column name.") + dtype: str = Field( + "str", + description="Column data type.", + ) + vlen: PositiveInt = Field( # type: ignore + 0, + description=( + "_Optional_. Vector length. " + "If this is larger than zero, then `dtype` must be one of the floating data types. Defaults to zero." + ), + ) + index: bool = Field( + True, + description=( + "_Optional_. Whether to build full-text-search (FTS) or vector index for this column. " + "Only applies to string and vector columns. Defaults to True." + ), + ) + gen_config: DiscriminatedGenConfig | None = Field( + None, + description=( + '_Optional_. Generation config. If provided, then this column will be an "Output Column". ' + "Table columns on its left can be referenced by `${column-name}`." + ), + ) + + +class ColumnSchemaCreate(ColumnSchema): + id: SanitisedNonEmptyStr = Field(description="Column name.") + dtype: Literal["int", "float", "bool", "str", "file", "image", "audio", "document"] = Field( + "str", + description=( + 'Column data type, one of ["int", "float", "bool", "str", "file", "image", "audio", "document"]' + ". Data type 'file' is deprecated, use 'image' instead." + ), + ) + + +class _TableBase(BaseModel): + id: str = Field( + description="Table name.", + ) + + +class TableSchemaCreate(_TableBase): + id: SanitisedNonEmptyStr = Field( + description="Table name.", + ) + cols: list[ColumnSchemaCreate] = Field( + description="List of column schema.", + ) + + +class ActionTableSchemaCreate(TableSchemaCreate): + pass + + +class AddActionColumnSchema(ActionTableSchemaCreate): + # TODO: Deprecate this + pass + + +class KnowledgeTableSchemaCreate(TableSchemaCreate): + # TODO: Maybe deprecate this and use EmbedGenConfig instead ? + embedding_model: str + + +class AddKnowledgeColumnSchema(TableSchemaCreate): + # TODO: Deprecate this + pass + + +class ChatTableSchemaCreate(TableSchemaCreate): + pass + + +class AddChatColumnSchema(TableSchemaCreate): + # TODO: Deprecate this + pass + + +class TableMeta(_TableBase): + meta: dict[str, Any] | None = Field( + None, + description="Additional metadata about the table.", + ) + cols: list[ColumnSchema] = Field( + description="List of column schema.", + ) + parent_id: str | None = Field( + description="The parent table ID. If None (default), it means this is a parent table.", + ) + title: str = Field( + description='Chat title. Defaults to "".', + ) + created_by: str | None = Field( + None, + description="ID of the user that created this table. Defaults to None.", + ) + updated_at: DatetimeUTC = Field( + description="Table last update datetime (UTC).", + ) + num_rows: int = Field( + -1, + description="Number of rows in the table. 
Defaults to -1 (not counted).", + ) + version: str = Field( + description="Version.", + ) + + @cached_property + def col_map(self) -> dict[str, ColumnSchema]: + return {c.id: c for c in self.cols} + + @cached_property + def cfg_map(self) -> dict[str, DiscriminatedGenConfig | None]: + return {c.id: c.gen_config for c in self.cols} + + +class TableMetaResponse(TableMeta): + # Legacy, for backwards compatibility + indexed_at_fts: str | None = Field( + None, + description="Table last FTS index timestamp (ISO 8601 UTC).", + ) + indexed_at_vec: str | None = Field( + None, + description="Table last vector index timestamp (ISO 8601 UTC).", + ) + indexed_at_sca: str | None = Field( + None, + description="Table last scalar index timestamp (ISO 8601 UTC).", + ) + + @model_validator(mode="after") + def remove_state_cols(self) -> Self: + self.cols = [c for c in self.cols if not c.id.endswith("_")] + return self + + +class GenConfigUpdateRequest(BaseModel): + table_id: str = Field(description="Table name or ID.") + column_map: dict[str, DiscriminatedGenConfig | None] = Field( + description=( + "Mapping of column ID to generation config JSON in the form of `GenConfig`. " + "Table columns on its left can be referenced by `${column-name}`." + ) + ) + + @model_validator(mode="after") + def check_column_map(self) -> Self: + if sum(n.lower() in ("id", "updated at") for n in self.column_map) > 0: + raise ValueError("column_map cannot contain keys: 'ID' or 'Updated at'.") + return self + + +class ColumnRenameRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + column_map: dict[str, str] = Field( + min_length=1, + description="Mapping of old column names to new column names.", + ) + + @model_validator(mode="after") + def check_column_map(self) -> Self: + if sum(n.lower() in ("id", "updated at") for n in self.column_map) > 0: + raise ValueError("`column_map` cannot contain keys: 'ID' or 'Updated at'.") + return self + + +class ColumnReorderRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + column_names: list[str] = Field( + min_length=1, + description="List of column ID in the desired order.", + ) + + @field_validator("column_names", mode="after") + @classmethod + def check_column_order(cls, values: list[str]) -> list[str]: + if values[0].lower() != "id": + values.insert(0, "ID") + if values[1].lower() != "updated at": + values.insert(1, "Updated at") + return values + + @field_validator("column_names", mode="after") + @classmethod + def check_unique_column_names(cls, values: list[str]) -> list[str]: + if len(set(n.lower() for n in values)) != len(values): + raise ValueError("Column names must be unique (case-insensitive).") + return values + + @field_validator("column_names", mode="after") + @classmethod + def check_state_column(cls, values: list[str]) -> list[str]: + if len(invalid_cols := [n for n in values if n.endswith("_")]) > 0: + raise ValueError(f"State columns cannot be reordered: {invalid_cols}") + return values + + +class ColumnDropRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + column_names: list[str] = Field( + min_length=1, + description="List of column ID to drop.", + ) + + @model_validator(mode="after") + def check_column_names(self) -> Self: + if sum(n.lower() in ("id", "updated at") for n in self.column_names) > 0: + raise ValueError("`column_names` cannot contain keys: 'ID' or 'Updated at'.") + return self + + +class MultiRowAddRequest(BaseModel): + table_id: SanitisedNonEmptyStr = Field( 
+ description="Table name or ID.", + ) + data: list[dict[str, Any]] = Field( + min_length=1, + description=( + "List of mapping of column names to its value. " + "In other words, each item in the list is a row, and each item is a mapping. " + "Minimum 1 row, maximum 100 rows." + ), + ) + stream: bool = Field( + True, + description="Whether or not to stream the LLM generation.", + ) + concurrent: bool = Field( + True, + description="_Optional_. Whether or not to concurrently generate the output rows and columns.", + ) + + def __repr__(self): + _data = [ + { + k: ( + {"type": type(v), "shape": v.shape, "dtype": v.dtype} + if isinstance(v, np.ndarray) + else v + ) + for k, v in d.items() + } + for d in self.data + ] + return ( + f"{self.__class__.__name__}(" + f"table_id={self.table_id} stream={self.stream} " + f"concurrent={self.concurrent} data={_data}" + ")" + ) + + +class RowUpdateRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + row_id: str = Field( + description="ID of the row to update.", + ) + data: dict[str, Any] = Field( + description="Mapping of column names to its value.", + ) + + +class MultiRowUpdateRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + data: dict[str, dict[str, Any]] = Field( + min_length=1, + description="Mapping of row IDs to row data, where each row data is a mapping of column names to its value.", + ) + + +class MultiRowUpdateRequestWithLimit(MultiRowUpdateRequest): + data: dict[str, dict[str, Any]] = Field( + min_length=1, + max_length=100, + description="Mapping of row IDs to row data, where each row data is a mapping of column names to its value.", + ) + + +class RowRegen(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + row_id: str = Field( + description="ID of the row to regenerate.", + ) + regen_strategy: str = Field( + "run_all", + description=( + "_Optional_. Strategy for selecting columns to regenerate." + "Choose `run_all` to regenerate all columns in the specified row; " + "Choose `run_before` to regenerate columns up to the specified column_id; " + "Choose `run_selected` to regenerate only the specified column_id; " + "Choose `run_after` to regenerate columns starting from the specified column_id; " + ), + ) + output_column_id: str | None = Field( + None, + min_length=1, + description=( + "_Optional_. Output column name to indicate the starting or ending point of regen for `run_before`, " + "`run_selected` and `run_after` strategies. Required if `regen_strategy` is not 'run_all'. " + "Given columns are 'C1', 'C2', 'C3' and 'C4', if column_id is 'C3': " + "`run_before` regenerate columns 'C1', 'C2' and 'C3'; " + "`run_selected` regenerate only column 'C3'; " + "`run_after` regenerate columns 'C3' and 'C4'; " + ), + ) + stream: bool = Field( + description="Whether or not to stream the LLM generation.", + ) + concurrent: bool = Field( + True, + description="_Optional_. Whether or not to concurrently generate the output columns.", + ) + + +class MultiRowRegenRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + row_ids: list[str] = Field( + min_length=1, + max_length=100, + description="List of ID of the row to regenerate. Minimum 1 row, maximum 100 rows.", + ) + regen_strategy: str = Field( + "run_all", + description=( + "_Optional_. Strategy for selecting columns to regenerate." 
+ "Choose `run_all` to regenerate all columns in the specified row; " + "Choose `run_before` to regenerate columns up to the specified column_id; " + "Choose `run_selected` to regenerate only the specified column_id; " + "Choose `run_after` to regenerate columns starting from the specified column_id; " + ), + ) + output_column_id: str | None = Field( + None, + min_length=1, + description=( + "_Optional_. Output column name to indicate the starting or ending point of regen for `run_before`, " + "`run_selected` and `run_after` strategies. Required if `regen_strategy` is not 'run_all'. " + "Given columns are 'C1', 'C2', 'C3' and 'C4', if column_id is 'C3': " + "`run_before` regenerate columns 'C1', 'C2' and 'C3'; " + "`run_selected` regenerate only column 'C3'; " + "`run_after` regenerate columns 'C3' and 'C4'; " + ), + ) + stream: bool = Field( + True, + description="Whether or not to stream the LLM generation.", + ) + concurrent: bool = Field( + True, + description="Whether or not to concurrently generate the output rows and columns. Defaults to True.", + ) + + +class MultiRowDeleteRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + row_ids: list[str] | None = Field( + None, + min_length=1, + max_length=100, + description="List of row IDs to be deleted. Maximum 100 rows. Defaults to None (match rows using `where`).", + ) + where: EmptyIfNoneStr = Field( + "", + description=( + "SQL where clause. " + "Can be nested ie `x = '1' AND (\"y (1)\" = 2 OR z = '3')`. " + "It will be combined with `row_ids` using `AND`. " + 'Defaults to "" (no filter).' + ), + ) + + +class SearchRequest(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + query: str = Field( + min_length=1, + description="Query for full-text-search (FTS) and vector search. Must not be empty.", + ) + limit: Annotated[int, Field(gt=0, le=1_000)] = Field( + 100, + description="_Optional_. Min 1, max 1000. Number of rows to return.", + ) + metric: str = Field( + "cosine", + description='_Optional_. Vector search similarity metric. Defaults to "cosine".', + ) + float_decimals: int = Field( + 0, + description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).", + ) + vec_decimals: int = Field( + 0, + description="_Optional_. Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding).", + ) + reranking_model: Annotated[ + str | None, Field(description="Reranking model to use for hybrid search.") + ] = None + + +class TableDataImportRequest(BaseModel): + file_path: Annotated[str, Field(description="CSV or TSV file path.")] + table_id: Annotated[ + str, Field(description="ID or name of the table that the data should be imported into.") + ] + stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( + True + ) + # column_names: Annotated[ + # list[str] | None, + # Field( + # description="A list of columns names if the CSV does not have header row. Defaults to None (read from CSV)." + # ), + # ] = None + # columns: Annotated[ + # list[str] | None, + # Field( + # description="A list of columns to be imported. Defaults to None (import all columns except 'ID' and 'Updated at')." + # ), + # ] = None + delimiter: Annotated[ + Literal[",", "\t"], + Field(description='The delimiter of the file: can be "," or "\\t". 
Defaults to ",".'), + ] = "," + + +class TableImportRequest(BaseModel): + file_path: Annotated[str, Field(description="The parquet file path.")] + table_id_dst: Annotated[ + str | None, Field(description="_Optional_. The ID or name of the new table.") + ] = None + blocking: Annotated[ + bool, + Field( + description=( + "If True, waits until import finishes. " + "If False, the task is submitted to a task queue and returns immediately." + ), + ), + ] = True diff --git a/clients/python/src/jamaibase/types/legacy.py b/clients/python/src/jamaibase/types/legacy.py new file mode 100644 index 0000000..f9b5741 --- /dev/null +++ b/clients/python/src/jamaibase/types/legacy.py @@ -0,0 +1,49 @@ +from pydantic import BaseModel, Field + +from jamaibase.types.lm import ( + Chunk, + RAGParams, +) + + +class VectorSearchRequest(RAGParams): + id: str = Field( + default="", + description="Request ID for logging purposes.", + examples=["018ed5f1-6399-71f7-86af-fc18d4a3e3f5"], + ) + search_query: str = Field(description="Query used to retrieve items from the Knowledge Table.") + + +class VectorSearchResponse(BaseModel): + object: str = Field( + default="kb.search_response", + description="Type of API response object.", + examples=["kb.search_response"], + ) + chunks: list[Chunk] = Field( + default=[], + description="A list of `Chunk`.", + examples=[ + [ + Chunk( + text="The Name of the Title is Hope\n\n...", + title="The Name of the Title is Hope", + page=0, + file_name="sample_tables.pdf", + file_path="amagpt/sample_tables.pdf", + metadata={ + "total_pages": 3, + "Author": "Ben Trovato", + "CreationDate": "D:20231031072817Z", + "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", + "Keywords": "Image Captioning, Deep Learning", + "ModDate": "D:20231031073146Z", + "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5", + "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", + "Trapped": "False", + }, + ) + ] + ], + ) diff --git a/clients/python/src/jamaibase/types/lm.py b/clients/python/src/jamaibase/types/lm.py new file mode 100644 index 0000000..3289640 --- /dev/null +++ b/clients/python/src/jamaibase/types/lm.py @@ -0,0 +1,1433 @@ +import re +from time import time +from typing import Annotated, Any, Literal, Union + +from pydantic import ( + AfterValidator, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + field_validator, + model_validator, +) + +from jamaibase.types.common import ( + EXAMPLE_EMBEDDING_MODEL_IDS, + EXAMPLE_RERANKING_MODEL_IDS, + EmptyIfNoneStr, + PositiveInt, + PositiveNonZeroInt, +) +from jamaibase.utils.types import StrEnum + +CITATION_PATTERN = r"\[(@[0-9]+)[; ]*\]" + + +class Chunk(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + text: str = Field( + description="Chunk text.", + ) + title: str = Field( + "", + description='Document title. Defaults to "".', + ) + page: int | None = Field( + None, + description="Document page the chunk text is from. Defaults to None.", + ) + file_name: str = Field( + "", + description='File name. Defaults to "".', + ) + file_path: str = Field( + "", + description='File path. Defaults to "".', + ) + document_id: str = Field( + "", + description='Document ID. Defaults to "".', + ) + chunk_id: str = Field( + "", + description='Chunk ID. 
Defaults to "".', + ) + context: dict[str, str] = Field( + {}, + description="Additional context that should be included in the RAG prompt. Defaults to an empty dictionary.", + ) + metadata: dict = Field( + {}, + description=( + "Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). " + "Defaults to an empty dictionary." + ), + ) + + +class SplitChunksParams(BaseModel): + method: str = Field( + "RecursiveCharacterTextSplitter", + description="Name of the splitter.", + examples=["RecursiveCharacterTextSplitter"], + ) + chunk_size: PositiveNonZeroInt = Field( + 1000, + description="Maximum chunk size (number of characters). Must be > 0.", + examples=[1000], + ) + chunk_overlap: PositiveInt = Field( + 200, + description="Overlap in characters between chunks. Must be >= 0.", + examples=[200], + ) + + +class SplitChunksRequest(BaseModel): + id: str = Field( + "", + description="Request ID for logging purposes.", + examples=["018ed5f1-6399-71f7-86af-fc18d4a3e3f5"], + ) + chunks: list[Chunk] = Field( + description="List of `Chunk` where each will be further split into chunks.", + examples=[ + [ + Chunk( + text="The Name of the Title is Hope\n\n...", + title="The Name of the Title is Hope", + page=0, + file_name="sample_tables.pdf", + file_path="amagpt/sample_tables.pdf", + metadata={ + "total_pages": 3, + "Author": "Ben Trovato", + "CreationDate": "D:20231031072817Z", + "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles", + "Keywords": "Image Captioning, Deep Learning", + "ModDate": "D:20231031073146Z", + }, + ) + ] + ], + ) + params: SplitChunksParams = Field( + SplitChunksParams(), + description=( + "How to split each document. " + "Defaults to `RecursiveCharacterTextSplitter` with chunk_size = 1000 and chunk_overlap = 200." + ), + examples=[SplitChunksParams()], + ) + + def str_trunc(self) -> str: + return f"id={self.id} len(chunks)={len(self.chunks)} params={self.params}" + + +class References(BaseModel): + object: Literal["chat.references"] = Field( + "chat.references", + description="Type of API response object.", + examples=["chat.references"], + ) + chunks: list[Chunk] = Field( + [], + description="A list of `Chunk`.", + examples=[ + [ + Chunk( + text="The Name of the Title is Hope\n\n...", + title="The Name of the Title is Hope", + page=0, + file_name="sample_tables.pdf", + file_path="amagpt/sample_tables.pdf", + metadata={ + "total_pages": 3, + "Author": "Ben Trovato", + "CreationDate": "D:20231031072817Z", + "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles", + "Keywords": "Image Captioning, Deep Learning", + "ModDate": "D:20231031073146Z", + }, + ) + ] + ], + ) + search_query: str = Field(description="Query used to retrieve items from the Knowledge Table.") + finish_reason: Literal["stop", "context_overflow"] | None = Field( + None, + deprecated=True, + description=""" +In streaming mode, reference chunk will be streamed first. +However, if the model's context length is exceeded, then there will be no further completion chunks. +In this case, "finish_reason" will be set to "context_overflow". +Otherwise, it will be None or null. +""", + ) + + def remove_contents(self): + copy = self.model_copy(deep=True) + for d in copy.documents: + d.page_content = "" + return copy + + +class RAGParams(BaseModel): + table_id: str = Field( + "", + description="Knowledge Table ID", + examples=["my-dataset"], + ) + reranking_model: str | None = Field( + None, + description="Reranking model to use for hybrid search. 
Defaults to None (no reranking).", + examples=[EXAMPLE_RERANKING_MODEL_IDS[0], None], + ) + search_query: str = Field( + "", + description=( + "Query used to retrieve items from the Knowledge Table. " + "If not provided (default), it will be generated using LLM." + ), + ) + k: Annotated[int, Field(gt=0, le=1024)] = Field( + 3, + gt=0, + le=1024, + description="Top-k closest text in terms of embedding distance. Must be in [1, 1024]. Defaults to 3.", + examples=[3], + ) + rerank: bool = Field( + True, + deprecated=True, + description="(Deprecated) Flag to perform rerank on the retrieved results. Defaults to True.", + examples=[True, False], + ) + concat_reranker_input: bool = Field( + False, + description="Flag to concat title and content as reranker input. Defaults to False.", + examples=[True, False], + ) + inline_citations: bool = Field( + True, + description=( + "If True, the model will cite sources as it writes using Pandoc-style in the form of `[@]`. " + "The number is the index of the source in the reference list, ie `[@0; @3]` means the 1st and 4th source in `References.chunks`. " + "Defaults to True." + ), + examples=[True, False], + ) + + +class FunctionCall(BaseModel): + name: str = Field( + description="The name of the function to call.", + ) + arguments: str = Field( + description="The arguments to call the function with, as generated by the model in JSON format.", + ) + + +class ToolCallFunction(BaseModel): + arguments: str + name: str | None + + +class ToolCall(BaseModel): + id: str = Field( + description="The ID of the tool call.", + ) + type: Literal["function"] = Field( + "function", + description="The type of the tool. Currently, only `function` is supported.", + ) + function: ToolCallFunction + + +class AudioResponse(BaseModel): + id: str = Field( + description="Unique identifier for this audio response.", + ) + expires_at: int = Field( + description="The Unix timestamp (in seconds) for when this audio response will no longer be accessible.", + ) + data: str = Field( + description="Base64 encoded audio bytes generated by the model.", + ) + transcript: str = Field( + description="Transcript of the audio generated by the model.", + ) + + +class ChatCompletionDelta(BaseModel): + role: str = Field( + "assistant", + description="The role of the author of this message.", + ) + content: str | None = Field( + None, + description="The contents of the chunk message.", + ) + reasoning_content: str | None = Field( + None, + description="The reasoning contents generated by the model.", + ) + refusal: str | None = Field( + None, + description="The refusal message generated by the model.", + ) + tool_calls: list[ToolCall] | None = Field( + None, + description="The tool calls generated by the model, such as function calls.", + ) + function_call: FunctionCall | None = Field( + None, + deprecated=True, + description=( + "Deprecated and replaced by `tool_calls`. " + "The name and arguments of a function that should be called." + ), + ) + + +class ChatCompletionMessage(ChatCompletionDelta): + # content: str = Field( + # description="The contents of the message.", + # ) + audio: AudioResponse | None = Field( + None, + description="If the audio output modality is requested, this object contains data about the audio response from the model.", + ) + + +class LogProbToken(BaseModel): + token: str = Field( + description="The token.", + ) + logprob: float = Field( + description=( + "The log probability of this token, if it is within the top 20 most likely tokens. 
" + "Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ) + bytes: list[int] | None = Field( + description="A list of integers representing the UTF-8 bytes representation of the token.", + ) + + +class LogProbs(BaseModel): + content: list[LogProbToken] | None = Field( + None, + description="A list of message content tokens with log probability information.", + ) + refusal: list[LogProbToken] | None = Field( + None, + description="A list of message refusal tokens with log probability information.", + ) + + +class ChatCompletionChoice(BaseModel): + index: int = Field( + description="The index of the choice in the list of choices.", + ) + message: ChatCompletionMessage | None = Field( + None, + description="A chat completion message generated by the model.", + ) + delta: ChatCompletionDelta | None = Field( + None, + description="A chat completion delta generated by streamed model responses.", + ) + logprobs: LogProbs | None = Field( + None, + description="Log probability information for the choice.", + ) + finish_reason: str | None = Field( + None, + description=( + "The reason the model stopped generating tokens. " + "This will be `stop` if the model hit a natural stop point or a provided stop sequence, " + "`length` if the maximum number of tokens specified in the request was reached." + ), + ) + + @property + def text(self) -> str: + """The text of the most recent chat completion.""" + message = self.message or self.delta + return getattr(message, "content", None) or "" + + @model_validator(mode="after") + def validate_message_delta(self): + if self.delta is not None: + self.message = ChatCompletionMessage.model_validate(self.delta.model_dump()) + return self + + +def _none_to_zero(v: int | None) -> int: + if v is None: + return 0 + return v + + +ZeroIfNoneInt = Annotated[int, BeforeValidator(_none_to_zero)] + + +class PromptUsageDetails(BaseModel): + cached_tokens: ZeroIfNoneInt = Field( + 0, + description="Cached tokens present in the prompt.", + ) + audio_tokens: ZeroIfNoneInt = Field( + 0, + description="Audio input tokens present in the prompt or generated by the model.", + ) + + +class CompletionUsageDetails(BaseModel): + audio_tokens: ZeroIfNoneInt = Field( + 0, + description="Audio input tokens present in the prompt or generated by the model.", + ) + reasoning_tokens: ZeroIfNoneInt = Field( + 0, + description="Tokens generated by the model for reasoning.", + ) + accepted_prediction_tokens: ZeroIfNoneInt = Field( + 0, + description="When using Predicted Outputs, the number of tokens in the prediction that appeared in the completion.", + ) + rejected_prediction_tokens: ZeroIfNoneInt = Field( + 0, + description="When using Predicted Outputs, the number of tokens in the prediction that did not appear in the completion.", + ) + + +class ToolUsageDetails(BaseModel): + web_search_calls: ZeroIfNoneInt = Field( + 0, + description="Number of web search calls.", + ) + code_interpreter_calls: ZeroIfNoneInt = Field( + 0, + description="Number of code interpreter calls.", + ) + + +class ChatCompletionUsage(BaseModel): + prompt_tokens: ZeroIfNoneInt = Field( + 0, + description="Number of tokens in the prompt.", + ) + completion_tokens: ZeroIfNoneInt = Field( + 0, + description="Number of tokens in the generated completion.", + ) + total_tokens: ZeroIfNoneInt = Field( + 0, + description="Total number of tokens used in the request (prompt + completion).", + ) + prompt_tokens_details: PromptUsageDetails | None = Field( + None, + description="Breakdown of 
tokens used in the prompt.", + ) + completion_tokens_details: CompletionUsageDetails | None = Field( + None, + description="Breakdown of tokens used in a completion.", + ) + tool_usage_details: ToolUsageDetails | None = Field( + None, + description="Breakdown of tool usage details, such as web search and code interpreter calls.", + ) + + @property + def reasoning_tokens(self) -> int: + return getattr(self.completion_tokens_details, "reasoning_tokens", 0) + + +class ChatCompletionResponse(BaseModel): + id: str = Field( + description="A unique identifier for the chat completion.", + ) + object: Literal["chat.completion"] = Field( + "chat.completion", + description="The object type, which is always `chat.completion`.", + ) + created: int = Field( + default_factory=lambda: int(time()), + description="The Unix timestamp (in seconds) of when the chat completion was created.", + ) + model: str = Field( + description="The model used for the chat completion.", + ) + choices: list[ChatCompletionChoice] = Field( + description=( + "A list of chat completion choices. " + "Can contain more than one elements if `n` is greater than 1." + ), + ) + usage: ChatCompletionUsage = Field( + description="Usage statistics for the completion request.", + ) + references: References | None = Field( + None, + description="References of this Retrieval Augmented Generation (RAG) response.", + ) + service_tier: str | None = Field( + None, + description="The service tier used for processing the request.", + ) + system_fingerprint: str | None = Field( + None, + description="This fingerprint represents the backend configuration that the model runs with.", + ) + + @field_validator("choices", mode="after") + @classmethod + def validate_choices(cls, v: list[ChatCompletionChoice]) -> list[ChatCompletionChoice]: + if len(v) > 0 and v[0].message is None: + raise ValueError("`message` must be defined.") + return v + + @property + def finish_reason(self) -> str | None: + return self.choices[0].finish_reason if len(self.choices) > 0 else None + + @property + def delta(self) -> ChatCompletionMessage | None: + """The delta of the first chat completion choice.""" + return self.message + + @property + def message(self) -> ChatCompletionMessage | None: + """The message of the first chat completion choice.""" + return self.choices[0].message if len(self.choices) > 0 else None + + @property + def reasoning_content(self) -> str: + """The reasoning text of the first chat completion choice message.""" + return getattr(self.message, "reasoning_content", None) or "" + + @property + def content(self) -> str: + """The text of the first chat completion choice message.""" + return getattr(self.message, "content", None) or "" + + @property + def text(self) -> str: + """The text of the most recent chat completion.""" + return self.content + + @property + def prompt_tokens(self) -> int: + return getattr(self.usage, "prompt_tokens", 0) + + @property + def completion_tokens(self) -> int: + return getattr(self.usage, "completion_tokens", 0) + + @property + def reasoning_tokens(self) -> int: + return getattr(self.usage, "reasoning_tokens", 0) + + @property + def total_tokens(self) -> int: + return getattr(self.usage, "total_tokens", 0) + + +class ChatCompletionChunkResponse(ChatCompletionResponse): + object: Literal["chat.completion.chunk"] = Field( + "chat.completion.chunk", + description="The object type, which is always `chat.completion.chunk`.", + ) + choices: list[ChatCompletionChoice] = Field( + description=( + "A list of chat completion choices. 
" + "Can contain more than one elements if `n` is greater than 1. " + 'Can also be empty for the last chunk if you set stream_options: `{"include_usage": true}`.' + ), + ) + usage: ChatCompletionUsage | None = Field( + None, + description="Contains a `null` value except for the last chunk which contains the token usage statistics for the entire request.", + ) + + @field_validator("choices", mode="after") + @classmethod + def validate_choices(cls, v: list[ChatCompletionChoice]) -> list[ChatCompletionChoice]: + # Override + return v + + +class TextContent(BaseModel): + type: Literal["text"] = Field( + "text", + description="The type of content.", + ) + text: EmptyIfNoneStr = Field( + description="The text content.", + ) + + +class ImageContentData(BaseModel): + url: str = Field( + description=( + "Either a URL of the image or the base64 encoded image data " + 'in the form of `"data:;base64,{base64_image}"`.' + ), + ) + + def __repr__(self): + _url = self.url + if len(_url) > 12: + _url = f"{_url[:6]}...{_url[-6:]}" + return f"{self.__class__.__name__}(url='{_url}')" + + +class ImageContent(BaseModel): + type: Literal["image_url"] = Field( + "image_url", + description="The type of content.", + ) + image_url: ImageContentData = Field( + description="The image content.", + ) + + +class AudioContentData(BaseModel): + data: str = Field( + description="Base-64 encoded audio data.", + ) + format: Literal["mp3", "wav"] = Field( + "wav", + description="The audio format.", + ) + + def __repr__(self): + _data = self.data + if len(_data) > 12: + _data = f"{_data[:6]}...{_data[-6:]}" + return f"{self.__class__.__name__}(data='{_data}', format='{self.format}')" + + +class AudioContent(BaseModel): + type: Literal["input_audio"] = Field( + "input_audio", + description="The type of content.", + ) + input_audio: AudioContentData = Field( + description="The audio content.", + ) + + +# class AudioURLData(BaseModel): +# url: str = Field( +# description=( +# "Either a URL of the audio or the base64 encoded audio data " +# 'in the form of `"data:;base64,{base64_audio}"`.' +# ), +# ) + +# def __repr__(self): +# _url = self.url +# if len(_url) > 12: +# _url = f"{_url[:6]}...{_url[-6:]}" +# return f"{self.__class__.__name__}(url='{_url}')" + + +# class AudioURL(BaseModel): +# type: Literal["audio_url"] = Field( +# "audio_url", +# description="The type of content.", +# ) +# audio_url: AudioURLData = Field( +# description="The audio content.", +# ) + + +class S3Content(BaseModel): + type: Literal["input_s3"] = Field( + "input_s3", + description="The type of content.", + ) + uri: str = Field( + description="The S3 URI.", + ) + column_name: str = Field( + description="The column holding this data.", + ) + + +ChatContent = Annotated[ + Union[TextContent, ImageContent, AudioContent], + Field(discriminator="type"), +] +ChatContentS3 = Annotated[ + Union[TextContent, S3Content], + Field(discriminator="type"), +] + + +class ChatRole(StrEnum): + """Represents who said a chat message.""" + + SYSTEM = "system" + """The message is from the system (usually a steering prompt).""" + USER = "user" + """The message is from the user.""" + ASSISTANT = "assistant" + """The message is from the language model.""" + # FUNCTION = "function" + # """The message is the result of a function call.""" + + +def _sanitise_name(v: str) -> str: + """Replace any non-alphanumeric and dash characters with space. + + Args: + v (str): Raw name string. + + Returns: + out (str): Sanitised name string that is safe for OpenAI. 
+ """ + return re.sub(r"[^a-zA-Z0-9_-]", "_", v).strip() + + +class ChatEntry(BaseModel): + """Represents a message in the chat context.""" + + model_config = ConfigDict(use_enum_values=True) + + role: ChatRole = Field( + description="Who said the message?", + ) + content: EmptyIfNoneStr | list[ChatContent] = Field( + description="The content of the message.", + ) + name: Annotated[str, AfterValidator(_sanitise_name)] | None = Field( + None, + description="The name of the user who sent the message, if set (user messages only).", + ) + + @property + def text_content(self) -> str: + if isinstance(self.content, str): + return self.content + text_contents = [c for c in self.content if isinstance(c, TextContent)] + if len(text_contents) > 0: + return "\n".join(c.text for c in text_contents) + return "" + + @property + def has_text_only(self) -> bool: + # Explicitly use `isinstance(self.content, str)` to help the type checker + return isinstance(self.content, str) or all( + isinstance(c, TextContent) for c in self.content + ) + + @property + def has_image(self) -> bool: + # Explicitly use `isinstance(self.content, str)` to help the type checker + return (not isinstance(self.content, str)) and any( + isinstance(c, ImageContent) for c in self.content + ) + + @property + def has_audio(self) -> bool: + # Explicitly use `isinstance(self.content, str)` to help the type checker + return (not isinstance(self.content, str)) and any( + isinstance(c, AudioContent) for c in self.content + ) + + @classmethod + def system(cls, content: str | list[ChatContent | ChatContentS3], **kwargs): + """Create a new system message.""" + return cls(role="system", content=content, **kwargs) + + @classmethod + def user(cls, content: str | list[ChatContent | ChatContentS3], **kwargs): + """Create a new user message.""" + return cls(role="user", content=content, **kwargs) + + @classmethod + def assistant(cls, content: str | None, **kwargs): + """Create a new assistant message.""" + return cls(role="assistant", content=content, **kwargs) + + +class ChatThreadEntry(ChatEntry): + """Represents a message in the chat thread response.""" + + content: EmptyIfNoneStr | list[ChatContentS3] = Field( + description="The content of the message.", + ) + user_prompt: str | None = Field( + None, + description=( + "Original prompt sent by the user without content interpolation/injection. " + 'Only applicable for Chat Table column that references the "User" column. ' + "Defaults to None (not applicable)." + ), + ) + references: References | None = Field( + None, + description=( + "References of this Retrieval Augmented Generation (RAG) response. " + "Defaults to None (not applicable)." + ), + ) + row_id: str | None = Field( + None, + description="Table row ID of this chat message. 
Defaults to None (not applicable).", + ) + + +class ChatThreadResponse(BaseModel): + object: Literal["chat.thread"] = Field( + "chat.thread", + description="Type of API response object.", + examples=["chat.thread"], + ) + thread: list[ChatThreadEntry] = Field( + [], + description="List of chat messages.", + examples=[ + [ + ChatThreadEntry.system("You are an assistant."), + ChatThreadEntry.user("Hello."), + ChatThreadEntry.assistant( + "Hello.", + references=References( + chunks=[Chunk(title="Title", text="Text")], + search_query="hello", + ), + ), + ] + ], + ) + column_id: str = Field( + "", + description="Table column ID of this chat thread.", + ) + + +class _ChatThreadsBase(BaseModel): + object: Literal["chat.threads"] = Field( + "chat.threads", + description="Type of API response object.", + examples=["chat.threads"], + ) + threads: dict[str, ChatThreadResponse] = Field( + [], + description="List of chat threads.", + examples=[ + dict( + AI=ChatThreadResponse( + thread=[ + ChatThreadEntry.system("You are an assistant."), + ChatThreadEntry.user("Hello."), + ChatThreadEntry.assistant( + "Hello.", + references=References( + chunks=[Chunk(title="Title", text="Text")], + search_query="hello", + ), + ), + ] + ), + ) + ], + ) + + +class ChatThreadsResponse(_ChatThreadsBase): + table_id: str = Field( + "", + description="Table ID of the chat threads.", + ) + + +class ConversationThreadsResponse(_ChatThreadsBase): + conversation_id: str = Field( + "", + description="Conversation ID of the chat threads.", + ) + + +class FunctionParameter(BaseModel): + type: str = Field( + "", + description="The type of the parameter, e.g., 'string', 'number'.", + ) + description: str = Field( + "", + description="A description of the parameter.", + ) + enum: list[str] = Field( + [], + description="An optional list of allowed values for the parameter.", + ) + + +class FunctionParameters(BaseModel): + type: str = Field( + "object", + description="The type of the parameters object, usually 'object'.", + ) + properties: dict[str, FunctionParameter] = Field( + description="The properties of the parameters object.", + ) + required: list[str] = Field( + description="A list of required parameter names.", + ) + additionalProperties: bool = Field( + False, + description="Whether additional properties are allowed.", + ) + + +class Function(BaseModel): + name: str = Field( + max_length=64, + description=( + "The name of the function to be called. " + "Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ) + description: str | None = Field( + None, + description="A description of what the function does, used by the model to choose when and how to call the function.", + ) + parameters: FunctionParameters | None = Field( + None, + description="The parameters the functions accepts, described as a JSON Schema object.", + ) + strict: bool = Field( + False, + description=( + "Whether to enable strict schema adherence when generating the function call. " + "If set to `true`, the model will follow the exact schema defined in the `parameters` field. " + "Only a subset of JSON Schema is supported when `strict` is `true`." + ), + ) + + +class FunctionTool(BaseModel): + type: Literal["function"] = Field( + "function", + description="The type of the tool. 
Currently, only `function` is supported.", + ) + function: Function + + +class WebSearchTool(BaseModel): + type: Literal["web_search"] = Field( + "web_search", + description="The type of tool.", + ) + + +class CodeInterpreterTool(BaseModel): + type: Literal["code_interpreter"] = Field( + "code_interpreter", + description="The type of tool.", + ) + container: dict[str, str] = Field( + {"type": "auto"}, + description="The code interpreter container.", + ) + + +Tool = Annotated[ + WebSearchTool | CodeInterpreterTool | FunctionTool, + Field( + discriminator="type", + description=( + "The type of tool. " + "Currently, one of `web_search`, `code_interpreter`, or `function`. " + "Note that `web_search` and `code_interpreter` are only supported with OpenAI models. " + "They will be ignored with other models." + ), + ), +] + + +class ToolChoiceFunction(BaseModel): + name: str = Field( + description="The name of the function to call.", + ) + + +class ToolChoice(BaseModel): + type: str = Field( + "function", + description="The type of the tool. Currently, only `function` is supported.", + ) + function: ToolChoiceFunction = Field( + description="The function that should be called.", + ) + + +def _empty_list_to_none(v: list[str]) -> list[str] | None: + if len(v) == 0: + v = None + return v + + +class ChatRequestBase(BaseModel): + """ + Base for chat request and LLM gen config. + """ + + model: str = Field( + "", + description='ID of the model to use. Defaults to "".', + ) + rag_params: RAGParams | None = Field( + None, + description="Retrieval Augmented Generation params. Defaults to None (disabled).", + examples=[RAGParams(table_id="papers"), None], + ) + tools: list[Tool] | None = Field( + None, + description=( + "A list of tools available for the chat model to use. " + "Note that `web_search` and `code_interpreter` are only supported with OpenAI models. " + "They will be ignored with other models." + ), + min_length=1, + examples=[ + [ + WebSearchTool(), + CodeInterpreterTool(), + FunctionTool( + type="function", + function=Function( + name="get_weather", + description="Get current temperature for a given location.", + parameters=FunctionParameters( + type="object", + properties={ + "location": FunctionParameter( + type="string", + description="City and country e.g. Bogotá, Colombia", + ) + }, + required=["location"], + additionalProperties=False, + ), + ), + ), + ], + ], + ) + tool_choice: Literal["none", "auto", "required"] | ToolChoice | None = Field( + None, + description=( + "Controls which (if any) tool is called by the model. " + '`"none"` means the model will not call any tool and instead generates a message. ' + '`"auto"` means the model can pick between generating a message or calling one or more tools. ' + '`"required"` means the model must call one or more tools. ' + 'Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}` forces the model to call that tool. ' + '`"none"` is the default when no tools are present. ' + '`"auto"` is the default if tools are present.' + ), + examples=[ + "auto", + ToolChoice(type="function", function=ToolChoiceFunction(name="get_delivery_date")), + ], + ) + temperature: float = Field( + 0.2, + ge=0, + description=( + "What sampling temperature to use. " + "Higher values like 0.8 will make the output more random, " + "while lower values like 0.2 will make it more focused and deterministic. 
" + "Note that this parameter will be ignored when using that do not support it, " + "such as OpenAI's reasoning models and Anthropic with extended thinking." + ), + examples=[0.2], + ) + top_p: float = Field( + 0.6, + ge=0.001, + description=( + "An alternative to sampling with temperature, called nucleus sampling, " + "where the model considers the results of the tokens with top_p probability mass. " + "So 0.1 means only the tokens comprising the top 10% probability mass are considered. " + "Note that this parameter will be ignored when using that do not support it, " + "such as OpenAI's reasoning models and Anthropic with extended thinking." + ), + examples=[0.6], + ) + stream: bool = Field( + True, + description=( + "If set, partial message deltas will be sent, like in ChatGPT. " + "Tokens will be sent as server-sent events (SSE) as they become available, " + "with the stream terminated by a `data: [DONE]` message." + ), + examples=[True, False], + ) + max_tokens: PositiveNonZeroInt = Field( + 2048, + description=( + "The maximum number of tokens to generate in the chat completion. " + "Must be in [1, context_length - 1). Default is 2048. " + "The total length of input tokens and generated tokens is limited by the model's context length." + ), + examples=[2048], + ) + stop: Annotated[list[str], AfterValidator(_empty_list_to_none)] | None = Field( + None, + min_length=1, + description=( + "A list of sequences where the API will stop generating further tokens. " + "Note that this parameter will be ignored when using that do not support it, " + "such as OpenAI's reasoning models." + ), + examples=[None], + ) + presence_penalty: float = Field( + 0.0, + description=( + "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, " + "increasing the model's likelihood to talk about new topics. " + "Note that this parameter will be ignored when using that do not support it, " + "such as OpenAI's reasoning models." + ), + examples=[0.0], + ) + frequency_penalty: float = Field( + 0.0, + description=( + "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, " + "decreasing the model's likelihood to repeat the same line verbatim. " + "Note that this parameter will be ignored when using that do not support it, " + "such as OpenAI's reasoning models." + ), + examples=[0.0], + ) + logit_bias: dict = Field( + {}, + description=( + "Modify the likelihood of specified tokens appearing in the completion. " + "Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) " + "to an associated bias value from -100 to 100. " + "Mathematically, the bias is added to the logits generated by the model prior to sampling. " + "The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; " + "values like -100 or 100 should result in a ban or exclusive selection of the relevant token. " + "Note that this parameter will be ignored when using that do not support it, " + "such as OpenAI's reasoning models." + ), + examples=[{}], + ) + reasoning_effort: Literal["disable", "minimal", "low", "medium", "high"] | None = Field( + "minimal", + description=( + "Constrains effort on reasoning for reasoning models. " + "Currently supported values are `disable`, `minimal`, `low`, `medium`, and `high`. " + "Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. 
" + "For non-OpenAI models, `low` ~ 1024 tokens, `medium` ~ 2048 tokens, `high` ~ 4096 tokens. " + "Note that this parameter will be ignored when using models that do not support it, " + "such as non-reasoning models." + ), + examples=["low"], + ) + reasoning_effort: Literal["disable", "minimal", "low", "medium", "high"] | None = Field( + None, + description=( + "Constrains effort on reasoning for reasoning models. " + "Currently supported values are `disable`, `minimal`, `low`, `medium`, and `high`. " + "Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. " + "For non-OpenAI models, `low` ~ 1024 tokens, `medium` ~ 4096 tokens, `high` ~ 8192 tokens. " + "Note that this parameter will be ignored when using models that do not support it, " + "such as non-reasoning models." + ), + examples=["low"], + ) + thinking_budget: int | None = Field( + None, + ge=0, + description=( + "Model reasoning budget in tokens. " + "Set to zero to disable reasoning if supported. " + "For OpenAI models, 1 <= budget <= 1024 is low, 1025 <= budget <= 4096 is medium, 4097 <= budget <= 8192 is high. " + "Note that this parameter will be ignored when using models that do not support it, " + "such as non-reasoning models." + ), + examples=[1024], + ) + reasoning_summary: Literal["auto", "concise", "detailed"] = Field( + "auto", + description=( + "To access the most detailed summarizer available for a model, set the value of this parameter to auto. " + "auto will be equivalent to detailed for most reasoning models today, " + "but there may be more granular settings in the future. " + "Will be ignored if the model does not support it." + ), # https://platform.openai.com/docs/guides/reasoning/advice-on-prompting#reasoning-summaries + ) + + @property + def hyperparams(self) -> dict[str, Any]: + # object key could cause issue to some LLM provider, ex: Anthropic + return self.model_dump(exclude_none=True, exclude={"object", "messages", "rag_params"}) + + +class ChatRequest(ChatRequestBase): + id: str = Field( + "", + description='Chat ID for logging. Defaults to "".', + ) + messages: list[ChatEntry] = Field( + min_length=1, + description="A list of messages comprising the conversation so far.", + ) + max_completion_tokens: PositiveNonZeroInt | None = Field( + None, + description=( + "An upper bound for the number of tokens that can be generated for a completion, " + "including visible output tokens and reasoning tokens. " + "Must be in [1, context_length - 1). Default is 2048. " + "If both `max_completion_tokens` and `max_tokens` are set, `max_completion_tokens` will be used. " + ), + examples=[2048], + ) + n: int = Field( + 1, + description=( + "How many chat completion choices to generate for each input message. " + "Note that this parameter will be ignored when using models and tools that do not support it." + ), + examples=[1], + ) + user: str = Field( + "", + description="A unique identifier representing your end-user. For monitoring and debugging purposes.", + examples=[""], + ) + stream: bool = Field( + False, + description=( + "If set, partial message deltas will be sent, like in ChatGPT. " + "Tokens will be sent as server-sent events (SSE) as they become available, " + "with the stream terminated by a 'data: [DONE]' message." 
+ ), + examples=[True, False], + ) + + @model_validator(mode="after") + def validate_params(self): + self.max_tokens = self.max_completion_tokens or self.max_tokens + if self.thinking_budget and self.thinking_budget > self.max_tokens: + raise ValueError("`thinking_budget` cannot be higher than `max_tokens`.") + return self + + +class EmbeddingRequest(BaseModel): + input: str | list[str] = Field( + description=( + "Input text to embed, encoded as a string or array of strings " + "(to embed multiple inputs in a single request). " + "The input must not exceed the max input tokens for the model, and cannot contain empty string." + ), + examples=["What is a llama?", ["What is a llama?", "What is an alpaca?"]], + ) + model: str = Field( + description=( + "The ID of the model to use. " + "You can use the List models API to see all of your available models." + ), + examples=EXAMPLE_EMBEDDING_MODEL_IDS, + ) + type: Literal["query", "document"] = Field( + "document", + description=( + 'Whether the input text is a "query" (used to retrieve) or a "document" (to be retrieved).' + ), + examples=["query", "document"], + ) + encoding_format: Literal["float", "base64"] = Field( + "float", + description=( + '_Optional_. The format to return the embeddings in. Can be either "float" or "base64". ' + "`base64` string should be decoded as a `float32` array. " + "Example: `np.frombuffer(base64.b64decode(response), dtype=np.float32)`" + ), + examples=["float", "base64"], + ) + dimensions: PositiveNonZeroInt | None = Field( + None, + description=( + "The number of dimensions the resulting output embeddings should have. " + "Note that this parameter will only be used when using models that support Matryoshka embeddings." + ), + ) + + +class EmbeddingResponseData(BaseModel): + object: Literal["embedding"] = Field( + "embedding", + description="The object type, which is always `embedding`.", + examples=["embedding"], + ) + embedding: list[float] | str = Field( + description=( + "The embedding vector, which is a list of floats or a base64-encoded string. " + "The length of vector depends on the model." + ), + examples=[[0.0, 1.0, 2.0], []], + ) + index: int = Field( + 0, + description="The index of the embedding in the list of embeddings.", + examples=[0, 1], + ) + + +class EmbeddingUsage(BaseModel): + prompt_tokens: ZeroIfNoneInt = Field( + 0, + description="Number of tokens in the prompt.", + ) + total_tokens: ZeroIfNoneInt = Field( + 0, + description="Total number of tokens used in the request.", + ) + + +class EmbeddingResponse(BaseModel): + object: Literal["list"] = Field( + "list", + description="Type of API response object.", + examples=["list"], + ) + data: list[EmbeddingResponseData] = Field( + description="List of `EmbeddingResponseData`.", + examples=[[EmbeddingResponseData(embedding=[0.0, 1.0, 2.0])]], + ) + model: str = Field( + description="The ID of the model used.", + examples=["openai/text-embedding-3-small-512"], + ) + usage: EmbeddingUsage = Field( + EmbeddingUsage(), + description="The number of tokens consumed.", + examples=[EmbeddingUsage()], + ) + + +class RerankingRequest(BaseModel): + model: str = Field( + description=( + "The ID of the model to use. " + "You can use the List models API to see all of your available models." 
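The `encoding_format` description above hints at how base64 payloads are decoded; a small sketch, assuming the import path and reusing the model ID from the examples in this file:

```python
import base64

import numpy as np

# Assumed import path; model ID mirrors the example used elsewhere in this file.
from jamaibase.types.chat import EmbeddingRequest

request = EmbeddingRequest(
    input=["What is a llama?", "What is an alpaca?"],
    model="openai/text-embedding-3-small-512",
    type="query",
    encoding_format="base64",
)


def decode_embedding(data: str) -> np.ndarray:
    # Base64 responses decode into a flat float32 vector, per the field description above.
    return np.frombuffer(base64.b64decode(data), dtype=np.float32)
```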
+ ), + examples=EXAMPLE_RERANKING_MODEL_IDS, + ) + documents: list[str] + query: str + + +class RerankingData(BaseModel): + object: Literal["reranking"] = Field( + "reranking", + description="Type of API response object.", + examples=["reranking"], + ) + index: int + relevance_score: float + + +class RerankingApiVersion(BaseModel): + version: str = Field( + "", + description="API version.", + examples=["2"], + ) + is_deprecated: bool = Field( + False, + description="Whether it is deprecated.", + examples=[False], + ) + is_experimental: bool = Field( + False, + description="Whether it is experimental.", + examples=[False], + ) + + +class RerankingBilledUnits(BaseModel): + images: int | None = Field(None, description="The number of billed images.") + input_tokens: int | None = Field(None, description="The number of billed input tokens.") + output_tokens: int | None = Field(None, description="The number of billed output tokens.") + search_units: float | None = Field(None, description="The number of billed search units.") + classifications: float | None = Field( + None, description="The number of billed classifications units." + ) + + +class RerankingMetaUsage(BaseModel): + input_tokens: int | None = Field( + None, + description="The number of tokens used as input to the model.", + ) + output_tokens: int | None = Field( + None, + description="The number of tokens produced by the model.", + ) + + +class RerankingUsage(RerankingMetaUsage): + documents: ZeroIfNoneInt = Field( + description="The number of documents processed.", + ) + + +class RerankingMeta(BaseModel): + model: str = Field( + description="The ID of the model used.", + examples=["cohere/rerank-multilingual-v3.0"], + ) + api_version: RerankingApiVersion | None = Field( + None, + description="API version.", + examples=[RerankingApiVersion(), None], + ) + billed_units: RerankingBilledUnits | None = Field( + None, + description="Billed units.", + examples=[RerankingBilledUnits(), None], + ) + tokens: RerankingMetaUsage | None = Field( + None, + description="Token usage.", + examples=[RerankingMetaUsage(input_tokens=500), None], + ) + warnings: list[str] | None = Field( + None, + description="Warnings.", + examples=[["This is a warning."], None], + ) + + +class RerankingResponse(BaseModel): + object: Literal["list"] = Field( + "list", + description="Type of API response object.", + examples=["list"], + ) + results: list[RerankingData] = Field( + description="List of `RerankingResponseData`.", + examples=[[RerankingData(index=0, relevance_score=0.0032)]], + ) + usage: RerankingUsage = Field( + description="Usage.", + examples=[RerankingUsage(documents=10), None], + ) + meta: RerankingMeta = Field( + description="Reranking metadata from Cohere.", + ) diff --git a/clients/python/src/jamaibase/types/logs.py b/clients/python/src/jamaibase/types/logs.py new file mode 100644 index 0000000..748dd52 --- /dev/null +++ b/clients/python/src/jamaibase/types/logs.py @@ -0,0 +1,7 @@ +from typing import Any + +from pydantic import BaseModel + + +class LogQueryResponse(BaseModel): + logs: list[dict[str, Any]] diff --git a/clients/python/src/jamaibase/types/mcp.py b/clients/python/src/jamaibase/types/mcp.py new file mode 100644 index 0000000..827498d --- /dev/null +++ b/clients/python/src/jamaibase/types/mcp.py @@ -0,0 +1,461 @@ +from enum import IntEnum +from typing import Any, Generic, Literal, TypeVar + +from pydantic import AnyUrl, BaseModel, Field + + +# Standard JSON-RPC error codes +class JSONRPCErrorCode(IntEnum): + """Standard JSON-RPC error 
codes as defined by the JSON-RPC 2.0 specification.""" + + # Standard error codes + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + + # Custom error codes + UNAUTHORIZED = -32001 + FORBIDDEN = -32003 + + +ProgressToken = str | int +Cursor = str +Role = Literal["user", "assistant"] +# AnyFunction: TypeAlias = Callable[..., Any] + +MetaT = TypeVar("ParamsT", bound=dict[str, Any]) + + +class RequestParamsMeta(BaseModel, extra="allow"): + """Metadata for request parameters.""" + + progressToken: ProgressToken | None = Field( + None, + description=( + "If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). " + "The value of this parameter is an opaque token that will be attached to any subsequent notifications. " + "The receiver is not obligated to provide these notifications." + ), + ) + + +class Params(BaseModel, Generic[MetaT], extra="allow"): + meta: MetaT | None = Field( + None, + alias="_meta", + description="This parameter name is reserved by MCP to allow clients and servers to attach additional metadata.", + ) + + +ParamsT = TypeVar("ParamsT", bound=Params) + + +class PaginatedRequestParams(Params): + cursor: str | None = Field( + None, + description=( + "An opaque token representing the current pagination position. " + "If provided, the server should return results starting after this cursor." + ), + ) + + +class JSONRPCBase(BaseModel, extra="allow"): + jsonrpc: Literal["2.0"] = "2.0" + + +class JSONRPCRequest(JSONRPCBase, Generic[ParamsT]): + """Base request interface.""" + + id: str | int = Field(description="Request ID.") + method: str + params: ParamsT | None = Field( + None, + description="Parameters for the request.", + ) + + +class PaginatedRequest(JSONRPCRequest[PaginatedRequestParams]): + pass + + +class JSONRPCNotification(JSONRPCBase, Generic[ParamsT]): + method: str + params: ParamsT | None = Field( + None, + description="This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications.", + ) + + +class InitializedNotification(JSONRPCNotification[Params]): + method: Literal["notifications/initialized"] = "notifications/initialized" + + +class Result(BaseModel, extra="allow"): + """ + Base result class that allows for additional metadata and arbitrary fields. + """ + + meta: dict[str, Any] | None = Field( + None, + alias="_meta", + description="This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses.", + ) + + +ResultT = TypeVar("ResultT", bound=Result) + + +class JSONRPCResponse(JSONRPCBase, Generic[ResultT]): + id: str | int = Field(description="Request ID that this response corresponds to.") + result: ResultT | None = None + + +class JSONRPCEmptyResponse(JSONRPCBase): + id: str | int = Field(description="Request ID that this response corresponds to.") + result: dict[str, Any] = {} + + +class ErrorData(BaseModel, extra="allow"): + """Error information for JSON-RPC error responses.""" + + code: int = Field( + description="The error code, which is a negative integer as defined by the JSON-RPC specification.", + ) + message: str = Field( + description="A short description of the error. This message should be concise and limited to a single sentence.", + ) + data: Any | None = Field( + None, + description=( + "Additional information about the error. 
" + "The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.)." + ), + ) + + +class JSONRPCError(JSONRPCBase): + id: str | int = Field(description="Request ID that this response corresponds to.") + error: ErrorData + + +class Capability(BaseModel): + """Capabilities related to prompt templates.""" + + listChanged: bool = Field( + False, + description="Whether this server supports notifications for changes to the prompt list.", + ) + + +class ResourcesCapability(Capability): + """Capabilities related to resources.""" + + subscribe: bool | None = Field( + None, + description="Whether this server supports subscribing to resource updates.", + ) + + +class ServerCapabilities(BaseModel, extra="allow"): + """ + Capabilities that a server may support. Known capabilities are defined here, + in this schema, but this is not a closed set: any server can define its own, + additional capabilities. + """ + + experimental: dict[str, dict[str, Any]] | None = Field( + None, + description="Experimental, non-standard capabilities that the server supports.", + ) + logging: dict[str, Any] | None = Field( + None, + description="Present if the server supports sending log messages to the client.", + ) + completions: dict[str, Any] | None = Field( + None, + description="Present if the server supports argument autocompletion suggestions.", + ) + prompts: Capability | None = Field( + None, + description="Present if the server offers any prompt templates.", + ) + resources: ResourcesCapability | None = Field( + None, + description="Present if the server offers any resources to read.", + ) + tools: Capability | None = Field( + Capability(listChanged=False), + description="Present if the server offers any tools to call.", + ) + + +class Implementation(BaseModel): + name: str + version: str + + +class InitializeRequestParams(BaseModel): + protocolVersion: str + capabilities: dict[str, Any] + clientInfo: Implementation + + +class InitializeRequest(JSONRPCRequest[InitializeRequestParams]): + method: Literal["initialize"] = "initialize" + + +class InitializeResult(Result): + protocolVersion: Literal["2025-03-26"] = "2025-03-26" + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = Field( + None, + description=( + "Instructions describing how to use the server and its features." + "This can be used by clients to improve the LLM's understanding of available tools, resources, etc. " + 'It can be thought of like a "hint" to the model. ' + "For example, this information MAY be added to the system prompt." + ), + ) + + +class ListToolsRequest(PaginatedRequest): + method: Literal["tools/list"] = "tools/list" + + +class ToolAnnotations(BaseModel): + """ + Additional properties describing a Tool to clients. + + NOTE: all properties in ToolAnnotations are *hints*. + They are not guaranteed to provide a faithful description of + tool behavior (including descriptive properties like `title`). + + Clients should never make tool use decisions based on ToolAnnotations + received from untrusted servers. + """ + + title: str | None = Field( + None, + description="A human-readable title for the tool.", + ) + readOnlyHint: bool | None = Field( + False, + description="If true, the tool does not modify its environment. Default: False", + ) + destructiveHint: bool | None = Field( + True, + description=( + "If true, the tool may perform destructive updates to its environment. " + "If false, the tool performs only additive updates. 
" + "(This property is meaningful only when `readOnlyHint == false`) Default: True" + ), + ) + idempotentHint: bool | None = Field( + False, + description=( + "If true, calling the tool repeatedly with the same arguments " + "will have no additional effect on the its environment. " + "(This property is meaningful only when `readOnlyHint == false`) Default: False" + ), + ) + openWorldHint: bool | None = Field( + True, + description=( + "If true, this tool may interact with an 'open world' of external " + "entities. If false, the tool's domain of interaction is closed. " + "For example, the world of a web search tool is open, whereas that " + "of a memory tool is not. Default: True" + ), + ) + + +class ToolInputSchema(BaseModel): + """JSON Schema object defining the expected parameters for the tool.""" + + type: Literal["object"] = "object" + properties: dict[str, dict[str, Any]] | None = Field( + None, + description="Schema properties defining the tool parameters.", + ) + required: list[str] | None = Field( + None, + description="List of required parameter names.", + ) + + +class ToolAPIInfo(BaseModel): + path: str + method: str + args_types: dict[str, Literal["header", "query", "path", "body"]] + method_info: dict[str, Any] + + +class Tool(BaseModel): + """Definition for a tool the client can call.""" + + name: str = Field( + description="The name of the tool.", + ) + description: str | None = Field( + None, + description=( + "A human-readable description of the tool. " + "This can be used by clients to improve the LLM's understanding of available tools. " + "It can be thought of like a 'hint' to the model." + ), + ) + input_schema: ToolInputSchema = Field( + alias="inputSchema", + description="A JSON Schema object defining the expected parameters for the tool.", + ) + annotations: ToolAnnotations | None = Field( + None, + description="Optional additional tool information.", + ) + + +class ToolAPI(Tool): + """Definition for a tool the client can call.""" + + api_info: ToolAPIInfo | None = Field( + None, + description="API information.", + ) + + +class ListToolsResult(Result): + tools: list[Tool] + + +class CallToolRequestParams(Params): + """Parameters specific to tool call requests.""" + + name: str = Field( + description="The name of the tool to call.", + ) + arguments: dict[str, Any] | None = Field( + None, + description="Arguments to pass to the tool.", + ) + + +class CallToolRequest(JSONRPCRequest[CallToolRequestParams]): + """Used by the client to invoke a tool provided by the server.""" + + method: Literal["tools/call"] = "tools/call" + + +class Annotations(BaseModel): + """ + Optional annotations for the client. The client can use annotations to inform how objects are used or displayed + """ + + audience: list[Role] | None = Field( + None, + description="Describes who the intended customer of this object or data is. " + "It can include multiple entries to indicate content useful for multiple audiences (e.g., ['user', 'assistant']).", + ) + priority: float | None = Field( + None, + ge=0.0, + le=1.0, + description="Describes how important this data is for operating the server. 
" + "A value of 1 means 'most important,' and indicates that the data is " + "effectively required, while 0 means 'least important,' and indicates that " + "the data is entirely optional.", + ) + + +class Content(BaseModel): + annotations: Annotations | None = Field( + None, + description="Optional annotations for the client.", + ) + + +class TextContent(Content): + """Text provided to or from an LLM.""" + + type: Literal["text"] = "text" + text: str = Field( + description="The text content of the message.", + ) + + +class ImageContent(Content): + """An image provided to or from an LLM.""" + + type: Literal["image"] = "image" + data: str = Field( + description="The base64-encoded image data.", + ) + mimeType: str = Field( + description="The MIME type of the image. Different providers may support different image types.", + ) + + +class AudioContent(Content): + """Audio provided to or from an LLM.""" + + type: Literal["audio"] = "audio" + data: str = Field( + description="The base64-encoded audio data.", + ) + mimeType: str = Field( + description="The MIME type of the audio. Different providers may support different audio types.", + ) + + +class ResourceContents(BaseModel): + """The contents of a specific resource or sub-resource.""" + + uri: AnyUrl = Field(description="The URI of this resource.") + mimeType: str | None = Field( + None, + description="The MIME type of this resource, if known.", + ) + + +class TextResourceContents(ResourceContents): + """Resource contents with text data.""" + + text: str = Field( + description="The text of the item. This must only be set if the item can actually be represented as text (not binary data).", + ) + + +class BlobResourceContents(ResourceContents): + """Resource contents with binary data.""" + + blob: str = Field( + description="A base64-encoded string representing the binary data of the item." + ) + + +class EmbeddedResource(Content): + """ + The contents of a resource, embedded into a prompt or tool call result. + + It is up to the client how best to render embedded resources for the benefit + of the LLM and/or the user. + """ + + type: Literal["resource"] = "resource" + resource: TextResourceContents | BlobResourceContents = Field( + description="The resource contents, either text or binary data." + ) + + +class CallToolResult(Result): + content: list[TextContent | ImageContent | AudioContent | EmbeddedResource] + isError: bool | None = Field( + False, + description=( + "Whether the tool call ended in an error. " + "If not set, this is assumed to be false (the call was successful)." 
+ ), + ) diff --git a/clients/python/src/jamaibase/types/model.py b/clients/python/src/jamaibase/types/model.py new file mode 100644 index 0000000..cc48f2e --- /dev/null +++ b/clients/python/src/jamaibase/types/model.py @@ -0,0 +1,126 @@ +from functools import cached_property +from typing import Literal, Self, Union + +from natsort import natsorted +from pydantic import ( + BaseModel, + Field, + model_validator, +) + +from jamaibase.types.common import ( + EXAMPLE_CHAT_MODEL_IDS, + EXAMPLE_EMBEDDING_MODEL_IDS, + EXAMPLE_RERANKING_MODEL_IDS, +) +from jamaibase.types.db import ModelInfoRead + + +class ModelInfoListResponse(BaseModel): + object: Literal["models.info"] = Field( + "models.info", + description="Type of API response object.", + examples=["models.info"], + ) + data: list[ModelInfoRead] = Field( + description="List of model information.", + ) + + @model_validator(mode="after") + def sort_models(self) -> Self: + self.data = list(natsorted(self.data, key=self._sort_key)) + return self + + @staticmethod + def _sort_key(x: ModelInfoRead) -> str: + return (int(not x.id.startswith("ellm")), x.name) + + +class _ModelPrice(BaseModel): + id: str = Field( + description=( + 'Unique identifier in the form of "{provider}/{model_id}". ' + "Users will specify this to select a model." + ), + examples=[ + EXAMPLE_CHAT_MODEL_IDS[0], + EXAMPLE_EMBEDDING_MODEL_IDS[0], + EXAMPLE_RERANKING_MODEL_IDS[0], + ], + ) + name: str = Field( + description="Name of the model.", + examples=["OpenAI GPT-4o Mini"], + ) + + +class LLMModelPrice(_ModelPrice): + llm_input_cost_per_mtoken: float = Field( + description="Cost in USD per million input / prompt token.", + ) + llm_output_cost_per_mtoken: float = Field( + description="Cost in USD per million output / completion token.", + ) + + +class EmbeddingModelPrice(_ModelPrice): + embedding_cost_per_mtoken: float = Field( + description="Cost in USD per million embedding tokens.", + ) + + +class RerankingModelPrice(_ModelPrice): + reranking_cost_per_ksearch: float = Field( + description="Cost in USD for a thousand (kilo) searches." + ) + + +class ModelPrice(BaseModel): + object: Literal["prices.models"] = Field( + "prices.models", + description="Type of API response object.", + examples=["prices.models"], + ) + llm_models: list[LLMModelPrice] = [] + embed_models: list[EmbeddingModelPrice] = [] + rerank_models: list[RerankingModelPrice] = [] + + @cached_property + def model_map( + self, + ) -> dict[str, Union[LLMModelPrice, EmbeddingModelPrice, RerankingModelPrice]]: + """ + Build and cache a dictionary of models for faster lookups. + + Returns: + Dict[str, Union[LLMModelPrice, EmbeddingModelPrice, RerankingModelPrice]]: A dictionary mapping model IDs to their price information. + """ + cache = {} + for model in self.llm_models: + cache[model.id] = model + for model in self.embed_models: + cache[model.id] = model + for model in self.rerank_models: + cache[model.id] = model + return cache + + def get(self, model_id: str) -> Union[LLMModelPrice, EmbeddingModelPrice, RerankingModelPrice]: + """ + Retrieve the price information for a specific model by its ID. + + Args: + model_id (str): The ID of the model to retrieve. + + Returns: + Union[LLMModelPrice, EmbeddingModelPrice, RerankingModelPrice]: + The pricing information for the requested model. + + Raises: + ValueError: If the model ID is not found in the `model_map`. + """ + try: + return self.model_map[model_id] + except KeyError as e: + raise ValueError( + f"Invalid model ID: {model_id}. 
Available models: {list(self.model_map.keys())}" + ) from e diff --git a/clients/python/src/jamaibase/types/telemetry.py b/clients/python/src/jamaibase/types/telemetry.py new file mode 100644 index 0000000..22e068e --- /dev/null +++ b/clients/python/src/jamaibase/types/telemetry.py @@ -0,0 +1,162 @@ +from datetime import datetime, timedelta +from typing import Any, ClassVar + +from pydantic import BaseModel + + +class Metric(BaseModel): + name: str + device_type: str + value: float + timestamp: int + hostname: str + device_model: str + device_id: str + + SYSTEM_METRIC_NAMES: ClassVar[dict[str, str]] = { + "cpu_util": "cpu_util", + "memory_util": "memory_util", + "disk_read_bytes": "disk_read_bytes", + "disk_write_bytes": "disk_write_bytes", + "network_receive_bytes": "network_receive_bytes", + "network_transmit_bytes": "network_transmit_bytes", + } + + AMD_METRIC_NAMES: ClassVar[dict[str, str]] = { + "gpu_clock": "device_clock", + "gpu_memory_clock": "device_memory_clock", + "gpu_edge_temperature": "device_temp", + "gpu_memory_temperature": "device_memory_temp", + "gpu_power_usage": "device_power_usage", + "gpu_gfx_activity": "device_util", + "gpu_umc_activity": "device_memory_utils", + "gpu_free_vram": "device_free_memory", + "gpu_used_vram": "device_used_memory", + } + + NVIDIA_METRIC_NAMES: ClassVar[dict[str, str]] = { + "DCGM_FI_DEV_SM_CLOCK": "device_clock", + "DCGM_FI_DEV_MEM_CLOCK": "device_memory_clock", + "DCGM_FI_DEV_GPU_TEMP": "device_temp", + "DCGM_FI_DEV_MEMORY_TEMP": "device_memory_temp", + "DCGM_FI_DEV_POWER_USAGE": "device_power_usage", + "DCGM_FI_DEV_GPU_UTIL": "device_util", + "DCGM_FI_DEV_MEM_COPY_UTIL": "device_memory_utils", + "DCGM_FI_DEV_FB_FREE": "device_free_memory", + "DCGM_FI_DEV_FB_USED": "device_used_memory", + } + + SYSTEM_LABELS: ClassVar[dict[str, str]] = { + "metric_name": "__name__", + "hostname": "instance", + "device_model": "N/A", + "device_id": "N/A", + } + + AMD_LABELS: ClassVar[dict[str, str]] = { + "metric_name": "__name__", + "hostname": "hostname", + "device_model": "card_series", + "device_id": "gpu_id", + } + + NVIDIA_LABELS: ClassVar[dict[str, str]] = { + "metric_name": "__name__", + "hostname": "Hostname", + "device_model": "modelName", + "device_id": "gpu", + } + + @classmethod + def from_response(cls, response: dict[str, Any]) -> "Metric": + """Create a Metric instance from a response dictionary. + + This method extracts relevant information from the response dictionary obtained from response + and uses it to create and return a Metric object. It determines the Device type and selects the appropriate + labels and metric names for processing. + + Args: + response (dict[str, Any]): A dictionary containing the metric data from response. + + Returns: + Metric: A Metric object populated with the data from the response. + + Raises: + ValueError: If the Device type is not recognized as either 'system', 'amd' or 'nvidia'. 
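A small sketch of `from_response` with a result entry shaped the way the label constants above expect; the module path is taken from this diff and all label values are placeholders:

```python
from jamaibase.types.telemetry import Metric

# Hypothetical instant-query result for a host-level metric;
# label names follow SYSTEM_LABELS, values are placeholders.
sample = {
    "metric": {
        "__name__": "cpu_util",
        "job": "system-node-exporter",
        "instance": "worker-01:9100",
    },
    "value": [1717000000, "37.5"],
}
metric = Metric.from_response(sample)
assert metric.device_type == "system"
assert metric.value == 37.5
```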
+ """ + device_type = response["metric"]["job"].split("-")[0] + if device_type.lower() not in ["amd", "nvidia", "system"]: + raise ValueError( + f"Expected device_type to be within [nvidia, amd, system] but instead got {device_type}" + ) + + if device_type == "nvidia": + device_labels = cls.NVIDIA_LABELS + metric_names = cls.NVIDIA_METRIC_NAMES + elif device_type == "amd": + device_labels = cls.AMD_LABELS + metric_names = cls.AMD_METRIC_NAMES + elif device_type == "system": + device_labels = cls.SYSTEM_LABELS + metric_names = cls.SYSTEM_METRIC_NAMES + + return cls( + name=metric_names[response["metric"][device_labels["metric_name"]]], + device_type=device_type, + value=float(response["value"][1]), + timestamp=response["value"][0], + hostname=response["metric"][device_labels["hostname"]], + device_model=response["metric"].get(device_labels["device_model"], "N/A"), + device_id=response["metric"].get(device_labels["device_id"], "N/A"), + ) + + +class Host(BaseModel): + name: str + metrics: list[Metric] + + +class Usage(BaseModel): + value: float + window_start: str + window_end: str + subject: str + groupBy: dict[str, str] + + @classmethod + def from_result( + cls, + value: list[Any], + metrics: dict[str, Any], + data_interval: timedelta, + group_by: list[str], + ) -> "Usage": + """Create a Usage instance from a result entry. + + Args: + value (list[Any]): A list containing the timestamp and value. + metrics (dict[str, Any]): A dictionary containing metric labels. + data_interval (timedelta): The data interval to adjust the window range. + group_by (list[str]): The group-by fields for the query. + + Returns: + Usage: A Usage object populated with the data from the result entry. + """ + return cls( + value=float(value[1]), + window_start=(datetime.fromtimestamp(value[0]) - data_interval).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + window_end=datetime.fromtimestamp(value[0]).strftime("%Y-%m-%dT%H:%M:%SZ"), + subject=metrics["org_id"], + groupBy={ + key: metrics[key] for key in group_by if key != "org_id" and key in metrics.keys() + }, + ) + + +class UsageResponse(BaseModel): + windowSize: str + data: list[Usage] + start: str + end: str diff --git a/clients/python/src/jamaibase/utils/__init__.py b/clients/python/src/jamaibase/utils/__init__.py index 0b7677e..cce9e31 100644 --- a/clients/python/src/jamaibase/utils/__init__.py +++ b/clients/python/src/jamaibase/utils/__init__.py @@ -1,7 +1,11 @@ +import time from asyncio.coroutines import iscoroutine -from datetime import datetime, timezone from typing import Any, AsyncGenerator, Awaitable, Callable, Generator, TypeVar +import numpy as np +from uuid_extensions import uuid7str as _uuid7_draft2_str +from uuid_utils import uuid7 as _uuid7 + R = TypeVar("R") @@ -18,5 +22,54 @@ async def run(fn: Callable[..., R | Awaitable[R]], *args: Any, **kwargs: Any) -> return ret -def datetime_now_iso() -> str: - return datetime.now(timezone.utc).isoformat() +def get_non_empty(mapping: dict[str, Any], key: str, default: Any): + value = mapping.get(key, None) + return value if value else default + + +def uuid7_draft2_str(prefix: str = "") -> str: + return f"{prefix}{_uuid7_draft2_str()}" + + +def uuid7_str(prefix: str = "") -> str: + return f"{prefix}{_uuid7()}" + + +def get_ttl_hash(seconds: int = 3600) -> int: + """Return the same value within `seconds` time period""" + return round(time.time() / max(1, seconds)) + + +def mask_string(x: str | None, *, include_len: bool = True) -> str | None: + if x is None or x == "": + return x + str_len = len(x) + if str_len < 4: + 
return f"{'*' * str_len} ({str_len=})" + visible_len = min(100, str_len // 5) + x = f"{x[:visible_len]}***{x[-visible_len:]}" + return f"{x} ({str_len=})" if include_len else x + + +def mask_content(x: str | list | dict | np.ndarray | Any) -> str | list | dict | None: + if isinstance(x, str): + return mask_string(x) + if isinstance(x, list): + return [mask_content(v) for v in x] + if isinstance(x, dict): + return {k: mask_content(v) for k, v in x.items()} + if isinstance(x, np.ndarray): + return f"array(shape={x.shape}, dtype={x.dtype})" + return None + + +def merge_dict(d: dict | Any, update: dict | Any): + if isinstance(d, dict) and isinstance(update, dict): + for k, v in update.items(): + d[k] = merge_dict(d.get(k, {}), v) + return d + return update + + +def mask_dict(value: dict[str, str | Any]) -> dict[str, str]: + return {k: "***" if v else v for k, v in value.items()} diff --git a/clients/python/src/jamaibase/utils/background_loop.py b/clients/python/src/jamaibase/utils/background_loop.py new file mode 100644 index 0000000..eca4582 --- /dev/null +++ b/clients/python/src/jamaibase/utils/background_loop.py @@ -0,0 +1,33 @@ +# Modified from LanceDB +# https://github.com/lancedb/lancedb/blob/main/python/python/lancedb/background_loop.py + +import asyncio +import threading + + +class BackgroundEventLoop: + """ + A background event loop that can run futures. + + Used to bridge sync and async code, without messing with users event loops. + """ + + def __init__(self): + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread( + target=self.loop.run_forever, + name="JamAIBackgroundEventLoop", + daemon=True, + ) + self.thread.start() + + def run(self, future): + return asyncio.run_coroutine_threadsafe(future, self.loop).result() + + def cleanup(self): + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() + self.loop.close() + + +LOOP = BackgroundEventLoop() diff --git a/clients/python/src/jamaibase/utils/dates.py b/clients/python/src/jamaibase/utils/dates.py new file mode 100644 index 0000000..822f3ca --- /dev/null +++ b/clients/python/src/jamaibase/utils/dates.py @@ -0,0 +1,82 @@ +from datetime import date, datetime, time, timezone +from uuid import UUID +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError + + +def now(tz: str = "UTC") -> datetime: + return datetime.now(ZoneInfo(tz)) + + +def now_iso(tz: str = "UTC") -> str: + return now(tz).isoformat() + + +def now_tz_naive(tz: str = "UTC") -> datetime: + return datetime.now(ZoneInfo(tz)).replace(tzinfo=None) + + +def earliest(tz: str = "UTC") -> datetime: + return datetime.min.replace(tzinfo=ZoneInfo(tz)) + + +def utc_iso_from_string(dt: str) -> str: + parsed_dt: datetime = datetime.fromisoformat(dt) # Explicitly declare type + if parsed_dt.tzinfo is None: + raise ValueError("Input datetime string is not timezone aware.") + return parsed_dt.astimezone(timezone.utc).isoformat() + + +def utc_iso_from_datetime(dt: datetime) -> str: + if dt.tzinfo is None: + raise ValueError("Input datetime object is not timezone aware.") + return dt.astimezone(timezone.utc).isoformat() + + +def utc_datetime_from_iso(dt: str) -> datetime: + parsed_dt: datetime = datetime.fromisoformat(dt) # Explicitly declare type + if parsed_dt.tzinfo is None: + raise ValueError("Input datetime string is not timezone aware.") + return parsed_dt.astimezone(timezone.utc) + + +def utc_iso_from_uuid7(uuid7_str: str) -> str: + # from uuid_utils import uuid7 + # Extract the timestamp (first 48 bits) + timestamp = UUID(uuid7_str).int >> 80 + dt = 
datetime.fromtimestamp(timestamp / 1000.0, tz=timezone.utc) + return dt.isoformat() + + +def utc_iso_from_uuid7_draft2(uuid7_str: str) -> str: + # from uuid_extensions import uuid7str + # https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-02.html#name-uuidv7-layout-and-bit-order + # Parse the UUID string + uuid_obj = UUID(uuid7_str) + # Extract the unix timestamp (first 36 bits) + unix_ts = uuid_obj.int >> 92 + # Extract the fractional seconds (next 24 bits) + frac_secs = (uuid_obj.int >> 68) & 0xFFFFFF + # Combine unix timestamp and fractional seconds + total_secs = unix_ts + (frac_secs / 0x1000000) + # Create a datetime object + dt = datetime.fromtimestamp(total_secs, tz=timezone.utc) + return dt.isoformat() + + +def date_to_utc(d: date, tz: str = "UTC") -> datetime: + try: + return datetime.combine(d, time.min, ZoneInfo(tz)).astimezone(timezone.utc) + except ZoneInfoNotFoundError as e: + raise ValueError(f"Invalid timezone: {tz}") from e + + +def date_to_utc_iso(d: date, tz: str = "UTC") -> str: + return date_to_utc(d, tz).isoformat() + + +def ensure_utc_timezone(value: str) -> str: + dt = datetime.fromisoformat(value) + tz = str(dt.tzinfo) + if tz != "UTC": + raise ValueError(f'Time zone must be UTC, but received "{tz}".') + return value diff --git a/clients/python/src/jamaibase/exceptions.py b/clients/python/src/jamaibase/utils/exceptions.py similarity index 51% rename from clients/python/src/jamaibase/exceptions.py rename to clients/python/src/jamaibase/utils/exceptions.py index 755ebc3..dfef23b 100644 --- a/clients/python/src/jamaibase/exceptions.py +++ b/clients/python/src/jamaibase/utils/exceptions.py @@ -1,9 +1,6 @@ -import functools +from functools import wraps from typing import Any -from pydantic import ValidationError -from pydantic_core import InitErrorDetails - def docstring_message(cls): """ @@ -13,7 +10,7 @@ def docstring_message(cls): # Must use cls_init name, not cls.__init__ itself, in closure to avoid recursion cls_init = cls.__init__ - @functools.wraps(cls.__init__) + @wraps(cls.__init__) def wrapped_init(self, msg=cls.__doc__, *args, **kwargs): cls_init(self, msg, *args, **kwargs) @@ -21,31 +18,14 @@ def wrapped_init(self, msg=cls.__doc__, *args, **kwargs): return cls -def make_validation_error( - exception: Exception, - *, - object_name: str = "", - loc: tuple = (), - input_value: Any = None, -) -> ValidationError: - return ValidationError.from_exception_data( - object_name, - line_errors=[ - InitErrorDetails( - type="value_error", - loc=loc, - input=input_value, - ctx={"error": exception}, - ) - ], - ) +@docstring_message +class JamaiException(Exception): + """Base exception class for Jamai errors.""" @docstring_message -class JamaiException(RuntimeError): - """Base exception class for JamAIBase errors.""" - - pass +class UpStreamError(JamaiException): + """One or more upstream columns errored out.""" @docstring_message @@ -65,12 +45,22 @@ class ForbiddenError(JamaiException): @docstring_message class UpgradeTierError(JamaiException): - """You have exhausted the allocations of your subscribed tier. Please upgrade.""" + """Your organization has exhausted the allocations of your subscribed plan. Please upgrade or top up credits.""" + + +@docstring_message +class NoTierError(UpgradeTierError): + """Your organization has not subscribed to any plan. Please subscribe in Organization Billing Settings.""" + + +@docstring_message +class BaseTierCountError(UpgradeTierError): + """You can have only one organization with Free Plan. 
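A brief sketch of the `docstring_message` behaviour, using the module path introduced by the rename in this diff:

```python
from jamaibase.utils.exceptions import JamaiException, UpgradeTierError

try:
    # `docstring_message` makes the class docstring the default message,
    # so the error can be raised without arguments.
    raise UpgradeTierError()
except JamaiException as exc:
    print(exc)
```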
Please upgrade.""" @docstring_message class InsufficientCreditsError(JamaiException): - """Please ensure that you have sufficient credits.""" + """Your organization has exhausted your credits. Please top up.""" @docstring_message @@ -87,14 +77,17 @@ class ResourceExistsError(JamaiException): class UnsupportedMediaTypeError(JamaiException): """This file type is unsupported.""" - pass - @docstring_message class BadInputError(JamaiException): """Your input is invalid.""" +@docstring_message +class ModelCapabilityError(BadInputError): + """No model has the specified capabilities.""" + + @docstring_message class TableSchemaFixedError(JamaiException): """Table schema cannot be modified.""" @@ -109,11 +102,45 @@ class ContextOverflowError(JamaiException): class UnexpectedError(JamaiException): """We ran into an unexpected error.""" - pass + +@docstring_message +class RateLimitExceedError(JamaiException): + """The rate limit is exceeded.""" + + def __init__( + self, + *args, + limit: int, + remaining: int, + reset_at: int, + used: int | None = None, + retry_after: int | None = None, + meta: dict[str, Any] | None = None, + ): + super().__init__(*args) + self.limit = limit + self.remaining = remaining + self.reset_at = reset_at + self.used = used + self.retry_after = retry_after + self.meta = meta + + +@docstring_message +class UnavailableError(JamaiException): + """The requested functionality is unavailable.""" @docstring_message class ServerBusyError(JamaiException): """The server is busy.""" - pass + +@docstring_message +class ModelOverloadError(JamaiException): + """The model is overloaded.""" + + +@docstring_message +class MethodNotAllowedError(JamaiException): + """Method is not allowed.""" diff --git a/clients/python/src/jamaibase/utils/io.py b/clients/python/src/jamaibase/utils/io.py index d790803..4a89ef6 100644 --- a/clients/python/src/jamaibase/utils/io.py +++ b/clients/python/src/jamaibase/utils/io.py @@ -1,29 +1,65 @@ -from __future__ import annotations - import csv import logging import pickle -from io import BytesIO, StringIO +from collections import OrderedDict +from io import StringIO +from mimetypes import guess_type +from os.path import splitext from typing import Any +import filetype import numpy as np import orjson import pandas as pd -import srsly import toml +import yaml from PIL import ExifTags, Image -from jamaibase.utils.types import JSONInput, JSONOutput +from jamaibase.types.common import JSONInput, JSONOutput logger = logging.getLogger(__name__) +EMBED_WHITE_LIST = { + "application/pdf": [".pdf"], + "application/xml": [".xml"], + "application/json": [".json"], + "application/jsonl": [".jsonl"], + "application/x-ndjson": [".jsonl"], + "application/json-lines": [".jsonl"], + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": [".docx"], + "application/vnd.openxmlformats-officedocument.presentationml.presentation": [".pptx"], + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": [".xlsx"], + "text/markdown": [".md"], + "text/plain": [".txt"], + "text/html": [".html"], + "text/tab-separated-values": [".tsv"], + "text/csv": [".csv"], + "text/xml": [".xml"], +} +DOC_WHITE_LIST = EMBED_WHITE_LIST +IMAGE_WHITE_LIST = { + "image/jpeg": [".jpg", ".jpeg"], + "image/png": [".png"], + "image/gif": [".gif"], + "image/webp": [".webp"], +} +AUDIO_WHITE_LIST = { + "audio/mpeg": [".mp3"], + "audio/wav": [".wav"], + "audio/x-wav": [".wav"], + "audio/x-pn-wav": [".wav"], + "audio/wave": [".wav"], + "audio/vnd.wav": [".wav"], + "audio/vnd.wave": 
[".wav"], +} + def load_pickle(file_path: str): with open(file_path, "rb") as f: return pickle.load(f) -def dump_pickle(out_path: str, obj: any): +def dump_pickle(out_path: str, obj: Any): with open(out_path, "wb") as f: pickle.dump(obj, f) @@ -61,8 +97,8 @@ def json_loads(data: str) -> JSONOutput: return orjson.loads(data) -def json_dumps(data: JSONInput) -> str: - return orjson.dumps(data).decode("utf-8") +def json_dumps(data: JSONInput, **kwargs) -> str: + return orjson.dumps(data, **kwargs).decode("utf-8") def read_yaml(path: str) -> JSONOutput: @@ -74,7 +110,8 @@ def read_yaml(path: str) -> JSONOutput: Returns: data (JSONOutput): The data. """ - return srsly.read_yaml(path) + with open(path, "r") as f: + return yaml.safe_load(f) def dump_yaml(data: JSONInput, path: str, **kwargs) -> str: @@ -83,12 +120,13 @@ def dump_yaml(data: JSONInput, path: str, **kwargs) -> str: Args: data (JSONInput): The data. path (str): Path to the file. - **kwargs: Other keyword arguments to pass into `srsly.write_yaml`. + **kwargs: Other keyword arguments to pass into `yaml.dump`. Returns: path (str): Path to the file. """ - srsly.write_yaml(path, data, **kwargs) + with open(path, "w") as f: + yaml.dump(data, f, **kwargs) return path @@ -116,6 +154,10 @@ def dump_toml(data: JSONInput, path: str, **kwargs) -> str: Returns: path (str): Path to the file. """ + # Convert non-dictionary data into a dictionary + if not isinstance(data, (dict, OrderedDict)): + data = {"value": data} # Wrap non-dictionary data in a dictionary + with open(path, "w") as f: toml.dump(data, f) return path @@ -126,14 +168,14 @@ def csv_to_df( column_names: list[str] | None = None, sep: str = ",", dtype: dict[str, Any] | None = None, + **kwargs, ) -> pd.DataFrame: - has_header = not (isinstance(column_names, list) and len(column_names) > 0) df = pd.read_csv( StringIO(data), - header=0 if has_header else None, names=column_names, sep=sep, dtype=dtype, + **kwargs, ) return df @@ -149,6 +191,7 @@ def df_to_csv( encoding="utf-8", lineterminator="\n", decimal=".", + header=True, index=False, quoting=csv.QUOTE_NONNUMERIC, quotechar='"', @@ -176,53 +219,32 @@ def read_image(img_path: str) -> tuple[np.ndarray, bool]: return np.asarray(image), is_rotated -def generate_image_thumbnail( - file_content: bytes, - size: tuple[float, float] = (450.0, 450.0), -) -> bytes: - try: - with Image.open(BytesIO(file_content)) as img: - # Check image mode - if img.mode not in ("RGB", "RGBA"): - img = img.convert("RGB") - # Resize and save - img.thumbnail(size=size) - with BytesIO() as f: - img.save( - f, - format="webp", - lossless=False, - quality=60, - alpha_quality=50, - method=6, - exact=False, - ) - return f.getvalue() - except Exception as e: - logger.exception(f"Failed to generate thumbnail due to {e.__class__.__name__}: {e}") - return b"" - - -def generate_audio_thumbnail(file_content: bytes, duration_ms: int = 30000) -> bytes: - """ - Generates a thumbnail audio by extracting a segment from the original audio. - - Args: - file_content (bytes): The audio file content. - duration_ms (int): Duration of the thumbnail in milliseconds. - - Returns: - bytes: The thumbnail audio segment as bytes. 
- """ - from pydub import AudioSegment - - # Use BytesIO to simulate a file object from the byte content - audio = AudioSegment.from_file(BytesIO(file_content)) +# Use the first MIME for each extension +MIME_WHITE_LIST = {**EMBED_WHITE_LIST, **IMAGE_WHITE_LIST, **AUDIO_WHITE_LIST} +EXT_TO_MIME = {} +for mime, exts in MIME_WHITE_LIST.items(): + for ext in exts: + EXT_TO_MIME[ext] = EXT_TO_MIME.get(ext, mime) - # Extract the first `duration_ms` milliseconds - thumbnail = audio[:duration_ms] - # Export the thumbnail to a bytes object - with BytesIO() as output: - thumbnail.export(output, format="mp3") - return output.getvalue() +def guess_mime(source: str | bytes) -> str: + if isinstance(source, str): + ext = splitext(source)[1].lower() + mime = EXT_TO_MIME.get(ext, None) + if mime is not None: + return mime + try: + # `filetype` can handle file path and content bytes + mime = filetype.guess(source) + if mime is not None: + return mime.mime + if isinstance(source, str): + # `mimetypes` can only handle file path + mime, _ = guess_type(source) + if mime is not None: + return mime + if source.endswith(".jsonl"): + return "application/jsonl" + except Exception: + logger.warning(f'Failed to sniff MIME type of file "{source}".') + return "application/octet-stream" diff --git a/clients/python/src/jamaibase/utils/types.py b/clients/python/src/jamaibase/utils/types.py index 5214177..74e8ba0 100644 --- a/clients/python/src/jamaibase/utils/types.py +++ b/clients/python/src/jamaibase/utils/types.py @@ -1,3 +1,54 @@ -from srsly.util import JSONInput, JSONOutput +import argparse +from enum import Enum +from typing import Callable, Type, TypeVar -__all__ = ["JSONInput", "JSONOutput"] +from pydantic import BaseModel + +try: + from enum import StrEnum +except ImportError: + + class StrEnum(str, Enum): + pass + + +### --- Enum Validator --- ### + +E = TypeVar("E", bound=Enum) + + +def get_enum_validator(enum_cls: Type[E]) -> Callable[[str], E]: + def _validator(v: str) -> E: + try: + return enum_cls[v] + except KeyError: + return enum_cls(v) + + return _validator + + +class CLI(BaseModel): + @classmethod + def parse_args(cls) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + for field_name, field_info in cls.model_fields.items(): + field_type = field_info.annotation + default = field_info.default + description = field_info.description or "" + if field_type is bool: + parser.add_argument( + f"--{field_name}", + action="store_true", + help=description, + ) + else: + parser.add_argument( + f"--{field_name}", + type=field_type, + default=default, + required=default is ..., + help=description, + ) + return cls(**vars(parser.parse_args())) diff --git a/clients/python/src/jamaibase/version.py b/clients/python/src/jamaibase/version.py index 3d26edf..382021f 100644 --- a/clients/python/src/jamaibase/version.py +++ b/clients/python/src/jamaibase/version.py @@ -1 +1 @@ -__version__ = "0.4.1" +__version__ = "1.0.6" diff --git a/clients/python/tests/cloud/test_admin.py b/clients/python/tests/cloud/test_admin.py deleted file mode 100644 index a760683..0000000 --- a/clients/python/tests/cloud/test_admin.py +++ /dev/null @@ -1,1445 +0,0 @@ -from contextlib import contextmanager -from datetime import datetime, timedelta, timezone -from inspect import signature -from multiprocessing import Manager, Process -from time import sleep -from typing import Generator, Type - -import pytest -from loguru import logger -from tenacity import retry, 
stop_after_attempt, wait_exponential - -from jamaibase import JamAI -from jamaibase.protocol import ( - ActionTableSchemaCreate, - AdminOrderBy, - ApiKeyCreate, - ApiKeyRead, - ChatCompletionChunk, - ChatEntry, - ChatRequest, - ChatTableSchemaCreate, - ColumnSchemaCreate, - EmbeddingRequest, - EmbeddingResponse, - EventCreate, - EventRead, - GenTableRowsChatCompletionChunks, - GenTableStreamChatCompletionChunk, - KnowledgeTableSchemaCreate, - LLMGenConfig, - LLMModelConfig, - ModelDeploymentConfig, - ModelListConfig, - ModelPrice, - OkResponse, - OrganizationCreate, - OrganizationRead, - OrganizationUpdate, - OrgMemberCreate, - OrgMemberRead, - PATCreate, - PATRead, - Price, - ProjectCreate, - RowAddRequest, - TableMetaResponse, - TableType, - UserCreate, - UserRead, - UserUpdate, -) -from jamaibase.utils import datetime_now_iso -from owl.configs.manager import ENV_CONFIG, PlanName, ProductType -from owl.utils import uuid7_str - -CLIENT_CLS = [JamAI] -USER_ID_A = "duncan" -USER_ID_B = "mama" -USER_ID_C = "sus" -TABLE_TYPES = [TableType.action, TableType.knowledge, TableType.chat] - - -@contextmanager -def _create_user( - owl: JamAI, - user_id: str = USER_ID_A, - **kwargs, -) -> Generator[UserRead, None, None]: - # TODO: Can make this work with OSS too by yielding a dummy UserRead - owl.admin.backend.delete_user(user_id) - try: - user = owl.admin.backend.create_user( - UserCreate( - id=user_id, - name=kwargs.pop("name", "Duncan Idaho"), - description=kwargs.pop("description", "A Ginaz Swordmaster from House Atreides."), - email=kwargs.pop("email", "duncan.idaho@gmail.com"), - meta=kwargs.pop("meta", {}), - ) - ) - yield user - finally: - owl.admin.backend.delete_user(user_id) - - -@contextmanager -def _create_org( - owl: JamAI, - user_id: str, - active: bool = True, - **kwargs, -) -> Generator[OrganizationRead, None, None]: - org_id = None - try: - org = owl.admin.backend.create_organization( - OrganizationCreate( - creator_user_id=user_id, - name=kwargs.pop("name", "Company"), - external_keys=kwargs.pop("external_keys", {}), - tier=kwargs.pop("tier", PlanName.FREE), - active=active, - **kwargs, - ) - ) - org_id = org.id - yield org - finally: - if org_id is not None: - owl.admin.backend.delete_organization(org_id) - - -def _delete_project(owl: JamAI, project_id: str | None): - if project_id is not None: - owl.admin.organization.delete_project(project_id) - - -@contextmanager -def _create_project( - owl: JamAI, - organization_id: str, - name: str = "default", -) -> Generator[OrganizationRead, None, None]: - project_id = None - try: - project = owl.admin.organization.create_project( - ProjectCreate( - organization_id=organization_id, - name=name, - ) - ) - project_id = project.id - yield project - finally: - _delete_project(owl, project_id) - - -@contextmanager -def _set_model_config(owl: JamAI, config: ModelListConfig): - old_config = owl.admin.backend.get_model_config() - try: - response = owl.admin.backend.set_model_config(config) - assert isinstance(response, OkResponse) - yield response - finally: - owl.admin.backend.set_model_config(old_config) - - -def _chat(jamai: JamAI, model_id: str): - request = ChatRequest( - model=model_id, - messages=[ - ChatEntry.system("You are a concise assistant."), - ChatEntry.user("What is a llama?"), - ], - temperature=0.001, - top_p=0.001, - max_tokens=3, - stream=False, - ) - completion = jamai.generate_chat_completions(request) - assert isinstance(completion, ChatCompletionChunk) - assert isinstance(completion.text, str) - assert len(completion.text) 
> 1 - - -def _embed(jamai: JamAI, model_id: str): - request = EmbeddingRequest( - input="什么是 llama?", - model=model_id, - type="document", - encoding_format="float", - ) - response = jamai.generate_embeddings(request) - assert isinstance(response, EmbeddingResponse) - assert isinstance(response.data, list) - assert isinstance(response.data[0].embedding, list) - assert len(response.data[0].embedding) > 0 - - -@contextmanager -def _create_gen_table( - jamai: JamAI, - table_type: TableType, - table_id: str, - model_id: str = "", - cols: list[ColumnSchemaCreate] | None = None, - chat_cols: list[ColumnSchemaCreate] | None = None, - embedding_model: str = "", - delete_first: bool = True, - delete: bool = True, -): - try: - if delete_first: - jamai.table.delete_table(table_type, table_id) - if cols is None: - cols = [ - ColumnSchemaCreate(id="input", dtype="str"), - ColumnSchemaCreate( - id="output", - dtype="str", - gen_config=LLMGenConfig( - model=model_id, - prompt="${input}", - max_tokens=3, - ), - ), - ] - if chat_cols is None: - chat_cols = [ - ColumnSchemaCreate(id="User", dtype="str"), - ColumnSchemaCreate( - id="AI", - dtype="str", - gen_config=LLMGenConfig( - model=model_id, - system_prompt="You are an assistant.", - max_tokens=3, - ), - ), - ] - if table_type == TableType.action: - table = jamai.table.create_action_table( - ActionTableSchemaCreate(id=table_id, cols=cols) - ) - elif table_type == TableType.knowledge: - table = jamai.table.create_knowledge_table( - KnowledgeTableSchemaCreate(id=table_id, cols=cols, embedding_model=embedding_model) - ) - elif table_type == TableType.chat: - table = jamai.table.create_chat_table( - ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) - ) - else: - raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, TableMetaResponse) - yield table - finally: - if delete: - jamai.table.delete_table(table_type, table_id) - - -def test_cors(): - import httpx - - def _assert_cors(_response: httpx.Response): - assert "Access-Control-Allow-Origin" in _response.headers, _response.headers - assert "Access-Control-Allow-Methods" in _response.headers, _response.headers - assert "Access-Control-Allow-Headers" in _response.headers, _response.headers - assert "Access-Control-Allow-Credentials" in _response.headers, _response.headers - assert _response.headers["Access-Control-Allow-Credentials"].lower() == "true" - - headers = { - "Origin": "http://example.com", - "Access-Control-Request-Method": "POST", - "Access-Control-Request-Headers": "Content-Type", - } - owl = JamAI() - # Preflight - response = httpx.options(owl.api_base, headers=headers) - _assert_cors(response) - - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id) as p0: - assert isinstance(p0.id, str) - endpoint = f"{owl.api_base}/v1/models" - # Assert preflight no auth - response = httpx.options(endpoint, headers=headers) - _assert_cors(response) - # Assert CORS headers in methods with auth - response = httpx.get(endpoint, headers=headers) - assert response.status_code == 401 - response = httpx.get( - endpoint, - headers={ - "Authorization": f"Bearer {owl.api_key}", - "X-PROJECT-ID": p0.id, - **headers, - }, - ) - assert "Access-Control-Allow-Origin" in response.headers, response.headers - assert "Access-Control-Allow-Credentials" in response.headers, response.headers - assert response.headers["Access-Control-Allow-Credentials"].lower() == "true" - - 
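A side note on the reworked jamaibase.utils.io module earlier in this diff: the new whitelists are collapsed into EXT_TO_MIME (the first MIME listed per extension wins) and guess_mime falls back to filetype/mimetypes sniffing, returning "application/octet-stream" when nothing matches. A minimal sketch of the expected behaviour, assuming the jamaibase package from this diff is installed:

from jamaibase.utils.io import guess_mime

# Whitelisted extensions resolve directly, using the first MIME listed per extension.
assert guess_mime("report.pdf") == "application/pdf"
assert guess_mime("notes.jsonl") == "application/jsonl"
assert guess_mime("clip.wav") == "audio/wav"

# Unrecognised byte content falls back to the generic binary type (a warning is logged).
assert guess_mime(b"\x00\x01\x02\x03") == "application/octet-stream"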
-@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_create_users(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as user: - assert isinstance(user, UserRead) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_get_and_list_users(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan, _create_user(owl, USER_ID_B) as mama: - # Test fetch - user = owl.admin.backend.get_user(duncan.id) - assert isinstance(user, UserRead) - assert user.id == duncan.id - - user = owl.admin.backend.get_user(mama.id) - assert isinstance(user, UserRead) - assert user.id == mama.id - - # Test list - users = owl.admin.backend.list_users() - assert isinstance(users.items, list) - assert all(isinstance(r, UserRead) for r in users.items) - assert users.total == 2 - assert users.offset == 0 - assert users.limit == 100 - assert len(users.items) == 2 - - users = owl.admin.backend.list_users(offset=1) - assert isinstance(users.items, list) - assert all(isinstance(r, UserRead) for r in users.items) - assert users.total == 2 - assert users.offset == 1 - assert users.limit == 100 - assert len(users.items) == 1 - - users = owl.admin.backend.list_users(limit=1) - assert isinstance(users.items, list) - assert all(isinstance(r, UserRead) for r in users.items) - assert users.total == 2 - assert users.offset == 0 - assert users.limit == 1 - assert len(users.items) == 1 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_update_user(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - updated_user_request = UserUpdate(id=duncan.id, name="Updated Duncan") - updated_user_response = owl.admin.backend.update_user(updated_user_request) - assert isinstance(updated_user_response, UserRead) - assert updated_user_response.id == duncan.id - assert updated_user_response.name == "Updated Duncan" - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_delete_users(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as user: - assert isinstance(user, UserRead) - # Assert there is a user - users = owl.admin.backend.list_users() - assert isinstance(users.items, list) - assert users.total == 1 - # Delete - response = owl.admin.backend.delete_user(user.id) - assert isinstance(response, OkResponse) - # Assert there is no user - users = owl.admin.backend.list_users() - assert isinstance(users.items, list) - assert users.total == 0 - - with pytest.raises(RuntimeError, match="User .+ is not found."): - owl.admin.backend.update_user(UserUpdate(id=user.id, name="Updated Name")) - - with pytest.raises(RuntimeError, match="User .+ is not found."): - owl.admin.backend.get_user(user.id) - - response = owl.admin.backend.delete_user(user.id) - assert isinstance(response, OkResponse) - with pytest.raises(RuntimeError, match="User .+ is not found."): - owl.admin.backend.delete_user(user.id, missing_ok=False) - - -def test_user_update_pydantic_model(): - sig = signature(UserUpdate) - for name, param in sig.parameters.items(): - if name == "id": - continue - assert ( - param.default is None - ), f'Parameter "{name}" has a default value of {param.default} instead of None.' 
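The updated jamaibase.utils.types module shown earlier also gains a small CLI base model that maps pydantic fields onto argparse flags (boolean fields become store_true switches, other fields take their pydantic defaults). A hedged sketch of how it might be used; ServeArgs and its fields are illustrative names only and are not part of this diff:

from pydantic import Field

from jamaibase.utils.types import CLI


class ServeArgs(CLI):  # hypothetical example model, not defined in this diff
    host: str = Field("127.0.0.1", description="Bind address")
    port: int = Field(6969, description="Port to listen on")
    debug: bool = Field(False, description="Enable verbose logging")


# Parses e.g. `python serve.py --port 7000 --debug` into a validated ServeArgs instance.
args = ServeArgs.parse_args()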
- - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_pat(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as u0, _create_user(owl, USER_ID_B) as u1: - with _create_org(owl, u0.id) as o0, _create_org(owl, u1.id): - with _create_project(owl, o0.id) as p0: - pat0 = owl.admin.backend.create_pat(PATCreate(user_id=u0.id)) - pat0_expire = owl.admin.backend.create_pat( - PATCreate( - user_id=u0.id, - expiry=(datetime.now(tz=timezone.utc) + timedelta(seconds=1)).isoformat(), - ) - ) - assert isinstance(pat0, PATRead) - pat1 = owl.admin.backend.create_pat(PATCreate(user_id=u1.id)) - assert isinstance(pat1, PATRead) - # Make some requests using the PAT - jamai = JamAI(project_id=p0.id, token=pat0.id) - models = jamai.model_names(capabilities=["chat"]) - assert isinstance(models, list) - assert len(models) > 0 - # Fetch the user - user = JamAI().admin.backend.get_user(u0.id) - assert isinstance(user, UserRead) - assert user.id == USER_ID_A - user = JamAI().admin.backend.get_user(u1.id) - assert isinstance(user, UserRead) - assert user.id == USER_ID_B - # Create gen table - with _create_gen_table(jamai, "action", "xx"): - table = jamai.table.get_table("action", "xx") - assert isinstance(table, TableMetaResponse) - ### --- Test service key auth --- ### - table = JamAI( - project_id=p0.id, - token=ENV_CONFIG.service_key_plain, - headers={"X-USER-ID": u0.id}, - ).table.get_table("action", "xx") - assert isinstance(table, TableMetaResponse) - # Try using invalid user ID - with pytest.raises(RuntimeError): - JamAI( - project_id=p0.id, - token=ENV_CONFIG.service_key_plain, - headers={"X-USER-ID": u1.id}, - ).table.get_table("action", "xx") - ### --- Test PAT --- ### - # Try using invalid PAT - with pytest.raises(RuntimeError): - JamAI(project_id=p0.id, token=pat1.id).table.get_table("action", "xx") - # Test PAT expiry - while datetime_now_iso() < pat0_expire.expiry: - sleep(1) - with pytest.raises(RuntimeError): - JamAI(project_id=p0.id, token=pat0_expire.id).table.get_table( - "action", "xx" - ) - # Test PAT fetch - pat0_read = owl.admin.backend.get_pat(pat0.id) - assert isinstance(pat0_read, PATRead) - assert pat0_read.id == pat0.id - # Test PAT deletion - response = owl.admin.backend.delete_pat(pat0.id) - assert isinstance(response, OkResponse) - with pytest.raises(RuntimeError): - owl.admin.backend.get_pat(pat0.id) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_create_organizations(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id, external_keys=dict(openai="sk-test")) as org: - assert isinstance(org, OrganizationRead) - assert isinstance(org.id, str) - assert len(org.id) > 0 - assert "openai" in org.external_keys - assert org.external_keys["openai"] == "sk-test" - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_create_organizations_free_tier_check(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with ( - _create_org(owl, duncan.id, name="Free 0", tier=PlanName.FREE) as o0, - _create_org(owl, duncan.id, name="Free 1", tier=PlanName.FREE) as o1, - _create_org(owl, duncan.id, name="Paid 0", tier=PlanName.PRO) as o2, - ): - assert isinstance(o0, OrganizationRead) - assert isinstance(o0.id, str) - assert len(o0.id) > 0 - assert isinstance(o1, OrganizationRead) - assert isinstance(o1.id, str) - assert len(o1.id) > 0 - assert isinstance(o2, OrganizationRead) - assert isinstance(o2.id, str) - assert len(o2.id) > 0 - assert o0.active is True 
- assert o1.active is False - assert o2.active is True - with _create_project(owl, o0.id, "Pear"): - pass - with pytest.raises(RuntimeError, match="not activated"): - with _create_project(owl, o1.id, "Pear"): - pass - with _create_project(owl, o2.id, "Pear"): - pass - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_create_organizations_invalid_key(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with pytest.raises(RuntimeError, match="Unsupported external provider"): - with _create_org(owl, duncan.id, external_keys=dict(invalid="sk-test")): - pass - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_get_and_list_organizations(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id, name="company") as company: - with _create_org(owl, duncan.id, name="Personal"): - # Test fetch - org = owl.admin.backend.get_organization(company.id) - assert isinstance(org, OrganizationRead) - assert org.id == company.id - assert isinstance(org.members, list) - assert isinstance(org.api_keys, list) - assert isinstance(org.projects, list) - assert duncan.id in set(u.user_id for u in org.members) - assert len(org.api_keys) == 0 - assert len(org.projects) == 0 - - with ( - _create_project(owl, company.id, "bear") as p0, - _create_project(owl, company.id) as p1, - ): - org = owl.admin.backend.get_organization(company.id) - assert isinstance(org, OrganizationRead) - assert org.id == company.id - assert isinstance(org.members, list) - assert isinstance(org.api_keys, list) - assert isinstance(org.projects, list) - assert duncan.id in set(u.user_id for u in org.members) - assert len(org.api_keys) == 0 - assert len(org.projects) == 2 - assert p0.id in set(p.id for p in org.projects) - assert p1.id in set(p.id for p in org.projects) - - # Test list - orgs = owl.admin.backend.list_organizations() - assert isinstance(orgs.items, list) - assert all(isinstance(r, OrganizationRead) for r in orgs.items) - assert orgs.total == 2 - assert orgs.offset == 0 - assert orgs.limit == 100 - assert len(orgs.items) == 2 - - orgs = owl.admin.backend.list_organizations(offset=1) - assert isinstance(orgs.items, list) - assert all(isinstance(r, OrganizationRead) for r in orgs.items) - assert orgs.total == 2 - assert orgs.offset == 1 - assert orgs.limit == 100 - assert len(orgs.items) == 1 - - orgs = owl.admin.backend.list_organizations(limit=1) - assert isinstance(orgs.items, list) - assert all(isinstance(r, OrganizationRead) for r in orgs.items) - assert orgs.total == 2 - assert orgs.offset == 0 - assert orgs.limit == 1 - assert len(orgs.items) == 1 - - # Test list with order_by - orgs = owl.admin.backend.list_organizations( - order_by="created_at", order_descending=False - ) - assert isinstance(orgs.items, list) - assert all(isinstance(r, OrganizationRead) for r in orgs.items) - assert orgs.items[0].name == "company" - assert orgs.items[1].name == "Personal" - assert orgs.total == 2 - assert orgs.offset == 0 - assert orgs.limit == 100 - assert len(orgs.items) == 2 - - # Ensure ordering is case-insensitive, otherwise uppercase will come before lowercase - orgs = owl.admin.backend.list_organizations( - order_by="name", order_descending=False - ) - assert isinstance(orgs.items, list) - assert all(isinstance(r, OrganizationRead) for r in orgs.items) - assert orgs.items[0].name == "company" - assert orgs.items[1].name == "Personal" - assert orgs.total == 2 - assert orgs.offset == 0 - assert orgs.limit == 100 - assert 
len(orgs.items) == 2 - - for order_by in AdminOrderBy: - orgs = owl.admin.backend.list_organizations(order_by=order_by) - org_ids = [org.id for org in orgs.items] - assert len(orgs.items) == 2 - orgs_desc = owl.admin.backend.list_organizations( - order_by=order_by, order_descending=False - ) - org_ids_desc = [org.id for org in orgs_desc.items] - assert len(orgs_desc.items) == 2 - assert ( - org_ids == org_ids_desc[::-1] - ), f"Failed to order by {order_by}: {org_ids} != {org_ids_desc[::-1]}" - - # # Test starting_after - # orgs = owl.admin.backend.list_organizations( - # order_by="created_at", order_descending=False, starting_after=company.id - # ) - # assert isinstance(orgs.items, list) - # assert all(isinstance(r, OrganizationRead) for r in orgs.items) - # assert orgs.items[0].name == "Personal" - # assert orgs.total == 2 - # assert orgs.offset == 0 - # assert orgs.limit == 100 - # assert len(orgs.items) == 1 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_update_organization(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - updated_org = owl.admin.backend.update_organization( - OrganizationUpdate( - id=org.id, - name="Company X", - active=True, - llm_tokens_usage_mtok=100.0, - ) - ) - assert isinstance(updated_org, OrganizationRead) - assert updated_org.id == org.id - assert updated_org.name == "Company X" - assert updated_org.llm_tokens_usage_mtok == 100.0 - updated_org = owl.admin.backend.update_organization( - OrganizationUpdate( - id=org.id, - embedding_tokens_quota_mtok=9.0, - ) - ) - assert isinstance(updated_org, OrganizationRead) - org = owl.admin.backend.get_organization(org.id) - assert isinstance(org, OrganizationRead) - assert updated_org.llm_tokens_usage_mtok == 100.0 - assert updated_org.embedding_tokens_quota_mtok == 9.0 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_delete_organizations(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org, OrganizationRead) - # Assert there is an org - orgs = owl.admin.backend.list_organizations() - assert isinstance(orgs.items, list) - assert orgs.total == 1 - - # Delete the organization - response = owl.admin.backend.delete_organization(org.id) - assert isinstance(response, OkResponse) - - # Assert there is no org - orgs = owl.admin.backend.list_organizations() - assert isinstance(orgs.items, list) - assert orgs.total == 0 - - response = owl.admin.backend.delete_organization(org.id) - assert isinstance(response, OkResponse) - with pytest.raises(RuntimeError, match="Organization .+ is not found."): - owl.admin.backend.delete_organization(org.id, missing_ok=False) - - with pytest.raises(RuntimeError, match="Organization .+ is not found."): - owl.admin.backend.update_organization( - OrganizationUpdate(id=org.id, name="Updated Name") - ) - - with pytest.raises(RuntimeError, match="Organization .+ is not found."): - owl.admin.backend.get_organization(org.id) - - with pytest.raises(RuntimeError, match="Organization .+ is not found."): - owl.admin.organization.create_project( - ProjectCreate(name="New Project", organization_id=org.id) - ) - - with pytest.raises(RuntimeError, match="Organization .+ is not found."): - owl.admin.backend.join_organization( - OrgMemberCreate(user_id=duncan.id, organization_id=org.id) - ) - - with pytest.raises(RuntimeError, match="Organization .+ is not found."): - 
owl.admin.backend.leave_organization(user_id=duncan.id, organization_id=org.id) - - -def test_organization_update_pydantic_model(): - sig = signature(OrganizationUpdate) - for name, param in sig.parameters.items(): - if name == "id": - continue - assert ( - param.default is None - ), f'Parameter "{name}" has a default value of {param.default} instead of None.' - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_refresh_quota(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id, tier=PlanName.FREE) as org: - free_quota = org.llm_tokens_quota_mtok - assert org.llm_tokens_usage_mtok == 0.0 - # Set to another tier - org = owl.admin.backend.update_organization( - OrganizationUpdate( - id=org.id, - tier=PlanName.PRO, - llm_tokens_usage_mtok=0.2, - ) - ) - # Quota should be unchanged before refresh - assert org.llm_tokens_quota_mtok == free_quota - assert org.llm_tokens_usage_mtok == 0.2 - # Quota should increase after refresh, usage should reset - org = owl.admin.backend.refresh_quota(org.id) - assert isinstance(org, OrganizationRead) - pro_quota = org.llm_tokens_quota_mtok - assert pro_quota > free_quota - assert org.llm_tokens_usage_mtok == 0.0 - # Test refresh without resetting usage - owl.admin.backend.update_organization( - OrganizationUpdate( - id=org.id, - tier=PlanName.FREE, - llm_tokens_usage_mtok=0.2, - ) - ) - org = owl.admin.backend.refresh_quota(org.id, False) - assert org.llm_tokens_quota_mtok < pro_quota - assert org.llm_tokens_usage_mtok == 0.2 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_create_fetch_delete_api_key(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id, tier=PlanName.PRO) as org: - # Create API key - api_key = owl.admin.backend.create_api_key(ApiKeyCreate(organization_id=org.id)) - assert isinstance(api_key, ApiKeyRead) - print(f"API key created: {api_key}\n") - - # Fetch API key info - fetched_key = owl.admin.backend.get_api_key(api_key.id) - assert isinstance(fetched_key, ApiKeyRead) - assert fetched_key.id == api_key.id - print(f"API key fetched: {fetched_key}\n") - - # Fetch company using API key - org = owl.admin.backend.get_organization(api_key.id) - assert isinstance(org, OrganizationRead) - print(f"Organization fetched: {org}\n") - - # Delete API key - response = owl.admin.backend.delete_api_key(api_key.id) - assert isinstance(response, OkResponse) - print(f"API key deleted: {api_key.id}\n") - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_fetch_specific_user(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - user = owl.admin.backend.get_user(duncan.id) - assert isinstance(user, UserRead) - print(f"User fetched: {user}\n") - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_join_and_leave_organization(client_cls: Type[JamAI]): - owl = client_cls() - with ( - _create_user(owl, USER_ID_A, email="a@gmail.com") as u0, - _create_user(owl, USER_ID_B, email="b@gmail.com") as u1, - _create_user(owl, USER_ID_C, email="c@gmail.com") as u2, - ): - # --- Join without invite link --- # - with _create_org(owl, u0.id, tier="pro") as pro_org, _create_org(owl, u0.id) as free_org: - assert u1.id not in set(m.user_id for m in pro_org.members) - member = owl.admin.backend.join_organization( - OrgMemberCreate(user_id=u1.id, organization_id=pro_org.id) - ) - assert isinstance(member, OrgMemberRead) - assert member.user_id == u1.id - assert 
member.organization_id == pro_org.id - assert member.role == "admin" - # Cannot join free org - with pytest.raises(RuntimeError): - owl.admin.backend.join_organization( - OrgMemberCreate(user_id=u1.id, organization_id=free_org.id) - ) - # --- Join with public invite link --- # - with _create_org(owl, u0.id, tier="pro") as pro_org: - assert u1.id not in set(m.user_id for m in pro_org.members) - invite = owl.admin.backend.generate_invite_token(pro_org.id, user_role="member") - member = owl.admin.backend.join_organization( - OrgMemberCreate( - user_id=u1.id, - organization_id=pro_org.id, - role="member", - invite_token=invite, - ) - ) - assert isinstance(member, OrgMemberRead) - assert member.user_id == u1.id - assert member.organization_id == pro_org.id - assert member.role == "member" - # --- Join with private invite link --- # - with _create_org(owl, u0.id, tier="pro") as pro_org: - assert u1.id not in set(m.user_id for m in pro_org.members) - # Invite token email validation should be case and space insensitive - invite = owl.admin.backend.generate_invite_token( - pro_org.id, f" {u1.email.upper()} ", user_role="admin" - ) - member = owl.admin.backend.join_organization( - OrgMemberCreate( - user_id=u1.id, - organization_id=pro_org.id, - role="admin", - invite_token=invite, - ) - ) - assert isinstance(member, OrgMemberRead) - assert member.user_id == u1.id - assert member.organization_id == pro_org.id - assert member.role == "admin" - # Other email should fail - with pytest.raises(RuntimeError): - owl.admin.backend.join_organization( - OrgMemberCreate( - user_id=u2.id, - organization_id=pro_org.id, - role="admin", - invite_token=invite, - ) - ) - # --- Leave organization --- # - leave_response = owl.admin.backend.leave_organization(u0.id, pro_org.id) - assert isinstance(leave_response, OkResponse) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_add_event(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - response = owl.admin.backend.add_event( - EventCreate( - id=f"{org.id}_token", - organization_id=org.id, - deltas={ProductType.LLM_TOKENS: -0.5}, - values={}, - ) - ) - assert isinstance(response, OkResponse) - - event = owl.admin.backend.get_event(f"{org.id}_token") - assert isinstance(event, EventRead) - assert event.id == f"{org.id}_token" - assert event.deltas.get(ProductType.LLM_TOKENS) == -0.5 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_get_event(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - owl.admin.backend.add_event( - EventCreate( - id=f"{org.id}_token", - organization_id=org.id, - deltas={ProductType.LLM_TOKENS: -0.5}, - values={}, - ) - ) - - event = owl.admin.backend.get_event(f"{org.id}_token") - assert isinstance(event, EventRead) - assert event.id == f"{org.id}_token" - assert event.deltas.get(ProductType.LLM_TOKENS) == -0.5 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_mark_event_as_done(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - owl.admin.backend.add_event( - EventCreate( - id=f"{org.id}_token", - organization_id=org.id, - deltas={ProductType.LLM_TOKENS: -0.5}, - values={}, - ) - ) - - response = owl.admin.backend.mark_event_as_done(f"{org.id}_token") - assert isinstance(response, OkResponse) - - event = owl.admin.backend.get_event(f"{org.id}_token") - assert 
isinstance(event, EventRead) - assert event.id == f"{org.id}_token" - assert event.pending is False - assert event.deltas.get(ProductType.LLM_TOKENS) == -0.5 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_get_pricing(client_cls: Type[JamAI]): - owl = client_cls() - response = owl.admin.backend.get_pricing() - assert isinstance(response, Price) - assert len(response.plans) > 0 - response = owl.admin.backend.get_model_pricing() - assert isinstance(response, ModelPrice) - assert len(response.llm_models) > 0 - assert len(response.embed_models) > 0 - assert len(response.rerank_models) > 0 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_add_credit(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org, OrganizationRead) - assert isinstance(org.id, str) - assert len(org.id) > 0 - - assert org.credit == 0 - assert org.credit_grant == 0 - assert org.llm_tokens_usage_mtok == 0 - assert org.db_usage_gib == 0 - assert org.file_usage_gib == 0 - assert org.egress_usage_gib == 0 - # Set values - response = owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ - ProductType.CREDIT: 20.0, - ProductType.CREDIT_GRANT: 1, - ProductType.LLM_TOKENS: 70, - ProductType.DB_STORAGE: 2.0, - ProductType.FILE_STORAGE: 3.0, - ProductType.EGRESS: 4.0, - ProductType.EMBEDDING_TOKENS: 5.0, - ProductType.RERANKER_SEARCHES: 6.0, - }, - ) - ) - assert isinstance(response, OkResponse) - org = owl.admin.backend.get_organization(org.id) - assert org.credit == 20.0 - assert org.credit_grant == 1.0 - assert org.llm_tokens_usage_mtok == 70 - assert org.db_usage_gib == 2.0 - assert org.file_usage_gib == 3.0 - assert org.egress_usage_gib == 4.0 - assert org.embedding_tokens_usage_mtok == 5.0 - assert org.reranker_usage_ksearch == 6.0 - for product in ProductType.exclude_credits(): - assert isinstance(org.quotas[product]["quota"], (int, float)) - assert isinstance(org.quotas[product]["usage"], (int, float)) - # Add deltas - response = owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - deltas={ - "credit": 1.0, - ProductType.CREDIT_GRANT: 1.0, - ProductType.LLM_TOKENS: 70, - ProductType.DB_STORAGE: 2.0, - ProductType.FILE_STORAGE: 3.0, - ProductType.EGRESS: 4.0, - ProductType.EMBEDDING_TOKENS: 5.0, - ProductType.RERANKER_SEARCHES: 6.0, - }, - ) - ) - assert isinstance(response, OkResponse) - org = owl.admin.backend.get_organization(org.id) - assert org.credit == 21.0 - assert org.credit_grant == 2.0 - assert org.llm_tokens_usage_mtok == 140 - assert org.db_usage_gib == 4.0 - assert org.file_usage_gib == 6.0 - assert org.egress_usage_gib == 8.0 - assert org.embedding_tokens_usage_mtok == 10.0 - assert org.reranker_usage_ksearch == 12.0 - # Ensure values cannot go to negative - response = owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - deltas={ - "credit": -200.0, - ProductType.CREDIT_GRANT: -200.0, - ProductType.LLM_TOKENS: -200, - ProductType.DB_STORAGE: -200.0, - ProductType.FILE_STORAGE: -200.0, - ProductType.EGRESS: -200.0, - ProductType.EMBEDDING_TOKENS: -200.0, - ProductType.RERANKER_SEARCHES: -200.0, - }, - ) - ) - assert isinstance(response, OkResponse) - org = owl.admin.backend.get_organization(org.id) - assert org.credit == 0 - assert org.credit_grant == 0 - assert 
org.llm_tokens_usage_mtok == 0 - assert org.db_usage_gib == 0 - assert org.file_usage_gib == 0 - assert org.egress_usage_gib == 0 - assert org.embedding_tokens_usage_mtok == 0.0 - assert org.reranker_usage_ksearch == 0.0 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_get_set_model_config(client_cls: Type[JamAI]): - owl = client_cls() - # Initial fetch - config = owl.admin.backend.get_model_config() - assert isinstance(config, ModelListConfig) - assert len(config.llm_models) > 1 - assert len(config.embed_models) > 1 - assert len(config.rerank_models) > 1 - llm_model_ids = [m.id for m in config.llm_models] - assert "ellm/new_model" not in llm_model_ids - # Set - new_config = config.model_copy(deep=True) - new_config.llm_models.append( - LLMModelConfig( - id="ellm/new_model", - name="ELLM New Model", - context_length=8000, - deployments=[ - ModelDeploymentConfig( - provider="ellm", - ) - ], - languages=["mul"], - capabilities=["chat"], - owned_by="ellm", - ) - ) - with _set_model_config(owl, new_config) as response: - assert isinstance(response, OkResponse) - # Fetch again - new_config = owl.admin.backend.get_model_config() - assert isinstance(new_config, ModelListConfig) - assert len(new_config.llm_models) == len(config.llm_models) + 1 - assert len(new_config.embed_models) == len(config.embed_models) - assert len(new_config.rerank_models) == len(config.rerank_models) - llm_model_ids = [m.id for m in new_config.llm_models] - assert "ellm/new_model" in llm_model_ids - # Fetch model list - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - jamai = JamAI(project_id=project.id) - models = jamai.model_names(capabilities=["chat"]) - assert isinstance(models, list) - assert "ellm/new_model" in models - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_credit_check_llm(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org, OrganizationRead) - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - # Get model list - jamai = JamAI(project_id=project.id) - models = jamai.model_info(capabilities=["chat"]).data - assert isinstance(models, list) - models = {m.owned_by: m for m in models} - model = models["openai"] - - # --- No credit to use 3rd party models --- # - assert org.credit == 0 - assert len(model.id) > 0 - # Error message should show model ID when called via API - with pytest.raises( - RuntimeError, - match=f"Insufficient LLM token quota or credits for model: {model.id}", - ): - _chat(jamai, model.id) - assert len(model.name) > 0 - assert model.name != model.id - # Error message should show model name when called via browser - name = model.name.replace("(", "\\(").replace(")", "\\)") - with pytest.raises( - RuntimeError, - match=f"Insufficient LLM token quota or credits for model: {name}", - ): - _chat( - JamAI(project_id=project.id, headers={"User-Agent": "Mozilla"}), - model.id, - ) - - @retry( - wait=wait_exponential(multiplier=1, min=1, max=10), - stop=stop_after_attempt(5), - reraise=True, - ) - def _assert_usage_updated(initial_value: int | float = 0): - org_read = owl.admin.backend.get_organization(org.id) - assert isinstance(org_read, 
OrganizationRead) - assert org_read.llm_tokens_usage_mtok > initial_value - - @retry( - wait=wait_exponential(multiplier=1, min=1, max=10), - stop=stop_after_attempt(5), - reraise=True, - ) - def _assert_chat_fail(_model_id: str): - # No more credit left - try: - _chat(jamai, _model_id) - logger.warning( - f"Org credit grant = {owl.admin.backend.get_organization(org.id).credit_grant}" - ) - except RuntimeError as e: - if ( - f"Insufficient LLM token quota or credits for model: {_model_id}" - not in str(e) - ): - raise ValueError("Error message mismatch") from e - # We actually want this to raise RuntimeError - else: - raise ValueError("Chat attempt did not fail.") - - # --- Test credit --- # - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ProductType.CREDIT: 1e-12}, - ) - ) - _chat(jamai, model.id) - _assert_chat_fail(model.id) - - # --- Test credit grant --- # - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ - ProductType.CREDIT: 0.0, - ProductType.CREDIT_GRANT: 1e-12, - }, - ) - ) - org = owl.admin.backend.get_organization(org.id) - assert org.credit == 0 - assert org.credit_grant == 1e-12 - _chat(jamai, model.id) - _assert_chat_fail(model.id) - - # --- Test ELLM model --- # - # ELLM model ok - ellm_model_id = "ellm/llama-3.1-8B" - config = owl.admin.backend.get_model_config() - config.llm_models.append( - LLMModelConfig( - id=ellm_model_id, - name="ELLM Meta Llama 3.1 (8B)", - deployments=[ - ModelDeploymentConfig( - litellm_id="together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - provider="together_ai", - ) - ], - context_length=8000, - languages=["mul"], - capabilities=["chat"], - owned_by="ellm", - ) - ) - with _set_model_config(owl, config): - _chat(jamai, ellm_model_id) - _assert_usage_updated() - # Exhaust the quota - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_llm_tokens_{uuid7_str()}", - organization_id=org.id, - values={ - ProductType.CREDIT: 0.0, - ProductType.CREDIT_GRANT: 0.0, - ProductType.LLM_TOKENS: 100000.0, - }, - ) - ) - # No quota to use ELLM model - _assert_chat_fail(ellm_model_id) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_credit_check_embedding(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org, OrganizationRead) - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - # Get model list - jamai = JamAI(project_id=project.id) - models = jamai.model_info(capabilities=["embed"]).data - assert isinstance(models, list) - models = {m.owned_by: m for m in models} - model = models["openai"] - - # --- No credit to use 3rd party models --- # - assert org.credit == 0 - assert len(model.id) > 0 - # Error message should show model ID when called via API - with pytest.raises( - RuntimeError, - match=rf"Insufficient Embedding token quota or credits for model: {model.id}", - ): - _embed(jamai, model.id) - assert len(model.name) > 0 - assert model.name != model.id - # Error message should show model name when called via browser - name = model.name.replace("(", "\\(").replace(")", "\\)") - with pytest.raises( - RuntimeError, - match=f"Insufficient Embedding token quota or credits for model: {name}", - ): - _embed( - JamAI(project_id=project.id, 
headers={"User-Agent": "Mozilla"}), - model.id, - ) - - @retry( - wait=wait_exponential(multiplier=1, min=1, max=10), - stop=stop_after_attempt(5), - reraise=True, - ) - def _assert_usage_updated(initial_value: int | float = 0): - org_read = owl.admin.backend.get_organization(org.id) - assert isinstance(org_read, OrganizationRead) - assert org_read.embedding_tokens_usage_mtok > initial_value - - @retry( - wait=wait_exponential(multiplier=1, min=1, max=20), stop=stop_after_attempt(10) - ) - def _assert_embed_fail(_model_id: str): - # No more credit left - try: - _embed(jamai, _model_id) - logger.warning( - f"Org credit grant = {owl.admin.backend.get_organization(org.id).credit_grant}" - ) - except RuntimeError as e: - if ( - f"Insufficient Embedding token quota or credits for model: {_model_id}" - not in str(e) - ): - raise ValueError("Error message mismatch") from e - # We actually want this to raise RuntimeError - else: - raise ValueError("Embedding attempt did not fail.") - - # --- Test credit --- # - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ProductType.CREDIT: 1e-12}, - ) - ) - _embed(jamai, model.id) - _assert_embed_fail(model.id) - - # --- Test credit grant --- # - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ - ProductType.CREDIT: 0.0, - ProductType.CREDIT_GRANT: 1e-12, - }, - ) - ) - org = owl.admin.backend.get_organization(org.id) - assert org.credit == 0 - assert org.credit_grant == 1e-12 - _embed(jamai, model.id) - _assert_embed_fail(model.id) - - # --- Test ELLM model --- # - # ELLM model ok - model = models["ellm"] - _embed(jamai, model.id) - _assert_usage_updated() - # Exhaust the quota - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_llm_tokens_{uuid7_str()}", - organization_id=org.id, - values={ - ProductType.CREDIT: 0.0, - ProductType.CREDIT_GRANT: 0.0, - ProductType.EMBEDDING_TOKENS: 100000.0, - }, - ) - ) - # No quota to use ELLM model - _assert_embed_fail(model.id) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_external_api_key(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org, OrganizationRead) - assert isinstance(org.id, str) - assert len(org.id) > 0 - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ProductType.CREDIT: 0.001}, - ) - ) - with _create_project(owl, org.id) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - # Get model list - jamai = JamAI(project_id=project.id) - models = jamai.model_info(capabilities=["chat"]).data - assert isinstance(models, list) - models = {m.owned_by: m for m in models} - model = models["openai"] - # Will use ELLM's OpenAI API key - _chat(jamai, model.id) - # Replace with fake key - org = owl.admin.backend.update_organization( - OrganizationUpdate(id=org.id, external_keys=dict(openai="fake-key")) - ) - assert org.external_keys["openai"] == "fake-key" - with pytest.raises(RuntimeError, match="AuthenticationError"): - _chat(jamai, model.id) - # Ensure no cooldown - org = owl.admin.backend.update_organization( - OrganizationUpdate(id=org.id, external_keys=dict()) - ) - _chat(jamai, model.id) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_concurrent_usage(client_cls: Type[JamAI]): - def 
_work(worker_id: int, mp_dict: dict): - owl = client_cls() - # Fetch model list as external org - with _create_user(owl, f"user_{worker_id}") as user: - with _create_org(owl, user.id, name=f"org_{worker_id}") as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - # Add credit - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ProductType.CREDIT: 20.0}, - ) - ) - with _create_project(owl, org.id, name=f"proj_{worker_id}") as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - # Test model list - jamai = JamAI(project_id=project.id) - models = jamai.model_names(capabilities=["chat"]) - assert isinstance(models, list) - # Test chat - _chat(jamai, "") - # Test gen table - data = dict( - input="Hi", - Title="Dune: Part Two.", - Text='"Dune: Part Two" is a 2024 American epic science fiction film.', - User="Tell me a joke.", - ) - for table_type in TABLE_TYPES: - with _create_gen_table( - jamai, table_type, f"table_{table_type}_{worker_id}" - ) as table: - response = jamai.table.add_table_rows( - table_type, - RowAddRequest(table_id=table.id, data=[data], stream=False), - ) - assert isinstance(response, GenTableRowsChatCompletionChunks) - assert len(response.rows) > 0 - response = jamai.table.add_table_rows( - table_type, - RowAddRequest(table_id=table.id, data=[data], stream=True), - ) - responses = [r for r in response] - assert len(responses) > 0 - assert all( - isinstance(r, GenTableStreamChatCompletionChunk) for r in responses - ) - meta = jamai.table.get_table(table_type, table.id) - mp_dict[str(worker_id)] = meta - - num_workers = 5 - manager = Manager() - return_dict = manager.dict() - workers = [Process(target=_work, args=(i, return_dict)) for i in range(num_workers)] - for worker in workers: - worker.start() - for worker in workers: - worker.join() - assert len(return_dict) == num_workers - metas = list(return_dict.values()) - assert all(isinstance(meta, TableMetaResponse) for meta in metas) - assert all(meta.num_rows == 2 for meta in metas) - - -if __name__ == "__main__": - test_pat(JamAI) diff --git a/clients/python/tests/cloud/test_org_admin.py b/clients/python/tests/cloud/test_org_admin.py deleted file mode 100644 index 16a8680..0000000 --- a/clients/python/tests/cloud/test_org_admin.py +++ /dev/null @@ -1,848 +0,0 @@ -from contextlib import asynccontextmanager, contextmanager -from inspect import signature -from io import BytesIO -from os.path import join -from tempfile import TemporaryDirectory -from time import perf_counter -from typing import Generator, Type - -import pytest -from loguru import logger -from tenacity import retry, stop_after_attempt, wait_exponential - -from jamaibase import JamAI, JamAIAsync -from jamaibase.protocol import ( - ActionTableSchemaCreate, - AdminOrderBy, - ChatTableSchemaCreate, - ColumnSchemaCreate, - EventCreate, - GenTableRowsChatCompletionChunks, - GenTableStreamChatCompletionChunk, - KnowledgeTableSchemaCreate, - LLMGenConfig, - LLMModelConfig, - ModelDeploymentConfig, - ModelListConfig, - OkResponse, - OrganizationCreate, - OrganizationRead, - ProjectCreate, - ProjectRead, - ProjectUpdate, - RowAddRequest, - TableMetaResponse, - TableType, - UserCreate, - UserRead, -) -from jamaibase.utils import run -from owl.configs.manager import PlanName, ProductType -from owl.utils import uuid7_str - -CLIENT_CLS = [JamAI] -USER_ID_A = "duncan" -USER_ID_B = "mama" -USER_ID_C = "sus" -TABLE_TYPES = [TableType.action, TableType.knowledge, 
TableType.chat] - - -@contextmanager -def _create_user( - owl: JamAI, - user_id: str = USER_ID_A, - **kwargs, -) -> Generator[UserRead, None, None]: - owl.admin.backend.delete_user(user_id) - try: - user = owl.admin.backend.create_user( - UserCreate( - id=user_id, - name=kwargs.pop("name", "Duncan Idaho"), - description=kwargs.pop("description", "A Ginaz Swordmaster from House Atreides."), - email=kwargs.pop("email", "duncan.idaho@gmail.com"), - meta=kwargs.pop("meta", {}), - ) - ) - yield user - finally: - owl.admin.backend.delete_user(user_id) - - -@contextmanager -def _create_org( - owl: JamAI, - user_id: str, - active: bool = True, - **kwargs, -) -> Generator[OrganizationRead, None, None]: - org_id = None - try: - org = owl.admin.backend.create_organization( - OrganizationCreate( - creator_user_id=user_id, - name=kwargs.pop("name", "Company"), - external_keys=kwargs.pop("external_keys", {}), - tier=kwargs.pop("tier", PlanName.FREE), - active=active, - **kwargs, - ) - ) - org_id = org.id - yield org - finally: - if org_id is not None: - owl.admin.backend.delete_organization(org_id) - - -def _delete_project(owl: JamAI, project_id: str | None): - if project_id is not None: - owl.admin.organization.delete_project(project_id) - - -@contextmanager -def _create_project( - owl: JamAI, - organization_id: str, - name: str = "default", -) -> Generator[OrganizationRead, None, None]: - project_id = None - try: - project = owl.admin.organization.create_project( - ProjectCreate( - organization_id=organization_id, - name=name, - ) - ) - project_id = project.id - yield project - finally: - _delete_project(owl, project_id) - - -@asynccontextmanager -async def _set_org_model_config( - jamai: JamAI | JamAIAsync, - org_id: str, - config: ModelListConfig, -): - old_config = await run(jamai.admin.organization.get_org_model_config, org_id) - try: - response = await run(jamai.admin.organization.set_org_model_config, org_id, config) - assert isinstance(response, OkResponse) - yield response - finally: - await run(jamai.admin.organization.set_org_model_config, org_id, old_config) - - -@contextmanager -def _create_gen_table( - jamai: JamAI, - table_type: TableType, - table_id: str, - model_id: str = "", - cols: list[ColumnSchemaCreate] | None = None, - chat_cols: list[ColumnSchemaCreate] | None = None, - embedding_model: str = "", - delete_first: bool = True, - delete: bool = True, -): - try: - if delete_first: - jamai.table.delete_table(table_type, table_id) - if cols is None: - cols = [ - ColumnSchemaCreate(id="input", dtype="str"), - ColumnSchemaCreate( - id="output", - dtype="str", - gen_config=LLMGenConfig( - model=model_id, - prompt="${input}", - max_tokens=3, - ), - ), - ] - if chat_cols is None: - chat_cols = [ - ColumnSchemaCreate(id="User", dtype="str"), - ColumnSchemaCreate( - id="AI", - dtype="str", - gen_config=LLMGenConfig( - model=model_id, - system_prompt="You are an assistant.", - max_tokens=3, - ), - ), - ] - if table_type == TableType.action: - table = jamai.table.create_action_table( - ActionTableSchemaCreate(id=table_id, cols=cols) - ) - elif table_type == TableType.knowledge: - table = jamai.table.create_knowledge_table( - KnowledgeTableSchemaCreate(id=table_id, cols=cols, embedding_model=embedding_model) - ) - elif table_type == TableType.chat: - table = jamai.table.create_chat_table( - ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) - ) - else: - raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, TableMetaResponse) - yield table - finally: - if 
delete: - jamai.table.delete_table(table_type, table_id) - - -def _add_row( - jamai: JamAI, - table_type: TableType, - table_id: str, - stream: bool = False, - data: dict | None = None, - knowledge_data: dict | None = None, - chat_data: dict | None = None, -): - if data is None: - data = dict(input="nano", output="shimmer") - - if knowledge_data is None: - knowledge_data = dict( - Title="Dune: Part Two.", - Text='"Dune: Part Two" is a 2024 American epic science fiction film.', - ) - if chat_data is None: - chat_data = dict(User="Tell me a joke.", AI="Who's there?") - if table_type == TableType.action: - pass - elif table_type == TableType.knowledge: - data.update(knowledge_data) - elif table_type == TableType.chat: - data.update(chat_data) - else: - raise ValueError(f"Invalid table type: {table_type}") - response = jamai.table.add_table_rows( - table_type, - RowAddRequest(table_id=table_id, data=[data], stream=stream), - ) - if stream: - response = list(response) - assert all(isinstance(r, GenTableStreamChatCompletionChunk) for r in response) - else: - assert isinstance(response, GenTableRowsChatCompletionChunks) - return response - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -async def test_get_set_org_model_config(client_cls: Type[JamAI | JamAIAsync]): - owl = client_cls() - # Get model config - config = await run(owl.admin.backend.get_model_config) - assert isinstance(config, ModelListConfig) - assert isinstance(config.models, list) - assert len(config.models) > 3 - assert isinstance(config.llm_models, list) - assert isinstance(config.embed_models, list) - assert isinstance(config.rerank_models, list) - assert len(config.llm_models) > 1 - assert len(config.embed_models) > 1 - assert len(config.rerank_models) > 1 - public_model_ids = [m.id for m in config.models] - assert "ellm/new_model" not in public_model_ids - # Set organization model config - with _create_user(owl) as duncan: - with ( - _create_org(owl, duncan.id) as org, - _create_org(owl, duncan.id, name="personal", tier=PlanName.PRO) as personal, - ): - assert isinstance(org.id, str) - assert len(org.id) > 0 - assert isinstance(personal.id, str) - assert len(personal.id) > 0 - with _create_project(owl, org.id) as p0, _create_project(owl, personal.id) as p1: - assert isinstance(p0.id, str) - assert len(p0.id) > 0 - assert isinstance(p1.id, str) - assert len(p1.id) > 0 - # Set - jamai = JamAI(project_id=p0.id) - new_model_config = ModelListConfig( - llm_models=[ - LLMModelConfig( - id="ellm/new_model", - name="ELLM hyperbolic Llama3.2-3B", - context_length=8000, - languages=["mul"], - capabilities=["chat"], - owned_by="ellm", - deployments=[ - ModelDeploymentConfig( - litellm_id="openai/meta-llama/Llama-3.2-3B-Instruct", - api_base="https://api.hyperbolic.xyz/v1", - provider="hyperbolic", - ), - ], - ) - ] - ) - async with _set_org_model_config(jamai, org.id, new_model_config): - # Fetch org-level config - models = await run(jamai.admin.organization.get_org_model_config, org.id) - assert isinstance(models, ModelListConfig) - assert isinstance(models.llm_models, list) - assert isinstance(models.embed_models, list) - assert isinstance(models.rerank_models, list) - assert len(models.llm_models) == 1 - assert len(models.embed_models) == 0 - assert len(models.rerank_models) == 0 - # Fetch model list - models = await run(jamai.model_names) - assert isinstance(models, list) - assert set(public_model_ids) - set(models) == set() - assert set(models) - set(public_model_ids) == {"ellm/new_model"} - # text add row with org_model - 
with _create_gen_table( - jamai, TableType.action, "test-org-model", "ellm/new_model", delete=True - ): - _add_row(jamai, TableType.action, "test-org-model") - # Try fetching from another org - jamai = JamAI(project_id=p1.id) - models = await run(jamai.model_names) - assert isinstance(models, list) - assert set(public_model_ids) - set(models) == set() - assert set(models) - set(public_model_ids) == set() - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_create_project(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id, "my-project") as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - # Duplicate name - with pytest.raises(RuntimeError): - with _create_project(owl, org.id, "my-project"): - pass - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize( - "name", ["a", "0", "冰:淇 淋", "a.b", "_a_", " (a) ", "=a", " " + "a" * 100] -) -def test_create_organization_project_valid_name( - client_cls: Type[JamAI], - name: str, -): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id, name=name) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id, name=name) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - assert project.name == name.strip() - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize("name", ["=", " ", "()", "a" * 101]) -def test_create_organization_project_invalid_name( - client_cls: Type[JamAI], - name: str, -): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - with pytest.raises(RuntimeError): - with _create_project(owl, org.id, name=name): - pass - with pytest.raises(RuntimeError): - with _create_org(owl, duncan.id, name=name): - pass - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_get_and_list_projects(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with ( - _create_org(owl, duncan.id) as org, - _create_org(owl, duncan.id, name="Personal", tier=PlanName.PRO) as personal, - ): - assert isinstance(org.id, str) - assert len(org.id) > 0 - assert org.name == "Company" - assert personal.name == "Personal" - with ( - _create_project(owl, org.id, "bear") as proj_bear, - _create_project(owl, personal.id) as personal_default, - ): - with _create_project(owl, org.id, "Pear") as proj_pear: - with _create_project(owl, org.id, "pearl") as proj_pearl: - assert isinstance(proj_bear.id, str) - assert len(proj_bear.id) > 0 - assert isinstance(proj_pear.id, str) - assert len(proj_pear.id) > 0 - - # Test fetch - project = owl.admin.organization.get_project(proj_bear.id) - assert isinstance(project, ProjectRead) - assert project.id == proj_bear.id - assert project.name == "bear" - assert isinstance(project.organization.members, list) - assert len(project.organization.members) == 1 - - project = owl.admin.organization.get_project(proj_pear.id) - assert isinstance(project, ProjectRead) - assert project.id == proj_pear.id - assert project.name == "Pear" - - project = owl.admin.organization.get_project(proj_pearl.id) - assert isinstance(project, ProjectRead) - assert project.id == proj_pearl.id - assert project.name == "pearl" - - project = 
owl.admin.organization.get_project(personal_default.id) - assert isinstance(project, ProjectRead) - assert project.id == personal_default.id - assert project.name == "default" - - # Test association - org = owl.admin.backend.get_organization(org.id) - assert isinstance(org, OrganizationRead) - assert all(isinstance(p, ProjectRead) for p in org.projects) - proj_names = [p.name for p in org.projects] - assert "bear" in proj_names - assert "Pear" in proj_names - assert "pearl" in proj_names - - # Test list - projects = owl.admin.organization.list_projects(org.id) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 3 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 3 - - projects = owl.admin.organization.list_projects(personal.id) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 1 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 1 - - projects = owl.admin.organization.list_projects(org.id, offset=1) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 3 - assert projects.offset == 1 - assert projects.limit == 100 - assert len(projects.items) == 2 - - projects = owl.admin.organization.list_projects(org.id, limit=1) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 3 - assert projects.offset == 0 - assert projects.limit == 1 - assert len(projects.items) == 1 - - # Test list with search query - projects = owl.admin.organization.list_projects(org.id, "ear") - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 3 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 3 - - projects = owl.admin.organization.list_projects(org.id, "pe") - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 2 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 2 - - projects = owl.admin.organization.list_projects(org.id, "pe", offset=1) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 2 - assert projects.offset == 1 - assert projects.limit == 100 - assert len(projects.items) == 1 - - projects = owl.admin.organization.list_projects(org.id, "pe", limit=1) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.total == 2 - assert projects.offset == 0 - assert projects.limit == 1 - assert len(projects.items) == 1 - - # Test list with order_by - projects = owl.admin.organization.list_projects(org.id, "pe") - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert projects.items[0].name == "pearl" - assert projects.items[1].name == "Pear" - assert projects.total == 2 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 2 - - projects = owl.admin.organization.list_projects( - org.id, "pe", order_descending=False - ) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert 
projects.items[0].name == "Pear" - assert projects.items[1].name == "pearl" - assert projects.total == 2 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 2 - - projects = owl.admin.organization.list_projects(org.id, order_by="name") - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert [p.name for p in projects.items] == ["pearl", "Pear", "bear"] - assert projects.total == 3 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 3 - - projects = owl.admin.organization.list_projects( - org.id, order_by="name", order_descending=False - ) - assert isinstance(projects.items, list) - assert all(isinstance(r, ProjectRead) for r in projects.items) - assert [p.name for p in projects.items] == ["bear", "Pear", "pearl"] - assert projects.total == 3 - assert projects.offset == 0 - assert projects.limit == 100 - assert len(projects.items) == 3 - - for order_by in AdminOrderBy: - projects = owl.admin.organization.list_projects( - org.id, order_by=order_by - ) - assert len(projects.items) == 3 - proj_ids = [p.id for p in projects.items] - projects_desc = owl.admin.organization.list_projects( - org.id, order_by=order_by, order_descending=False - ) - assert len(projects_desc.items) == 3 - proj_desc_ids = [p.id for p in projects_desc.items] - assert ( - proj_ids == proj_desc_ids[::-1] - ), f"Failed to order by {order_by}: {proj_ids} != {proj_desc_ids[::-1]}" - - # # Test starting_after - # projects = owl.admin.organization.list_projects( - # org.id, order_by="name", starting_after=proj_pearl.id - # ) - # assert isinstance(projects.items, list) - # assert all(isinstance(r, ProjectRead) for r in projects.items) - # assert [p.name for p in projects.items] == ["Pear", "bear"] - # assert projects.total == 3 - # assert projects.offset == 0 - # assert projects.limit == 100 - # assert len(projects.items) == 2 - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_delete_projects(client_cls: Type[JamAI]): - owl = client_cls() - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - with _create_project(owl, org.id) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - response = owl.admin.organization.delete_project(project.id) - assert isinstance(response, OkResponse) - with pytest.raises(RuntimeError, match="Project .+ is not found."): - owl.admin.organization.update_project( - ProjectUpdate(id=project.id, name="Updated Project") - ) - - with pytest.raises(RuntimeError, match="Project .+ is not found."): - owl.admin.organization.get_project(project.id) - - response = owl.admin.organization.delete_project(project.id) - assert isinstance(response, OkResponse) - with pytest.raises(RuntimeError, match="Project .+ is not found."): - owl.admin.organization.delete_project(project.id, missing_ok=False) - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_update_project(client_cls: Type[JamAI]): - owl = client_cls() - - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - with _create_project(owl, org.id) as project: - updated_project_request = ProjectUpdate(id=project.id, name="Updated Project") - updated_project_response = owl.admin.organization.update_project( - updated_project_request - ) - assert isinstance(updated_project_response, ProjectRead) - assert updated_project_response.id == project.id - assert 
updated_project_response.name == "Updated Project" - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -def test_project_updated_at(client_cls: Type[JamAI]): - owl = client_cls() - - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id) as org: - assert isinstance(org.id, str) - assert len(org.id) > 0 - # Add credit - owl.admin.backend.add_event( - EventCreate( - id=f"{org.quota_reset_at}_credit_{uuid7_str()}", - organization_id=org.id, - values={ProductType.CREDIT: 20.0}, - ) - ) - with _create_project(owl, org.id) as project: - assert isinstance(project.id, str) - assert len(project.id) > 0 - old_proj_updated_at = project.updated_at - jamai = JamAI(project_id=project.id) - # Test gen table - with _create_gen_table(jamai, TABLE_TYPES[0], "xx"): - pass - - @retry( - wait=wait_exponential(multiplier=1, min=1, max=10), - stop=stop_after_attempt(5), - reraise=True, - ) - def _assert_bumped_updated_at(): - proj = owl.admin.organization.get_project(project.id) - assert isinstance(proj, ProjectRead) - assert proj.updated_at > old_proj_updated_at - - t0 = perf_counter() - _assert_bumped_updated_at() - logger.info(f"Succeeded after {perf_counter() - t0:,.2f} seconds") - - -def test_project_update_model(): - sig = signature(ProjectUpdate) - for name, param in sig.parameters.items(): - if name == "id": - continue - assert ( - param.default is None - ), f'Parameter "{name}" has a default value of {param.default} instead of None.' - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize("empty_project", [True, False], ids=["Empty project", "With data"]) -def test_project_import_export_round_trip(client_cls: Type[JamAI], empty_project: bool): - owl = client_cls() - - with _create_user(owl) as duncan: - with ( - _create_org(owl, duncan.id, name="Personal", tier=PlanName.PRO) as o0, - _create_org(owl, duncan.id, name="Company", tier=PlanName.PRO) as o1, - ): - assert isinstance(o0.id, str) - assert len(o0.id) > 0 - assert isinstance(o1.id, str) - assert len(o1.id) > 0 - assert o0.id != o1.id - # Add credit - owl.admin.backend.add_event( - EventCreate( - id=f"{o0.quota_reset_at}_credit_{uuid7_str()}", - organization_id=o0.id, - values={ProductType.CREDIT: 20.0}, - ) - ) - with _create_project(owl, o0.id) as p0, _create_project(owl, o0.id, "p1") as p1: - assert isinstance(p0.id, str) - assert len(p0.id) > 0 - # Create some tables - jamai = JamAI(project_id=p0.id) - if not empty_project: - for table_type in TABLE_TYPES: - with _create_gen_table(jamai, table_type, table_type, delete=False): - _add_row(jamai, table_type, table_type) - - def _check_tables(_project_id: str): - jamai = JamAI(project_id=_project_id) - if empty_project: - for table_type in TABLE_TYPES: - assert jamai.table.list_tables(table_type).total == 0 - else: - for table_type in TABLE_TYPES: - assert jamai.table.list_tables(table_type).total == 1 - rows = jamai.table.list_table_rows(table_type, table_type) - assert len(rows.items) == 1 - - # --- Export --- # - data = jamai.admin.organization.export_project(p0.id) - - # --- Import as new project --- # - # Test file-like object - with BytesIO(data) as f: - new_p0 = jamai.admin.organization.import_project(f, o0.id) - assert isinstance(new_p0, ProjectRead) - _check_tables(new_p0.id) - # List the projects - proj_ids = set(p.id for p in owl.admin.organization.list_projects(o0.id).items) - assert len(proj_ids) == 3 # Also ensures uniqueness - assert p0.id in proj_ids - assert p1.id in proj_ids - assert new_p0.id in proj_ids - - # --- Import into existing 
project --- # - # Test file path - with TemporaryDirectory() as tmp_dir: - export_filepath = join(tmp_dir, "project.parquet") - with open(export_filepath, "wb") as f: - f.write(data) - new_p1 = jamai.admin.organization.import_project(export_filepath, o0.id, p1.id) - assert isinstance(new_p1, ProjectRead) - assert new_p1.id == p1.id - _check_tables(new_p1.id) - # List the projects - proj_ids = set(p.id for p in owl.admin.organization.list_projects(o0.id).items) - assert len(proj_ids) == 3 # Also ensures uniqueness - assert p0.id in proj_ids - assert p1.id in proj_ids - assert new_p0.id in proj_ids - - # --- Import again, should fail --- # - if not empty_project: - with BytesIO(data) as f: - with pytest.raises(RuntimeError): - jamai.admin.organization.import_project(f, o0.id, p1.id) - - # --- Import into another organization --- # - with BytesIO(data) as f: - project = JamAI().admin.organization.import_project(f, o1.id) - assert isinstance(project, ProjectRead) - _check_tables(project.id) - # List the projects - proj_ids = set(p.id for p in owl.admin.organization.list_projects(o1.id).items) - assert len(proj_ids) == 1 - assert project.id in proj_ids - - -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize("empty_project", [True, False], ids=["Empty project", "With data"]) -def test_project_import_export_template(client_cls: Type[JamAI], empty_project: bool): - owl = client_cls() - - with _create_user(owl) as duncan: - with _create_org(owl, duncan.id, name="Personal") as o0: - assert isinstance(o0.id, str) - assert len(o0.id) > 0 - # Add credit - owl.admin.backend.add_event( - EventCreate( - id=f"{o0.quota_reset_at}_credit_{uuid7_str()}", - organization_id=o0.id, - values={ProductType.CREDIT: 20.0}, - ) - ) - with ( - _create_project(owl, o0.id) as p0, - _create_project(owl, o0.id, "p1") as p1, - _create_project(owl, o0.id, "p2") as p2, - ): - assert isinstance(p0.id, str) - assert len(p0.id) > 0 - # Create some tables - jamai = JamAI(project_id=p0.id) - if not empty_project: - for table_type in TABLE_TYPES: - with _create_gen_table(jamai, table_type, table_type, delete=False): - _add_row(jamai, table_type, table_type) - - def _check_tables(_project_id: str): - jamai = JamAI(project_id=_project_id) - if empty_project: - for table_type in TABLE_TYPES: - assert jamai.table.list_tables(table_type).total == 0 - else: - for table_type in TABLE_TYPES: - assert jamai.table.list_tables(table_type).total == 1 - rows = jamai.table.list_table_rows(table_type, table_type) - assert len(rows.items) == 1 - - # --- Export template --- # - data = jamai.admin.organization.export_project_as_template( - p0.id, - name="Template 试验", - tags=["sector:finance", "sector:科技"], - description="テンプレート description", - ) - with BytesIO(data) as f: - # Import as new project - new_p0 = jamai.admin.organization.import_project(f, o0.id) - assert isinstance(new_p0, ProjectRead) - _check_tables(new_p0.id) - # Import into existing project - new_p1 = jamai.admin.organization.import_project(f, o0.id, p1.id) - assert isinstance(new_p1, ProjectRead) - assert new_p1.id == p1.id - _check_tables(new_p1.id) - # List the projects - proj_ids = set(p.id for p in owl.admin.organization.list_projects(o0.id).items) - assert len(proj_ids) == 4 # Also ensures uniqueness - assert p0.id in proj_ids - assert p1.id in proj_ids - assert p2.id in proj_ids - assert new_p0.id in proj_ids - - # --- Add template --- # - new_template_id = "test_template" - response = jamai.admin.backend.add_template(f, new_template_id, True) - 
assert isinstance(response, OkResponse) - # Add again, should fail - with pytest.raises(RuntimeError): - jamai.admin.backend.add_template(f, new_template_id) - # List templates - template_ids = set(t.id for t in jamai.template.list_templates().items) - assert new_template_id in template_ids - # Import as new project - new_p2 = jamai.admin.organization.import_project_from_template( - o0.id, new_template_id - ) - assert isinstance(new_p2, ProjectRead) - _check_tables(new_p2.id) - # Import into existing project - new_p3 = jamai.admin.organization.import_project_from_template( - o0.id, new_template_id, p2.id - ) - assert isinstance(new_p3, ProjectRead) - assert new_p3.id == p2.id - _check_tables(new_p3.id) - # List the projects - proj_ids = set(p.id for p in owl.admin.organization.list_projects(o0.id).items) - assert len(proj_ids) == 5 # Also ensures uniqueness - assert p0.id in proj_ids - assert p1.id in proj_ids - assert p2.id in proj_ids - assert new_p0.id in proj_ids - assert new_p2.id in proj_ids diff --git a/clients/python/tests/oss/gen_table/test_export_ops.py b/clients/python/tests/oss/gen_table/test_export_ops.py index 4a04446..1907869 100644 --- a/clients/python/tests/oss/gen_table/test_export_ops.py +++ b/clients/python/tests/oss/gen_table/test_export_ops.py @@ -12,11 +12,11 @@ from flaky import flaky from jamaibase import JamAI -from jamaibase import protocol as p +from jamaibase import types as t from jamaibase.utils.io import csv_to_df, df_to_csv CLIENT_CLS = [JamAI] -TABLE_TYPES = [p.TableType.action, p.TableType.knowledge, p.TableType.chat] +TABLE_TYPES = [t.TableType.action, t.TableType.knowledge, t.TableType.chat] TABLE_ID_A = "table_a" TABLE_ID_B = "table_b" @@ -65,10 +65,10 @@ def _rerun_on_fs_error_with_delay(err, *args): @contextmanager def _create_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id: str = TABLE_ID_A, - cols: list[p.ColumnSchemaCreate] | None = None, - chat_cols: list[p.ColumnSchemaCreate] | None = None, + cols: list[t.ColumnSchemaCreate] | None = None, + chat_cols: list[t.ColumnSchemaCreate] | None = None, embedding_model: str | None = None, delete_first: bool = True, ): @@ -77,15 +77,15 @@ def _create_table( jamai.table.delete_table(table_type, table_id) if cols is None: cols = [ - p.ColumnSchemaCreate(id="good", dtype="bool"), - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate(id="stars", dtype="float"), - p.ColumnSchemaCreate(id="inputs", dtype="str"), - p.ColumnSchemaCreate(id="photo", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="good", dtype="bool"), + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate(id="stars", dtype="float"), + t.ColumnSchemaCreate(id="inputs", dtype="str"), + t.ColumnSchemaCreate(id="photo", dtype="image"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", # Interpolate string and non-string input columns @@ -95,10 +95,10 @@ def _create_table( max_tokens=10, ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="captioning", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", system_prompt="You are a concise assistant.", # Interpolate file input column @@ -111,11 +111,11 @@ def _create_table( ] if chat_cols is None: chat_cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( 
id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a wacky assistant.", temperature=0.001, @@ -125,25 +125,25 @@ def _create_table( ), ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: table = jamai.table.create_action_table( - p.ActionTableSchemaCreate(id=table_id, cols=cols) + t.ActionTableSchemaCreate(id=table_id, cols=cols) ) - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: if embedding_model is None: embedding_model = "" table = jamai.table.create_knowledge_table( - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id=table_id, cols=cols, embedding_model=embedding_model ) ) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: table = jamai.table.create_chat_table( - p.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + t.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) ) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) yield table finally: jamai.table.delete_table(table_type, table_id) @@ -151,7 +151,7 @@ def _create_table( def _add_row( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, stream: bool, table_name: str = TABLE_ID_A, data: dict | None = None, @@ -175,21 +175,21 @@ def _add_row( ) if chat_data is None: chat_data = dict(User="Tell me a joke.") - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: data.update(knowledge_data) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: data.update(chat_data) else: raise ValueError(f"Invalid table type: {table_type}") response = jamai.table.add_table_rows( table_type, - p.RowAddRequest(table_id=table_name, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), ) if stream: return response - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert len(response.rows) == 1 return response.rows[0] @@ -201,13 +201,13 @@ def _add_row( @pytest.mark.parametrize("delimiter", [","], ids=["comma_delimiter"]) def test_import_data_complete( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, delimiter: str, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Complete CSV with TemporaryDirectory() as tmp_dir: @@ -264,7 +264,7 @@ def test_import_data_complete( df_to_csv(df, file_path, delimiter) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -275,14 +275,14 @@ def test_import_data_complete( responses = [r for r in response] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" rows = jamai.table.list_table_rows(table_type, table.id, vec_decimals=2) assert isinstance(rows.items, list) assert len(rows.items) == 4 for row, d in zip(rows.items[::-1], data, strict=True): - if table_type == p.TableType.knowledge: + if table_type == 
t.TableType.knowledge: assert isinstance(row["Text Embed"]["value"], list) assert len(row["Text Embed"]["value"]) > 0 assert isinstance(row["Title Embed"]["value"], list) @@ -291,13 +291,13 @@ def test_import_data_complete( if k not in row and k in chat_data: continue if v == "": - assert ( - row[k]["value"] is None - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] is None, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) else: - assert ( - row[k]["value"] == v - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] == v, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -306,23 +306,23 @@ def test_import_data_complete( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_import_data_cast_to_string( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() - gen_cfg = p.LLMGenConfig() + gen_cfg = t.LLMGenConfig() cols = [ - p.ColumnSchemaCreate(id="bool", dtype="str"), - p.ColumnSchemaCreate(id="int", dtype="str"), - p.ColumnSchemaCreate(id="float", dtype="str"), - p.ColumnSchemaCreate(id="str", dtype="str"), + t.ColumnSchemaCreate(id="bool", dtype="str"), + t.ColumnSchemaCreate(id="int", dtype="str"), + t.ColumnSchemaCreate(id="float", dtype="str"), + t.ColumnSchemaCreate(id="str", dtype="str"), # p.ColumnSchemaCreate(id="bool_out", dtype="bool", gen_config=gen_cfg), # p.ColumnSchemaCreate(id="int_out", dtype="int", gen_config=gen_cfg), # p.ColumnSchemaCreate(id="float_out", dtype="float", gen_config=gen_cfg), - p.ColumnSchemaCreate(id="str_out", dtype="str", gen_config=gen_cfg), + t.ColumnSchemaCreate(id="str_out", dtype="str", gen_config=gen_cfg), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Complete CSV with TemporaryDirectory() as tmp_dir: @@ -356,7 +356,7 @@ def test_import_data_cast_to_string( df_to_csv(df, file_path) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -366,14 +366,14 @@ def test_import_data_cast_to_string( responses = [r for r in response] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" rows = jamai.table.list_table_rows(table_type, table.id, vec_decimals=2) assert isinstance(rows.items, list) assert len(rows.items) == 1 for row, d in zip(rows.items[::-1], data, strict=True): - if table_type == p.TableType.knowledge: + if table_type == t.TableType.knowledge: assert isinstance(row["Text Embed"]["value"], list) assert len(row["Text Embed"]["value"]) > 0 assert isinstance(row["Title Embed"]["value"], list) @@ -381,9 +381,9 @@ def test_import_data_cast_to_string( for k, v in d.items(): if k not in row and k in chat_data: continue - assert row[k]["value"] == str( - v - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] == str(v), ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -392,23 +392,23 @@ def test_import_data_cast_to_string( @pytest.mark.parametrize("stream", [True, 
False], ids=["stream", "non-stream"]) def test_import_data_cast_from_string( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() - gen_cfg = p.LLMGenConfig() + gen_cfg = t.LLMGenConfig() cols = [ - p.ColumnSchemaCreate(id="bool", dtype="bool"), - p.ColumnSchemaCreate(id="int", dtype="int"), - p.ColumnSchemaCreate(id="float", dtype="float"), - p.ColumnSchemaCreate(id="str", dtype="str"), + t.ColumnSchemaCreate(id="bool", dtype="bool"), + t.ColumnSchemaCreate(id="int", dtype="int"), + t.ColumnSchemaCreate(id="float", dtype="float"), + t.ColumnSchemaCreate(id="str", dtype="str"), # p.ColumnSchemaCreate(id="bool_out", dtype="bool", gen_config=gen_cfg), # p.ColumnSchemaCreate(id="int_out", dtype="int", gen_config=gen_cfg), # p.ColumnSchemaCreate(id="float_out", dtype="float", gen_config=gen_cfg), - p.ColumnSchemaCreate(id="str_out", dtype="str", gen_config=gen_cfg), + t.ColumnSchemaCreate(id="str_out", dtype="str", gen_config=gen_cfg), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Complete CSV with TemporaryDirectory() as tmp_dir: @@ -442,7 +442,7 @@ def test_import_data_cast_from_string( df_to_csv(df, file_path) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -452,14 +452,14 @@ def test_import_data_cast_from_string( responses = [r for r in response] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" rows = jamai.table.list_table_rows(table_type, table.id, vec_decimals=2) assert isinstance(rows.items, list) assert len(rows.items) == 1 for row, d in zip(rows.items[::-1], data, strict=True): - if table_type == p.TableType.knowledge: + if table_type == t.TableType.knowledge: assert isinstance(row["Text Embed"]["value"], list) assert len(row["Text Embed"]["value"]) > 0 assert isinstance(row["Title Embed"]["value"], list) @@ -467,9 +467,9 @@ def test_import_data_cast_from_string( for k, v in d.items(): if k not in row and k in chat_data: continue - assert ( - str(row[k]["value"]) == v - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert str(row[k]["value"]) == v, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -478,12 +478,12 @@ def test_import_data_cast_from_string( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_import_data_cast_dtype( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Complete CSV with TemporaryDirectory() as tmp_dir: @@ -524,7 +524,7 @@ def test_import_data_cast_dtype( df_to_csv(df, file_path) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -534,14 +534,14 @@ def test_import_data_cast_dtype( responses = [r for r in response] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, 
t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" rows = jamai.table.list_table_rows(table_type, table.id, vec_decimals=2) assert isinstance(rows.items, list) assert len(rows.items) == len(data) for row, d in zip(rows.items[::-1], gt_data, strict=True): - if table_type == p.TableType.knowledge: + if table_type == t.TableType.knowledge: assert isinstance(row["Text Embed"]["value"], list) assert len(row["Text Embed"]["value"]) > 0 assert isinstance(row["Title Embed"]["value"], list) @@ -550,13 +550,13 @@ def test_import_data_cast_dtype( if k not in row and k in chat_data: continue if v == "": - assert ( - row[k]["value"] is None - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] is None, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) else: - assert ( - row[k]["value"] == v - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] == v, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -565,12 +565,12 @@ def test_import_data_cast_dtype( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_import_data_incomplete( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # CSV without input column with TemporaryDirectory() as tmp_dir: @@ -623,7 +623,7 @@ def test_import_data_incomplete( df_to_csv(df, file_path) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -633,7 +633,7 @@ def test_import_data_incomplete( responses = [r for r in response] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" rows = jamai.table.list_table_rows(table_type, table.id, vec_decimals=2) @@ -642,13 +642,13 @@ def test_import_data_incomplete( for row, d in zip(rows.items[::-1], data, strict=True): for k in cols: if k not in d: - assert ( - row[k]["value"] is None - ), f"Imported data should be None: col=`{k}` val={row[k]}" + assert row[k]["value"] is None, ( + f"Imported data should be None: col=`{k}` val={row[k]}" + ) else: - assert ( - row[k]["value"] == d[k] - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{d[k]}`" + assert row[k]["value"] == d[k], ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{d[k]}`" + ) @flaky(max_runs=3, min_passes=1) @@ -657,12 +657,12 @@ def test_import_data_incomplete( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_import_data_with_generation( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # CSV without output column with TemporaryDirectory() as tmp_dir: @@ -695,7 +695,7 @@ def test_import_data_with_generation( df_to_csv(df, file_path) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ 
-704,7 +704,7 @@ def test_import_data_with_generation( if stream: responses = [r for r in response] assert len(responses) > 0 - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) assert all(r.output_column_name in ("summary", "captioning") for r in responses) summaries = defaultdict(list) @@ -715,9 +715,9 @@ def test_import_data_with_generation( summaries = {k: "".join(v) for k, v in summaries.items()} assert len(summaries) == 2 assert all(len(v) > 0 for v in summaries.values()) - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all( - isinstance(r.usage, p.CompletionUsage) + isinstance(r.usage, t.CompletionUsage) for r in responses if r.output_column_name in ("summary", "captioning") ) @@ -732,12 +732,12 @@ def test_import_data_with_generation( if r.output_column_name in ("summary", "captioning") ) else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" for row in response.rows: for output_column_name in ("summary", "captioning"): assert len(row.columns[output_column_name].text) > 0 - assert isinstance(row.columns[output_column_name].usage, p.CompletionUsage) + assert isinstance(row.columns[output_column_name].usage, t.CompletionUsage) assert isinstance(row.columns[output_column_name].prompt_tokens, int) assert isinstance(row.columns[output_column_name].completion_tokens, int) @@ -749,13 +749,13 @@ def test_import_data_with_generation( if k not in row and k in chat_data: continue if v == "": - assert ( - row[k]["value"] is None - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] is None, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) else: - assert ( - row[k]["value"] == v - ), f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k]["value"] == v, ( + f"Imported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -764,12 +764,12 @@ def test_import_data_with_generation( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_import_data_empty( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) with TemporaryDirectory() as tmp_dir: # Empty @@ -778,7 +778,7 @@ def test_import_data_empty( with pytest.raises(RuntimeError, match="No columns to parse"): response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream ), ) @@ -792,7 +792,7 @@ def test_import_data_empty( with pytest.raises(RuntimeError, match="empty"): response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream ), ) @@ -810,15 +810,15 @@ def test_import_data_with_vector( client_cls: Type[JamAI], stream: bool, ): - table_type = p.TableType.knowledge + table_type = t.TableType.knowledge jamai = client_cls() with _create_table(jamai, 
table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Add a row first to figure out the vector length response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ { @@ -896,7 +896,7 @@ def test_import_data_with_vector( df_to_csv(df, file_path) response = jamai.import_table_data( table_type, - p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -906,7 +906,7 @@ def test_import_data_with_vector( responses = [r for r in response] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" rows = jamai.table.list_table_rows(table_type, table.id, vec_decimals=2) @@ -925,12 +925,12 @@ def test_import_data_with_vector( @pytest.mark.parametrize("delimiter", [","], ids=["comma_delimiter"]) def test_export_data( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, delimiter: str, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) data = [ {"good": True, "words": 5, "stars": 0.0, "inputs": TEXT, "summary": TEXT}, {"good": False, "words": 5, "stars": 1.0, "inputs": TEXT, "summary": TEXT}, @@ -959,9 +959,9 @@ def test_export_data( for row, d in zip(exported_rows, data, strict=True): for k, v in d.items(): if k in columns: - assert ( - row[k] == v - ), f"Exported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + assert row[k] == v, ( + f"Exported data is wrong: col=`{k}` val={row[k]} ori=`{v}`" + ) else: assert k not in row @@ -972,12 +972,12 @@ def test_export_data( @pytest.mark.parametrize("delimiter", [","], ids=["comma_delimiter"]) def test_export_reordered_columns_data( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, delimiter: str, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row( jamai, table_type, @@ -994,7 +994,7 @@ def test_export_reordered_columns_data( "captioning", "good", ] - if table_type == p.TableType.knowledge: + if table_type == t.TableType.knowledge: new_cols_order = [ "Title", "Title Embed", @@ -1003,12 +1003,12 @@ def test_export_reordered_columns_data( "File ID", "Page", ] + new_cols_order - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: new_cols_order = ["User", "AI"] + new_cols_order jamai.table.reorder_columns( table_type=table_type, - request=p.ColumnReorderRequest( + request=t.ColumnReorderRequest( table_id=TABLE_ID_A, column_names=new_cols_order, ), @@ -1049,13 +1049,13 @@ def test_export_reordered_columns_data( @pytest.mark.parametrize("delimiter", [",", "\t"], ids=["comma_delimiter", "tab_delimiter"]) def test_import_export_data_round_trip( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, delimiter: str, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) with TemporaryDirectory() as tmp_dir: data = [ { @@ -1105,7 +1105,7 @@ def test_import_export_data_round_trip( df_to_csv(df, file_path, delimiter) response = jamai.import_table_data( table_type, - 
p.TableDataImportRequest( + t.TableDataImportRequest( file_path=file_path, table_id=table.id, stream=stream, @@ -1114,12 +1114,12 @@ def test_import_export_data_round_trip( ) if stream: responses = [r for r in response] - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert len(responses) > 0 else: assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" csv_data = jamai.export_table_data(table_type, table.id, delimiter=delimiter) @@ -1133,11 +1133,11 @@ def test_import_export_data_round_trip( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_import_export_round_trip( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row(jamai, table_type, False) _add_row( jamai, @@ -1163,12 +1163,12 @@ def test_import_export_round_trip( try: imported_table = jamai.table.import_table( table_type, - p.TableImportRequest( + t.TableImportRequest( file_path=file_path, table_id_dst=table_id_dst, ), ) - assert isinstance(imported_table, p.TableMetaResponse) + assert isinstance(imported_table, t.TableMetaResponse) assert imported_table.id == table_id_dst imported_rows = jamai.table.list_table_rows(table_type, imported_table.id) assert len(imported_rows.items) == len(rows.items) @@ -1178,10 +1178,7 @@ def test_import_export_round_trip( raw_urls = jamai.file.get_raw_urls( [rows.items[2]["photo"]["value"], imported_rows.items[2]["photo"]["value"]] ) - raw_files = [ - httpx.get(url, headers={"X-PROJECT-ID": "default"}).content - for url in raw_urls.urls - ] + raw_files = [httpx.get(url).content for url in raw_urls.urls] assert ( raw_urls.urls[0] != raw_urls.urls[1] ) # URL is different but file should match @@ -1200,7 +1197,7 @@ def test_import_export_wrong_table_type( ): jamai = client_cls() with _create_table(jamai, "action") as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row(jamai, "action", False) _add_row( jamai, @@ -1220,7 +1217,7 @@ def test_import_export_wrong_table_type( with pytest.raises(RuntimeError): jamai.import_table( "knowledge", - p.TableImportRequest( + t.TableImportRequest( file_path=file_path, table_id_dst=table_id_dst, ), @@ -1229,7 +1226,7 @@ def test_import_export_wrong_table_type( with pytest.raises(RuntimeError): jamai.import_table( "chat", - p.TableImportRequest( + t.TableImportRequest( file_path=file_path, table_id_dst=table_id_dst, ), @@ -1237,4 +1234,4 @@ def test_import_export_wrong_table_type( if __name__ == "__main__": - test_import_export_round_trip(JamAI, p.TableType.action) + test_import_export_round_trip(JamAI, t.TableType.action) diff --git a/clients/python/tests/oss/gen_table/test_row_ops.py b/clients/python/tests/oss/gen_table/test_row_ops.py index 2a3dd62..4bdf375 100644 --- a/clients/python/tests/oss/gen_table/test_row_ops.py +++ b/clients/python/tests/oss/gen_table/test_row_ops.py @@ -13,13 +13,13 @@ from pydantic import ValidationError from jamaibase import JamAI -from jamaibase import protocol as p -from jamaibase.exceptions import ResourceNotFoundError -from jamaibase.protocol import IMAGE_FILE_EXTENSIONS +from jamaibase import types as t +from jamaibase.types import IMAGE_FILE_EXTENSIONS +from jamaibase.utils.exceptions 
import ResourceNotFoundError from jamaibase.utils.io import df_to_csv CLIENT_CLS = [JamAI] -TABLE_TYPES = [p.TableType.action, p.TableType.knowledge, p.TableType.chat] +TABLE_TYPES = [t.TableType.action, t.TableType.knowledge, t.TableType.chat] TABLE_ID_A = "table_a" TABLE_ID_B = "table_b" @@ -43,11 +43,8 @@ "application/x-ndjson", # alternative for jsonl "application/json-lines", # another alternative for jsonl "application/vnd.openxmlformats-officedocument.wordprocessingml.document", # docx - "application/msword", # doc "application/vnd.openxmlformats-officedocument.presentationml.presentation", # pptx - "application/vnd.ms-powerpoint", # ppt "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", # xlsx - "application/vnd.ms-excel", # xls "text/tab-separated-values", # tsv "text/csv", # csv ] @@ -107,10 +104,10 @@ def _rerun_on_fs_error_with_delay(err, *args): @contextmanager def _create_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id: str = TABLE_ID_A, - cols: list[p.ColumnSchemaCreate] | None = None, - chat_cols: list[p.ColumnSchemaCreate] | None = None, + cols: list[t.ColumnSchemaCreate] | None = None, + chat_cols: list[t.ColumnSchemaCreate] | None = None, embedding_model: str | None = None, delete_first: bool = True, ): @@ -119,16 +116,17 @@ def _create_table( jamai.table.delete_table(table_type, table_id) if cols is None: cols = [ - p.ColumnSchemaCreate(id="good", dtype="bool"), - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate(id="stars", dtype="float"), - p.ColumnSchemaCreate(id="inputs", dtype="str"), - p.ColumnSchemaCreate(id="photo", dtype="image"), - p.ColumnSchemaCreate(id="audio", dtype="audio"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="good", dtype="bool"), + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate(id="stars", dtype="float"), + t.ColumnSchemaCreate(id="inputs", dtype="str"), + t.ColumnSchemaCreate(id="photo", dtype="image"), + t.ColumnSchemaCreate(id="audio", dtype="audio"), + t.ColumnSchemaCreate(id="paper", dtype="document"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", # Interpolate string and non-string input columns @@ -138,10 +136,10 @@ def _create_table( max_tokens=10, ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="captioning", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", system_prompt="You are a concise assistant.", # Interpolate file input column @@ -151,10 +149,10 @@ def _create_table( max_tokens=20, ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="narration", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt="${audio} \n\nWhat happened?", temperature=0.001, @@ -162,14 +160,25 @@ def _create_table( max_tokens=10, ), ), + t.ColumnSchemaCreate( + id="concept", + dtype="str", + gen_config=t.LLMGenConfig( + model="", + prompt="${paper} \n\nTell the main concept of the paper in 5 words.", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), ] if chat_cols is None: chat_cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a wacky assistant.", temperature=0.001, @@ -179,25 +188,25 @@ def 
_create_table( ), ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: table = jamai.table.create_action_table( - p.ActionTableSchemaCreate(id=table_id, cols=cols) + t.ActionTableSchemaCreate(id=table_id, cols=cols) ) - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: if embedding_model is None: embedding_model = "" table = jamai.table.create_knowledge_table( - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id=table_id, cols=cols, embedding_model=embedding_model ) ) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: table = jamai.table.create_chat_table( - p.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + t.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) ) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) yield table finally: jamai.table.delete_table(table_type, table_id) @@ -205,7 +214,7 @@ def _create_table( def _add_row( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, stream: bool, table_name: str = TABLE_ID_A, data: dict | None = None, @@ -219,6 +228,9 @@ def _add_row( audio_upload_response = jamai.file.upload_file( "clients/python/tests/files/mp3/turning-a4-size-magazine.mp3" ) + document_upload_response = jamai.file.upload_file( + "clients/python/tests/files/pdf/LLMs as Optimizers [DeepMind ; 2023].pdf" + ) data = dict( good=True, words=5, @@ -226,6 +238,7 @@ def _add_row( inputs=TEXT, photo=image_upload_response.uri, audio=audio_upload_response.uri, + paper=document_upload_response.uri, ) if knowledge_data is None: @@ -235,21 +248,21 @@ def _add_row( ) if chat_data is None: chat_data = dict(User="Tell me a joke.") - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: data.update(knowledge_data) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: data.update(chat_data) else: raise ValueError(f"Invalid table type: {table_type}") response = jamai.table.add_table_rows( table_type, - p.RowAddRequest(table_id=table_name, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), ) if stream: return response - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert len(response.rows) == 1 return response.rows[0] @@ -261,11 +274,10 @@ def _assert_is_vector(x: Any): def _collect_text( - responses: p.GenTableRowsChatCompletionChunks - | Generator[p.GenTableStreamChatCompletionChunk, None, None], + responses: t.MultiRowCompletionResponse | Generator[t.CellCompletionResponse, None, None], col: str, ): - if isinstance(responses, p.GenTableRowsChatCompletionChunks): + if isinstance(responses, t.MultiRowCompletionResponse): return "".join(r.columns[col].text for r in responses.rows) return "".join(r.text for r in responses if r.output_column_name == col) @@ -283,7 +295,7 @@ def test_knowledge_table_embedding( ): jamai = client_cls() with _create_table(jamai, "knowledge", cols=[], embedding_model="") as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Don't include embeddings data = [ dict( @@ -307,13 +319,13 @@ def test_knowledge_table_embedding( ] response = jamai.table.add_table_rows( "knowledge", - p.RowAddRequest(table_id=table.id, 
data=data, stream=stream), + t.MultiRowAddRequest(table_id=table.id, data=data, stream=stream), ) if stream: responses = [r for r in response] assert len(responses) == 0 # We currently dont return anything if LLM is not called else: - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) + assert isinstance(response.rows[0], t.RowCompletionResponse) # Check embeddings rows = jamai.table.list_table_rows("knowledge", table.id) assert isinstance(rows.items, list) @@ -342,11 +354,11 @@ def test_knowledge_table_no_embed_input( ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), temperature=0.001, top_p=0.001, @@ -355,21 +367,21 @@ def test_knowledge_table_no_embed_input( ), ] with _create_table(jamai, "knowledge", cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Purposely leave out Title and Text data = dict(words=5) response = jamai.table.add_table_rows( "knowledge", - p.RowAddRequest(table_id=table.id, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table.id, data=[data], stream=stream), ) if stream: # Must wait until stream ends responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) summary = "".join(r.text for r in responses if r.output_column_name == "summary") assert len(summary) > 0 else: - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) + assert isinstance(response.rows[0], t.RowCompletionResponse) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -380,9 +392,9 @@ def test_full_text_search( stream: bool, ): jamai = client_cls() - cols = [p.ColumnSchemaCreate(id="text", dtype="str")] + cols = [t.ColumnSchemaCreate(id="text", dtype="str")] with _create_table(jamai, "action", cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Add data texts = [ '"Dune: Part Two" 2024 is Denis\'s science-fiction film.', @@ -392,19 +404,21 @@ def test_full_text_search( ] response = jamai.table.add_table_rows( "action", - p.RowAddRequest(table_id=table.id, data=[{"text": t} for t in texts], stream=stream), + t.MultiRowAddRequest( + table_id=table.id, data=[{"text": t} for t in texts], stream=stream + ), ) if stream: # Must wait until stream ends responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) # Search def _search(query: str): return jamai.table.hybrid_search( - "action", p.SearchRequest(table_id=table.id, query=query) + "action", t.SearchRequest(table_id=table.id, query=query) ) assert len(_search("AND")) == 0 # SQL-like statements should still work @@ -423,16 +437,16 @@ def _search(query: str): @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_rag( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() # Create Knowledge Table and add some rows with 
_create_table(jamai, "knowledge", cols=[]) as ktable: - assert isinstance(ktable, p.TableMetaResponse) + assert isinstance(ktable, t.TableMetaResponse) response = jamai.table.add_table_rows( - p.TableType.knowledge, - p.RowAddRequest( + t.TableType.knowledge, + t.MultiRowAddRequest( table_id=ktable.id, data=[ dict( @@ -451,27 +465,27 @@ def test_rag( stream=False, ), ) - assert isinstance(response, p.GenTableRowsChatCompletionChunks) - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) - rows = jamai.table.list_table_rows(p.TableType.knowledge, ktable.id) + assert isinstance(response, t.MultiRowCompletionResponse) + assert isinstance(response.rows[0], t.RowCompletionResponse) + rows = jamai.table.list_table_rows(t.TableType.knowledge, ktable.id) assert isinstance(rows.items, list) assert len(rows.items) == 3 # Create the other table cols = [ - p.ColumnSchemaCreate(id="question", dtype="str"), - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="question", dtype="str"), + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="rag", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", prompt="${question}? Summarise in ${words} words", temperature=0.001, top_p=0.001, max_tokens=10, - rag_params=p.RAGParams( + rag_params=t.RAGParams( table_id=ktable.id, reranking_model=_get_reranking_model(jamai), search_query="", # Generate using LM @@ -481,31 +495,31 @@ def test_rag( ), ] with _create_table(jamai, table_type, TABLE_ID_B, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Perform RAG data = dict(question="What is a burnet?", words=5) response = jamai.table.add_table_rows( table_type, - p.RowAddRequest(table_id=table.id, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table.id, data=[data], stream=stream), ) if stream: responses = [r for r in response if r.output_column_name == "rag"] assert len(responses) > 0 - assert isinstance(responses[0], p.GenTableStreamReferences) + assert isinstance(responses[0], t.CellReferencesResponse) responses = responses[1:] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) rag = "".join(r.text for r in responses) assert len(rag) > 0 else: assert len(response.rows) > 0 for row in response.rows: - assert isinstance(row, p.GenTableChatCompletionChunks) + assert isinstance(row, t.RowCompletionResponse) assert len(row.columns) > 0 - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert "AI" in row.columns assert "rag" in row.columns - assert isinstance(row.columns["rag"], p.ChatCompletionChunk) - assert isinstance(row.columns["rag"].references, p.References) + assert isinstance(row.columns["rag"], t.ChatCompletionChunk) + assert isinstance(row.columns["rag"].references, t.References) assert len(row.columns["rag"].text) > 0 @@ -515,16 +529,16 @@ def test_rag( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_rag_with_image_input( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() # Create Knowledge Table and add some rows with _create_table(jamai, "knowledge", cols=[]) as ktable: - assert isinstance(ktable, p.TableMetaResponse) + assert isinstance(ktable, t.TableMetaResponse) 
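# Minimal sketch of the RAG wiring these tests exercise, written against the
# renamed `types` module (`t`) that this change introduces. The table ids, the
# question text, and the empty `model` / `search_query` strings (deferring to
# server-side defaults) are illustrative assumptions; the tests above also pass
# an explicit reranking model, which this sketch assumes can be left to its default.
from jamaibase import JamAI
from jamaibase import types as t

jamai = JamAI()

# Knowledge table with the default Title/Text columns, plus one row of content.
ktable = jamai.table.create_knowledge_table(
    t.KnowledgeTableSchemaCreate(id="facts", cols=[], embedding_model="")
)
jamai.table.add_table_rows(
    "knowledge",
    t.MultiRowAddRequest(
        table_id=ktable.id,
        data=[dict(Title="Dune: Part Two", Text="Dune: Part Two is a 2024 science fiction film.")],
        stream=False,
    ),
)

# Action table whose output column retrieves from the knowledge table via rag_params.
atable = jamai.table.create_action_table(
    t.ActionTableSchemaCreate(
        id="qa",
        cols=[
            t.ColumnSchemaCreate(id="question", dtype="str"),
            t.ColumnSchemaCreate(
                id="answer",
                dtype="str",
                gen_config=t.LLMGenConfig(
                    model="",  # empty string: let the server pick a default chat model
                    prompt="${question}",
                    rag_params=t.RAGParams(table_id=ktable.id, search_query=""),
                ),
            ),
        ],
    )
)

# Non-streaming row add returns a MultiRowCompletionResponse with per-column outputs.
response = jamai.table.add_table_rows(
    "action",
    t.MultiRowAddRequest(
        table_id=atable.id, data=[{"question": "What is Dune: Part Two?"}], stream=False
    ),
)
print(response.rows[0].columns["answer"].text)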
response = jamai.table.add_table_rows( - p.TableType.knowledge, - p.RowAddRequest( + t.TableType.knowledge, + t.MultiRowAddRequest( table_id=ktable.id, data=[ dict( @@ -539,28 +553,28 @@ def test_rag_with_image_input( stream=False, ), ) - assert isinstance(response, p.GenTableRowsChatCompletionChunks) - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) - rows = jamai.table.list_table_rows(p.TableType.knowledge, ktable.id) + assert isinstance(response, t.MultiRowCompletionResponse) + assert isinstance(response.rows[0], t.RowCompletionResponse) + rows = jamai.table.list_table_rows(t.TableType.knowledge, ktable.id) assert isinstance(rows.items, list) assert len(rows.items) == 2 # Create the other table cols = [ - p.ColumnSchemaCreate(id="photo", dtype="image"), - p.ColumnSchemaCreate(id="question", dtype="str"), - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="photo", dtype="image"), + t.ColumnSchemaCreate(id="question", dtype="str"), + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="rag", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", prompt="${photo} What's the animal? ${question} Summarise in ${words} words", temperature=0.001, top_p=0.001, max_tokens=10, - rag_params=p.RAGParams( + rag_params=t.RAGParams( table_id=ktable.id, reranking_model=_get_reranking_model(jamai), search_query="", # Generate using LM @@ -570,32 +584,32 @@ def test_rag_with_image_input( ), ] with _create_table(jamai, table_type, TABLE_ID_B, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) upload_response = jamai.file.upload_file("clients/python/tests/files/jpeg/rabbit.jpeg") # Perform RAG data = dict(photo=upload_response.uri, question="Get it's name.", words=5) response = jamai.table.add_table_rows( table_type, - p.RowAddRequest(table_id=table.id, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table.id, data=[data], stream=stream), ) if stream: responses = [r for r in response if r.output_column_name == "rag"] assert len(responses) > 0 - assert isinstance(responses[0], p.GenTableStreamReferences) + assert isinstance(responses[0], t.CellReferencesResponse) responses = responses[1:] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) rag = "".join(r.text for r in responses) assert len(rag) > 0 else: assert len(response.rows) > 0 for row in response.rows: - assert isinstance(row, p.GenTableChatCompletionChunks) + assert isinstance(row, t.RowCompletionResponse) assert len(row.columns) > 0 - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert "AI" in row.columns assert "rag" in row.columns - assert isinstance(row.columns["rag"], p.ChatCompletionChunk) - assert isinstance(row.columns["rag"].references, p.References) + assert isinstance(row.columns["rag"], t.ChatCompletionChunk) + assert isinstance(row.columns["rag"].references, t.References) assert len(row.columns["rag"].text) > 0 rows = jamai.table.list_table_rows(table_type, TABLE_ID_B) @@ -615,11 +629,11 @@ def test_conversation_starter( ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + 
gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You help remember facts.", temperature=0.001, @@ -627,11 +641,11 @@ def test_conversation_starter( max_tokens=10, ), ), - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are an assistant", temperature=0.001, @@ -641,22 +655,24 @@ def test_conversation_starter( ), ] with _create_table(jamai, "chat", cols=[], chat_cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Add the starter response = jamai.table.add_table_rows( "chat", - p.RowAddRequest(table_id=table.id, data=[dict(AI="Jim has 5 apples.")], stream=stream), + t.MultiRowAddRequest( + table_id=table.id, data=[dict(AI="Jim has 5 apples.")], stream=stream + ), ) if stream: # Must wait until stream ends responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) else: - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) + assert isinstance(response.rows[0], t.RowCompletionResponse) # Chat with it response = jamai.table.add_table_rows( "chat", - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[dict(User="How many apples does Jim have?")], stream=stream, @@ -665,13 +681,13 @@ def test_conversation_starter( if stream: # Must wait until stream ends responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) answer = "".join(r.text for r in responses if r.output_column_name == "AI") assert "5" in answer or "five" in answer.lower() summary = "".join(r.text for r in responses if r.output_column_name == "summary") assert len(summary) > 0 else: - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) + assert isinstance(response.rows[0], t.RowCompletionResponse) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -680,38 +696,38 @@ def test_conversation_starter( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_add_row( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) response = _add_row(jamai, table_type, stream) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( - r.output_column_name in ("summary", "captioning", "narration", "AI") + r.output_column_name in ("summary", "captioning", "narration", "concept", "AI") for r in responses ) else: assert all( - r.output_column_name in ("summary", "captioning", "narration") + r.output_column_name in ("summary", "captioning", "narration", "concept") for r in responses ) assert len("".join(r.text for r in responses)) > 0 - assert all(isinstance(r, 
p.GenTableStreamChatCompletionChunk) for r in responses) - assert all(isinstance(r.usage, p.CompletionUsage) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) + assert all(isinstance(r.usage, t.CompletionUsage) for r in responses) assert all(isinstance(r.prompt_tokens, int) for r in responses) assert all(isinstance(r.completion_tokens, int) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" - for output_column_name in ("summary", "captioning", "narration"): + for output_column_name in ("summary", "captioning", "narration", "concept"): assert len(response.columns[output_column_name].text) > 0 - assert isinstance(response.columns[output_column_name].usage, p.CompletionUsage) + assert isinstance(response.columns[output_column_name].usage, t.CompletionUsage) assert isinstance(response.columns[output_column_name].prompt_tokens, int) assert isinstance(response.columns[output_column_name].completion_tokens, int) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) @@ -725,10 +741,96 @@ def test_add_row( assert row["audio"]["value"].endswith("/turning-a4-size-magazine.mp3"), row["audio"][ "value" ] + assert row["paper"]["value"].endswith("/LLMs as Optimizers [DeepMind ; 2023].pdf"), row[ + "paper" + ]["value"] for animal in ["deer", "rabbit"]: if animal in row["photo"]["value"].split("_")[0]: assert animal in row["captioning"]["value"] assert "paper" in row["narration"]["value"] or "turn" in row["narration"]["value"] + assert ( + "optimization" in row["concept"]["value"].lower() + or "optimize" in row["concept"]["value"].lower() + ) + + +@flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) +@pytest.mark.timeout(180) +@pytest.mark.parametrize("client_cls", CLIENT_CLS) +@pytest.mark.parametrize("table_type", TABLE_TYPES[:1]) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize( + "docpath", + [ + "clients/python/tests/files/pdf/salary 总结.pdf", + "clients/python/tests/files/pdf_scan/1978_APL_FP_detrapping.PDF", + "clients/python/tests/files/pdf_mixed/digital_scan_combined.pdf", + "clients/python/tests/files/md/creative-story.md", + "clients/python/tests/files/txt/creative-story.txt", + "clients/python/tests/files/html/multilingual-code-examples.html", + "clients/python/tests/files/xml/weather-forecast-service.xml", + "clients/python/tests/files/jsonl/ChatMed_TCM-v0.2-5records.jsonl", + "clients/python/tests/files/docx/Recommendation Letter.docx", + "clients/python/tests/files/pptx/(2017.06.30) NMT in Linear Time (ByteNet).pptx", + "clients/python/tests/files/xlsx/Claims Form.xlsx", + "clients/python/tests/files/tsv/weather_observations.tsv", + "clients/python/tests/files/csv/weather_observations_long.csv", + ], + ids=lambda x: basename(x), +) +def test_add_row_document_dtype( + client_cls: Type[JamAI], table_type: t.TableType, stream: bool, docpath: str +): + jamai = client_cls() + cols = [ + t.ColumnSchemaCreate(id="doc", dtype="document"), + t.ColumnSchemaCreate( + id="content", + dtype="str", + gen_config=t.LLMGenConfig( + model="", + prompt="Document: \n${doc} \n\nReply 0 if document received, else -1. 
Omit any explanation, only answer 0 or -1.", + ), + ), + ] + with _create_table(jamai, table_type, cols=cols) as table: + assert isinstance(table, t.TableMetaResponse) + + upload_response = jamai.file.upload_file(docpath) + response = _add_row( + jamai, + table_type, + stream, + TABLE_ID_A, + data=dict(doc=upload_response.uri), + ) + if stream: + responses = [r for r in response] + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) + assert all(r.object == "gen_table.completion.chunk" for r in responses) + if table_type == t.TableType.chat: + assert all(r.output_column_name in ("content", "AI") for r in responses) + else: + assert all(r.output_column_name in ("content",) for r in responses) + assert len("".join(r.text for r in responses)) > 0 + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) + assert all(isinstance(r.usage, t.CompletionUsage) for r in responses) + assert all(isinstance(r.prompt_tokens, int) for r in responses) + assert all(isinstance(r.completion_tokens, int) for r in responses) + else: + assert isinstance(response, t.RowCompletionResponse) + assert response.object == "gen_table.completion.chunks" + output_column_name = "content" + assert len(response.columns[output_column_name].text) > 0 + assert isinstance(response.columns[output_column_name].usage, t.CompletionUsage) + assert isinstance(response.columns[output_column_name].prompt_tokens, int) + assert isinstance(response.columns[output_column_name].completion_tokens, int) + rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) + assert isinstance(rows.items, list) + assert len(rows.items) == 1 + row = rows.items[0] + assert row["doc"]["value"] == upload_response.uri, row["doc"]["value"] + assert "0" in row["content"]["value"] @@ -737,16 +839,16 @@ def test_add_row( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_regen_with_reordered_columns( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="number", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="number", dtype="int"), + t.ColumnSchemaCreate( + id="col1-english", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt=( "Number: ${number} \n\nTell the 'Number' in English, " @@ -754,10 +856,10 @@ def test_regen_with_reordered_columns( ), ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="col2-malay", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt=( "Number: ${number} \n\nTell the 'Number' in Malay, " @@ -765,10 +867,10 @@ def test_regen_with_reordered_columns( ), ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="col3-mandarin", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt=( "Number: ${number} \n\nTell the 'Number' in Mandarin (Chinese Character), " @@ -776,10 +878,10 @@ def test_regen_with_reordered_columns( ), ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="col4-roman", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt=( "Number: ${number} \n\nTell the 'Number' in Roman Numerals, " @@ -790,14 +892,14 @@ def test_regen_with_reordered_columns( ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, 
t.TableMetaResponse) row = _add_row( jamai, table_type, False, data=dict(number=1), ) - assert isinstance(row, p.GenTableChatCompletionChunks) + assert isinstance(row, t.RowCompletionResponse) rows = jamai.table.list_table_rows(table_type, table.id) assert isinstance(rows.items, list) assert len(rows.items) == 1 @@ -812,7 +914,7 @@ def test_regen_with_reordered_columns( # Update Input + Regen jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=_id, data=dict(number=2), @@ -821,10 +923,10 @@ def test_regen_with_reordered_columns( response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=table.id, row_ids=[_id], - regen_strategy=p.RegenStrategy.RUN_ALL, + regen_strategy=t.RegenStrategy.RUN_ALL, stream=stream, ), ) @@ -850,13 +952,13 @@ def test_regen_with_reordered_columns( "col4-roman", "col2-malay", ] - if table_type == p.TableType.knowledge: + if table_type == t.TableType.knowledge: new_cols += ["Title", "Text", "Title Embed", "Text Embed", "File ID", "Page"] - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: new_cols += ["User", "AI"] jamai.table.reorder_columns( table_type=table_type, - request=p.ColumnReorderRequest( + request=t.ColumnReorderRequest( table_id=TABLE_ID_A, column_names=new_cols, ), @@ -864,7 +966,7 @@ def test_regen_with_reordered_columns( # RUN_SELECTED jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=_id, data=dict(number=5), @@ -872,10 +974,10 @@ def test_regen_with_reordered_columns( ) response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=TABLE_ID_A, row_ids=[_id], - regen_strategy=p.RegenStrategy.RUN_SELECTED, + regen_strategy=t.RegenStrategy.RUN_SELECTED, output_column_id="col1-english", stream=stream, ), @@ -895,7 +997,7 @@ def test_regen_with_reordered_columns( # RUN_BEFORE jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=_id, data=dict(number=6), @@ -903,10 +1005,10 @@ def test_regen_with_reordered_columns( ) response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=TABLE_ID_A, row_ids=[_id], - regen_strategy=p.RegenStrategy.RUN_BEFORE, + regen_strategy=t.RegenStrategy.RUN_BEFORE, output_column_id="col4-roman", stream=stream, ), @@ -926,7 +1028,7 @@ def test_regen_with_reordered_columns( # RUN_AFTER jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=_id, data=dict(number=7), @@ -934,10 +1036,10 @@ def test_regen_with_reordered_columns( ) response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=TABLE_ID_A, row_ids=[_id], - regen_strategy=p.RegenStrategy.RUN_AFTER, + regen_strategy=t.RegenStrategy.RUN_AFTER, output_column_id="col4-roman", stream=stream, ), @@ -961,29 +1063,29 @@ def test_regen_with_reordered_columns( @pytest.mark.parametrize("stream", [True, False]) def test_add_row_sequential_image_model_completion( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="photo", dtype="image"), - p.ColumnSchemaCreate(id="photo2", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="photo", dtype="image"), + t.ColumnSchemaCreate(id="photo2", dtype="image"), + 
t.ColumnSchemaCreate( id="caption", dtype="str", - gen_config=p.LLMGenConfig(model="", prompt="${photo} What's in the image?"), + gen_config=t.LLMGenConfig(model="", prompt="${photo} What's in the image?"), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="question", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt="Caption: ${caption}\n\nImage: ${photo2}\n\nDoes the caption match? Reply True or False.", ), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) upload_response = jamai.file.upload_file("clients/python/tests/files/jpeg/rabbit.jpeg") response = _add_row( @@ -995,25 +1097,25 @@ def test_add_row_sequential_image_model_completion( ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( r.output_column_name in ("caption", "question", "AI") for r in responses ) else: assert all(r.output_column_name in ("caption", "question") for r in responses) assert len("".join(r.text for r in responses)) > 0 - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) - assert all(isinstance(r.usage, p.CompletionUsage) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) + assert all(isinstance(r.usage, t.CompletionUsage) for r in responses) assert all(isinstance(r.prompt_tokens, int) for r in responses) assert all(isinstance(r.completion_tokens, int) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" for output_column_name in ("caption", "question"): assert len(response.columns[output_column_name].text) > 0 - assert isinstance(response.columns[output_column_name].usage, p.CompletionUsage) + assert isinstance(response.columns[output_column_name].usage, t.CompletionUsage) assert isinstance(response.columns[output_column_name].prompt_tokens, int) assert isinstance(response.columns[output_column_name].completion_tokens, int) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) @@ -1035,29 +1137,29 @@ def test_add_row_sequential_image_model_completion( @pytest.mark.parametrize("stream", [True, False]) def test_add_row_map_dtype_file_to_image( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="photo", dtype="file"), - p.ColumnSchemaCreate(id="photo2", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="photo", dtype="file"), + t.ColumnSchemaCreate(id="photo2", dtype="image"), + t.ColumnSchemaCreate( id="caption", dtype="str", - gen_config=p.LLMGenConfig(model="", prompt="${photo} What's in the image?"), + gen_config=t.LLMGenConfig(model="", prompt="${photo} What's in the image?"), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="question", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt="Caption: ${caption}\n\nImage: ${photo2}\n\nDoes the caption match? 
Reply True or False.", ), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) upload_response = jamai.file.upload_file("clients/python/tests/files/jpeg/rabbit.jpeg") response = _add_row( @@ -1069,25 +1171,25 @@ def test_add_row_map_dtype_file_to_image( ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( r.output_column_name in ("caption", "question", "AI") for r in responses ) else: assert all(r.output_column_name in ("caption", "question") for r in responses) assert len("".join(r.text for r in responses)) > 0 - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) - assert all(isinstance(r.usage, p.CompletionUsage) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) + assert all(isinstance(r.usage, t.CompletionUsage) for r in responses) assert all(isinstance(r.prompt_tokens, int) for r in responses) assert all(isinstance(r.completion_tokens, int) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" for output_column_name in ("caption", "question"): assert len(response.columns[output_column_name].text) > 0 - assert isinstance(response.columns[output_column_name].usage, p.CompletionUsage) + assert isinstance(response.columns[output_column_name].usage, t.CompletionUsage) assert isinstance(response.columns[output_column_name].prompt_tokens, int) assert isinstance(response.columns[output_column_name].completion_tokens, int) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) @@ -1149,42 +1251,42 @@ def test_add_row_map_dtype_file_to_image( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_add_row_output_column_referred_image_input_with_chat_model( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="photo", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="photo", dtype="image"), + t.ColumnSchemaCreate( id="captioning", dtype="str", - gen_config=p.LLMGenConfig(model="", prompt="${photo} What's in the image?"), + gen_config=t.LLMGenConfig(model="", prompt="${photo} What's in the image?"), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Add output column that referred to image file, but using chat model # (Notes: chat model can be set due to default prompt was added afterward) chat_only_model = _get_chat_only_model(jamai) cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="captioning2", dtype="str", - gen_config=p.LLMGenConfig(model=chat_only_model), + gen_config=t.LLMGenConfig(model=chat_only_model), ), ] with pytest.raises(RuntimeError): - if table_type == p.TableType.action: - jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif 
table_type == t.TableType.knowledge: jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) @pytest.mark.parametrize("client_cls", CLIENT_CLS) @@ -1192,31 +1294,31 @@ def test_add_row_output_column_referred_image_input_with_chat_model( @pytest.mark.parametrize("stream", [True, False]) def test_add_row_sequential_completion_with_error( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input", dtype="str"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt="Summarise ${input}.", ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="rephrase", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", prompt="Rephrase ${summary}", ), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) response = _add_row( jamai, @@ -1227,18 +1329,18 @@ def test_add_row_sequential_completion_with_error( ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( r.output_column_name in ("summary", "rephrase", "AI") for r in responses ) else: assert all(r.output_column_name in ("summary", "rephrase") for r in responses) assert len("".join(r.text for r in responses)) > 0 - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" for output_column_name in ("summary", "rephrase"): assert len(response.columns[output_column_name].text) > 0 @@ -1271,11 +1373,11 @@ def test_add_row_sequential_completion_with_error( ids=lambda x: basename(x), ) def test_add_row_image_file_type_with_generation( - client_cls: Type[JamAI], table_type: p.TableType, stream: bool, img_filename: str + client_cls: Type[JamAI], table_type: t.TableType, stream: bool, img_filename: str ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) upload_response = jamai.file.upload_file(img_filename) response = _add_row( @@ -1288,21 +1390,21 @@ def test_add_row_image_file_type_with_generation( ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == 
"gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( - r.output_column_name in ("summary", "captioning", "narration", "AI") + r.output_column_name in ("summary", "captioning", "narration", "concept", "AI") for r in responses ) else: assert all( - r.output_column_name in ("summary", "captioning", "narration") + r.output_column_name in ("summary", "captioning", "narration", "concept") for r in responses ) assert len("".join(r.text for r in responses)) > 0 else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" assert len(response.columns["captioning"].text) > 0 rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) @@ -1336,11 +1438,11 @@ def test_add_row_image_file_type_with_generation( ], ) def test_add_row_image_file_column_invalid_extension( - client_cls: Type[JamAI], table_type: p.TableType, stream: bool, img_filename: str + client_cls: Type[JamAI], table_type: t.TableType, stream: bool, img_filename: str ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) with pytest.raises( ValidationError, match=( @@ -1363,18 +1465,18 @@ def test_add_row_image_file_column_invalid_extension( @pytest.mark.parametrize("client_cls", CLIENT_CLS) @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_add_row_validate_one_image_per_completion( - client_cls: Type[JamAI], table_type: p.TableType, stream: bool = True + client_cls: Type[JamAI], table_type: t.TableType, stream: bool = True ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - captioning=p.LLMGenConfig( + captioning=t.LLMGenConfig( system_prompt="You are a concise assistant.", prompt="${photo} ${photo}\n\nWhat's in the image?", ), @@ -1392,16 +1494,17 @@ def test_add_row_validate_one_image_per_completion( ), ) responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( - r.output_column_name in ("summary", "captioning", "narration", "AI") + r.output_column_name in ("summary", "captioning", "narration", "concept", "AI") for r in responses ) else: assert all( - r.output_column_name in ("summary", "captioning", "narration") for r in responses + r.output_column_name in ("summary", "captioning", "narration", "concept") + for r in responses ) assert len("".join(r.text for r in responses)) > 0 @@ -1419,30 +1522,30 @@ def test_add_row_validate_one_image_per_completion( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_add_row_wrong_dtype( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) response = _add_row(jamai, table_type, stream) if stream: 
responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( - r.output_column_name in ("summary", "captioning", "narration", "AI") + r.output_column_name in ("summary", "captioning", "narration", "concept", "AI") for r in responses ) else: assert all( - r.output_column_name in ("summary", "captioning", "narration") + r.output_column_name in ("summary", "captioning", "narration", "concept") for r in responses ) assert len("".join(r.text for r in responses)) > 0 else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" assert len(response.columns["summary"].text) > 0 @@ -1456,9 +1559,9 @@ def test_add_row_wrong_dtype( ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 2 @@ -1477,30 +1580,30 @@ def test_add_row_wrong_dtype( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_add_row_missing_columns( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) response = _add_row(jamai, table_type, stream) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( - r.output_column_name in ("summary", "captioning", "narration", "AI") + r.output_column_name in ("summary", "captioning", "narration", "concept", "AI") for r in responses ) else: assert all( - r.output_column_name in ("summary", "captioning", "narration") + r.output_column_name in ("summary", "captioning", "narration", "concept") for r in responses ) assert len("".join(r.text for r in responses)) > 0 else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) assert response.object == "gen_table.completion.chunks" assert len(response.columns["summary"].text) > 0 @@ -1514,9 +1617,9 @@ def test_add_row_missing_columns( ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 2 @@ -1535,21 +1638,21 @@ def test_add_row_missing_columns( @pytest.mark.parametrize("stream", 
[True, False], ids=["stream", "non-stream"]) def test_add_rows_all_input( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="0", dtype="int"), - p.ColumnSchemaCreate(id="1", dtype="float"), - p.ColumnSchemaCreate(id="2", dtype="bool"), - p.ColumnSchemaCreate(id="3", dtype="str"), + t.ColumnSchemaCreate(id="0", dtype="int"), + t.ColumnSchemaCreate(id="1", dtype="float"), + t.ColumnSchemaCreate(id="2", dtype="bool"), + t.ColumnSchemaCreate(id="3", dtype="str"), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ {"0": 1, "1": 2.0, "2": False, "3": "days"}, @@ -1562,7 +1665,7 @@ def test_add_rows_all_input( responses = [r for r in response if r.output_column_name != "AI"] assert len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert len(response.rows) == 2 rows = jamai.table.list_table_rows(table_type, table.id) assert isinstance(rows.items, list) @@ -1574,18 +1677,18 @@ def test_add_rows_all_input( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_update_row( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) row = _add_row( jamai, table_type, False, data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary="dummy"), ) - assert isinstance(row, p.GenTableChatCompletionChunks) + assert isinstance(row, t.RowCompletionResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 1 @@ -1597,13 +1700,13 @@ def test_update_row( # Regular update response = jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=row["ID"], data=dict(good=False, stars=1.0), ), ) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 1 @@ -1616,13 +1719,13 @@ def test_update_row( # Test updating data with wrong dtype response = jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=row["ID"], data=dict(good="dummy", words="dummy", stars="dummy"), ), ) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 1 @@ -1638,13 +1741,13 @@ def test_update_row( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_regen_rows( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) image_upload_response = jamai.file.upload_file( 
"clients/python/tests/files/jpeg/rabbit.jpeg" @@ -1665,7 +1768,7 @@ def test_regen_rows( audio=audio_upload_response.uri, ), ) - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 1 @@ -1676,7 +1779,7 @@ def test_regen_rows( # Regen jamai.table.update_table_row( table_type, - p.RowUpdateRequest( + t.RowUpdateRequest( table_id=TABLE_ID_A, row_id=_id, data=dict( @@ -1685,25 +1788,25 @@ def test_regen_rows( ), ) response = jamai.table.regen_table_rows( - table_type, p.RowRegenRequest(table_id=TABLE_ID_A, row_ids=[_id], stream=stream) + table_type, t.MultiRowRegenRequest(table_id=TABLE_ID_A, row_ids=[_id], stream=stream) ) if stream: responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) assert all(r.object == "gen_table.completion.chunk" for r in responses) - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: assert all( - r.output_column_name in ("summary", "captioning", "narration", "AI") + r.output_column_name in ("summary", "captioning", "narration", "concept", "AI") for r in responses ) else: assert all( - r.output_column_name in ("summary", "captioning", "narration") + r.output_column_name in ("summary", "captioning", "narration", "concept") for r in responses ) assert len("".join(r.text for r in responses)) > 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.rows[0].object == "gen_table.completion.chunks" assert len(response.rows[0].columns["summary"].text) > 0 rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) @@ -1725,21 +1828,21 @@ def test_regen_rows( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_regen_rows_all_input( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="0", dtype="int"), - p.ColumnSchemaCreate(id="1", dtype="float"), - p.ColumnSchemaCreate(id="2", dtype="bool"), - p.ColumnSchemaCreate(id="3", dtype="str"), + t.ColumnSchemaCreate(id="0", dtype="int"), + t.ColumnSchemaCreate(id="1", dtype="float"), + t.ColumnSchemaCreate(id="2", dtype="bool"), + t.ColumnSchemaCreate(id="3", dtype="str"), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ {"0": 1, "1": 2.0, "2": False, "3": "days"}, @@ -1748,7 +1851,7 @@ def test_regen_rows_all_input( stream=False, ), ) - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert len(response.rows) == 2 rows = jamai.table.list_table_rows(table_type, table.id) assert isinstance(rows.items, list) @@ -1756,7 +1859,7 @@ def test_regen_rows_all_input( # Regen response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=table.id, row_ids=[r["ID"] for r in rows.items], stream=stream ), ) @@ -1764,7 +1867,7 @@ def test_regen_rows_all_input( responses = [r for r in response if r.output_column_name != "AI"] assert 
len(responses) == 0 else: - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -1772,12 +1875,12 @@ def test_regen_rows_all_input( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_delete_rows( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) data = dict(good=True, words=5, stars=9.9, inputs=TEXT, summary="dummy") _add_row(jamai, table_type, False, data=data) _add_row(jamai, table_type, False, data=data) @@ -1802,7 +1905,7 @@ def test_delete_rows( # Delete one row response = jamai.table.delete_table_row(table_type, TABLE_ID_A, delete_id) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 5 @@ -1812,12 +1915,12 @@ def test_delete_rows( delete_ids = [r["ID"] for r in ori_rows.items[1:4]] response = jamai.table.delete_table_rows( table_type, - p.RowDeleteRequest( + t.MultiRowDeleteRequest( table_id=TABLE_ID_A, row_ids=delete_ids, ), ) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) rows = jamai.table.list_table_rows(table_type, TABLE_ID_A) assert isinstance(rows.items, list) assert len(rows.items) == 2 @@ -1830,11 +1933,11 @@ def test_delete_rows( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_get_and_list_rows( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row(jamai, table_type, False) _add_row( jamai, @@ -1870,15 +1973,17 @@ def test_get_and_list_rows( "inputs", "photo", "audio", + "paper", "summary", "captioning", "narration", + "concept", } - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} else: raise ValueError(f"Invalid table type: {table_type}") @@ -2091,15 +2196,15 @@ def test_get_and_list_rows( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_column_interpolate( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", prompt='Say "Jan has 5 apples.".', @@ -2108,11 +2213,11 @@ def test_column_interpolate( max_tokens=10, ), ), - p.ColumnSchemaCreate(id="input0", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="int"), + t.ColumnSchemaCreate( id="output1", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise 
assistant.", prompt=( @@ -2126,7 +2231,7 @@ def test_column_interpolate( ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) def _add_row_wrapped(stream, data): return _add_row( @@ -2165,16 +2270,16 @@ def _add_row_wrapped(stream, data): @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_chat_history_and_sequential_add( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input", dtype="str"), + t.ColumnSchemaCreate( id="output", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( system_prompt="You are a calculator.", prompt="${input}", multi_turn=True, @@ -2185,11 +2290,11 @@ def test_chat_history_and_sequential_add( ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Initialise chat thread and set output format response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ dict(input="x = 0", output="0"), @@ -2204,7 +2309,7 @@ def test_chat_history_and_sequential_add( # Test adding one row response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[dict(input="Add 1")], stream=stream, @@ -2215,7 +2320,7 @@ def test_chat_history_and_sequential_add( # Test adding multiple rows response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ dict(input="Add 1"), @@ -2237,16 +2342,16 @@ def test_chat_history_and_sequential_add( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_chat_history_and_sequential_regen( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input", dtype="str"), + t.ColumnSchemaCreate( id="output", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( system_prompt="You are a calculator.", prompt="${input}", multi_turn=True, @@ -2257,11 +2362,11 @@ def test_chat_history_and_sequential_regen( ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Initialise chat thread and set output format response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ dict(input="x = 0", output="0"), @@ -2278,7 +2383,7 @@ def test_chat_history_and_sequential_regen( # Test regen one row response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=table.id, row_ids=row_ids[3:4], stream=stream, @@ -2290,7 +2395,7 @@ def test_chat_history_and_sequential_regen( # Also test if regen proceeds in correct order from earliest row to latest response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=table.id, row_ids=row_ids[3:][::-1], stream=stream, @@ -2308,16 +2413,16 @@ def test_chat_history_and_sequential_regen( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def 
test_convert_into_multi_turn( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input", dtype="str"), + t.ColumnSchemaCreate( id="output", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( system_prompt="You are a calculator.", prompt="${input}", multi_turn=False, @@ -2328,11 +2433,11 @@ def test_convert_into_multi_turn( ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Initialise chat thread and set output format response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[ dict(input="x = 0", output="0"), @@ -2346,7 +2451,7 @@ def test_convert_into_multi_turn( # Test adding one row as single-turn response = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=table.id, data=[dict(input="x += 1")], stream=stream, @@ -2357,10 +2462,10 @@ def test_convert_into_multi_turn( # Convert into multi-turn table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output=p.LLMGenConfig( + output=t.LLMGenConfig( system_prompt="You are a calculator.", prompt="${input}", multi_turn=True, @@ -2371,12 +2476,12 @@ def test_convert_into_multi_turn( ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Regen rows = jamai.table.list_table_rows(table_type, table.id) response = jamai.table.regen_table_rows( table_type, - p.RowRegenRequest( + t.MultiRowRegenRequest( table_id=table.id, row_ids=[rows.items[0]["ID"]], stream=stream, @@ -2391,15 +2496,15 @@ def test_convert_into_multi_turn( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_get_conversation_thread( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input", dtype="str"), + t.ColumnSchemaCreate( id="output", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( system_prompt="You are a calculator.", prompt="${input}", multi_turn=True, @@ -2410,7 +2515,7 @@ def test_get_conversation_thread( ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Initialise chat thread and set output format data = [ dict(input="x = 0", output="0"), @@ -2419,22 +2524,22 @@ def test_get_conversation_thread( dict(input="Add 3", output="6"), ] response = jamai.table.add_table_rows( - table_type, p.RowAddRequest(table_id=table.id, data=data, stream=False) + table_type, t.MultiRowAddRequest(table_id=table.id, data=data, stream=False) ) row_ids = sorted([r.row_id for r in response.rows]) def _check_thread(_chat): - assert isinstance(_chat, p.ChatThread) + assert isinstance(_chat, t.ChatThreadResponse) for i, message in enumerate(_chat.thread): assert isinstance(message.content, str) assert len(message.content) > 0 if i == 0: - assert message.role == p.ChatRole.SYSTEM + assert message.role == t.ChatRole.SYSTEM elif i % 2 == 1: - assert message.role == p.ChatRole.USER + assert message.role == t.ChatRole.USER assert message.content == data[(i - 1) // 
2]["input"] else: - assert message.role == p.ChatRole.ASSISTANT + assert message.role == t.ChatRole.ASSISTANT assert message.content == data[(i // 2) - 1]["output"] # --- Fetch complete thread --- # @@ -2471,32 +2576,32 @@ def test_hybrid_search( client_cls: Type[JamAI], ): jamai = client_cls() - table_type = p.TableType.knowledge + table_type = t.TableType.knowledge with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) data = dict(good=True, words=5, stars=9.9, inputs=TEXT, summary="dummy") rows = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=TABLE_ID_A, data=[dict(Title="Resume 2012", Text="Hi there, I am a farmer.", **data)], stream=False, ), ) - assert isinstance(rows, p.GenTableRowsChatCompletionChunks) + assert isinstance(rows, t.MultiRowCompletionResponse) rows = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=TABLE_ID_A, data=[dict(Title="Resume 2013", Text="Hi there, I am a carpenter.", **data)], stream=False, ), ) - assert isinstance(rows, p.GenTableRowsChatCompletionChunks) + assert isinstance(rows, t.MultiRowCompletionResponse) rows = jamai.table.add_table_rows( table_type, - p.RowAddRequest( + t.MultiRowAddRequest( table_id=TABLE_ID_A, data=[ dict( @@ -2508,12 +2613,12 @@ def test_hybrid_search( stream=False, ), ) - assert isinstance(rows, p.GenTableRowsChatCompletionChunks) + assert isinstance(rows, t.MultiRowCompletionResponse) sleep(1) # Optional, give it some time to index # Rely on embedding rows = jamai.table.hybrid_search( table_type, - p.SearchRequest( + t.SearchRequest( table_id=TABLE_ID_A, query="language", reranking_model=_get_reranking_model(jamai), @@ -2525,7 +2630,7 @@ def test_hybrid_search( # Rely on FTS rows = jamai.table.hybrid_search( table_type, - p.SearchRequest( + t.SearchRequest( table_id=TABLE_ID_A, query="candidate 2013", reranking_model=_get_reranking_model(jamai), @@ -2537,7 +2642,7 @@ def test_hybrid_search( # hybrid_search without reranker (RRF only) rows = jamai.table.hybrid_search( table_type, - p.SearchRequest( + t.SearchRequest( table_id=TABLE_ID_A, query="language", reranking_model=None, @@ -2555,7 +2660,6 @@ def test_hybrid_search( "file_path", [ "clients/python/tests/files/pdf/salary 总结.pdf", - "clients/python/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf", "clients/python/tests/files/pdf_scan/1978_APL_FP_detrapping.PDF", "clients/python/tests/files/pdf_mixed/digital_scan_combined.pdf", "clients/python/tests/files/md/creative-story.md", @@ -2568,11 +2672,8 @@ def test_hybrid_search( "clients/python/tests/files/jsonl/llm-models.jsonl", "clients/python/tests/files/jsonl/ChatMed_TCM-v0.2-5records.jsonl", "clients/python/tests/files/docx/Recommendation Letter.docx", - "clients/python/tests/files/doc/Recommendation Letter.doc", - "clients/python/tests/files/pptx/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).pptx", - "clients/python/tests/files/ppt/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).ppt", + "clients/python/tests/files/pptx/(2017.06.30) NMT in Linear Time (ByteNet).pptx", "clients/python/tests/files/xlsx/Claims Form.xlsx", - "clients/python/tests/files/xls/Claims Form.xls", "clients/python/tests/files/tsv/weather_observations.tsv", "clients/python/tests/files/csv/company-profile.csv", 
"clients/python/tests/files/csv/weather_observations_long.csv", @@ -2584,12 +2685,12 @@ def test_upload_file( file_path: str, ): jamai = client_cls() - table_type = p.TableType.knowledge + table_type = t.TableType.knowledge with _create_table(jamai, table_type, cols=[]) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) response = jamai.table.embed_file(file_path, table.id) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) rows = jamai.table.list_table_rows(table_type, table.id) assert isinstance(rows.items, list) assert all(isinstance(r, dict) for r in rows.items) @@ -2622,15 +2723,15 @@ def test_upload_empty_file( file_path: str, ): jamai = client_cls() - table_type = p.TableType.knowledge + table_type = t.TableType.knowledge with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) pattern = re.compile("There is no text or content to embed") with pytest.raises(RuntimeError, match=pattern): response = jamai.table.embed_file(file_path, table.id) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -2649,10 +2750,10 @@ def test_upload_file_invalid_file_type( file_path: str, ): jamai = client_cls() - table_type = p.TableType.knowledge + table_type = t.TableType.knowledge with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) with pytest.raises(RuntimeError, match=r"File type .+ is unsupported"): jamai.table.embed_file(file_path, table.id) @@ -2694,7 +2795,7 @@ def test_upload_long_file( ): jamai = client_cls() with _create_table(jamai, "knowledge", cols=[], embedding_model="") as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) with TemporaryDirectory() as tmp_dir: # Create a long CSV data = [ @@ -2705,10 +2806,10 @@ def test_upload_long_file( file_path = join(tmp_dir, "long.csv") df_to_csv(pd.DataFrame.from_dict(data * 100), file_path) # Embed the CSV - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) response = jamai.table.embed_file(file_path, table.id) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) rows = jamai.table.list_table_rows("knowledge", table.id) assert isinstance(rows.items, list) assert all(isinstance(r, dict) for r in rows.items) @@ -2724,4 +2825,4 @@ def test_upload_long_file( if __name__ == "__main__": - test_get_conversation_thread(JamAI, p.TableType.action) + test_get_conversation_thread(JamAI, t.TableType.action) diff --git a/clients/python/tests/oss/gen_table/test_table_ops.py b/clients/python/tests/oss/gen_table/test_table_ops.py index a53d587..8e80531 100644 --- a/clients/python/tests/oss/gen_table/test_table_ops.py +++ 
b/clients/python/tests/oss/gen_table/test_table_ops.py @@ -8,11 +8,11 @@ from pydantic import ValidationError from jamaibase import JamAI -from jamaibase import protocol as p -from jamaibase.exceptions import ResourceNotFoundError +from jamaibase import types as t +from jamaibase.utils.exceptions import ResourceNotFoundError CLIENT_CLS = [JamAI] -TABLE_TYPES = [p.TableType.action, p.TableType.knowledge, p.TableType.chat] +TABLE_TYPES = [t.TableType.action, t.TableType.knowledge, t.TableType.chat] REGULAR_COLUMN_DTYPES: list[str] = ["int", "float", "bool", "str"] SAMPLE_DATA = { "int": -1, @@ -87,10 +87,10 @@ def _rerun_on_fs_error_with_delay(err, *args): @contextmanager def _create_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id: str = TABLE_ID_A, - cols: list[p.ColumnSchemaCreate] | None = None, - chat_cols: list[p.ColumnSchemaCreate] | None = None, + cols: list[t.ColumnSchemaCreate] | None = None, + chat_cols: list[t.ColumnSchemaCreate] | None = None, embedding_model: str | None = None, delete_first: bool = True, ): @@ -99,15 +99,15 @@ def _create_table( jamai.table.delete_table(table_type, table_id) if cols is None: cols = [ - p.ColumnSchemaCreate(id="good", dtype="bool"), - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate(id="stars", dtype="float"), - p.ColumnSchemaCreate(id="inputs", dtype="str"), - p.ColumnSchemaCreate(id="photo", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="good", dtype="bool"), + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate(id="stars", dtype="float"), + t.ColumnSchemaCreate(id="inputs", dtype="str"), + t.ColumnSchemaCreate(id="photo", dtype="image"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", # Interpolate string and non-string input columns @@ -117,10 +117,10 @@ def _create_table( max_tokens=10, ), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="captioning", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", system_prompt="You are a concise assistant.", # Interpolate file input column @@ -133,11 +133,11 @@ def _create_table( ] if chat_cols is None: chat_cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a wacky assistant.", temperature=0.001, @@ -147,25 +147,25 @@ def _create_table( ), ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: table = jamai.table.create_action_table( - p.ActionTableSchemaCreate(id=table_id, cols=cols) + t.ActionTableSchemaCreate(id=table_id, cols=cols) ) - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: if embedding_model is None: embedding_model = "" table = jamai.table.create_knowledge_table( - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id=table_id, cols=cols, embedding_model=embedding_model ) ) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: table = jamai.table.create_chat_table( - p.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + t.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) ) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + 
assert isinstance(table, t.TableMetaResponse) yield table finally: jamai.table.delete_table(table_type, table_id) @@ -174,29 +174,29 @@ def _create_table( @contextmanager def _create_table_v2( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id: str = TABLE_ID_A, - cols: list[p.ColumnSchemaCreate] | None = None, - chat_cols: list[p.ColumnSchemaCreate] | None = None, + cols: list[t.ColumnSchemaCreate] | None = None, + chat_cols: list[t.ColumnSchemaCreate] | None = None, llm_model: str = "", embedding_model: str = "", system_prompt: str = "", prompt: str = "", delete_first: bool = True, -) -> Generator[p.TableMetaResponse, None, None]: +) -> Generator[t.TableMetaResponse, None, None]: try: if delete_first: jamai.table.delete_table(table_type, table_id) if cols is None: _input_cols = [ - p.ColumnSchemaCreate(id=f"in_{dtype}", dtype=dtype) + t.ColumnSchemaCreate(id=f"in_{dtype}", dtype=dtype) for dtype in REGULAR_COLUMN_DTYPES ] _output_cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id=f"out_{dtype}", dtype=dtype, - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=llm_model, system_prompt=system_prompt, prompt=" ".join(f"${{{col.id}}}" for col in _input_cols) + prompt, @@ -208,11 +208,11 @@ def _create_table_v2( cols = _input_cols + _output_cols if chat_cols is None: chat_cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=llm_model, system_prompt=system_prompt, max_tokens=10, @@ -222,25 +222,25 @@ def _create_table_v2( expected_cols = {"ID", "Updated at"} expected_cols |= {c.id for c in cols} - if table_type == p.TableType.action: + if table_type == t.TableType.action: table = jamai.table.create_action_table( - p.ActionTableSchemaCreate(id=table_id, cols=cols) + t.ActionTableSchemaCreate(id=table_id, cols=cols) ) - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: table = jamai.table.create_knowledge_table( - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id=table_id, cols=cols, embedding_model=embedding_model ) ) expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: table = jamai.table.create_chat_table( - p.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + t.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) ) expected_cols |= {c.id for c in chat_cols} else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) col_ids = set(c.id for c in table.cols) assert col_ids == expected_cols yield table @@ -250,7 +250,7 @@ def _create_table_v2( def _add_row( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, stream: bool, table_name: str = TABLE_ID_A, data: dict | None = None, @@ -273,35 +273,35 @@ def _add_row( ) if chat_data is None: chat_data = dict(User="Tell me a joke.") - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: data.update(knowledge_data) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: data.update(chat_data) else: raise ValueError(f"Invalid table type: {table_type}") response = jamai.table.add_table_rows( table_type, - 
p.RowAddRequest(table_id=table_name, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), ) if stream: return response - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert len(response.rows) == 1 return response.rows[0] def _add_row_v2( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, stream: bool, table_name: str = TABLE_ID_A, data: dict | None = None, knowledge_data: dict | None = None, chat_data: dict | None = None, include_output_data: bool = False, -) -> p.GenTableRowsChatCompletionChunks: +) -> t.MultiRowCompletionResponse: if data is None: data = {f"in_{dtype}": SAMPLE_DATA[dtype] for dtype in REGULAR_COLUMN_DTYPES} if include_output_data: @@ -318,28 +318,28 @@ def _add_row_v2( chat_data = dict(User="Tell me a joke.") if include_output_data: chat_data.update({"AI": "Nah"}) - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: data.update(knowledge_data) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: data.update(chat_data) else: raise ValueError(f"Invalid table type: {table_type}") response = jamai.table.add_table_rows( table_type, - p.RowAddRequest(table_id=table_name, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), ) if stream: chunks = [r for r in response] - assert all(isinstance(c, p.GenTableStreamChatCompletionChunk) for c in chunks) + assert all(isinstance(c, t.CellCompletionResponse) for c in chunks) assert all(c.object == "gen_table.completion.chunk" for c in chunks) assert len(set(c.row_id for c in chunks)) == 1 columns = {c.output_column_name: c for c in chunks} - return p.GenTableRowsChatCompletionChunks( - rows=[p.GenTableChatCompletionChunks(columns=columns, row_id=chunks[0].row_id)] + return t.MultiRowCompletionResponse( + rows=[t.RowCompletionResponse(columns=columns, row_id=chunks[0].row_id)] ) - assert isinstance(response, p.GenTableRowsChatCompletionChunks) + assert isinstance(response, t.MultiRowCompletionResponse) assert response.object == "gen_table.completion.rows" assert len(response.rows) == 1 return response @@ -348,7 +348,7 @@ def _add_row_v2( @contextmanager def _rename_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id_src: str, table_id_dst: str, delete_first: bool = True, @@ -357,7 +357,7 @@ def _rename_table( if delete_first: jamai.table.delete_table(table_type, table_id_dst) table = jamai.table.rename_table(table_type, table_id_src, table_id_dst) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) yield table finally: jamai.table.delete_table(table_type, table_id_dst) @@ -366,7 +366,7 @@ def _rename_table( @contextmanager def _duplicate_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id_src: str, table_id_dst: str, include_data: bool = True, @@ -383,7 +383,7 @@ def _duplicate_table( include_data=include_data, create_as_child=deploy, ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) yield table finally: jamai.table.delete_table(table_type, table_id_dst) @@ -392,7 +392,7 @@ def _duplicate_table( @contextmanager def _create_child_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id_src: str, table_id_dst: str | None, 
delete_first: bool = True, @@ -404,7 +404,7 @@ def _create_child_table( table_type, table_id_src, table_id_dst, create_as_child=True ) table_id_dst = table.id - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) yield table finally: if isinstance(table_id_dst, str): @@ -412,11 +412,10 @@ def _create_child_table( def _collect_text( - responses: p.GenTableRowsChatCompletionChunks - | Generator[p.GenTableStreamChatCompletionChunk, None, None], + responses: t.MultiRowCompletionResponse | Generator[t.CellCompletionResponse, None, None], col: str, ): - if isinstance(responses, p.GenTableRowsChatCompletionChunks): + if isinstance(responses, t.MultiRowCompletionResponse): return "".join(r.columns[col].text for r in responses.rows) return "".join(r.text for r in responses if r.output_column_name == col) @@ -426,18 +425,18 @@ def _collect_text( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_create_delete_table( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table_v2(jamai, table_type) as table_a: with _create_table_v2(jamai, table_type, TABLE_ID_B) as table_b: - assert isinstance(table_a, p.TableMetaResponse) + assert isinstance(table_a, t.TableMetaResponse) assert table_a.id == TABLE_ID_A assert table_b.id == TABLE_ID_B assert isinstance(table_a.cols, list) - assert all(isinstance(c, p.ColumnSchema) for c in table_a.cols) + assert all(isinstance(c, t.ColumnSchema) for c in table_a.cols) table = jamai.table.get_table(table_type, TABLE_ID_B) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # After deleting table B with pytest.raises(ResourceNotFoundError, match="is not found."): jamai.table.get_table(table_type, TABLE_ID_B) @@ -451,12 +450,12 @@ def test_create_delete_table( ) def test_create_table_valid_table_id( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, table_id: str, ): jamai = client_cls() with _create_table(jamai, table_type, table_id) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) assert table.id == table_id @@ -465,29 +464,29 @@ def test_create_table_valid_table_id( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_create_table_valid_column_id( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): table_id = TABLE_ID_A col_ids = ["a", "0", "a b", "a-b", "a_b", "a-_b", "a-_0b", "a -_0b", "0_0"] jamai = client_cls() # --- Test input column --- # - cols = [p.ColumnSchemaCreate(id=_id, dtype="str") for _id in col_ids] + cols = [t.ColumnSchemaCreate(id=_id, dtype="str") for _id in col_ids] with _create_table(jamai, table_type, table_id, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) assert len(set(col_ids) - {c.id for c in table.cols}) == 0 # --- Test output column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id=_id, dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ) for _id in col_ids ] with _create_table(jamai, table_type, table_id, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) assert len(set(col_ids) - {c.id for c in table.cols}) == 0 @@ -499,7 +498,7 @@ def test_create_table_valid_column_id( ) def test_create_table_invalid_table_id( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, 
column_id: str, ): table_id = TABLE_ID_A @@ -507,7 +506,7 @@ def test_create_table_invalid_table_id( # --- Test input column --- # cols = [ - p.ColumnSchemaCreate(id=column_id, dtype="str"), + t.ColumnSchemaCreate(id=column_id, dtype="str"), ] with pytest.raises(RuntimeError): with _create_table(jamai, table_type, table_id, cols=cols): @@ -515,10 +514,10 @@ def test_create_table_invalid_table_id( # --- Test output column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id=column_id, dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), ] with pytest.raises(RuntimeError): @@ -534,7 +533,7 @@ def test_create_table_invalid_table_id( ) def test_create_table_invalid_column_id( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, column_id: str, ): table_id = TABLE_ID_A @@ -542,7 +541,7 @@ def test_create_table_invalid_column_id( # --- Test input column --- # cols = [ - p.ColumnSchemaCreate(id=column_id, dtype="str"), + t.ColumnSchemaCreate(id=column_id, dtype="str"), ] with pytest.raises(RuntimeError): with _create_table(jamai, table_type, table_id, cols=cols): @@ -550,10 +549,10 @@ def test_create_table_invalid_column_id( # --- Test output column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id=column_id, dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), ] with pytest.raises(RuntimeError): @@ -566,16 +565,16 @@ def test_create_table_invalid_column_id( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_create_table_invalid_model( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): table_id = TABLE_ID_A jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(model="INVALID"), + gen_config=t.LLMGenConfig(model="INVALID"), ), ] with pytest.raises(ResourceNotFoundError): @@ -588,16 +587,16 @@ def test_create_table_invalid_model( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_create_table_invalid_column_ref( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): table_id = TABLE_ID_A jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(prompt="Summarise ${input2}"), + gen_config=t.LLMGenConfig(prompt="Summarise ${input2}"), ), ] with pytest.raises(RuntimeError): @@ -610,7 +609,7 @@ def test_create_table_invalid_column_ref( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_create_table_invalid_rag( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() @@ -618,25 +617,25 @@ def test_create_table_invalid_rag( with _create_table(jamai, "knowledge", TABLE_ID_B, cols=[]) as ktable: # --- Valid knowledge table ID --- # cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig( - rag_params=p.RAGParams(table_id=ktable.id), + gen_config=t.LLMGenConfig( + rag_params=t.RAGParams(table_id=ktable.id), ), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # 
--- Invalid knowledge table ID --- # cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig( - rag_params=p.RAGParams(table_id="INVALID"), + gen_config=t.LLMGenConfig( + rag_params=t.RAGParams(table_id="INVALID"), ), ), ] @@ -646,28 +645,28 @@ def test_create_table_invalid_rag( # --- Valid reranker --- # cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig( - rag_params=p.RAGParams( + gen_config=t.LLMGenConfig( + rag_params=t.RAGParams( table_id=ktable.id, reranking_model=_get_reranking_model(jamai) ), ), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # --- Invalid reranker --- # cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig( - rag_params=p.RAGParams(table_id=ktable.id, reranking_model="INVALID"), + gen_config=t.LLMGenConfig( + rag_params=t.RAGParams(table_id=ktable.id, reranking_model="INVALID"), ), ), ] @@ -681,93 +680,93 @@ def test_create_table_invalid_rag( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_default_llm_model( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", gen_config=None, ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert isinstance(cols["output0"].gen_config.model, str) assert len(cols["output0"].gen_config.model) > 0 assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) assert isinstance(cols["AI"].gen_config.model, str) assert len(cols["AI"].gen_config.model) > 0 # --- Update gen config --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=TABLE_ID_A, column_map=dict( output0=None, - output1=p.LLMGenConfig(), + output1=t.LLMGenConfig(), ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} assert cols["output0"].gen_config is None - assert isinstance(cols["output1"].gen_config, p.GenConfig) + assert isinstance(cols["output1"].gen_config, t.GenConfig) assert isinstance(cols["output1"].gen_config.model, str) assert len(cols["output1"].gen_config.model) > 0 - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == 
t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) assert isinstance(cols["AI"].gen_config.model, str) assert len(cols["AI"].gen_config.model) > 0 # --- Add column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output2", dtype="str", gen_config=None, ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output3", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), ] - if table_type == p.TableType.action: - table = jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + table = jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") # Check gen configs cols = {c.id: c for c in table.cols} assert cols["output0"].gen_config is None - assert isinstance(cols["output1"].gen_config, p.GenConfig) + assert isinstance(cols["output1"].gen_config, t.GenConfig) assert isinstance(cols["output1"].gen_config.model, str) assert len(cols["output1"].gen_config.model) > 0 assert cols["output2"].gen_config is None - assert isinstance(cols["output3"].gen_config, p.GenConfig) + assert isinstance(cols["output3"].gen_config, t.GenConfig) assert isinstance(cols["output3"].gen_config.model, str) assert len(cols["output3"].gen_config.model) > 0 - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) assert isinstance(cols["AI"].gen_config.model, str) assert len(cols["AI"].gen_config.model) > 0 @@ -777,107 +776,107 @@ def test_default_llm_model( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_default_image_model( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() available_image_models = _get_image_models(jamai) cols = [ - p.ColumnSchemaCreate(id="input0", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="image"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(prompt="${input0}"), + gen_config=t.LLMGenConfig(prompt="${input0}"), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", gen_config=None, ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert isinstance(cols["output0"].gen_config.model, str) assert cols["output0"].gen_config.model in available_image_models assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) assert isinstance(cols["AI"].gen_config.model, str) assert 
cols["AI"].gen_config.model in available_image_models # --- Update gen config --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=TABLE_ID_A, column_map=dict( output0=None, - output1=p.LLMGenConfig(prompt="${input0}"), + output1=t.LLMGenConfig(prompt="${input0}"), ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} assert cols["output0"].gen_config is None - assert isinstance(cols["output1"].gen_config, p.GenConfig) + assert isinstance(cols["output1"].gen_config, t.GenConfig) assert isinstance(cols["output1"].gen_config.model, str) assert cols["output1"].gen_config.model in available_image_models # --- Add column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output2", dtype="str", - gen_config=p.LLMGenConfig(prompt="${input0}"), + gen_config=t.LLMGenConfig(prompt="${input0}"), ), - p.ColumnSchemaCreate(id="file_input1", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="file_input1", dtype="image"), + t.ColumnSchemaCreate( id="output3", dtype="str", - gen_config=p.LLMGenConfig(prompt="${file_input1}"), + gen_config=t.LLMGenConfig(prompt="${file_input1}"), ), ] - if table_type == p.TableType.action: - table = jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + table = jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") # Add a column with default prompt cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output4", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), ] - if table_type == p.TableType.action: - table = jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + table = jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") # Check gen configs cols = {c.id: c for c in table.cols} assert cols["output0"].gen_config is None for output_column_name in ["output1", "output2", "output3", "output4"]: - assert isinstance(cols[output_column_name].gen_config, p.GenConfig) + assert isinstance(cols[output_column_name].gen_config, t.GenConfig) model = cols[output_column_name].gen_config.model assert isinstance(model, str) - assert ( - model in available_image_models - ), f'Column 
{output_column_name} has invalid default model "{model}". Valid: {available_image_models}' + assert model in available_image_models, ( + f'Column {output_column_name} has invalid default model "{model}". Valid: {available_image_models}' + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -885,16 +884,16 @@ def test_default_image_model( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_invalid_image_model( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() available_image_models = _get_image_models(jamai) cols = [ - p.ColumnSchemaCreate(id="input0", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="image"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(model=_get_chat_only_model(jamai), prompt="${input0}"), + gen_config=t.LLMGenConfig(model=_get_chat_only_model(jamai), prompt="${input0}"), ), ] with pytest.raises(RuntimeError): @@ -902,22 +901,22 @@ def test_invalid_image_model( pass cols = [ - p.ColumnSchemaCreate(id="input0", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="image"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(prompt="${input0}"), + gen_config=t.LLMGenConfig(prompt="${input0}"), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert isinstance(cols["output0"].gen_config.model, str) assert cols["output0"].gen_config.model in available_image_models - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) assert isinstance(cols["AI"].gen_config.model, str) assert cols["AI"].gen_config.model in available_image_models @@ -925,10 +924,10 @@ def test_invalid_image_model( with pytest.raises(RuntimeError): table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=TABLE_ID_A, column_map=dict( - output0=p.LLMGenConfig( + output0=t.LLMGenConfig( model=_get_chat_only_model(jamai), prompt="${input0}", ), @@ -937,43 +936,43 @@ def test_invalid_image_model( ) table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=TABLE_ID_A, column_map=dict( - output0=p.LLMGenConfig(prompt="${input0}"), + output0=t.LLMGenConfig(prompt="${input0}"), ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert isinstance(cols["output0"].gen_config.model, str) assert cols["output0"].gen_config.model in available_image_models - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) assert isinstance(cols["AI"].gen_config.model, str) assert cols["AI"].gen_config.model in available_image_models # --- Add column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", - 
gen_config=p.LLMGenConfig(model=_get_chat_only_model(jamai), prompt="${input0}"), + gen_config=t.LLMGenConfig(model=_get_chat_only_model(jamai), prompt="${input0}"), ) ] with pytest.raises(RuntimeError): - if table_type == p.TableType.action: + if table_type == t.TableType.action: table = jamai.table.add_action_columns( - p.AddActionColumnSchema(id=table.id, cols=cols) + t.AddActionColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") @@ -985,7 +984,7 @@ def test_default_embedding_model( ): jamai = client_cls() with _create_table(jamai, "knowledge", cols=[], embedding_model="") as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) for col in table.cols: if col.vlen == 0: continue @@ -997,23 +996,23 @@ def test_default_embedding_model( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_default_reranker( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() # Create the knowledge table first with _create_table(jamai, "knowledge", TABLE_ID_B, cols=[]) as ktable: cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig( - rag_params=p.RAGParams(table_id=ktable.id, reranking_model=""), + gen_config=t.LLMGenConfig( + rag_params=t.RAGParams(table_id=ktable.id, reranking_model=""), ), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) cols = {c.id: c for c in table.cols} reranking_model = cols["output0"].gen_config.rag_params.reranking_model assert isinstance(reranking_model, str) @@ -1026,131 +1025,131 @@ def test_default_reranker( @pytest.mark.parametrize( "messages", [ - [p.ChatEntry.system(""), p.ChatEntry.user("")], - [p.ChatEntry.user("")], + [t.ChatEntry.system(""), t.ChatEntry.user("")], + [t.ChatEntry.user("")], ], ids=["system + user", "user only"], ) def test_default_prompts( client_cls: Type[JamAI], - table_type: p.TableType, - messages: list[p.ChatEntry], + table_type: t.TableType, + messages: list[t.ChatEntry], ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate(id="input1", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate(id="input1", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.ChatRequest(messages=messages), + gen_config=t.ChatRequest(messages=messages), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", - gen_config=p.ChatRequest(messages=messages), + gen_config=t.ChatRequest(messages=messages), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output2", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( system_prompt="You are an assistant.", prompt="Summarise ${input0}.", ), ), ] with 
_create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # ["output0", "output1"] should have default prompts input_cols = {"input0", "input1"} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: input_cols |= {"Title", "Text", "File ID", "Page"} else: input_cols |= {"User"} cols = {c.id: c for c in table.cols} for col_id in ["output0", "output1"]: - assert isinstance(cols[col_id].gen_config, p.LLMGenConfig) + assert isinstance(cols[col_id].gen_config, t.LLMGenConfig) user_prompt = cols[col_id].gen_config.prompt - referenced_cols = set(re.findall(p.GEN_CONFIG_VAR_PATTERN, user_prompt)) - assert ( - input_cols == referenced_cols - ), f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + referenced_cols = set(re.findall(t.GEN_CONFIG_VAR_PATTERN, user_prompt)) + assert input_cols == referenced_cols, ( + f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + ) # ["output2"] should have provided prompts input_cols = {"input0"} cols = {c.id: c for c in table.cols} for col_id in ["output2"]: - assert isinstance(cols[col_id].gen_config, p.LLMGenConfig) + assert isinstance(cols[col_id].gen_config, t.LLMGenConfig) user_prompt = cols[col_id].gen_config.prompt - referenced_cols = set(re.findall(p.GEN_CONFIG_VAR_PATTERN, user_prompt)) - assert ( - input_cols == referenced_cols - ), f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + referenced_cols = set(re.findall(t.GEN_CONFIG_VAR_PATTERN, user_prompt)) + assert input_cols == referenced_cols, ( + f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + ) # --- Add column --- # cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="input2", dtype="int", ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output3", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), ] - if table_type == p.TableType.action: - table = jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + table = jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # ["output0", "output1"] should have default prompts input_cols = {"input0", "input1"} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: input_cols |= {"Title", "Text", "File ID", "Page"} else: input_cols |= {"User"} cols = {c.id: c for c in table.cols} for col_id in ["output0", "output1"]: - assert isinstance(cols[col_id].gen_config, p.LLMGenConfig) + assert isinstance(cols[col_id].gen_config, t.LLMGenConfig) user_prompt = cols[col_id].gen_config.prompt 
- referenced_cols = set(re.findall(p.GEN_CONFIG_VAR_PATTERN, user_prompt)) - assert ( - input_cols == referenced_cols - ), f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + referenced_cols = set(re.findall(t.GEN_CONFIG_VAR_PATTERN, user_prompt)) + assert input_cols == referenced_cols, ( + f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + ) # ["output3"] should have default prompts input_cols = {"input0", "input1", "input2"} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: input_cols |= {"Title", "Text", "File ID", "Page"} else: input_cols |= {"User"} for col_id in ["output3"]: - assert isinstance(cols[col_id].gen_config, p.LLMGenConfig) + assert isinstance(cols[col_id].gen_config, t.LLMGenConfig) user_prompt = cols[col_id].gen_config.prompt - referenced_cols = set(re.findall(p.GEN_CONFIG_VAR_PATTERN, user_prompt)) - assert ( - input_cols == referenced_cols - ), f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + referenced_cols = set(re.findall(t.GEN_CONFIG_VAR_PATTERN, user_prompt)) + assert input_cols == referenced_cols, ( + f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + ) # ["output2"] should have provided prompts input_cols = {"input0"} for col_id in ["output2"]: - assert isinstance(cols[col_id].gen_config, p.LLMGenConfig) + assert isinstance(cols[col_id].gen_config, t.LLMGenConfig) user_prompt = cols[col_id].gen_config.prompt - referenced_cols = set(re.findall(p.GEN_CONFIG_VAR_PATTERN, user_prompt)) - assert ( - input_cols == referenced_cols - ), f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + referenced_cols = set(re.findall(t.GEN_CONFIG_VAR_PATTERN, user_prompt)) + assert input_cols == referenced_cols, ( + f"Expected input cols = {input_cols}, referenced cols = {referenced_cols}" + ) @flaky(max_runs=5, min_passes=1, rerun_filter=_rerun_on_fs_error_with_delay) @@ -1158,12 +1157,12 @@ def test_default_prompts( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_add_drop_columns( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table_v2(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) _add_row_v2( jamai, table_type, @@ -1173,14 +1172,14 @@ def test_add_drop_columns( # --- COLUMN ADD --- # _input_cols = [ - p.ColumnSchemaCreate(id=f"add_in_{dtype}", dtype=dtype) + t.ColumnSchemaCreate(id=f"add_in_{dtype}", dtype=dtype) for dtype in REGULAR_COLUMN_DTYPES ] _output_cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id=f"add_out_{dtype}", dtype=dtype, - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", system_prompt="", prompt=" ".join(f"${{{col.id}}}" for col in _input_cols), @@ -1195,20 +1194,20 @@ def test_add_drop_columns( expected_cols |= {f"out_{dtype}" for dtype in ["str"]} expected_cols |= {f"add_in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} expected_cols |= {f"add_out_{dtype}" for dtype in ["str"]} - if table_type == p.TableType.action: - table = jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + table = 
jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) cols = set(c.id for c in table.cols) assert cols == expected_cols, cols # Existing row of new columns should contain None @@ -1242,7 +1241,7 @@ def test_add_drop_columns( # --- COLUMN DROP --- # table = jamai.table.drop_columns( table_type, - p.ColumnDropRequest( + t.ColumnDropRequest( table_id=table.id, column_names=[f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES] + [f"out_{dtype}" for dtype in ["str"]], @@ -1251,16 +1250,16 @@ def test_add_drop_columns( expected_cols = {"ID", "Updated at"} expected_cols |= {f"add_in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} expected_cols |= {f"add_out_{dtype}" for dtype in ["str"]} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) cols = set(c.id for c in table.cols) assert cols == expected_cols, cols rows = jamai.table.list_table_rows(table_type, table.id) @@ -1282,12 +1281,12 @@ def test_add_drop_columns( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_add_drop_file_column( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table_v2(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) _add_row_v2( jamai, table_type, @@ -1297,11 +1296,11 @@ def test_add_drop_file_column( # --- COLUMN ADD --- # cols = [ - p.ColumnSchemaCreate(id="add_in_file", dtype="image"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="add_in_file", dtype="image"), + t.ColumnSchemaCreate( id="add_out_str", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", system_prompt="", prompt="Describe image ${add_in_file}", @@ -1312,20 +1311,20 @@ def test_add_drop_file_column( expected_cols = {"ID", "Updated at", "add_in_file", "add_out_str"} expected_cols |= {f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} expected_cols |= {f"out_{dtype}" for dtype in ["str"]} - if table_type == 
p.TableType.action: - table = jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + table = jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: table = jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} - table = jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + table = jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) cols = set(c.id for c in table.cols) assert cols == expected_cols, cols # Existing row of new columns should contain None @@ -1358,10 +1357,10 @@ def test_add_drop_file_column( # Block file output column with pytest.raises(RuntimeError): cols = [ - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="add_out_file", dtype="image", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model="", system_prompt="", prompt="Describe image ${add_in_file}", @@ -1369,37 +1368,37 @@ def test_add_drop_file_column( ), ), ] - if table_type == p.TableType.action: - jamai.table.add_action_columns(p.AddActionColumnSchema(id=table.id, cols=cols)) - elif table_type == p.TableType.knowledge: + if table_type == t.TableType.action: + jamai.table.add_action_columns(t.AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.knowledge: jamai.table.add_knowledge_columns( - p.AddKnowledgeColumnSchema(id=table.id, cols=cols) + t.AddKnowledgeColumnSchema(id=table.id, cols=cols) ) - elif table_type == p.TableType.chat: - jamai.table.add_chat_columns(p.AddChatColumnSchema(id=table.id, cols=cols)) + elif table_type == t.TableType.chat: + jamai.table.add_chat_columns(t.AddChatColumnSchema(id=table.id, cols=cols)) else: raise ValueError(f"Invalid table type: {table_type}") # --- COLUMN DROP --- # table = jamai.table.drop_columns( table_type, - p.ColumnDropRequest( + t.ColumnDropRequest( table_id=table.id, column_names=[f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES] + [f"out_{dtype}" for dtype in ["str"]], ), ) expected_cols = {"ID", "Updated at", "add_in_file", "add_out_str"} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) cols = set(c.id for c in table.cols) assert cols == expected_cols, cols rows = jamai.table.list_table_rows(table_type, table.id) @@ -1422,12 +1421,12 @@ def 
test_kt_drop_invalid_columns(client_cls: Type[JamAI]): table_type = "knowledge" jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) for col in KT_FIXED_COLUMN_IDS: with pytest.raises(RuntimeError): jamai.table.drop_columns( table_type, - p.ColumnDropRequest(table_id=table.id, column_names=[col]), + t.ColumnDropRequest(table_id=table.id, column_names=[col]), ) @@ -1437,12 +1436,12 @@ def test_ct_drop_invalid_columns(client_cls: Type[JamAI]): table_type = "chat" jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) for col in CT_FIXED_COLUMN_IDS: with pytest.raises(RuntimeError): jamai.table.drop_columns( table_type, - p.ColumnDropRequest(table_id=table.id, column_names=[col]), + t.ColumnDropRequest(table_id=table.id, column_names=[col]), ) @@ -1451,32 +1450,32 @@ def test_ct_drop_invalid_columns(client_cls: Type[JamAI]): @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_rename_columns( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="x", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="x", dtype="str"), + t.ColumnSchemaCreate( id="y", dtype="str", - gen_config=p.LLMGenConfig(prompt=r"Summarise ${x}, \${x}"), + gen_config=t.LLMGenConfig(prompt=r"Summarise ${x}, \${x}"), ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) # Test rename on empty table table = jamai.table.rename_columns( table_type, - p.ColumnRenameRequest(table_id=table.id, column_map=dict(y="z")), + t.ColumnRenameRequest(table_id=table.id, column_map=dict(y="z")), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) expected_cols = {"ID", "Updated at", "x", "z"} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} else: raise ValueError(f"Invalid table type: {table_type}") @@ -1484,7 +1483,7 @@ def test_rename_columns( assert cols == expected_cols table = jamai.table.get_table(table_type, table.id) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) cols = set(c.id for c in table.cols) assert cols == expected_cols # Test adding data with new column names @@ -1493,22 +1492,22 @@ def test_rename_columns( # Test also auto gen config reference update table = jamai.table.rename_columns( table_type, - p.ColumnRenameRequest(table_id=table.id, column_map=dict(x="a")), + t.ColumnRenameRequest(table_id=table.id, column_map=dict(x="a")), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) expected_cols = {"ID", "Updated at", "a", "z"} - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: expected_cols |= {"Title", 
"Title Embed", "Text", "Text Embed", "File ID", "Page"} - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_cols |= {"User", "AI"} else: raise ValueError(f"Invalid table type: {table_type}") cols = set(c.id for c in table.cols) assert cols == expected_cols table = jamai.table.get_table(table_type, table.id) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) cols = set(c.id for c in table.cols) assert cols == expected_cols # Test auto gen config reference update @@ -1521,14 +1520,14 @@ def test_rename_columns( with pytest.raises(RuntimeError): jamai.table.rename_columns( table_type, - p.ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="b")), + t.ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="b")), ) # Overlapping new and old column names with pytest.raises(RuntimeError): jamai.table.rename_columns( table_type, - p.ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="a")), + t.ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="a")), ) @@ -1538,12 +1537,12 @@ def test_kt_rename_invalid_columns(client_cls: Type[JamAI]): table_type = "knowledge" jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) for col in KT_FIXED_COLUMN_IDS: with pytest.raises(RuntimeError): jamai.table.rename_columns( table_type, - p.ColumnRenameRequest(table_id=table.id, column_map={col: col}), + t.ColumnRenameRequest(table_id=table.id, column_map={col: col}), ) @@ -1553,12 +1552,12 @@ def test_ct_rename_invalid_columns(client_cls: Type[JamAI]): table_type = "chat" jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) for col in CT_FIXED_COLUMN_IDS: with pytest.raises(RuntimeError): jamai.table.rename_columns( table_type, - p.ColumnRenameRequest(table_id=table.id, column_map={col: col}), + t.ColumnRenameRequest(table_id=table.id, column_map={col: col}), ) @@ -1567,14 +1566,14 @@ def test_ct_rename_invalid_columns(client_cls: Type[JamAI]): @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_reorder_columns( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) table = jamai.table.get_table(table_type, TABLE_ID_A) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) column_names = [ "inputs", @@ -1596,16 +1595,16 @@ def test_reorder_columns( "summary", "captioning", ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: column_names += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] expected_order = ( expected_order[:2] + ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + expected_order[2:] ) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: column_names += ["User", "AI"] expected_order = expected_order[:2] + ["User", "AI"] + expected_order[2:] else: @@ -1615,7 +1614,7 @@ def test_reorder_columns( # Test reorder empty table 
table = jamai.table.reorder_columns( table_type, - p.ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), + t.ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), ) expected_order = [ "ID", @@ -1628,18 +1627,18 @@ def test_reorder_columns( "summary", "captioning", ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: expected_order += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: expected_order += ["User", "AI"] else: raise ValueError(f"Invalid table type: {table_type}") cols = [c.id for c in table.cols] assert cols == expected_order, cols table = jamai.table.get_table(table_type, TABLE_ID_A) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) cols = [c.id for c in table.cols] assert cols == expected_order, cols # Test add row @@ -1658,14 +1657,14 @@ def test_reorder_columns( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_reorder_columns_invalid( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert isinstance(table, t.TableMetaResponse) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) table = jamai.table.get_table(table_type, TABLE_ID_A) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) column_names = [ "inputs", @@ -1687,16 +1686,16 @@ def test_reorder_columns_invalid( "summary", "captioning", ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: column_names += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] expected_order = ( expected_order[:2] + ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + expected_order[2:] ) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: column_names += ["User", "AI"] expected_order = expected_order[:2] + ["User", "AI"] + expected_order[2:] else: @@ -1714,18 +1713,18 @@ def test_reorder_columns_invalid( "photo", "captioning", ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: pass - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: column_names += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: column_names += ["User", "AI"] else: raise ValueError(f"Invalid table type: {table_type}") with pytest.raises(RuntimeError, match="referenced an invalid source column"): jamai.table.reorder_columns( table_type, - p.ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), + t.ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), ) @@ -1734,119 +1733,119 @@ def test_reorder_columns_invalid( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_update_gen_config( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( 
id="output0", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", gen_config=None, ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.LLMGenConfig) + assert isinstance(cols["output0"].gen_config, t.LLMGenConfig) assert isinstance(cols["output0"].gen_config.system_prompt, str) assert isinstance(cols["output0"].gen_config.prompt, str) assert len(cols["output0"].gen_config.system_prompt) > 0 assert len(cols["output0"].gen_config.prompt) > 0 assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Switch gen config --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( output0=None, - output1=p.LLMGenConfig(), + output1=t.LLMGenConfig(), ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} assert cols["output0"].gen_config is None - assert isinstance(cols["output1"].gen_config, p.LLMGenConfig) + assert isinstance(cols["output1"].gen_config, t.LLMGenConfig) assert isinstance(cols["output1"].gen_config.system_prompt, str) assert isinstance(cols["output1"].gen_config.prompt, str) assert len(cols["output1"].gen_config.system_prompt) > 0 assert len(cols["output1"].gen_config.prompt) > 0 - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Update gen config --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig(), + output0=t.LLMGenConfig(), ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) - assert isinstance(cols["output1"].gen_config, p.GenConfig) - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) + assert isinstance(cols["output1"].gen_config, t.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Update gen config --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( output1=None, ), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Chat AI column must always have gen config --- # - if 
table_type == p.TableType.chat: + if table_type == t.TableType.chat: table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict(AI=None), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) cols = {c.id: c for c in table.cols} assert cols["AI"].gen_config is not None # --- Chat AI column multi-turn must always be True --- # - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: chat_cfg = {c.id: c for c in table.cols}["AI"].gen_config chat_cfg.multi_turn = False table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict(AI=chat_cfg), ), ) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) cols = {c.id: c for c in table.cols} assert cols["AI"].gen_config.multi_turn is True @@ -1856,48 +1855,48 @@ def test_update_gen_config( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_update_gen_config_invalid_model( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", gen_config=None, ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Update gen config --- # with pytest.raises(ResourceNotFoundError): table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig(model="INVALID"), + output0=t.LLMGenConfig(model="INVALID"), ), ), ) table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig(model=_get_chat_model(jamai)), + output0=t.LLMGenConfig(model=_get_chat_model(jamai)), ), ), ) @@ -1908,57 +1907,57 @@ def test_update_gen_config_invalid_model( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_update_gen_config_invalid_column_ref( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", gen_config=None, ), ] with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert 
isinstance(cols["output0"].gen_config, p.LLMGenConfig) + assert isinstance(cols["output0"].gen_config, t.LLMGenConfig) assert isinstance(cols["output0"].gen_config.system_prompt, str) assert isinstance(cols["output0"].gen_config.prompt, str) assert len(cols["output0"].gen_config.system_prompt) > 0 assert len(cols["output0"].gen_config.prompt) > 0 assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Update gen config --- # with pytest.raises(RuntimeError): table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig(prompt="Summarise ${input2}"), + output0=t.LLMGenConfig(prompt="Summarise ${input2}"), ), ), ) table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig(prompt="Summarise ${input0}"), + output0=t.LLMGenConfig(prompt="Summarise ${input0}"), ), ), ) cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.LLMGenConfig) + assert isinstance(cols["output0"].gen_config, t.LLMGenConfig) assert isinstance(cols["output0"].gen_config.system_prompt, str) assert isinstance(cols["output0"].gen_config.prompt, str) assert len(cols["output0"].gen_config.system_prompt) > 0 @@ -1970,42 +1969,42 @@ def test_update_gen_config_invalid_column_ref( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_update_gen_config_invalid_rag( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="input0", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input0", dtype="str"), + t.ColumnSchemaCreate( id="output0", dtype="str", - gen_config=p.LLMGenConfig(), + gen_config=t.LLMGenConfig(), ), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate( id="output1", dtype="str", gen_config=None, ), ] with _create_table(jamai, "knowledge", cols=[]) as ktable: - assert isinstance(ktable, p.TableMetaResponse) + assert isinstance(ktable, t.TableMetaResponse) with _create_table(jamai, table_type, cols=cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) # Check gen configs cols = {c.id: c for c in table.cols} - assert isinstance(cols["output0"].gen_config, p.GenConfig) + assert isinstance(cols["output0"].gen_config, t.GenConfig) assert cols["output1"].gen_config is None - if table_type == p.TableType.chat: - assert isinstance(cols["AI"].gen_config, p.GenConfig) + if table_type == t.TableType.chat: + assert isinstance(cols["AI"].gen_config, t.GenConfig) # --- Invalid knowledge table ID --- # with pytest.raises(ResourceNotFoundError): table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig( - rag_params=p.RAGParams(table_id="INVALID"), + output0=t.LLMGenConfig( + rag_params=t.RAGParams(table_id="INVALID"), ), ), ), @@ -2013,11 +2012,11 @@ def test_update_gen_config_invalid_rag( # --- Valid knowledge table ID --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig( - rag_params=p.RAGParams(table_id=ktable.id), + 
output0=t.LLMGenConfig( + rag_params=t.RAGParams(table_id=ktable.id), ), ), ), @@ -2027,11 +2026,11 @@ def test_update_gen_config_invalid_rag( with pytest.raises(ResourceNotFoundError): table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig( - rag_params=p.RAGParams( + output0=t.LLMGenConfig( + rag_params=t.RAGParams( table_id=ktable.id, reranking_model="INVALID" ), ), @@ -2041,11 +2040,11 @@ def test_update_gen_config_invalid_rag( # --- Valid reranker --- # table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig( - rag_params=p.RAGParams(table_id=ktable.id, reranking_model=None), + output0=t.LLMGenConfig( + rag_params=t.RAGParams(table_id=ktable.id, reranking_model=None), ), ), ), @@ -2054,11 +2053,11 @@ def test_update_gen_config_invalid_rag( assert cols["output0"].gen_config.rag_params.reranking_model is None table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest( + t.GenConfigUpdateRequest( table_id=table.id, column_map=dict( - output0=p.LLMGenConfig( - rag_params=p.RAGParams(table_id=ktable.id, reranking_model=""), + output0=t.LLMGenConfig( + rag_params=t.RAGParams(table_id=ktable.id, reranking_model=""), ), ), ), @@ -2074,15 +2073,15 @@ def test_update_gen_config_invalid_rag( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_null_gen_config( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) table = jamai.table.update_gen_config( table_type, - p.GenConfigUpdateRequest(table_id=table.id, column_map=dict(summary=None)), + t.GenConfigUpdateRequest(table_id=table.id, column_map=dict(summary=None)), ) response = _add_row( jamai, table_type, stream, data=dict(good=True, words=5, stars=9.9, inputs=TEXT) @@ -2090,9 +2089,9 @@ def test_null_gen_config( if stream: # Must wait until stream ends responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) else: - assert isinstance(response, p.GenTableChatCompletionChunks) + assert isinstance(response, t.RowCompletionResponse) rows = jamai.table.list_table_rows(table_type, table.id) assert isinstance(rows.items, list) assert len(rows.items) == 1 @@ -2105,16 +2104,16 @@ def test_null_gen_config( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_invalid_referenced_column( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() # --- Non-existent column --- # cols = [ - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", prompt="Summarise ${inputs}", @@ -2130,11 +2129,11 @@ def test_invalid_referenced_column( # --- Vector column --- # cols = [ - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="summary", dtype="str", - 
gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", prompt="Summarise ${Text Embed}", @@ -2155,16 +2154,16 @@ def test_invalid_referenced_column( @pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) def test_gen_config_empty_prompts( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, stream: bool, ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), temperature=0.001, top_p=0.001, @@ -2173,11 +2172,11 @@ def test_gen_config_empty_prompts( ), ] chat_cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), temperature=0.001, top_p=0.001, @@ -2186,26 +2185,26 @@ def test_gen_config_empty_prompts( ), ] with _create_table(jamai, table_type, cols=cols, chat_cols=chat_cols) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) data = dict(words=5) - if table_type == p.TableType.knowledge: + if table_type == t.TableType.knowledge: data["Title"] = "Dune: Part Two." data["Text"] = "Dune: Part Two is a 2024 American epic science fiction film." response = jamai.table.add_table_rows( table_type, - p.RowAddRequest(table_id=table.id, data=[data], stream=stream), + t.MultiRowAddRequest(table_id=table.id, data=[data], stream=stream), ) if stream: # Must wait until stream ends responses = [r for r in response] - assert all(isinstance(r, p.GenTableStreamChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.CellCompletionResponse) for r in responses) summary = "".join(r.text for r in responses if r.output_column_name == "summary") assert len(summary) > 0 - if table_type == p.TableType.chat: + if table_type == t.TableType.chat: ai = "".join(r.text for r in responses if r.output_column_name == "AI") assert len(ai) > 0 else: - assert isinstance(response.rows[0], p.GenTableChatCompletionChunks) + assert isinstance(response.rows[0], t.RowCompletionResponse) @pytest.mark.parametrize("client_cls", CLIENT_CLS) @@ -2215,11 +2214,11 @@ def test_gen_config_no_message( jamai = client_cls() with pytest.raises(ValidationError, match="at least 1 item"): _ = [ - p.ColumnSchemaCreate(id="words", dtype="int"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="words", dtype="int"), + t.ColumnSchemaCreate( id="summary", dtype="str", - gen_config=p.ChatRequest( + gen_config=t.ChatRequest( model=_get_chat_model(jamai), messages=[], temperature=0.001, @@ -2235,7 +2234,7 @@ def test_gen_config_no_message( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_get_and_list_tables( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() _delete_tables(jamai) @@ -2245,7 +2244,7 @@ def test_get_and_list_tables( _create_table(jamai, table_type, TABLE_ID_C), _create_table(jamai, table_type, TABLE_ID_X), ): - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row( jamai, table_type, @@ -2255,7 +2254,7 @@ def test_get_and_list_tables( # Regular case table = jamai.table.get_table(table_type, TABLE_ID_B) 
- assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) assert table.id == TABLE_ID_B tables = jamai.table.list_tables(table_type) @@ -2264,7 +2263,7 @@ def test_get_and_list_tables( assert tables.offset == 0 assert tables.limit == 100 assert len(tables.items) == 4 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) # Test various offset and limit tables = jamai.table.list_tables(table_type, offset=3, limit=2) @@ -2273,7 +2272,7 @@ def test_get_and_list_tables( assert tables.offset == 3 assert tables.limit == 2 assert len(tables.items) == 1 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) tables = jamai.table.list_tables(table_type, offset=4, limit=2) assert isinstance(tables.items, list) @@ -2295,7 +2294,7 @@ def test_get_and_list_tables( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_table_search_and_parent_id( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() _delete_tables(jamai) @@ -2305,7 +2304,7 @@ def test_table_search_and_parent_id( _create_table(jamai, table_type, "bear"), _create_table(jamai, table_type, "fear"), ): - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) with ( _create_child_table(jamai, table_type, "beast", "least"), _create_child_table(jamai, table_type, "beast", "lease"), @@ -2318,7 +2317,7 @@ def test_table_search_and_parent_id( assert tables.offset == 0 assert tables.limit == 3 assert len(tables.items) == 3 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) # Search tables = jamai.table.list_tables(table_type, search_query="be", limit=3) assert isinstance(tables.items, list) @@ -2326,7 +2325,7 @@ def test_table_search_and_parent_id( assert tables.offset == 0 assert tables.limit == 3 assert len(tables.items) == 2 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) # Search tables = jamai.table.list_tables(table_type, search_query="ast", limit=3) assert isinstance(tables.items, list) @@ -2334,7 +2333,7 @@ def test_table_search_and_parent_id( assert tables.offset == 0 assert tables.limit == 3 assert len(tables.items) == 3 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) # Search with parent ID tables = jamai.table.list_tables(table_type, search_query="ast", parent_id="beast") assert isinstance(tables.items, list) @@ -2342,7 +2341,7 @@ def test_table_search_and_parent_id( assert tables.offset == 0 assert tables.limit == 100 assert len(tables.items) == 2 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) # Search with parent ID tables = jamai.table.list_tables(table_type, search_query="as", parent_id="beast") assert isinstance(tables.items, list) @@ -2350,7 +2349,7 @@ def test_table_search_and_parent_id( assert tables.offset == 0 assert tables.limit == 100 assert len(tables.items) == 3 - assert all(isinstance(r, p.TableMetaResponse) for r in tables.items) + assert all(isinstance(r, t.TableMetaResponse) for r in tables.items) @flaky(max_runs=5, min_passes=1, 
rerun_filter=_rerun_on_fs_error_with_delay) @@ -2358,11 +2357,11 @@ def test_table_search_and_parent_id( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_duplicate_table( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row( jamai, table_type, @@ -2418,12 +2417,12 @@ def test_duplicate_table( ) def test_duplicate_table_invalid_name( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, table_id_dst: str, ): jamai = client_cls() with _create_table(jamai, table_type) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row( jamai, table_type, @@ -2441,11 +2440,11 @@ def test_duplicate_table_invalid_name( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_create_child_table( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type) as table_a: - assert isinstance(table_a, p.TableMetaResponse) + assert isinstance(table_a, t.TableMetaResponse) _add_row( jamai, table_type, @@ -2454,7 +2453,7 @@ def test_create_child_table( ) # Duplicate with data with _create_child_table(jamai, table_type, TABLE_ID_A, TABLE_ID_B) as table_b: - assert isinstance(table_b, p.TableMetaResponse) + assert isinstance(table_b, t.TableMetaResponse) # Add another to table A _add_row( jamai, @@ -2484,11 +2483,11 @@ def test_create_child_table( @pytest.mark.parametrize("table_type", TABLE_TYPES) def test_rename_table( client_cls: Type[JamAI], - table_type: p.TableType, + table_type: t.TableType, ): jamai = client_cls() with _create_table(jamai, table_type, TABLE_ID_A) as table: - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) _add_row( jamai, table_type, @@ -2497,7 +2496,7 @@ def test_rename_table( ) # Create child table with _create_child_table(jamai, table_type, TABLE_ID_A, TABLE_ID_B) as child: - assert isinstance(child, p.TableMetaResponse) + assert isinstance(child, t.TableMetaResponse) # Rename with _rename_table(jamai, table_type, TABLE_ID_A, TABLE_ID_C) as table: rows = jamai.table.list_table_rows(table_type, TABLE_ID_C) @@ -2531,11 +2530,11 @@ def test_chat_table_gen_config( ): jamai = client_cls() cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=_get_chat_model(jamai), system_prompt="You are a concise assistant.", multi_turn=False, @@ -2552,4 +2551,4 @@ def test_chat_table_gen_config( if __name__ == "__main__": - test_add_drop_columns(JamAI, p.TableType.action) + test_add_drop_columns(JamAI, t.TableType.action) diff --git a/clients/python/tests/oss/test_admin.py b/clients/python/tests/oss/test_admin.py index c7eb4bb..5261e0b 100644 --- a/clients/python/tests/oss/test_admin.py +++ b/clients/python/tests/oss/test_admin.py @@ -4,7 +4,7 @@ import pytest from jamaibase import JamAI, JamAIAsync -from jamaibase.protocol import LLMModelConfig, ModelDeploymentConfig, ModelListConfig, OkResponse +from jamaibase.types import LLMModelConfig, ModelDeploymentConfig, ModelListConfig, OkResponse from jamaibase.utils import run CLIENT_CLS = [JamAI, JamAIAsync] @@ -60,7 +60,7 @@ async def 
test_get_set_org_model_config( context_length=8000, languages=["mul"], capabilities=["chat"], - owned_by="ellm", + owned_by=ORG_ID, ) ) async with _set_org_model_config(jamai, ORG_ID, new_config) as response: diff --git a/clients/python/tests/oss/test_chat.py b/clients/python/tests/oss/test_chat.py index 7a0d18a..33345d1 100644 --- a/clients/python/tests/oss/test_chat.py +++ b/clients/python/tests/oss/test_chat.py @@ -5,7 +5,7 @@ from loguru import logger from jamaibase import JamAI, JamAIAsync -from jamaibase import protocol as p +from jamaibase import types as t from jamaibase.utils import run CLIENT_CLS = [JamAI, JamAIAsync] @@ -19,9 +19,9 @@ async def test_model_info( # Get all model info response = await run(jamai.model_info) - assert isinstance(response, p.ModelInfoResponse) + assert isinstance(response, t.ModelInfoListResponse) assert len(response.data) > 0 - assert isinstance(response.data[0], p.ModelInfo) + assert isinstance(response.data[0], t.ModelInfo) model = response.data[0] assert isinstance(model.id, str) assert isinstance(model.context_length, int) @@ -31,20 +31,20 @@ async def test_model_info( # Get specific model info response = await run(jamai.model_info, name=model.id) - assert isinstance(response, p.ModelInfoResponse) + assert isinstance(response, t.ModelInfoListResponse) assert len(response.data) == 1 assert response.data[0].id == model.id # Filter based on capability response = await run(jamai.model_info, capabilities=["chat"]) - assert isinstance(response, p.ModelInfoResponse) + assert isinstance(response, t.ModelInfoListResponse) for m in response.data: assert "chat" in m.capabilities assert "embed" not in m.capabilities assert "rerank" not in m.capabilities response = await run(jamai.model_info, capabilities=["chat", "image"]) - assert isinstance(response, p.ModelInfoResponse) + assert isinstance(response, t.ModelInfoListResponse) for m in response.data: assert "chat" in m.capabilities assert "image" in m.capabilities @@ -52,14 +52,14 @@ async def test_model_info( assert "rerank" not in m.capabilities response = await run(jamai.model_info, capabilities=["embed"]) - assert isinstance(response, p.ModelInfoResponse) + assert isinstance(response, t.ModelInfoListResponse) for m in response.data: assert "chat" not in m.capabilities assert "embed" in m.capabilities assert "rerank" not in m.capabilities response = await run(jamai.model_info, capabilities=["rerank"]) - assert isinstance(response, p.ModelInfoResponse) + assert isinstance(response, t.ModelInfoListResponse) for m in response.data: assert "chat" not in m.capabilities assert "embed" not in m.capabilities @@ -124,12 +124,12 @@ async def test_model_names( def _get_chat_request(model: str, **kwargs): - request = p.ChatRequest( + request = t.ChatRequest( id="test", model=model, messages=[ - p.ChatEntry.system("You are a concise assistant."), - p.ChatEntry.user("What is a llama?"), + t.ChatEntry.system("You are a concise assistant."), + t.ChatEntry.user("What is a llama?"), ], temperature=0.001, top_p=0.001, @@ -171,10 +171,10 @@ async def test_chat_completion( # Non-streaming request = _get_chat_request(model, stream=False) response = await run(jamai.generate_chat_completions, request) - assert isinstance(response, p.ChatCompletionChunk) + assert isinstance(response, t.ChatCompletionChunk) assert isinstance(response.text, str) assert len(response.text) > 1 - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) 
assert isinstance(response.completion_tokens, int) assert response.prompt_tokens > 0 @@ -186,12 +186,12 @@ async def test_chat_completion( request.stream = True responses = await run(jamai.generate_chat_completions, request) assert len(responses) > 0 - assert all(isinstance(r, p.ChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.ChatCompletionChunk) for r in responses) assert all(isinstance(r.text, str) for r in responses) assert len("".join(r.text for r in responses)) > 1 assert all(r.references is None for r in responses) response = responses[-1] - assert all(isinstance(r.usage, p.CompletionUsage) for r in responses) + assert all(isinstance(r.usage, t.CompletionUsage) for r in responses) assert all(isinstance(r.prompt_tokens, int) for r in responses) assert all(isinstance(r.completion_tokens, int) for r in responses) assert response.prompt_tokens > 0 @@ -200,15 +200,15 @@ async def test_chat_completion( TOOLS = { - "get_weather": p.Tool( + "get_weather": t.Tool( type="function", - function=p.Function( + function=t.Function( name="get_weather", description="Get the current weather for a location", - parameters=p.FunctionParameters( + parameters=t.FunctionParameters( type="object", properties={ - "location": p.FunctionParameter( + "location": t.FunctionParameter( type="string", description="The city and state, e.g. San Francisco, CA" ) }, @@ -217,24 +217,24 @@ async def test_chat_completion( ), ), ), - "calculator": p.Tool( + "calculator": t.Tool( type="function", - function=p.Function( + function=t.Function( name="calculator", description="Perform a basic arithmetic operation", - parameters=p.FunctionParameters( + parameters=t.FunctionParameters( type="object", properties={ - "operation": p.FunctionParameter( + "operation": t.FunctionParameter( type="string", description="The arithmetic operation to perform", enum=["add", "subtract", "multiply", "divide"], ), - "first_number": p.FunctionParameter( + "first_number": t.FunctionParameter( type="number", description="The first number", ), - "second_number": p.FunctionParameter( + "second_number": t.FunctionParameter( type="number", description="The second number", ), @@ -270,20 +270,20 @@ async def test_chat_completion_with_tools( ): jamai = client_cls() - tool_choice = p.ToolChoice( + tool_choice = t.ToolChoice( type="function", - function=p.ToolChoiceFunction( + function=t.ToolChoiceFunction( name=tool_prompt["tool_choice"], ), ) # Create a chat request with a tool - request = p.ChatRequestWithTools( + request = t.ChatRequestWithTools( id="test", model=model, messages=[ - p.ChatEntry.system("You are a concise assistant."), - p.ChatEntry.user(tool_prompt["prompt"]), + t.ChatEntry.system("You are a concise assistant."), + t.ChatEntry.user(tool_prompt["prompt"]), ], tools=[v for _, v in TOOLS.items()] if set_multi_tools @@ -297,7 +297,7 @@ async def test_chat_completion_with_tools( # Non-streaming response = await run(jamai.generate_chat_completions, request) - assert isinstance(response, p.ChatCompletionChunk) + assert isinstance(response, t.ChatCompletionChunk) assert isinstance(response.text, str) assert len(response.text) == 0 tool_calls = response.message.tool_calls @@ -306,7 +306,7 @@ async def test_chat_completion_with_tools( assert tool_calls[0].function.name == tool_prompt["tool_choice"] for argument in tool_prompt["response"]: assert argument in tool_calls[0].function.arguments.replace(" ", "") - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert 
isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) assert response.references is None @@ -315,12 +315,12 @@ async def test_chat_completion_with_tools( request.stream = True responses = await run(jamai.generate_chat_completions, request) assert len(responses) > 0 - assert all(isinstance(r, p.ChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.ChatCompletionChunk) for r in responses) assert all(isinstance(r.text, str) for r in responses) assert len("".join(r.text for r in responses)) == 0 assert all(r.references is None for r in responses) response = responses[-1] - assert all(isinstance(r.usage, p.CompletionUsage) for r in responses) + assert all(isinstance(r.usage, t.CompletionUsage) for r in responses) assert all(isinstance(r.prompt_tokens, int) for r in responses) assert all(isinstance(r.completion_tokens, int) for r in responses) assert response.prompt_tokens > 0 @@ -351,13 +351,13 @@ async def test_chat_opener( jamai = client_cls() # Non-streaming - request = p.ChatRequest( + request = t.ChatRequest( id="test", model=model, messages=[ - p.ChatEntry.system("You are a concise assistant."), - p.ChatEntry.assistant("Sam has 7 apples."), - p.ChatEntry.user("How many apples does Sam have?"), + t.ChatEntry.system("You are a concise assistant."), + t.ChatEntry.assistant("Sam has 7 apples."), + t.ChatEntry.user("How many apples does Sam have?"), ], temperature=0.001, top_p=0.001, @@ -365,11 +365,11 @@ async def test_chat_opener( stream=False, ) response = await run(jamai.generate_chat_completions, request) - assert isinstance(response, p.ChatCompletionChunk) + assert isinstance(response, t.ChatCompletionChunk) assert isinstance(response.text, str) assert "7" in response.text or "seven" in response.text.lower() assert len(response.text) > 1 - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) assert response.references is None @@ -378,11 +378,11 @@ async def test_chat_opener( request.stream = True responses = await run(jamai.generate_chat_completions, request) assert len(responses) > 0 - assert all(isinstance(r, p.ChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.ChatCompletionChunk) for r in responses) assert all(isinstance(r.text, str) for r in responses) assert "7" in response.text or "seven" in response.text.lower() assert all(r.references is None for r in responses) - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) @@ -397,20 +397,20 @@ async def test_chat_user_only( jamai = client_cls() # Non-streaming - request = p.ChatRequest( + request = t.ChatRequest( id="test", model=model, - messages=[p.ChatEntry.user("Hi there")], + messages=[t.ChatEntry.user("Hi there")], temperature=0.001, top_p=0.001, max_tokens=30, stream=False, ) response = await run(jamai.generate_chat_completions, request) - assert isinstance(response, p.ChatCompletionChunk) + assert isinstance(response, t.ChatCompletionChunk) assert isinstance(response.text, str) assert len(response.text) > 1 - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) assert response.references is None @@ -419,11 
+419,11 @@ async def test_chat_user_only( request.stream = True responses = await run(jamai.generate_chat_completions, request) assert len(responses) > 0 - assert all(isinstance(r, p.ChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.ChatCompletionChunk) for r in responses) assert all(isinstance(r.text, str) for r in responses) assert len("".join(r.text for r in responses)) > 1 assert all(r.references is None for r in responses) - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) @@ -438,20 +438,20 @@ async def test_chat_system_only( jamai = client_cls() # Non-streaming - request = p.ChatRequest( + request = t.ChatRequest( id="test", model=model, - messages=[p.ChatEntry.system("You are a concise assistant.")], + messages=[t.ChatEntry.system("You are a concise assistant.")], temperature=0.001, top_p=0.001, max_tokens=30, stream=False, ) response = await run(jamai.generate_chat_completions, request) - assert isinstance(response, p.ChatCompletionChunk) + assert isinstance(response, t.ChatCompletionChunk) assert isinstance(response.text, str) assert len(response.text) > 1 - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) assert response.references is None @@ -460,11 +460,11 @@ async def test_chat_system_only( request.stream = True responses = await run(jamai.generate_chat_completions, request) assert len(responses) > 0 - assert all(isinstance(r, p.ChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.ChatCompletionChunk) for r in responses) assert all(isinstance(r.text, str) for r in responses) assert len("".join(r.text for r in responses)) > 1 assert all(r.references is None for r in responses) - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) assert isinstance(response.prompt_tokens, int) assert isinstance(response.completion_tokens, int) @@ -479,12 +479,12 @@ async def test_long_chat_completion( jamai = client_cls() # Streaming - request = p.ChatRequest( + request = t.ChatRequest( id="test", model=model, messages=[ - p.ChatEntry.system("You are a concise assistant."), - p.ChatEntry.user(" ".join(["What is a llama?"] * 50000)), + t.ChatEntry.system("You are a concise assistant."), + t.ChatEntry.user(" ".join(["What is a llama?"] * 50000)), ], temperature=0.001, top_p=0.001, @@ -493,7 +493,7 @@ async def test_long_chat_completion( ) responses = await run(jamai.generate_chat_completions, request) assert len(responses) == 1 - assert all(isinstance(r, p.ChatCompletionChunk) for r in responses) + assert all(isinstance(r, t.ChatCompletionChunk) for r in responses) completion = responses[0] assert completion.finish_reason == "error" assert "ContextWindowExceededError" in completion.text diff --git a/clients/python/tests/oss/test_embeddings.py b/clients/python/tests/oss/test_embeddings.py index 13a3a01..335e6a3 100644 --- a/clients/python/tests/oss/test_embeddings.py +++ b/clients/python/tests/oss/test_embeddings.py @@ -5,7 +5,7 @@ import pytest from jamaibase import JamAI, JamAIAsync -from jamaibase import protocol as p +from jamaibase import types as t from jamaibase.utils import run CLIENT_CLS = [JamAI, JamAIAsync] @@ -39,13 +39,13 @@ async def test_generate_embeddings( } # Get float embeddings 
- response = await run(jamai.generate_embeddings, p.EmbeddingRequest(**kwargs)) - assert isinstance(response, p.EmbeddingResponse) + response = await run(jamai.generate_embeddings, t.EmbeddingRequest(**kwargs)) + assert isinstance(response, t.EmbeddingResponse) assert isinstance(response.data, list) - assert all(isinstance(d, p.EmbeddingResponseData) for d in response.data) + assert all(isinstance(d, t.EmbeddingResponseData) for d in response.data) assert all(isinstance(d.embedding, list) for d in response.data) assert isinstance(response.model, str) - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) if isinstance(inputs, str): assert len(response.data) == 1 else: @@ -54,13 +54,13 @@ async def test_generate_embeddings( # Get base64 embeddings kwargs["encoding_format"] = "base64" - response = await run(jamai.generate_embeddings, p.EmbeddingRequest(**kwargs)) - assert isinstance(response, p.EmbeddingResponse) + response = await run(jamai.generate_embeddings, t.EmbeddingRequest(**kwargs)) + assert isinstance(response, t.EmbeddingResponse) assert isinstance(response.data, list) - assert all(isinstance(d, p.EmbeddingResponseData) for d in response.data) + assert all(isinstance(d, t.EmbeddingResponseData) for d in response.data) assert all(isinstance(d.embedding, str) for d in response.data) assert isinstance(response.model, str) - assert isinstance(response.usage, p.CompletionUsage) + assert isinstance(response.usage, t.CompletionUsage) if isinstance(inputs, str): assert len(response.data) == 1 else: diff --git a/clients/python/tests/oss/test_file.py b/clients/python/tests/oss/test_file.py index d31f90a..2e9be1c 100644 --- a/clients/python/tests/oss/test_file.py +++ b/clients/python/tests/oss/test_file.py @@ -11,12 +11,16 @@ from PIL import Image from jamaibase import JamAI, JamAIAsync -from jamaibase.protocol import ( +from jamaibase.types import ( FileUploadResponse, GetURLResponse, ) from jamaibase.utils import run -from jamaibase.utils.io import generate_audio_thumbnail, generate_image_thumbnail +from jamaibase.utils.io import ( + generate_extension_name_thumbnail, + generate_pdf_thumbnail, +) +from owl.utils.io import generate_audio_thumbnail, generate_image_thumbnail def read_file_content(file_path): @@ -37,6 +41,11 @@ def read_file_content(file_path): "clients/python/tests/files/mp3/turning-a4-size-magazine.mp3", ] +DOC_FILES = [ + "clients/python/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf", + "clients/python/tests/files/xlsx/Claims Form.xlsx", +] + CLIENT_CLS = [JamAI, JamAIAsync] @@ -51,9 +60,9 @@ async def test_upload_image(client_cls: Type[JamAI | JamAIAsync], image_file: st # Upload the file upload_response = await run(jamai.file.upload_file, image_file) assert isinstance(upload_response, FileUploadResponse) - assert upload_response.uri.startswith( - ("file://", "s3://") - ), f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + assert upload_response.uri.startswith(("file://", "s3://")), ( + f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + ) filename = os.path.basename(image_file) expected_uri_pattern = re.compile( @@ -80,9 +89,9 @@ async def test_upload_audio(client_cls: Type[JamAI | JamAIAsync], audio_file: st # Upload the file upload_response = await run(jamai.file.upload_file, audio_file) assert isinstance(upload_response, FileUploadResponse) - assert upload_response.uri.startswith( - ("file://", "s3://") - ), f"Returned URI '{upload_response.uri}' does not 
start with 'file://' or 's3://'" + assert upload_response.uri.startswith(("file://", "s3://")), ( + f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + ) filename = os.path.basename(audio_file) expected_uri_pattern = re.compile( @@ -98,6 +107,35 @@ async def test_upload_audio(client_cls: Type[JamAI | JamAIAsync], audio_file: st print(f"Returned URI matches the expected format: {upload_response.uri}") +@pytest.mark.parametrize("client_cls", CLIENT_CLS) +@pytest.mark.parametrize("doc_file", DOC_FILES) +async def test_upload_doc(client_cls: Type[JamAI | JamAIAsync], doc_file: str): + # Initialize the client + jamai = client_cls() + + # Ensure the doc file exists + assert os.path.exists(doc_file), f"Test doc file does not exist: {doc_file}" + # Upload the file + upload_response = await run(jamai.file.upload_file, doc_file) + assert isinstance(upload_response, FileUploadResponse) + assert upload_response.uri.startswith(("file://", "s3://")), ( + f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + ) + + filename = os.path.basename(doc_file) + expected_uri_pattern = re.compile( + r"(file|s3)://[^/]+/raw/default/default/[a-f0-9-]{36}/" + re.escape(filename) + "$" + ) + + # Check if the returned URI matches the expected format + assert expected_uri_pattern.match(upload_response.uri), ( + f"Returned URI '{upload_response.uri}' does not match the expected format: " + f"(file|s3)://file/raw/default/default/{{UUID}}/{filename}" + ) + + print(f"Returned URI matches the expected format: {upload_response.uri}") + + @pytest.mark.parametrize("client_cls", CLIENT_CLS) async def test_upload_large_image_file(client_cls: Type[JamAI | JamAIAsync]): jamai = client_cls() @@ -121,28 +159,23 @@ async def test_get_raw_urls(client_cls: Type[JamAI | JamAIAsync]): jamai = client_cls() # Upload files first uploaded_uris = [] - for file in IMAGE_FILES + AUDIO_FILES: + for file in IMAGE_FILES + AUDIO_FILES + DOC_FILES: response = await run(jamai.file.upload_file, file) uploaded_uris.append(response.uri) # Now test get_raw_urls response = await run(jamai.file.get_raw_urls, uploaded_uris) assert isinstance(response, GetURLResponse) - assert len(response.urls) == len(IMAGE_FILES + AUDIO_FILES) - for original_file, url in zip(IMAGE_FILES + AUDIO_FILES, response.urls, strict=True): - if url.startswith(("http://", "https://")): - # Handle HTTP/HTTPS URLs - HEADERS = {"X-PROJECT-ID": "default"} - with httpx.Client() as client: - downloaded_content = client.get(url, headers=HEADERS).content - + assert len(response.urls) == len(IMAGE_FILES + AUDIO_FILES + DOC_FILES) + for original_file, url in zip( + IMAGE_FILES + AUDIO_FILES + DOC_FILES, response.urls, strict=True + ): # Read the original file content original_content = read_file_content(original_file) - # Compare the contents - assert ( - original_content == downloaded_content - ), f"Content mismatch for file: {original_file}" + assert original_content == httpx.get(url).content, ( + f"Content mismatch for file: {original_file}" + ) # Check if the returned URIs are absolute paths for url in response.urls: @@ -163,14 +196,14 @@ async def test_get_thumbnail_urls(client_cls: Type[JamAI | JamAIAsync]): # Upload files first uploaded_uris = [] - for file in IMAGE_FILES + AUDIO_FILES: + for file in IMAGE_FILES + AUDIO_FILES + DOC_FILES: response = await run(jamai.file.upload_file, file) uploaded_uris.append(response.uri) # Now test get_thumbnail_urls response = await run(jamai.file.get_thumbnail_urls, uploaded_uris) assert 
isinstance(response, GetURLResponse) - assert len(response.urls) == len(IMAGE_FILES + AUDIO_FILES) + assert len(response.urls) == len(IMAGE_FILES + AUDIO_FILES + DOC_FILES) # Generate thumbnails and compare for original_file, url in zip(IMAGE_FILES, response.urls[: len(IMAGE_FILES)], strict=True): @@ -182,17 +215,21 @@ async def test_get_thumbnail_urls(client_cls: Type[JamAI | JamAIAsync]): assert expected_thumbnail is not None, f"Failed to generate thumbnail for {original_file}" if url.startswith(("http://", "https://")): - downloaded_thumbnail = httpx.get(url, headers={"X-PROJECT-ID": "default"}).content + downloaded_thumbnail = httpx.get(url).content else: downloaded_thumbnail = read_file_content(url) # Compare thumbnails - assert ( - expected_thumbnail == downloaded_thumbnail - ), f"Thumbnail mismatch for file: {original_file}" + assert expected_thumbnail == downloaded_thumbnail, ( + f"Thumbnail mismatch for file: {original_file}" + ) # Generate audio thumbnails and compare - for original_file, url in zip(AUDIO_FILES, response.urls[len(IMAGE_FILES) :], strict=True): + for original_file, url in zip( + AUDIO_FILES, + response.urls[len(IMAGE_FILES) : len(IMAGE_FILES) + len(AUDIO_FILES)], + strict=True, + ): # Read original file content original_content = read_file_content(original_file) @@ -201,7 +238,7 @@ async def test_get_thumbnail_urls(client_cls: Type[JamAI | JamAIAsync]): assert expected_thumbnail is not None, f"Failed to generate thumbnail for {original_file}" if url.startswith(("http://", "https://")): - downloaded_thumbnail = httpx.get(url, headers={"X-PROJECT-ID": "default"}).content + downloaded_thumbnail = httpx.get(url).content else: downloaded_thumbnail = read_file_content(url) @@ -212,6 +249,31 @@ async def test_get_thumbnail_urls(client_cls: Type[JamAI | JamAIAsync]): == downloaded_thumbnail[-round(len(expected_thumbnail) * 0.9) :] ), f"Thumbnail mismatch for file: {original_file}" + # Generate doc thumbnails and compare + for original_file, url in zip( + DOC_FILES, response.urls[len(IMAGE_FILES) + len(AUDIO_FILES) :], strict=True + ): + # Read original file content + original_content = read_file_content(original_file) + + # Generate document thumbnail + file_extension = os.path.splitext(original_file)[1].lower() + if file_extension == ".pdf": + expected_thumbnail = generate_pdf_thumbnail(original_content) + else: + expected_thumbnail = generate_extension_name_thumbnail(file_extension) + assert expected_thumbnail is not None, f"Failed to generate thumbnail for {original_file}" + + if url.startswith(("http://", "https://")): + downloaded_thumbnail = httpx.get(url).content + else: + downloaded_thumbnail = read_file_content(url) + + # Compare thumbnails + assert expected_thumbnail == downloaded_thumbnail, ( + f"Thumbnail mismatch for file: {original_file}" + ) + # Check if the returned URIs are valid for url in response.urls: parsed_uri = urlparse(url) @@ -233,7 +295,7 @@ async def test_thumbnail_transparency(client_cls: Type[JamAI | JamAIAsync]): assert len(response.urls) == 1 thumb_url = response.urls[0] if thumb_url.startswith(("http://", "https://")): - downloaded_thumbnail = httpx.get(thumb_url, headers={"X-PROJECT-ID": "default"}).content + downloaded_thumbnail = httpx.get(thumb_url).content else: downloaded_thumbnail = read_file_content(thumb_url) diff --git a/clients/python/tests/oss/test_gen_executor.py b/clients/python/tests/oss/test_gen_executor.py index 1ca3d1d..df399b3 100644 --- a/clients/python/tests/oss/test_gen_executor.py +++ 
b/clients/python/tests/oss/test_gen_executor.py @@ -1,6 +1,5 @@ import asyncio import io -import time from contextlib import asynccontextmanager import httpx @@ -9,22 +8,22 @@ from PIL import Image from jamaibase import JamAI, JamAIAsync -from jamaibase.exceptions import ResourceNotFoundError -from jamaibase.protocol import ( +from jamaibase.types import ( + CellCompletionResponse, CodeGenConfig, ColumnSchemaCreate, GenConfigUpdateRequest, - GenTableRowsChatCompletionChunks, - GenTableStreamChatCompletionChunk, GetURLResponse, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowRegenRequest, RegenStrategy, - RowAddRequest, - RowRegenRequest, RowUpdateRequest, TableSchemaCreate, TableType, ) from jamaibase.utils import run +from jamaibase.utils.exceptions import ResourceNotFoundError CLIENT_CLS = [JamAI, JamAIAsync] REGEN_STRATEGY = [ @@ -183,12 +182,12 @@ async def test_exceed_context_length( chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) + assert isinstance(chunks[0], CellCompletionResponse) else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) + assert isinstance(chunks, MultiRowCompletionResponse) # Get rows rows = await run(jamai.table.list_table_rows, TableType.action, table_id) @@ -199,140 +198,6 @@ async def test_exceed_context_length( assert column_gen.startswith("[ERROR]") -@flaky(max_runs=3, min_passes=1) -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) -async def test_multicols_concurrency_timing( - client_cls: JamAI | JamAIAsync, - stream: bool, -): - jamai = client_cls() - cols = IN_COLS[:2] + OUT_COLS[:3] - async with _create_table(jamai, TableType.action, cols) as table_id: - row_input_data = {"in_01": "0", "in_02": "100"} - column_map = COLUMN_MAP_CONCURRENCY.copy() - - async def execute(): - gen_config = GenConfigUpdateRequest( - table_id=table_id, - column_map=column_map, - ) - await _update_gen_config(jamai, TableType.action, gen_config) - - start_time = time.time() - chunks = await run( - jamai.table.add_table_rows, - TableType.action, - RowAddRequest( - table_id=table_id, data=[row_input_data], stream=stream, concurrent=True - ), - ) - if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) - else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) - execution_time = time.time() - start_time - return execution_time - - execution_time_3_cols = await execute() - column_map.pop("out_02") - column_map.pop("out_03") - execution_time_1_col = await execute() - - assert abs(execution_time_3_cols - execution_time_1_col) < (execution_time_1_col * 1.5) - - -@flaky(max_runs=3, min_passes=1) -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) -async def test_multirows_multicols_concurrency_timing( - client_cls: JamAI | JamAIAsync, - stream: bool, -): - jamai = client_cls() - cols = IN_COLS[:2] + OUT_COLS[:3] - async with _create_table(jamai, TableType.action, cols) as table_id: - rows_input_data = [ - {"in_01": "0", "in_02": "200"}, - {"in_01": "1", "in_02": "201"}, - {"in_01": "2", "in_02": "202"}, - ] - column_map = COLUMN_MAP_CONCURRENCY - - async def execute(): - gen_config = GenConfigUpdateRequest( - 
table_id=table_id, - column_map=column_map, - ) - await _update_gen_config(jamai, TableType.action, gen_config) - - start_time = time.time() - chunks = await run( - jamai.table.add_table_rows, - TableType.action, - RowAddRequest( - table_id=table_id, data=rows_input_data, stream=stream, concurrent=True - ), - ) - if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) - else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) - execution_time = time.time() - start_time - return execution_time - - execution_time_3_rows = await execute() - rows_input_data = rows_input_data[:1] - execution_time_1_row = await execute() - - assert abs(execution_time_3_rows - execution_time_1_row) < (execution_time_1_row * 1.5) - - -@flaky(max_runs=3, min_passes=1) -@pytest.mark.parametrize("client_cls", CLIENT_CLS) -@pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) -async def test_multicols_dependency( - client_cls: JamAI | JamAIAsync, - stream: bool, -): - jamai = client_cls() - cols = IN_COLS[:2] + OUT_COLS[:5] - async with _create_table(jamai, TableType.action, cols) as table_id: - row_input_data = {"in_01": "8", "in_02": "2"} - column_map = COLUMN_MAP_DEPENDENCY - ground_truths = { - "out_01": "10", - "out_02": "-6", - "out_03": "-60", - "out_04": "360", - "out_05": "120", - } - - gen_config = GenConfigUpdateRequest( - table_id=table_id, - column_map=column_map, - ) - await _update_gen_config(jamai, TableType.action, gen_config) - - chunks = await run( - jamai.table.add_table_rows, - TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), - ) - if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) - else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) - - # Get rows - rows = await run(jamai.table.list_table_rows, TableType.action, table_id) - row_id = rows.items[0]["ID"] - row = await run(jamai.table.get_table_row, TableType.action, table_id, row_id) - - for output_column_name in column_map.keys(): - assert ground_truths[output_column_name] in row[output_column_name]["value"] - - @flaky(max_runs=3, min_passes=1) @pytest.mark.parametrize("client_cls", CLIENT_CLS) @pytest.mark.parametrize("regen_strategy", REGEN_STRATEGY) @@ -414,13 +279,13 @@ async def test_multicols_regen( chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) if not only_input_columns: if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) + assert isinstance(chunks[0], CellCompletionResponse) else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) + assert isinstance(chunks, MultiRowCompletionResponse) # Get rows rows = await run(jamai.table.list_table_rows, TableType.action, table_id) @@ -442,7 +307,7 @@ async def test_multicols_regen( chunks = await run( jamai.table.regen_table_rows, TableType.action, - RowRegenRequest( + MultiRowRegenRequest( table_id=table_id, row_ids=[row_id], regen_strategy=regen_strategy, @@ -453,9 +318,9 @@ async def test_multicols_regen( ) if not only_input_columns: if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) + assert isinstance(chunks[0], CellCompletionResponse) else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) + assert isinstance(chunks, MultiRowCompletionResponse) # Get rows rows = await run(jamai.table.list_table_rows, 
TableType.action, table_id) @@ -500,12 +365,12 @@ async def test_multicols_regen_invalid_column_id( chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) if stream: - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) + assert isinstance(chunks[0], CellCompletionResponse) else: - assert isinstance(chunks, GenTableRowsChatCompletionChunks) + assert isinstance(chunks, MultiRowCompletionResponse) # Get rows rows = await run(jamai.table.list_table_rows, TableType.action, table_id) @@ -526,14 +391,14 @@ async def test_multicols_regen_invalid_column_id( with pytest.raises( ResourceNotFoundError, match=( - f'`output_column_id` .*{invalid_output_column_id}.* is not found. ' + f"`output_column_id` .*{invalid_output_column_id}.* is not found. " f"Available output columns:.*{'.*'.join(ground_truths.keys())}.*" ), ): await run( jamai.table.regen_table_rows, TableType.action, - RowRegenRequest( + MultiRowRegenRequest( table_id=table_id, row_ids=[row_id], regen_strategy=regen_strategy, @@ -580,15 +445,15 @@ async def test_code_str(client_cls: JamAI | JamAIAsync, stream: bool): chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) if stream: print(chunks[0]) - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) + assert isinstance(chunks[0], CellCompletionResponse) else: print(chunks) - assert isinstance(chunks, GenTableRowsChatCompletionChunks) + assert isinstance(chunks, MultiRowCompletionResponse) # Get rows rows = await run(jamai.table.list_table_rows, TableType.action, table_id) @@ -602,7 +467,7 @@ async def test_code_str(client_cls: JamAI | JamAIAsync, stream: bool): chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) rows = await run(jamai.table.list_table_rows, TableType.action, table_id) row_id = rows.items[0]["ID"] @@ -668,15 +533,15 @@ async def test_code_image(client_cls: JamAI | JamAIAsync, stream: bool): chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) if stream: print(chunks[0]) - assert isinstance(chunks[0], GenTableStreamChatCompletionChunk) + assert isinstance(chunks[0], CellCompletionResponse) else: print(chunks) - assert isinstance(chunks, GenTableRowsChatCompletionChunks) + assert isinstance(chunks, MultiRowCompletionResponse) # Get rows rows = await run(jamai.table.list_table_rows, TableType.action, table_id) @@ -693,12 +558,7 @@ async def test_code_image(client_cls: JamAI | JamAIAsync, stream: bool): assert isinstance(response, GetURLResponse) for url in response.urls: if url.startswith(("http://", "https://")): - # Handle HTTP/HTTPS URLs - HEADERS = {"X-PROJECT-ID": "default"} - with httpx.Client() as client: - downloaded_content = client.get(url, headers=HEADERS).content - - image = Image.open(io.BytesIO(downloaded_content)) + image = Image.open(io.BytesIO(httpx.get(url).content)) assert image.format == case["expected_format"] # Test error handling @@ -707,7 +567,7 @@ async def 
test_code_image(client_cls: JamAI | JamAIAsync, stream: bool): chunks = await run( jamai.table.add_table_rows, TableType.action, - RowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), + MultiRowAddRequest(table_id=table_id, data=[row_input_data], stream=stream), ) rows = await run(jamai.table.list_table_rows, TableType.action, table_id) diff --git a/clients/python/tests/oss/test_template.py b/clients/python/tests/oss/test_template.py index c9aeb95..cc0ab1e 100644 --- a/clients/python/tests/oss/test_template.py +++ b/clients/python/tests/oss/test_template.py @@ -5,21 +5,21 @@ import pytest from jamaibase import JamAI, JamAIAsync -from jamaibase import protocol as p +from jamaibase import types as t from jamaibase.utils import run CLIENT_CLS = [JamAI, JamAIAsync] -TABLE_TYPES = [p.TableType.action, p.TableType.knowledge, p.TableType.chat] +TABLE_TYPES = [t.TableType.action, t.TableType.knowledge, t.TableType.chat] @asynccontextmanager async def _create_gen_table( jamai: JamAI, - table_type: p.TableType, + table_type: t.TableType, table_id: str, model_id: str = "", - cols: list[p.ColumnSchemaCreate] | None = None, - chat_cols: list[p.ColumnSchemaCreate] | None = None, + cols: list[t.ColumnSchemaCreate] | None = None, + chat_cols: list[t.ColumnSchemaCreate] | None = None, embedding_model: str = "", delete_first: bool = True, delete: bool = True, @@ -29,11 +29,11 @@ async def _create_gen_table( await run(jamai.table.delete_table, table_type, table_id) if cols is None: cols = [ - p.ColumnSchemaCreate(id="input", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="input", dtype="str"), + t.ColumnSchemaCreate( id="output", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=model_id, prompt="${input}", max_tokens=3, @@ -42,36 +42,36 @@ async def _create_gen_table( ] if chat_cols is None: chat_cols = [ - p.ColumnSchemaCreate(id="User", dtype="str"), - p.ColumnSchemaCreate( + t.ColumnSchemaCreate(id="User", dtype="str"), + t.ColumnSchemaCreate( id="AI", dtype="str", - gen_config=p.LLMGenConfig( + gen_config=t.LLMGenConfig( model=model_id, system_prompt="You are an assistant.", max_tokens=3, ), ), ] - if table_type == p.TableType.action: + if table_type == t.TableType.action: table = await run( - jamai.table.create_action_table, p.ActionTableSchemaCreate(id=table_id, cols=cols) + jamai.table.create_action_table, t.ActionTableSchemaCreate(id=table_id, cols=cols) ) - elif table_type == p.TableType.knowledge: + elif table_type == t.TableType.knowledge: table = await run( jamai.table.create_knowledge_table, - p.KnowledgeTableSchemaCreate( + t.KnowledgeTableSchemaCreate( id=table_id, cols=cols, embedding_model=embedding_model ), ) - elif table_type == p.TableType.chat: + elif table_type == t.TableType.chat: table = await run( jamai.table.create_chat_table, - p.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols), + t.ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols), ) else: raise ValueError(f"Invalid table type: {table_type}") - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) yield table finally: if delete: @@ -82,7 +82,7 @@ async def _create_gen_table( async def test_populate_templates(client_cls: Type[JamAI]): client = client_cls() response = await run(client.admin.backend.populate_templates) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) @pytest.mark.parametrize("client_cls", CLIENT_CLS) @@ -92,7 +92,7 @@ async def test_list_templates(client_cls: 
Type[JamAI]): assert len(response.items) == response.total templates = response.items assert len(templates) > 0 - assert all(isinstance(t, p.Template) for t in templates) + assert all(isinstance(x, t.Template) for x in templates) for template in templates: assert len(template.id) > 0 assert len(template.name) > 0 @@ -107,12 +107,12 @@ async def test_get_template(client_cls: Type[JamAI]): template_id = templates[0].id # Fetch template template = await run(client.template.get_template, template_id) - assert isinstance(template, p.Template) + assert isinstance(template, t.Template) assert len(template.id) > 0 assert len(template.name) > 0 assert len(template.created_at) > 0 assert len(template.tags) > 0 - assert all(isinstance(t, p.TemplateTag) for t in template.tags) + assert all(isinstance(x, t.TemplateTag) for x in template.tags) @@ -123,15 +123,15 @@ async def test_list_tables(client_cls: Type[JamAI]): assert len(templates) > 0 template_id = templates[0].id # List tables - tables: list[p.TableMetaResponse] = [] + tables: list[t.TableMetaResponse] = [] for table_type in TABLE_TYPES: tables += (await run(client.template.list_tables, template_id, table_type)).items assert len(tables) > 0 - assert all(isinstance(t, p.TableMetaResponse) for t in tables) + assert all(isinstance(x, t.TableMetaResponse) for x in tables) for table in tables: assert len(table.id) > 0 assert len(table.cols) > 0 - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) assert len(table.updated_at) > 0 # Create a template by exporting default project @@ -147,7 +147,7 @@ async def test_list_tables(client_cls: Type[JamAI]): new_template_id = "test_template" with BytesIO(data) as f: response = await run(client.admin.backend.add_template, f, new_template_id, True) - assert isinstance(response, p.OkResponse) + assert isinstance(response, t.OkResponse) # Search query tables = ( @@ -220,10 +220,10 @@ async def test_get_table(client_cls: Type[JamAI]): table_count += len(tables) table_id = tables[0].id table = await run(client.template.get_table, template_id, table_type, table_id) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) assert len(table.id) > 0 assert len(table.cols) > 0 - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) assert len(table.updated_at) > 0 assert table_count > 0 @@ -244,10 +244,10 @@ async def test_list_table_rows(client_cls: Type[JamAI]): table_count += len(tables) table_id = tables[0].id table = await run(client.template.get_table, template_id, table_type, table_id) - assert isinstance(table, p.TableMetaResponse) + assert isinstance(table, t.TableMetaResponse) assert len(table.id) > 0 assert len(table.cols) > 0 - assert all(isinstance(c, p.ColumnSchema) for c in table.cols) + assert all(isinstance(c, t.ColumnSchema) for c in table.cols) assert len(table.updated_at) > 0 # List rows rows = ( diff --git a/clients/typescript/README.md b/clients/typescript/README.md index 910c304..5107148 100644 --- a/clients/typescript/README.md +++ b/clients/typescript/README.md @@ -130,7 +130,7 @@ Create an API client with api key and project id: ```javascript import JamAI from "jamaibase"; -const jamai = new JamAI({ apiKey: "jamai_apikey", projectId: "proj_id" }); +const jamai = new JamAI({ token: "jamai_pat", projectId: "proj_id" }); ``` Create an API client with custom HTTP
client: @@ -186,18 +186,12 @@ jamai.setHttpagentConfig({ }); ``` -Can be imported from different modules depending on the need: - -```javascript -import JamAI from "jamaibase/index.umd.js"; -``` - ### Types Types can be imported from resources: ```javascript -import { ChatRequest } from "jamaibase/dist/resources/llm/chat"; +import { ChatRequest } from "jamaibase/resources/llm/chat"; let response: ChatRequest; ``` @@ -297,7 +291,7 @@ To integrate JamAI into a React application, follow these steps: import { useEffect, useState } from "react"; import JamAI from "jamaibase"; -import { PageListTableMetaResponse } from "jamaibase/dist/resources/gen_tables/tables"; +import { PageListTableMetaResponse } from "jamaibase/resources/gen_tables/tables"; export default function Home() { const [tableData, setTableData] = useState(); @@ -420,7 +414,7 @@ export async function GET(request: NextRequest) { "use client"; -import { PageListTableMetaResponse } from "jamaibase/dist/resources/gen_tables/tables"; +import { PageListTableMetaResponse } from "jamaibase/resources/gen_tables/tables"; import { ChangeEvent, useEffect, useState } from "react"; export default function Home() { diff --git a/clients/typescript/__tests__/file.test.ts b/clients/typescript/__tests__/file.test.ts index 65f2482..670625d 100644 --- a/clients/typescript/__tests__/file.test.ts +++ b/clients/typescript/__tests__/file.test.ts @@ -155,4 +155,13 @@ describe("APIClient File", () => { const parsedDataGetThumbUrl = GetUrlResponseSchema.parse(responseGetThumbUrls); expect(parsedDataGetThumbUrl).toEqual(responseGetThumbUrls); }); + + it("audio file upload by file path", async () => { + const responseUploadFile = await client.file.uploadFile({ + file_path: path.resolve(__dirname, "./zoom-in-audio.mp3") + }); + + const parsedDataUploadFile = UploadFileResponseSchema.parse(responseUploadFile); + expect(parsedDataUploadFile).toEqual(responseUploadFile); + }); }); diff --git a/clients/typescript/__tests__/gentable.test.ts b/clients/typescript/__tests__/gentable.test.ts index 56936bc..bcbffd2 100644 --- a/clients/typescript/__tests__/gentable.test.ts +++ b/clients/typescript/__tests__/gentable.test.ts @@ -1,7 +1,7 @@ import JamAI from "@/index"; import tmp from "tmp"; -import { GenTableRowsChatCompletionChunksSchema, GetConversationThreadResponseSchema } from "@/resources/gen_tables/chat"; +import { GetConversationThreadResponseSchema, MultiRowCompletionResponseSchema } from "@/resources/gen_tables/chat"; import { ColumnSchema, ColumnSchemaCreate, @@ -508,7 +508,7 @@ describe("APIClient Gentable", () => { concurrent: true }); - const parsedData = GenTableRowsChatCompletionChunksSchema.parse(response); + const parsedData = MultiRowCompletionResponseSchema.parse(response); expect(parsedData).toEqual(response); } }); @@ -530,7 +530,7 @@ describe("APIClient Gentable", () => { concurrent: true }); - const parsedData = GenTableRowsChatCompletionChunksSchema.parse(response); + const parsedData = MultiRowCompletionResponseSchema.parse(response); expect(parsedData).toEqual(response); } }); @@ -1244,7 +1244,7 @@ describe("APIClient Gentable", () => { // @TODO // verify that the suggestions output is different after regen - // const parsedregenRowResponseData = GenTableRowsChatCompletionChunksSchema.parse(regenRowResponse); + // const parsedregenRowResponseData = MultiRowCompletionResponseSchema.parse(regenRowResponse); // expect(parsedregenRowResponseData).toEqual(regenRowResponse); // const listRowResponse2 = await client.table.listRows({ @@ -1315,7 +1315,7 
@@ describe("APIClient Gentable", () => { if (done) { break; } - // console.log(GenTableStreamChatCompletionChunkSchema.parse(value)); + // console.log(ColumnCompletionResponseSchema.parse(value)); } } }); diff --git a/clients/typescript/__tests__/llm.test.ts b/clients/typescript/__tests__/llm.test.ts index b22f9d2..2623064 100644 --- a/clients/typescript/__tests__/llm.test.ts +++ b/clients/typescript/__tests__/llm.test.ts @@ -231,14 +231,9 @@ describe("APIClient LLM", () => { }); it("generate chat completion", async () => { - try { - console.log("model: ", requestDataChat.model); - const response = await client.llm.generateChatCompletions(requestDataChat); + const response = await client.llm.generateChatCompletions(requestDataChat); - expect(ChatCompletionChunkSchema.parse(response)).toEqual(response); - } catch (err: any) { - console.log("error: ", err.response.data); - } + expect(ChatCompletionChunkSchema.parse(response)).toEqual(response); }); it("generate chat completion - stream", async () => { diff --git a/clients/typescript/__tests__/template.test.ts b/clients/typescript/__tests__/template.test.ts index ad6b837..48c982c 100644 --- a/clients/typescript/__tests__/template.test.ts +++ b/clients/typescript/__tests__/template.test.ts @@ -13,7 +13,7 @@ dotenv.config({ path: "__tests__/.env" }); -describe("APIClient Templates", () => { +describe.skip("APIClient Templates", () => { let client: JamAI; jest.setTimeout(30000); jest.retryTimes(1, { diff --git a/clients/typescript/__tests__/zoom-in-audio.mp3 b/clients/typescript/__tests__/zoom-in-audio.mp3 new file mode 100644 index 0000000..75587f7 Binary files /dev/null and b/clients/typescript/__tests__/zoom-in-audio.mp3 differ diff --git a/clients/typescript/build b/clients/typescript/build old mode 100644 new mode 100755 index 690e10c..713911d --- a/clients/typescript/build +++ b/clients/typescript/build @@ -3,14 +3,16 @@ set -e node scripts/remove-tests-tsconfig.cjs # rm -rf dist && npx microbundle --tsconfig tsconfig.json --no-sourcemap && tsc-alias -p tsconfig.json -rimraf dist && tsc --project tsconfig.json && tsc-alias -p tsconfig.json && rollup -c --bundleConfigAsCjs +rimraf dist && tsc --project tsconfig.json && rollup -c --bundleConfigAsCjs && tsc-alias -p tsconfig.json -# cp -rp README.md dist -# for file in LICENSE CHANGELOG.md; do -# if [ -e "${file}" ]; then cp "${file}" dist; fi -# done +# Generate index.d.mts for ESM compatibility +cp dist/index.d.ts dist/index.d.mts +# Post-process files to fix import paths and compatibility +# node scripts/postprocess-files.cjs + +# Create package.json for dist folder node scripts/make-dist-package-json.cjs > dist/package.json # make sure that nothing crashes when we require the output CJS or @@ -18,7 +20,6 @@ node scripts/make-dist-package-json.cjs > dist/package.json (cd dist && node -e 'require("jamaibase")') (cd dist && node -e 'import("jamaibase")' --input-type=module) - # include "__tests__" folder in tsconfig to facilitate unit test. 
node scripts/include-tests-tsconfig.cjs diff --git a/clients/typescript/package-lock.json b/clients/typescript/package-lock.json index b2a4fce..e34e6ac 100644 --- a/clients/typescript/package-lock.json +++ b/clients/typescript/package-lock.json @@ -1,12 +1,12 @@ { "name": "jamaibase", - "version": "0.2.1", + "version": "0.5.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "jamaibase", - "version": "0.2.1", + "version": "0.5.0", "license": "Apache-2.0", "dependencies": { "agentkeepalive": "^4.5.0", diff --git a/clients/typescript/package.json b/clients/typescript/package.json index 04bb4b6..83d06a5 100644 --- a/clients/typescript/package.json +++ b/clients/typescript/package.json @@ -1,6 +1,6 @@ { "name": "jamaibase", - "version": "0.3.0", + "version": "0.5.0", "description": "JamAIBase Client SDK (JS/TS). JamAI Base: Let Your Database Orchestrate LLMs and RAG", "main": "dist/index.cjs", "module": "dist/index.mjs", @@ -13,7 +13,7 @@ "build": "/bin/bash build", "openapi-to-zod": "openapi-zod-client openapi.json -o zodschema/zodmodels.ts", "doc-ts-moduler": "typedoc --includeVersion --tsconfig tsconfig.build.json --includes ./dist/*.d.ts --includes ./dist/**/*.d.ts --includes ./dist/resources/**/*.d.ts --out docs-autogen-ts", - "doc-ts": "typedoc --readme ./README.md --includeVersion --tsconfig tsconfig.build.json --entryPoints ./dist/index.d.ts --out docs-autogen-ts && cp JamAI_Base_Cover.png docs-autogen-ts/" + "doc-ts": "typedoc --readme ./README.md --includeVersion --tsconfig tsconfig.build.json --entryPointStrategy expand --entryPoints ./dist/index.d.ts --entryPoints ./dist/resources --out docs-autogen-ts --categorizeByGroup --sort alphabetical && cp JamAI_Base_Cover.png docs-autogen-ts/" }, "repository": { "type": "git", @@ -27,29 +27,26 @@ ], "exports": { ".": { - "import": "./dist/index.mjs", - "require": "./dist/index.cjs.js", + "import": { + "types": "./dist/index.d.ts", + "default": "./dist/index.mjs" + }, + "require": { + "types": "./dist/index.d.ts", + "default": "./dist/index.cjs.js" + }, "browser": "./dist/index.umd.js", - "types": "./dist/index.d.ts", "default": "./dist/index.mjs" }, - "./index.mjs": { - "types": "./dist/index.d.ts", - "default": "./dist/index.mjs" - }, - "./index.cjs.js": { - "types": "./dist/index.d.ts", - "default": "./index.cjs.js" - }, - "./index.umd.js": { - "types": "./dist/index.d.ts", - "default": "./dist/index.umd.js" + "./resources/*": { + "import": "./dist/resources/*", + "require": "./dist/resources/*" } }, "files": [ "dist/**/*" ], - "author": "EmbeddedLLM, Tan Tun Jian", + "author": "EmbeddedLLM, Tan Tun Jian, Kamil Hassan", "license": "Apache-2.0", "bugs": { "url": "https://github.com/EmbeddedLLM/JAM.ai.dev/issues" diff --git a/clients/typescript/rollup.config.ts b/clients/typescript/rollup.config.ts new file mode 100644 index 0000000..d1381dc --- /dev/null +++ b/clients/typescript/rollup.config.ts @@ -0,0 +1,117 @@ +import commonjs from "@rollup/plugin-commonjs"; +import dynamicImportVars from "@rollup/plugin-dynamic-import-vars"; +import json from "@rollup/plugin-json"; +import resolve from "@rollup/plugin-node-resolve"; +import typescript from "@rollup/plugin-typescript"; +import copy from "rollup-plugin-copy"; +import builtins from "rollup-plugin-node-builtins"; +import globals from "rollup-plugin-node-globals"; +import polyfillNode from "rollup-plugin-polyfill-node"; + +export default [ + // Node.js Builds (CJS and ES Modules) + { + input: "src/index.ts", + output: [ + { + file: "dist/index.cjs.js", + format: "cjs", 
+ sourcemap: true, + inlineDynamicImports: true + }, + { + file: "dist/index.mjs", + format: "es", + sourcemap: true, + inlineDynamicImports: true + } + ], + external: [ + "axios", + "zod", + "uuid", + "path", + "fs", + "os", + "agentkeepalive", + "axios-retry", + "csv-parser", + "mime-types", + "formdata-node", + "path-browserify" + ], + plugins: [ + + typescript({ + tsconfig: "./tsconfig.json", + sourceMap: true, + declaration: true, + emitDeclarationOnly: false + }), + json(), + + resolve({ + browser: false, + preferBuiltins: true, + extensions: ['.mjs', '.js', ".ts", '.jsx', '.json', '.sass', '.scss'] + }), + commonjs(), + dynamicImportVars(), + + copy({ + targets: [ + { src: "README.md", dest: "dist" }, + { src: "LICENSE", dest: "dist" }, + { src: "CHANGELOG.md", dest: "dist" } + ] + }) + ] + }, + + // Browser Build (UMD) + { + input: "src/index.ts", + output: { + file: "dist/index.umd.js", + format: "umd", + exports: "named", + name: "JamAI", + sourcemap: true, + inlineDynamicImports: true, + globals: { + axios: "axios", + zod: "zod", + uuid: "uuid" + } + }, + external: ["axios", "zod", "uuid"], + plugins: [ + + typescript({ + tsconfig: "./tsconfig.json", + sourceMap: true, + declaration: true, + emitDeclarationOnly: false + }), + json(), + polyfillNode(), + builtins(), + globals(), + + resolve({ + browser: true, + preferBuiltins: false, + extensions: ['.js', '.ts', '.tsx', '.json'] + }), + + commonjs(), + dynamicImportVars(), + + copy({ + targets: [ + // Only need to copy once, so you can remove this if already copied + ] + }) + ] + } +]; \ No newline at end of file diff --git a/clients/typescript/scripts/postprocess-files.cjs b/clients/typescript/scripts/postprocess-files.cjs index e7adf87..fa06da4 100644 --- a/clients/typescript/scripts/postprocess-files.cjs +++ b/clients/typescript/scripts/postprocess-files.cjs @@ -2,7 +2,7 @@ const fs = require("fs"); const path = require("path"); const { parse } = require("@typescript-eslint/parser"); -const pkgImportPath = process.env["PKG_IMPORT_PATH"] ?? "jamaisdk/"; +const pkgImportPath = process.env["PKG_IMPORT_PATH"] ?? "jamaibase/"; const distDir = process.env["DIST_PATH"] ? 
path.resolve(process.env["DIST_PATH"]) : path.resolve(__dirname, "..", "dist"); const distSrcDir = path.join(distDir, "src"); diff --git a/clients/typescript/src/index.ts b/clients/typescript/src/index.ts index 7f61ee2..46f4544 100644 --- a/clients/typescript/src/index.ts +++ b/clients/typescript/src/index.ts @@ -74,4 +74,11 @@ class JamAI extends Base { } } -export default JamAI; +// // Re-export types from internal modules for easier access +// export * from "@/resources/base"; +// export * from "@/resources/files"; +// export * from "@/resources/gen_tables/tables"; +// export * from "@/resources/llm/chat"; +// export * from "@/resources/templates"; + +export default JamAI; diff --git a/clients/typescript/src/resources/files/index.ts b/clients/typescript/src/resources/files/index.ts index 1cfe7d6..cad5d8f 100644 --- a/clients/typescript/src/resources/files/index.ts +++ b/clients/typescript/src/resources/files/index.ts @@ -12,14 +12,25 @@ import { UploadFileResponseSchema } from "./types"; +async function createFormData() { + if (!isRunningInBrowser()) { + // Node environment + // (import from `formdata-node`) + const { FormData } = await import("formdata-node"); + return new FormData(); + } else { + // Browser environment + return new FormData(); + } +} export class Files extends Base { public async uploadFile(params: IUploadFileRequest): Promise { - const apiURL = `/api/v1/files/upload`; + const apiURL = `/api/v2/files/upload`; const parsedParams = UploadFileRequestSchema.parse(params); // Create FormData to send as multipart/form-data - const formData = new FormData(); + const formData = await createFormData(); if (parsedParams.file) { formData.append("file", parsedParams.file, parsedParams.file.name); } else if (parsedParams.file_path) { @@ -27,7 +38,10 @@ export class Files extends Base { const mimeType = await getMimeType(parsedParams.file_path!); const fileName = await getFileName(parsedParams.file_path!); const data = await readFile(parsedParams.file_path!); - const file = new Blob([data], { type: mimeType }); + // const file = new Blob([data], { type: mimeType }); + const { File } = await import("formdata-node"); + const file = new File([data], fileName, { type: mimeType }); + // @ts-ignore formData.append("file", file, fileName); } else { throw new Error("Pass File instead of file path if you are using this function in client."); @@ -47,7 +61,7 @@ export class Files extends Base { public async getRawUrls(params: IGetUrlRequest): Promise { const parsedParams = GetUrlRequestSchema.parse(params); - const apiURL = `/api/v1/files/url/raw`; + const apiURL = `/api/v2/files/url/raw`; const response = await this.httpClient.post(apiURL, { uris: parsedParams.uris }); @@ -56,7 +70,7 @@ export class Files extends Base { public async getThumbUrls(params: IGetUrlRequest): Promise { const parsedParams = GetUrlRequestSchema.parse(params); - const apiURL = `/api/v1/files/url/thumb`; + const apiURL = `/api/v2/files/url/thumb`; const response = await this.httpClient.post(apiURL, { uris: parsedParams.uris }); diff --git a/clients/typescript/src/resources/files/types.ts b/clients/typescript/src/resources/files/types.ts index 14aba8a..a11f16a 100644 --- a/clients/typescript/src/resources/files/types.ts +++ b/clients/typescript/src/resources/files/types.ts @@ -3,9 +3,6 @@ import { z } from "zod"; export const UploadFileRequestSchema = z.object({ file: z .any() - .refine((value) => value instanceof File, { - message: "Value must be a File object" - }) .optional(), file_path: z.string().optional() }); 
diff --git a/clients/typescript/src/resources/gen_tables/chat.ts b/clients/typescript/src/resources/gen_tables/chat.ts index c95baf4..eee07c5 100644 --- a/clients/typescript/src/resources/gen_tables/chat.ts +++ b/clients/typescript/src/resources/gen_tables/chat.ts @@ -3,11 +3,11 @@ import { ChatCompletionChunkSchema, ChatEntrySchema, ReferencesSchema } from "@/ import { z } from "zod"; export const GetConversationThreadRequestSchema = z.object({ + table_type: TableTypesSchema, table_id: IdSchema, column_id: IdSchema, - row_id: z.string().default(""), - table_type: TableTypesSchema, - include: z.boolean().default(true) + row_id: z.string().optional(), + include: z.boolean().optional() }); export const GetConversationThreadResponseSchema = z.object({ @@ -21,18 +21,18 @@ export const GenTableChatCompletionChunksSchema = z.object({ row_id: z.string() }); -export const GenTableRowsChatCompletionChunksSchema = z.object({ +export const MultiRowCompletionResponseSchema = z.object({ object: z.enum(["gen_table.completion.rows"]), rows: z.array(GenTableChatCompletionChunksSchema) }); -export const GenTableStreamChatCompletionChunkSchema = ChatCompletionChunkSchema.extend({ +export const ColumnCompletionResponseSchema = ChatCompletionChunkSchema.extend({ object: z.enum(["gen_table.completion.chunk"]), output_column_name: z.string(), row_id: z.string() }); -export const GenTableStreamReferencesSchema = ReferencesSchema.extend({ +export const RowReferencesResponseSchema = ReferencesSchema.extend({ object: z.enum(["gen_table.references"]), output_column_name: z.string() }); @@ -42,7 +42,7 @@ export type CreateChatTableRequest = z.input; export type GetConversationThreadResponse = z.infer; -export type GenTableChatCompletionChunks = z.infer; -export type GenTableRowsChatCompletionChunks = z.infer; -export type GenTableStreamChatCompletionChunk = z.infer; -export type GenTableStreamReferences = z.infer; +export type RowCompletionResponse = z.infer; +export type MultiRowCompletionResponse = z.infer; +export type CellCompletionResponse = z.infer; +export type CellReferencesResponse = z.infer; diff --git a/clients/typescript/src/resources/gen_tables/index.ts b/clients/typescript/src/resources/gen_tables/index.ts index 8ff882c..5936e92 100644 --- a/clients/typescript/src/resources/gen_tables/index.ts +++ b/clients/typescript/src/resources/gen_tables/index.ts @@ -9,28 +9,31 @@ import { CreateActionTableRequestSchema } from "@/resources/gen_tables/action"; import { + CellCompletionResponse, + CellReferencesResponse, + ColumnCompletionResponseSchema, CreateChatTableRequest, CreateChatTableRequestSchema, - GenTableRowsChatCompletionChunks, - GenTableRowsChatCompletionChunksSchema, - GenTableStreamChatCompletionChunk, - GenTableStreamChatCompletionChunkSchema, - GenTableStreamReferences, - GenTableStreamReferencesSchema, GetConversationThreadRequest, GetConversationThreadRequestSchema, GetConversationThreadResponse, - GetConversationThreadResponseSchema + GetConversationThreadResponseSchema, + MultiRowCompletionResponse, + MultiRowCompletionResponseSchema, + RowReferencesResponseSchema } from "@/resources/gen_tables/chat"; import { CreateKnowledgeTableRequest, CreateKnowledgeTableRequestSchema, UploadFileRequest } from "@/resources/gen_tables/knowledge"; import { AddColumnRequest, AddColumnRequestSchema, AddRowRequest, - DeleteRowRequest, + AddRowRequestSchema, DeleteRowsRequest, + DeleteRowsRequestSchema, DeleteTableRequest, + DeleteTableRequestSchema, DropColumnsRequest, + DropColumnsRequestSchema, 
DuplicateTableRequest, DuplicateTableRequestSchema, ExportTableRequest, @@ -54,33 +57,50 @@ import { PageListTableRowsResponse, PageListTableRowsResponseSchema, RegenRowRequest, + RegenRowRequestSchema, RenameColumnsRequest, + RenameColumnsRequestSchema, RenameTableRequest, + RenameTableRequestSchema, ReorderColumnsRequest, + ReorderColumnsRequestSchema, TableMetaRequest, + TableMetaRequestSchema, TableMetaResponse, TableMetaResponseSchema, UpdateGenConfigRequest, UpdateGenConfigRequestSchema, - UpdateRowRequest + UpdateRowRequest, + UpdateRowRequestSchema } from "@/resources/gen_tables/tables"; import { ChunkError } from "@/resources/shared/error"; import axios, { AxiosResponse } from "axios"; -import { Blob, FormData } from "formdata-node"; +// import { Blob, FormData } from "formdata-node"; + +async function createFormData() { + if (!isRunningInBrowser()) { + // Node environment + // (import from `formdata-node`) + const { FormData } = await import("formdata-node"); + + return new FormData(); + } else { + // Browser environment + return new FormData(); + } +} export class GenTable extends Base { // Helper method to handle stream responses - private handleGenTableStreamResponse( - response: AxiosResponse - ): ReadableStream { + private handleGenTableStreamResponse(response: AxiosResponse): ReadableStream { this.logWarning(response); if (response.status != 200) { throw new Error(`Received Error Status: ${response.status}`); } - return new ReadableStream({ - async start(controller: ReadableStreamDefaultController) { + return new ReadableStream({ + async start(controller: ReadableStreamDefaultController) { response.data.on("data", (data: any) => { data = data.toString(); if (data.endsWith("\n\n")) { @@ -100,9 +120,9 @@ export class GenTable extends Base { try { const parsedValue = JSON.parse(chunk); if (parsedValue["object"] === "gen_table.completion.chunk") { - controller.enqueue(GenTableStreamChatCompletionChunkSchema.parse(parsedValue)); + controller.enqueue(ColumnCompletionResponseSchema.parse(parsedValue)); } else if (parsedValue["object"] === "gen_table.references") { - controller.enqueue(GenTableStreamReferencesSchema.parse(parsedValue)); + controller.enqueue(RowReferencesResponseSchema.parse(parsedValue)); } else { throw new ChunkError(`Unexpected SSE Chunk: ${parsedValue}`); } @@ -125,9 +145,9 @@ export class GenTable extends Base { try { const parsedValue = JSON.parse(chunk); if (parsedValue["object"] === "gen_table.completion.chunk") { - controller.enqueue(GenTableStreamChatCompletionChunkSchema.parse(parsedValue)); + controller.enqueue(ColumnCompletionResponseSchema.parse(parsedValue)); } else if (parsedValue["object"] === "gen_table.references") { - controller.enqueue(GenTableStreamReferencesSchema.parse(parsedValue)); + controller.enqueue(RowReferencesResponseSchema.parse(parsedValue)); } else { throw new ChunkError(`Unexpected SSE Chunk: ${parsedValue}`); } @@ -154,39 +174,29 @@ export class GenTable extends Base { public async listTables(params: ListTableRequest): Promise { const parsedParams = ListTableRequestSchema.parse(params); - let getURL = `/api/v1/gen_tables/${params.table_type}`; - - delete (parsedParams as any).table_type; + let getURL = `/api/v2/gen_tables/${params.table_type}/list`; const response = await this.httpClient.get(getURL, { - params: { - ...parsedParams, - search_query: encodeURIComponent(parsedParams.search_query) - } + params: parsedParams }); return this.handleResponse(response, PageListTableMetaResponseSchema); } public async getTable(params: 
TableMetaRequest): Promise { - let getURL = `/api/v1/gen_tables/${params.table_type}/${params.table_id}`; + const parsedParams = TableMetaRequestSchema.parse(params); + let getURL = `/api/v2/gen_tables/${params.table_type}`; - const response = await this.httpClient.get(getURL); + const response = await this.httpClient.get(getURL, { + params: parsedParams + }); return this.handleResponse(response, TableMetaResponseSchema); } public async listRows(params: ListTableRowsRequest): Promise { const parsedParams = ListTableRowsRequestSchema.parse(params); - const response = await this.httpClient.get(`/api/v1/gen_tables/${parsedParams.table_type}/${parsedParams.table_id}/rows`, { - params: { - offset: parsedParams.offset, - limit: parsedParams.limit, - search_query: encodeURIComponent(parsedParams.search_query), - columns: parsedParams.columns ? parsedParams.columns?.map(encodeURIComponent) : [], - float_decimals: parsedParams.float_decimals, - vec_decimals: parsedParams.vec_decimals, - order_descending: parsedParams.order_descending - }, + const response = await this.httpClient.get(`/api/v2/gen_tables/${parsedParams.table_type}/rows/list`, { + params: parsedParams, paramsSerializer: (params) => { return Object.entries(params) .flatMap(([key, value]) => (Array.isArray(value) ? value.map((val) => `${key}=${val}`) : `${key}=${value}`)) @@ -199,12 +209,8 @@ export class GenTable extends Base { public async getRow(params: GetRowRequest): Promise { const parsedParams = GetRowRequestSchema.parse(params); - const response = await this.httpClient.get(`/api/v1/gen_tables/${params.table_type}/${params.table_id}/rows/${params.row_id}`, { - params: { - columns: parsedParams.columns ? parsedParams.columns?.map(encodeURIComponent) : [], - float_decimals: parsedParams.float_decimals, - vec_decimals: parsedParams.vec_decimals - }, + const response = await this.httpClient.get(`/api/v2/gen_tables/${params.table_type}/rows`, { + params: parsedParams, paramsSerializer: (params) => { return Object.entries(params) .flatMap(([key, value]) => (Array.isArray(value) ? 
value.map((val) => `${key}=${val}`) : `${key}=${value}`)) @@ -218,13 +224,9 @@ export class GenTable extends Base { public async getConversationThread(params: GetConversationThreadRequest): Promise { const parsedParams = GetConversationThreadRequestSchema.parse(params); - let getURL = `/api/v1/gen_tables/${parsedParams.table_type}/${parsedParams.table_id}/thread`; + let getURL = `/api/v2/gen_tables/${parsedParams.table_type}/thread`; const response = await this.httpClient.get(getURL, { - params: { - column_id: parsedParams.column_id, - row_id: parsedParams.row_id, - include: parsedParams.include - } + params: parsedParams }); return this.handleResponse(response, GetConversationThreadResponseSchema); @@ -235,21 +237,15 @@ export class GenTable extends Base { */ public async createActionTable(params: CreateActionTableRequest): Promise { const parsedParams = CreateActionTableRequestSchema.parse(params); - const apiURL = "/api/v1/gen_tables/action"; - const response = await this.httpClient.post( - apiURL, - { - ...parsedParams, - stream: false - }, - {} - ); + const apiURL = "/api/v2/gen_tables/action"; + const response = await this.httpClient.post(apiURL, parsedParams); + return this.handleResponse(response, TableMetaResponseSchema); } public async createChatTable(params: CreateChatTableRequest): Promise { const parsedParams = CreateChatTableRequestSchema.parse(params); - const apiURL = "/api/v1/gen_tables/chat"; + const apiURL = "/api/v2/gen_tables/chat"; const response = await this.httpClient.post(apiURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); @@ -257,7 +253,7 @@ export class GenTable extends Base { public async createKnowledgeTable(params: CreateKnowledgeTableRequest): Promise { const parsedParams = CreateKnowledgeTableRequestSchema.parse(params); - const apiURL = "/api/v1/gen_tables/knowledge"; + const apiURL = "/api/v2/gen_tables/knowledge"; const response = await this.httpClient.post(apiURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); @@ -267,33 +263,35 @@ export class GenTable extends Base { * Gen Table Delete */ public async deleteTable(params: DeleteTableRequest): Promise { - let deleteURL = `/api/v1/gen_tables/${params.table_type}/${params.table_id}`; - const response = await this.httpClient.delete(deleteURL); + const parsedParams = DeleteTableRequestSchema.parse(params); + let deleteURL = `/api/v2/gen_tables/${params.table_type}`; + const response = await this.httpClient.delete(deleteURL, { + params: parsedParams + }); return this.handleResponse(response, OkResponseSchema); } - public async deleteRow(params: DeleteRowRequest): Promise { - let deleteURL = `/api/v1/gen_tables/${params.table_type}/${params.table_id}/rows/${params.row_id}`; + // public async deleteRow(params: DeleteRowRequest): Promise { + // let deleteURL = `/api/v2/gen_tables/${params.table_type}/${params.table_id}/rows/${params.row_id}`; - const response = await this.httpClient.delete(deleteURL, { - params: { - reindex: params?.reindex - } - }); + // const response = await this.httpClient.delete(deleteURL, { + // params: { + // reindex: params?.reindex + // } + // }); - return this.handleResponse(response, OkResponseSchema); - } + // return this.handleResponse(response, OkResponseSchema); + // } /** * @param {string} [params.where] - Optional. SQL where clause. If not provided, will match all rows and thus deleting all table content. 
*/ public async deleteRows(params: DeleteRowsRequest): Promise { - const apiURL = `/api/v1/gen_tables/${params.table_type}/rows/delete`; - const response = await this.httpClient.post(apiURL, { - table_id: params.table_id, - where: params.where // Optional. SQL where clause. If not provided, will match all rows and thus deleting all table content. - }); + const parsedParams = DeleteRowsRequestSchema.parse(params); + const apiURL = `/api/v2/gen_tables/${params.table_type}/rows/delete`; + const response = await this.httpClient.post(apiURL, parsedParams); + return this.handleResponse(response, OkResponseSchema); } @@ -301,8 +299,9 @@ export class GenTable extends Base { * Gen Table Update */ public async renameTable(params: RenameTableRequest): Promise { - let postURL = `/api/v1/gen_tables/${params.table_type}/rename/${params.table_id_src}/${params.table_id_dst}`; - const response = await this.httpClient.post(postURL, {}, {}); + const parsedParams = RenameTableRequestSchema.parse(params); + let postURL = `/api/v2/gen_tables/${params.table_type}/rename`; + const response = await this.httpClient.post(postURL, undefined, { params: parsedParams }); return this.handleResponse(response, TableMetaResponseSchema); } @@ -316,69 +315,41 @@ export class GenTable extends Base { } const parsedParams = DuplicateTableRequestSchema.parse(params); - - let postURL = `/api/v1/gen_tables/${params.table_type}/duplicate/${params.table_id_src}`; - const response = await this.httpClient.post( - postURL, - {}, - { - params: { - table_id_dst: parsedParams.table_id_dst, - include_data: parsedParams.include_data, - create_as_child: parsedParams.create_as_child - } - } - ); + let postURL = `/api/v2/gen_tables/${params.table_type}/duplicate`; + const response = await this.httpClient.post(postURL, undefined, { + params: parsedParams + }); return this.handleResponse(response, TableMetaResponseSchema); } public async renameColumns(params: RenameColumnsRequest): Promise { - let postURL = `/api/v1/gen_tables/${params.table_type}/columns/rename`; - const response = await this.httpClient.post( - postURL, - { - table_id: params.table_id, - column_map: params.column_map - }, - {} - ); + const parsedParams = RenameColumnsRequestSchema.parse(params); + let postURL = `/api/v2/gen_tables/${params.table_type}/columns/rename`; + const response = await this.httpClient.post(postURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); } public async reorderColumns(params: ReorderColumnsRequest): Promise { - let postURL = `/api/v1/gen_tables/${params.table_type}/columns/reorder`; - const response = await this.httpClient.post( - postURL, - { - table_id: params.table_id, - column_names: params.column_names - }, - {} - ); + const parsedParams = ReorderColumnsRequestSchema.parse(params); + let postURL = `/api/v2/gen_tables/${params.table_type}/columns/reorder`; + const response = await this.httpClient.post(postURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); } public async dropColumns(params: DropColumnsRequest): Promise { - let postURL = `/api/v1/gen_tables/${params.table_type}/columns/drop`; - const response = await this.httpClient.post( - postURL, - { - table_id: params.table_id, - column_names: params.column_names - }, - {} - ); + const parsedParams = DropColumnsRequestSchema.parse(params); + let postURL = `/api/v2/gen_tables/${params.table_type}/columns/drop`; + const response = await this.httpClient.post(postURL, parsedParams); return this.handleResponse(response, 
TableMetaResponseSchema); } public async addActionColumns(params: AddActionColumnRequest): Promise { const parsedParams = AddActionColumnRequestSchema.parse(params); - let postURL = `/api/v1/gen_tables/action/columns/add`; - + let postURL = `/api/v2/gen_tables/action/columns/add`; const response = await this.httpClient.post(postURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); @@ -386,7 +357,7 @@ export class GenTable extends Base { public async addKnowledgeColumns(params: AddColumnRequest): Promise { const parsedParams = AddColumnRequestSchema.parse(params); - let postURL = `/api/v1/gen_tables/knowledge/columns/add`; + let postURL = `/api/v2/gen_tables/knowledge/columns/add`; const response = await this.httpClient.post(postURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); @@ -394,7 +365,7 @@ export class GenTable extends Base { public async addChatColumns(params: AddColumnRequest): Promise { const parsedParams = AddColumnRequestSchema.parse(params); - let postURL = `/api/v1/gen_tables/chat/columns/add`; + let postURL = `/api/v2/gen_tables/chat/columns/add`; const response = await this.httpClient.post(postURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); @@ -402,30 +373,20 @@ export class GenTable extends Base { public async updateGenConfig(params: UpdateGenConfigRequest): Promise { const parsedParams = UpdateGenConfigRequestSchema.parse(params); - let postURL = `/api/v1/gen_tables/${params.table_type}/gen_config/update`; - const response = await this.httpClient.post( - postURL, - { - table_id: parsedParams.table_id, - column_map: parsedParams.column_map - }, - {} - ); + let postURL = `/api/v2/gen_tables/${params.table_type}/gen_config`; + const response = await this.httpClient.patch(postURL, parsedParams); return this.handleResponse(response, TableMetaResponseSchema); } - public async addRowStream(params: AddRowRequest): Promise> { - const apiURL = `/api/v1/gen_tables/${params.table_type}/rows/add`; - + public async addRowStream(params: AddRowRequest): Promise> { + const parsedParams = AddRowRequestSchema.parse(params); + const apiURL = `/api/v2/gen_tables/${params.table_type}/rows/add`; const response = await this.httpClient.post( apiURL, { - table_id: params.table_id, - data: params.data, - stream: true, - reindex: params.reindex, - concurrent: params.concurrent + ...parsedParams, + stream: true }, { responseType: "stream" @@ -435,35 +396,29 @@ export class GenTable extends Base { return this.handleGenTableStreamResponse(response); } - public async addRow(params: AddRowRequest): Promise { - const url = `/api/v1/gen_tables/${params.table_type}/rows/add`; - + public async addRow(params: AddRowRequest): Promise { + const parsedParams = AddRowRequestSchema.parse(params); + const url = `/api/v2/gen_tables/${params.table_type}/rows/add`; const response = await this.httpClient.post( url, { - table_id: params.table_id, - stream: false, - data: params.data, - reindex: params.reindex, - concurrent: params.concurrent + ...parsedParams, + stream: false }, {} ); - return this.handleResponse(response, GenTableRowsChatCompletionChunksSchema); + return this.handleResponse(response, MultiRowCompletionResponseSchema); } public async regenRowStream(params: RegenRowRequest) { - const apiURL = `/api/v1/gen_tables/${params.table_type}/rows/regen`; - + const parsedParams = RegenRowRequestSchema.parse(params); + const apiURL = `/api/v2/gen_tables/${params.table_type}/rows/regen`; const response = await 
this.httpClient.post( apiURL, { - table_id: params.table_id, - row_ids: params.row_ids, - stream: true, - reindex: params.reindex, - concurrent: params.concurrent + ...parsedParams, + stream: true }, { responseType: "stream" @@ -473,36 +428,46 @@ export class GenTable extends Base { return this.handleGenTableStreamResponse(response); } - public async regenRow(params: RegenRowRequest): Promise { - const apiURL = `/api/v1/gen_tables/${params.table_type}/rows/regen`; + public async regenRow(params: RegenRowRequest): Promise { + const parsedParams = RegenRowRequestSchema.parse(params); + const apiURL = `/api/v2/gen_tables/${params.table_type}/rows/regen`; const response = await this.httpClient.post( apiURL, { - table_id: params.table_id, - row_ids: params.row_ids, - stream: false, - reindex: params.reindex, - concurrent: params.concurrent + ...parsedParams, + stream: false }, {} ); - return this.handleResponse(response, GenTableRowsChatCompletionChunksSchema); + + return this.handleResponse(response, MultiRowCompletionResponseSchema); } - public async updateRow(params: UpdateRowRequest): Promise { - const apiURL = `/api/v1/gen_tables/${params.table_type}/rows/update`; - const response = await this.httpClient.post(apiURL, { + /** + * @deprecated Deprecated since 0.4.0, use updateRows instead + */ + public async updateRow(params: UpdateRowRequest & { row_id: string }): Promise { + const apiURL = `/api/v2/gen_tables/${params.table_type}/rows`; + const response = await this.httpClient.patch(apiURL, { table_id: params.table_id, - row_id: params.row_id, - data: params.data, - reindex: params.reindex + data: { + [params.row_id]: params.data + } }); return this.handleResponse(response, OkResponseSchema); } + public async updateRows(params: UpdateRowRequest): Promise { + const parsedParams = UpdateRowRequestSchema.parse(params); + const apiURL = `/api/v2/gen_tables/${params.table_type}/rows`; + const response = await this.httpClient.patch(apiURL, parsedParams); + + return this.handleResponse(response, OkResponseSchema); + } + public async hybridSearch(params: HybridSearchRequest): Promise { - const apiURL = `/api/v1/gen_tables/${params.table_type}/hybrid_search`; + const apiURL = `/api/v2/gen_tables/${params.table_type}/hybrid_search`; const { table_type, ...requestBody } = params; @@ -520,7 +485,7 @@ export class GenTable extends Base { const apiURL = `/api/v1/gen_tables/knowledge/upload_file`; // Create FormData to send as multipart/form-data - const formData = new FormData(); + const formData = await createFormData(); if (params.file) { formData.append("file", params.file, params.file.name); } else if (params.file_path) { @@ -528,7 +493,10 @@ export class GenTable extends Base { const mimeType = await getMimeType(params.file_path!); const fileName = await getFileName(params.file_path!); const data = await readFile(params.file_path!); - const file = new Blob([data], { type: mimeType }); + const { File } = await import("formdata-node"); + const file = new File([data], fileName, { type: mimeType }); + + // @ts-ignore formData.append("file", file, fileName); } else { throw new Error("Pass File instead of file path if you are using this function in client."); @@ -556,10 +524,10 @@ export class GenTable extends Base { } public async embedFile(params: UploadFileRequest): Promise { - const apiURL = `/api/v1/gen_tables/knowledge/embed_file`; + const apiURL = `/api/v2/gen_tables/knowledge/embed_file`; // Create FormData to send as multipart/form-data - const formData = new FormData(); + const formData = await 
createFormData(); if (params.file) { formData.append("file", params.file, params.file.name); } else if (params.file_path) { @@ -567,7 +535,10 @@ export class GenTable extends Base { const mimeType = await getMimeType(params.file_path!); const fileName = await getFileName(params.file_path!); const data = await readFile(params.file_path!); - const file = new Blob([data], { type: mimeType }); + const { File } = await import("formdata-node"); + const file = new File([data], fileName, { type: mimeType }); + + // @ts-ignore formData.append("file", file, fileName); } else { throw new Error("Pass File instead of file path if you are using this method in client."); @@ -594,12 +565,10 @@ export class GenTable extends Base { return this.handleResponse(response, OkResponseSchema); } - public async importTableData(params: ImportTableRequest): Promise { - const apiURL = `/api/v1/gen_tables/${params.table_type}/import_data`; + public async importTableData(params: ImportTableRequest): Promise { + const apiURL = `/api/v2/gen_tables/${params.table_type}/import_data`; - const delimiter = params.delimiter ? params.delimiter : ","; - - const formData = new FormData(); + const formData = await createFormData(); if (params.file) { formData.append("file", params.file, params.file.name); } else if (params.file_path) { @@ -607,7 +576,10 @@ export class GenTable extends Base { const mimeType = await getMimeType(params.file_path!); const fileName = await getFileName(params.file_path!); const data = await readFile(params.file_path!); - const file = new Blob([data], { type: mimeType }); + // const file = new Blob([data], { type: mimeType }); + const { File } = await import("formdata-node"); + const file = new File([data], fileName, { type: mimeType }); + // @ts-ignore formData.append("file", file, fileName); } else { throw new Error("Pass File instead of file path if you are using this function in client."); @@ -617,7 +589,7 @@ export class GenTable extends Base { } formData.append("table_id", params.table_id); - formData.append("delimiter", delimiter); + if (params.delimiter) formData.append("delimiter", params.delimiter); formData.append("stream", JSON.stringify(false)); const response = await this.httpClient.post(apiURL, formData, { @@ -626,17 +598,15 @@ export class GenTable extends Base { } }); - return this.handleResponse(response, GenTableRowsChatCompletionChunksSchema); + return this.handleResponse(response, MultiRowCompletionResponseSchema); } - public async importTableDataStream( - params: ImportTableRequest - ): Promise> { - const apiURL = `/api/v1/gen_tables/${params.table_type}/import_data`; + public async importTableDataStream(params: ImportTableRequest): Promise> { + const apiURL = `/api/v2/gen_tables/${params.table_type}/import_data`; // const fileName = params.file.name; const delimiter = params.delimiter ? 
params.delimiter : ","; - const formData = new FormData(); + const formData = await createFormData(); if (params.file) { formData.append("file", params.file, params.file.name); } else if (params.file_path) { @@ -644,7 +614,10 @@ export class GenTable extends Base { const mimeType = await getMimeType(params.file_path!); const fileName = await getFileName(params.file_path!); const data = await readFile(params.file_path!); - const file = new Blob([data], { type: mimeType }); + + const { File } = await import("formdata-node"); + const file = new File([data], fileName, { type: mimeType }); + // @ts-ignore formData.append("file", file, fileName); } else { throw new Error("Pass File instead of file path if you are using this function in client."); @@ -670,13 +643,10 @@ export class GenTable extends Base { public async exportTableData(params: ExportTableRequest): Promise { const parsedParams = ExportTableRequestSchema.parse(params); - const apiURL = `/api/v1/gen_tables/${parsedParams.table_type}/${encodeURIComponent(parsedParams.table_id)}/export_data`; + const apiURL = `/api/v2/gen_tables/${parsedParams.table_type}/export_data`; try { const response = await this.httpClient.get(apiURL, { - params: { - delimiter: encodeURIComponent(parsedParams.delimiter), - columns: parsedParams.columns ? parsedParams.columns?.map(encodeURIComponent) : [] - }, + params: parsedParams, paramsSerializer: (params) => { return Object.entries(params) .flatMap(([key, value]) => (Array.isArray(value) ? value.map((val) => `${key}=${val}`) : `${key}=${value}`)) diff --git a/clients/typescript/src/resources/gen_tables/knowledge.ts b/clients/typescript/src/resources/gen_tables/knowledge.ts index 8627263..8c9a28a 100644 --- a/clients/typescript/src/resources/gen_tables/knowledge.ts +++ b/clients/typescript/src/resources/gen_tables/knowledge.ts @@ -7,12 +7,7 @@ export const CreateKnowledgeTableRequestSchema = TableSchemaCreateSchema.extend( export type CreateKnowledgeTableRequest = z.input; export const UploadFileRequestSchema = z.object({ - file: z - .any() - .refine((value) => value instanceof File, { - message: "Value must be a File object" - }) - .optional(), + file: z.any().optional(), file_path: z.string().optional(), table_id: IdSchema, chunk_size: z.number().gt(0).optional(), diff --git a/clients/typescript/src/resources/gen_tables/tables.ts b/clients/typescript/src/resources/gen_tables/tables.ts index 5cb2de2..38e047d 100644 --- a/clients/typescript/src/resources/gen_tables/tables.ts +++ b/clients/typescript/src/resources/gen_tables/tables.ts @@ -4,6 +4,7 @@ import { z } from "zod"; export const GenTableOrderBy = Object.freeze({ ID: "id", // Sort by `id` column + TABLE_ID: "table_id", // Sort by `table_id` column UPDATED_AT: "updated_at" // Sort by `updated_at` column }); @@ -17,9 +18,9 @@ export const TableTypesSchema = z.enum(["action", "knowledge", "chat"]); export const IdSchema = z.string().regex(/^[A-Za-z0-9]([A-Za-z0-9 _-]{0,98}[A-Za-z0-9])?$/, "Invalid Id"); export const TableIdSchema = z.string().regex(/^[A-Za-z0-9]([A-Za-z0-9._-]{0,98}[A-Za-z0-9])?$/, "Invalid Table Id"); -const DtypeCreateEnumSchema = z.enum(["int", "float", "str", "bool", "image"]); +const DtypeCreateEnumSchema = z.enum(["int", "float", "bool", "str", "image", "audio", "document"]); -const DtypeEnumSchema = z.enum(["int", "int8", "float", "float64", "float32", "float16", "bool", "str", "date-time", "image", "bytes"]); +const DtypeEnumSchema = z.enum(["int", "int8", "float", "float32", "float16", "bool", "str", "image", "audio", "document", 
"date-time", "json"]); export const EmbedGenConfigSchema = z.object({ object: z.literal("gen_config.embed").default("gen_config.embed"), @@ -42,12 +43,17 @@ export const LLMGenConfigSchema = z.object({ logit_bias: z.record(z.string(), z.any()).default({}) }); +export const CodeGenConfigSchema = z.object({ + object: z.literal("gen_config.code").default("gen_config.code"), + source_column: z.string() +}); + export const ColumnSchemaSchema = z.object({ id: z.string(), dtype: DtypeEnumSchema.default("str"), vlen: z.number().int().gte(0).default(0), index: z.boolean().default(true), - gen_config: z.union([LLMGenConfigSchema, EmbedGenConfigSchema, z.null()]).optional() + gen_config: z.union([LLMGenConfigSchema, EmbedGenConfigSchema, CodeGenConfigSchema, z.null()]).optional() }); export const ColumnSchemaCreateSchema = ColumnSchemaSchema.extend({ @@ -61,12 +67,15 @@ export const TableSchemaCreateSchema = z.object({ }); export let ListTableRequestSchema = QueryRequestParams.extend({ - parent_id: z.union([z.string(), z.null()]).optional(), table_type: TableTypesSchema, - search_query: z.string().default(""), - order_by: z.string().optional().default(GenTableOrderBy.UPDATED_AT), - order_descending: z.boolean().optional().default(true), - count_rows: z.boolean().optional().default(false) + offset: z.number().min(0).optional(), + limit: z.number().min(1).max(100).optional(), + order_by: z.string().optional(), + order_ascending: z.boolean().optional().optional(), + parent_id: z.string().nullable().optional(), + search_query: z.string().optional(), + count_rows: z.boolean().optional(), + created_by: z.string().nullable().optional(), }); export const TableMetaResponseSchema = z.object({ @@ -74,12 +83,10 @@ export const TableMetaResponseSchema = z.object({ cols: z.array(ColumnSchemaSchema), parent_id: z.union([z.string(), z.null()]), title: z.string(), - lock_till: z.union([z.number(), z.null()]).optional(), + created_by: z.string(), updated_at: z.string(), - indexed_at_fts: z.union([z.string(), z.null()]), - indexed_at_vec: z.union([z.string(), z.null()]), - indexed_at_sca: z.union([z.string(), z.null()]), - num_rows: z.number().int() + num_rows: z.number().int(), + version: z.string() }); export const TableMetaRequestSchema = z.object({ @@ -104,8 +111,8 @@ export const GetRowRequestSchema = z.object({ table_id: TableIdSchema, row_id: z.string(), columns: z.array(IdSchema).nullable().optional(), - float_decimals: z.number().int().default(0), - vec_decimals: z.number().int().default(0) + float_decimals: z.number().int().optional(), + vec_decimals: z.number().int().optional() }); export const GetRowResponseSchema = z.record(z.string(), z.any()); @@ -131,9 +138,9 @@ export const RenameTableRequestSchema = z.object({ export const DuplicateTableRequestSchema = z.object({ table_type: TableTypesSchema, table_id_src: TableIdSchema, - table_id_dst: TableIdSchema.nullable().default(null), - include_data: z.boolean().optional().default(true), - create_as_child: z.boolean().optional().default(false) + table_id_dst: TableIdSchema.nullable().optional(), + include_data: z.boolean().optional(), + create_as_child: z.boolean().optional() }); export const CreateChildTableRequestSchema = z.object({ @@ -142,18 +149,20 @@ export const CreateChildTableRequestSchema = z.object({ table_id_dst: TableIdSchema }); -export const RenameColumnsRequestScheme = z.object({ +export const RenameColumnsRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, column_map: z.record(IdSchema, IdSchema) }); -export const 
ReorderColumnsRequestScheme = z.object({ +export const ReorderColumnsRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, column_names: z.array(IdSchema) }); +export const DropColumnsRequestSchema = ReorderColumnsRequestSchema; + export const AddColumnRequestSchema = z.object({ id: TableIdSchema, cols: z.array(ColumnSchemaCreateSchema) @@ -162,7 +171,7 @@ export const AddColumnRequestSchema = z.object({ export const UpdateGenConfigRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, - column_map: z.record(z.string(), z.union([LLMGenConfigSchema, EmbedGenConfigSchema, z.null()])) + column_map: z.record(z.string(), z.union([LLMGenConfigSchema, EmbedGenConfigSchema, CodeGenConfigSchema, z.null()])) }); export const DeleteRowRequestSchema = z.object({ @@ -174,7 +183,6 @@ export const DeleteRowRequestSchema = z.object({ export const AddRowRequestSchema = z.object({ table_type: TableTypesSchema, - reindex: z.boolean().nullable().default(true), table_id: TableIdSchema, data: z.array(z.record(IdSchema, z.any())), concurrent: z.boolean().default(true) @@ -185,35 +193,34 @@ export const RegenRowRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, row_ids: z.array(z.string()), - reindex: z.boolean().nullable().default(null), - concurrent: z.boolean().default(true) + regen_strategy: z.string().nullable().optional(), + output_column_id: z.string().nullable().optional(), + concurrent: z.boolean().optional() // stream: z.boolean() }); export const UpdateRowRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, - row_id: z.string(), - data: z.record(IdSchema, z.any()), - reindex: z.boolean().nullable().default(null) + data: z.record(z.string(), z.record(IdSchema, z.any())), }); export const DeleteRowsRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, + row_ids: z.array(z.string()).nullable().optional(), where: z.string().optional(), - reindex: z.boolean().default(true) }); export const HybridSearchRequestSchema = z.object({ table_type: TableTypesSchema, table_id: TableIdSchema, query: z.string(), - where: z.string().nullable().default(null).optional(), + // where: z.string().nullable().default(null).optional(), limit: z.number().gt(0).lte(1000).optional(), metric: z.string().optional(), - nprobes: z.number().gt(0).lte(1000).optional(), - refine_factor: z.number().gt(0).lte(1000).optional(), + // nprobes: z.number().gt(0).lte(1000).optional(), + // refine_factor: z.number().gt(0).lte(1000).optional(), reranking_model: z.string().nullable().default(null).optional(), float_decimals: z.number().int().default(0), vec_decimals: z.number().int().default(0) @@ -228,12 +235,7 @@ export const CreateTableRequestSchema = z.object({ export const ImportTableRequestSchema = z.object({ file_path: z.string().optional(), - file: z - .any() - .refine((value) => value instanceof File, { - message: "Value must be a File object" - }) - .optional(), + file: z.any().optional(), table_id: TableIdSchema, table_type: TableTypesSchema, delimiter: z.string().default(",").optional() @@ -264,8 +266,8 @@ export type DeleteTableRequest = z.input; export type RenameTableRequest = z.input; export type DuplicateTableRequest = z.input; export type CreateChildTableRequest = z.input; -export type RenameColumnsRequest = z.infer; -export type ReorderColumnsRequest = z.infer; +export type RenameColumnsRequest = z.infer; +export type ReorderColumnsRequest = z.infer; export type DropColumnsRequest = 
ReorderColumnsRequest; export type AddColumnRequest = z.input; export type UpdateGenConfigRequest = z.input; diff --git a/clients/typescript/src/resources/llm/model.ts b/clients/typescript/src/resources/llm/model.ts index 36c6f47..0f862c0 100644 --- a/clients/typescript/src/resources/llm/model.ts +++ b/clients/typescript/src/resources/llm/model.ts @@ -3,7 +3,7 @@ import { z } from "zod"; export const ModelInfoRequestSchema = z.object({ model: z.string().optional(), capabilities: z - .array(z.enum(["completion", "chat", "image", "audio", "tool", "embed", "rerank"])) + .array(z.enum(["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"])) .nullable() .optional() }); @@ -14,7 +14,7 @@ export const ModelInfoSchema = z.object({ name: z.string(), context_length: z.number().default(16384), languages: z.array(z.string()), - capabilities: z.array(z.enum(["completion", "chat", "image", "audio", "tool", "embed", "rerank"])).default(["chat"]), + capabilities: z.array(z.enum(["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"])).default(["chat"]), owned_by: z.string() }); @@ -26,7 +26,7 @@ export const ModelInfoResponseSchema = z.object({ export const ModelNamesRequestSchema = z.object({ prefer: z.string().optional(), capabilities: z - .array(z.enum(["completion", "chat", "image", "audio", "tool", "embed", "rerank"])) + .array(z.enum(["completion", "chat", "image", "audio", "document", "tool", "embed", "rerank"])) .nullable() .optional() }); diff --git a/clients/typescript/src/resources/templates/index.ts b/clients/typescript/src/resources/templates/index.ts index 8c56a36..69292d2 100644 --- a/clients/typescript/src/resources/templates/index.ts +++ b/clients/typescript/src/resources/templates/index.ts @@ -26,12 +26,10 @@ export class Templates extends Base { public async listTemplates(params: IListTemplatesRequest = {}): Promise { const parsedParams = ListTemplatesRequestSchema.parse(params); - let getURL = `/api/public/v1/templates`; + let getURL = `/api/v2/templates/list`; const response = await this.httpClient.get(getURL, { - params: { - search_query: encodeURIComponent(parsedParams.search_query) - } + params: parsedParams }); return this.handleResponse(response, ListTemplatesResponseSchema); @@ -39,38 +37,40 @@ export class Templates extends Base { public async getTemplate(params: IGetTemplateRequest): Promise { const parsedParams = GetTemplateRequestSchema.parse(params); - let getURL = `/api/public/v1/templates/${parsedParams.template_id}`; + let getURL = `/api/v2/templates`; - const response = await this.httpClient.get(getURL); + const response = await this.httpClient.get(getURL, { + params: parsedParams + }); return this.handleResponse(response, GetTemplateResponseSchema); } public async listTables(params: IListTablesRequest): Promise { const parsedParams = ListTablesRequestSchema.parse(params); - let getURL = `/api/public/v1/templates/${parsedParams.template_id}/gen_tables/${parsedParams.table_type}`; + let getURL = `/api/v2/templates/gen_tables/${parsedParams.table_type}/list`; - const response = await this.httpClient.get(getURL); + const response = await this.httpClient.get(getURL, { + params: parsedParams + }); return this.handleResponse(response, ListTablesResponseSchema); } public async getTable(params: IGetTableRequest): Promise { const parsedParams = GetTableRequestSchema.parse(params); - let getURL = `/api/public/v1/templates/${parsedParams.template_id}/gen_tables/${parsedParams.table_type}/${parsedParams.table_id}`; + let getURL = 
`/api/v2/templates/gen_tables/${parsedParams.table_type}`; - const response = await this.httpClient.get(getURL); + const response = await this.httpClient.get(getURL, { + params: parsedParams + }); return this.handleResponse(response, GetTableResponseSchema); } public async listTableRows(params: IListTableRowsRequest): Promise { const parsedParams = ListTableRowsRequestSchema.parse(params); - let getURL = `/api/public/v1/templates/${parsedParams.template_id}/gen_tables/${parsedParams.table_type}/${parsedParams.table_id}/rows`; - - delete (parsedParams as any).template_id; - delete (parsedParams as any).table_type; - delete (parsedParams as any).table_id; + let getURL = `/api/v2/templates/gen_tables/${parsedParams.table_type}/rows/list`; const response = await this.httpClient.get(getURL, { params: parsedParams diff --git a/clients/typescript/src/resources/templates/types.ts b/clients/typescript/src/resources/templates/types.ts index 3c64129..d185fcc 100644 --- a/clients/typescript/src/resources/templates/types.ts +++ b/clients/typescript/src/resources/templates/types.ts @@ -14,7 +14,11 @@ const TemplateSchema = z.object({ // List Templates export const ListTemplatesRequestSchema = z.object({ - search_query: z.string().default("") + offset: z.number().int().min(0).optional(), + limit: z.number().int().min(1).max(1000).optional(), + order_by: z.string().optional(), + order_ascending: z.boolean().optional(), + search_query: z.string().optional() }); export const ListTemplatesResponseSchema = createPaginationSchema(TemplateSchema); @@ -27,7 +31,14 @@ export const GetTemplateResponseSchema = TemplateSchema; // List Table export const ListTablesRequestSchema = z.object({ table_type: TableTypesSchema, - template_id: z.string() + template_id: z.string(), + offset: z.number().int().min(0).optional(), + limit: z.number().int().min(1).max(100).optional(), + order_by: z.string().optional(), + order_ascending: z.boolean().optional(), + parent_id: z.string().optional(), + search_query: z.string().optional(), + count_rows: z.boolean().optional() }); export const ListTablesResponseSchema = createPaginationSchema(TableMetaResponseSchema); @@ -47,8 +58,9 @@ export const ListTableRowsRequestSchema = z.object({ starting_after: z.string().nullable().optional(), offset: z.number().int().min(0).default(0), limit: z.number().int().min(1).max(100).default(100), - order_by: z.string().default("Updated at"), - order_descending: z.boolean().default(true), + order_by: z.string().default("updated_at"), + order_ascending: z.boolean().default(true), + parent_id: z.string().nullable().optional(), float_decimals: z.number().int().min(0).default(0), vec_decimals: z.number().int().min(0).default(0) }); diff --git a/clients/typescript/tsconfig.json b/clients/typescript/tsconfig.json index bd5c5e6..f583895 100644 --- a/clients/typescript/tsconfig.json +++ b/clients/typescript/tsconfig.json @@ -12,6 +12,7 @@ "compilerOptions": { "emitDeclarationOnly": true, "module": "ESNext", + "moduleResolution": "node", "removeComments": false, "declaration": true, "allowSyntheticDefaultImports": true, @@ -21,7 +22,6 @@ "baseUrl": "./", "incremental": false, "esModuleInterop": true, - "moduleResolution": "Node", "noUncheckedIndexedAccess": true, "paths": { "@/*": [ diff --git a/docker/Dockerfile.cnpg17 b/docker/Dockerfile.cnpg17 new file mode 100644 index 0000000..f5e9608 --- /dev/null +++ b/docker/Dockerfile.cnpg17 @@ -0,0 +1,76 @@ +# ======================== STAGE 1: BUILDER ======================== +ARG 
CNPG_IMAGE="ghcr.io/cloudnative-pg/postgresql:17.5-bookworm"
+ARG PG_MAJOR=17
+FROM ${CNPG_IMAGE} AS builder
+
+# Switch to root to install build tools
+USER root
+
+# Install build dependencies for compiling from source
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    build-essential pkg-config git \
+    postgresql-server-dev-${PG_MAJOR} \
+    curl libssl-dev libclang-dev clang \
+    && rm -rf /var/lib/apt/lists/*
+
+# ---- Install Rust toolchain ----
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
+    export PATH="$HOME/.cargo/bin:$PATH" && \
+    rustup update && \
+    cargo install --locked cargo-pgrx --version 0.12.9 && \
+    cargo pgrx init --pg${PG_MAJOR} $(which pg_config)
+
+# ---- Set build directory ----
+WORKDIR /tmp
+
+# ---- Build pgvector and pgvectorscale from source ----
+RUN git clone --branch v0.8.0 https://github.com/pgvector/pgvector.git && \
+    cd pgvector && make OPTFLAGS="" && make install && cd .. && rm -rf pgvector
+
+# commit: 6af0ee1953ca3dab6dc45011a985cffb2aa865c1 (v0.8.0) as of 2025-07-17
+RUN export PATH="$HOME/.cargo/bin:$PATH" && \
+    git clone https://github.com/timescale/pgvectorscale.git && \
+    cd pgvectorscale && git checkout 6af0ee1953ca3dab6dc45011a985cffb2aa865c1 && \
+    cd pgvectorscale && cargo pgrx install --release && cd ../.. && rm -rf pgvectorscale
+
+# ======================== STAGE 2: FINAL IMAGE ========================
+# Start from a fresh, clean CNPG image.
+FROM ${CNPG_IMAGE}
+
+USER root
+
+# ---- PART 1: Copy extensions BUILT FROM SOURCE in the builder stage ----
+COPY --from=builder --chown=postgres:postgres \
+    /usr/lib/postgresql/${PG_MAJOR}/lib/*vectorscale* \
+    /usr/lib/postgresql/${PG_MAJOR}/lib/
+COPY --from=builder --chown=postgres:postgres \
+    /usr/share/postgresql/${PG_MAJOR}/extension/*vectorscale* \
+    /usr/share/postgresql/${PG_MAJOR}/extension/
+
+COPY --from=builder --chown=postgres:postgres \
+    /usr/lib/postgresql/${PG_MAJOR}/lib/*vector* \
+    /usr/lib/postgresql/${PG_MAJOR}/lib/
+COPY --from=builder --chown=postgres:postgres \
+    /usr/share/postgresql/${PG_MAJOR}/extension/*vector* \
+    /usr/share/postgresql/${PG_MAJOR}/extension/
+
+# ---- PART 2: Install PRE-PACKAGED extensions with ZERO recommended packages ----
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    wget lsb-release ca-certificates && \
+    wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb && \
+    # Using 'apt-get install' here is slightly better as it can handle dependencies of the .deb itself
+    apt-get install -y --no-install-recommends ./apache-arrow-apt-source-latest-*.deb && \
+    wget https://packages.groonga.org/debian/groonga-apt-source-latest-$(lsb_release --codename --short).deb && \
+    apt-get install -y --no-install-recommends ./groonga-apt-source-latest-*.deb && \
+    apt-get update && \
+    # apt search groonga && \
+    # Install the final package, again with no recommended extras
+    apt-get install -y --no-install-recommends postgresql-${PG_MAJOR}-pgdg-pgroonga && \
+    # Clean up build-only tools and cache
+    apt-get purge -y --auto-remove wget && \
+    rm -rf /var/lib/apt/lists/* *.deb
+
+# Switch back to the default non-root user for security
+USER postgres
\ No newline at end of file
diff --git a/docker/Dockerfile.docio b/docker/Dockerfile.docio
deleted file mode 100644
index bb3ecba..0000000
--- a/docker/Dockerfile.docio
+++ /dev/null
@@ -1,12 +0,0 @@
-FROM 
docker.io/embeddedllminfo/jamaibase:ci - -RUN apt-get update && apt-get install -y --no-install-recommends \ - git \ - && apt-get clean && rm -rf /var/lib/apt/lists/* - -WORKDIR /app - -COPY --chown=$MAMBA_USER:$MAMBA_USER ./services/docio /app/services/docio -ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found) - -RUN cd /app/services/docio && python -m pip install --no-cache-dir --upgrade . diff --git a/docker/Dockerfile.frontend b/docker/Dockerfile.frontend index 4a3b06f..6853504 100644 --- a/docker/Dockerfile.frontend +++ b/docker/Dockerfile.frontend @@ -1,16 +1,17 @@ FROM node:20-alpine -ARG JAMAI_URL=http://owl:6969 -ARG PUBLIC_JAMAI_URL= +ARG PUBLIC_JAMAI_URL="" ARG PUBLIC_IS_SPA=false ARG CHECK_ORIGIN=false +ENV PUBLIC_ADMIN_ORGANIZATION_ID="0" WORKDIR /app COPY ./services/app . -RUN mv .env.example .env -RUN npm ci --force +# RUN mv .env.example .env +RUN rm -rf .env +RUN npm ci -RUN JAMAI_URL=${JAMAI_URL} PUBLIC_JAMAI_URL=${PUBLIC_JAMAI_URL} PUBLIC_IS_SPA=${PUBLIC_IS_SPA} CHECK_ORIGIN=${CHECK_ORIGIN} npx vite build +RUN PUBLIC_JAMAI_URL=${PUBLIC_JAMAI_URL} PUBLIC_IS_SPA=${PUBLIC_IS_SPA} CHECK_ORIGIN=${CHECK_ORIGIN} npx vite build RUN mv temp build RUN apk --no-cache add curl diff --git a/docker/Dockerfile.owl b/docker/Dockerfile.owl index a6ed08f..ed2dc2b 100644 --- a/docker/Dockerfile.owl +++ b/docker/Dockerfile.owl @@ -1,17 +1,36 @@ -FROM python:3.12 +FROM ghcr.io/embeddedllm/jamaibase/owl.base:latest -RUN pip install --no-cache-dir --upgrade setuptools -RUN apt-get update -qq && apt-get install ffmpeg libavcodec-extra -y +# Set initial working directory +WORKDIR /usr/src/app -WORKDIR /app +# Install owl requirements +COPY ./services/api/pyproject.toml ./api/ +COPY ./services/api/src/owl/version.py ./api/src/owl/version.py +RUN uv venv --python 3.12 && cd api && uv pip install --no-cache-dir --upgrade -e . -COPY ./clients/python /app/client -WORKDIR /app/client -RUN pip install --no-cache-dir --upgrade . +# Install Python client requirements +COPY ./clients/python/pyproject.toml ./client/ +COPY ./clients/python/src/jamaibase/version.py ./client/src/jamaibase/version.py +RUN cd client && uv pip install --no-cache-dir -e . -COPY ./services/api /app/api -WORKDIR /app/api +# Copy Python client source code +COPY ./clients/python/ ./client/ +RUN cd client && uv pip install --no-cache-dir -e . -RUN pip install --no-cache-dir --upgrade . +# Copy owl source code +COPY ./services/api/ ./api/ +RUN cd api && uv pip install --no-cache-dir -e . 
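Note on the reworked `Dockerfile.owl`: only `pyproject.toml` and the `version.py` files are copied before the full sources, so the dependency-install layers stay cached across source-only rebuilds, while system packages (ffmpeg, poppler-utils, uv) now come from the shared `owl.base` image. A minimal local bring-up sketch, assuming the `build_owl_base_image.sh` script and `compose.dev.yml` introduced later in this diff, a populated `.env` at the repo root, and that the base image is either built locally or pullable:

```bash
# Sketch only: build the shared base image, then the dev stack on top of it.
# Run from the repository root; build_owl_base_image.sh also pushes to ghcr.io,
# so for purely local use the single `docker build` line inside it is enough.
bash docker/build_owl_base_image.sh
docker compose -f docker/compose.dev.yml up -d --build
curl --fail http://localhost:6969/api/health   # same check the owl compose healthcheck runs
```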
-CMD ["python", "-m", "owl.entrypoints.api"] +# Setup environment variables +ENV OWL_OPENTELEMETRY_HOST=otel-collector +ENV OWL_OPENTELEMETRY_PORT=4317 +ENV OTEL_PYTHON_FASTAPI_EXCLUDED_URLS=api/health +ENV OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST='X-.*' +ENV OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION=base2_exponential_bucket_histogram +ENV SKYPILOT_DISABLE_USAGE_COLLECTION=1 +# If wanted to skip instrumentation on some components, ex: redis +# ENV OTEL_PYTHON_DISABLED_INSTRUMENTATIONS=redis + +# Run the service +# OTEL_RESOURCE_ATTRIBUTES needs to be set at runtime +CMD uv run python -m owl.entrypoints.api diff --git a/docker/Dockerfile.owl.base b/docker/Dockerfile.owl.base new file mode 100644 index 0000000..bb84454 --- /dev/null +++ b/docker/Dockerfile.owl.base @@ -0,0 +1,12 @@ +FROM python:3.12 + +RUN apt-get update \ + && apt-get install -y \ + ffmpeg \ + git \ + libavcodec-extra \ + poppler-utils \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=ghcr.io/astral-sh/uv:0.5.8 /uv /uvx /bin/ diff --git a/docker/Dockerfile.pg17 b/docker/Dockerfile.pg17 new file mode 100644 index 0000000..e31f98e --- /dev/null +++ b/docker/Dockerfile.pg17 @@ -0,0 +1,63 @@ +# Use the official PostgreSQL 17 base image +FROM postgres:17.4 + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive + +# Install prerequisites for pgvectorscale and PGroonga +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + git \ + build-essential \ + pkg-config \ + libssl-dev \ + libclang-dev \ + clang \ + lsb-release \ + wget \ + ca-certificates \ + jq \ + postgresql-server-dev-17 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install Rust and cargo-pgrx 0.12.5 +# compatible with the version of pgvectorscale +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \ + export PATH="$HOME/.cargo/bin:$PATH" && \ + rustup update && \ + cargo install --locked cargo-pgrx --version "0.12.5" && \ + cargo pgrx init --pg17 $(which pg_config) + +# Install pgvectorscale +# 6c01899405c19ab545c4e43881cc07f2cd5dd0d9 is the commit of pgvectorscale main branch as of 2025-03-04 (v0.6) +RUN export PATH="$HOME/.cargo/bin:$PATH" && \ + cd /tmp && \ + git clone https://github.com/timescale/pgvectorscale && \ + cd pgvectorscale && \ + git checkout 6c01899405c19ab545c4e43881cc07f2cd5dd0d9 && \ + cd pgvectorscale && \ + cargo pgrx install --release + +# Install PGroonga +RUN wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb && \ + apt-get install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb && \ + wget https://packages.groonga.org/debian/groonga-apt-source-latest-$(lsb_release --codename --short).deb && \ + apt-get install -y -V ./groonga-apt-source-latest-$(lsb_release --codename --short).deb && \ + apt-get update && \ + apt-get install -y -V postgresql-17-pgdg-pgroonga && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install pgvector +# OPTFLAGS = "" to avoid optimization (default flag = -march=native) +RUN cd /tmp && \ + git clone --branch v0.8.0 https://github.com/pgvector/pgvector.git && \ + cd pgvector && \ + make clean && \ + make OPTFLAGS="" && \ + make install + +# Set the default command to run PostgreSQL +CMD ["postgres"] \ No newline at end of file diff --git a/docker/amd.yml b/docker/amd.yml index 81209af..2f8deaa 100644 --- a/docker/amd.yml +++ b/docker/amd.yml 
@@ -45,16 +45,3 @@ services: - video # Alternatively, you could use privileged mode (use with caution): # privileged: true - - docio: - cap_add: - - SYS_PTRACE - devices: - - /dev/kfd - - /dev/dri/renderD128 - security_opt: - - seccomp:unconfined - group_add: - - video - # Alternatively, you could use privileged mode (use with caution): - # privileged: true diff --git a/docker/build_owl_base_image.sh b/docker/build_owl_base_image.sh new file mode 100644 index 0000000..1bab7ac --- /dev/null +++ b/docker/build_owl_base_image.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -e + +# Get the current date in YYYYMMDD format +current_date=$(date +"%Y%m%d") + +docker build -t ghcr.io/embeddedllm/jamaibase/owl.base:latest -f docker/Dockerfile.owl.base . +docker image tag ghcr.io/embeddedllm/jamaibase/owl.base:latest ghcr.io/embeddedllm/jamaibase/owl.base:${current_date} +docker push ghcr.io/embeddedllm/jamaibase/owl.base:latest +docker push ghcr.io/embeddedllm/jamaibase/owl.base:${current_date} diff --git a/docker/ch_configs/clickhouse_config.xml b/docker/ch_configs/clickhouse_config.xml new file mode 100644 index 0000000..1f424bd --- /dev/null +++ b/docker/ch_configs/clickhouse_config.xml @@ -0,0 +1,35 @@ + + 0.0.0.0 + + + 9363 + + + /metrics + + expose_metrics + true + true + true + true + + + + /write + + remote_write + jamaibase_owl + jamaibase_owl_metrics
+
+
+ + /read + + remote_read + jamaibase_owl + jamaibase_owl_metrics
+
+
+
+
+
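Taken together, the `clickhouse_config.xml` settings above appear to make ClickHouse listen on `0.0.0.0`, expose Prometheus metrics on port `9363` at `/metrics`, and register `/write` and `/read` HTTP handlers acting as `remote_write` / `remote_read` endpoints backed by `jamaibase_owl.jamaibase_owl_metrics`. A hedged sanity check from the host, assuming the ports published by `docker/override.dev.yml` (8123 and 9363):

```bash
# Sketch only: verify the ClickHouse observability wiring after `docker compose up`.
curl -sf http://localhost:8123/ping                 # native HTTP interface (same path as the compose healthcheck)
curl -sf http://localhost:9363/metrics | head -n 5  # Prometheus endpoint enabled by this config
```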
diff --git a/docker/ch_configs/clickhouse_user_config.xml b/docker/ch_configs/clickhouse_user_config.xml new file mode 100644 index 0000000..7620fbd --- /dev/null +++ b/docker/ch_configs/clickhouse_user_config.xml @@ -0,0 +1,8 @@ + + + + 1 + 1 + + + \ No newline at end of file diff --git a/docker/ch_configs/create_ch_prom_db.sh b/docker/ch_configs/create_ch_prom_db.sh new file mode 100644 index 0000000..cb80d20 --- /dev/null +++ b/docker/ch_configs/create_ch_prom_db.sh @@ -0,0 +1,182 @@ +#!/bin/bash + +# Table to record llm, embed, rerank usage and costs +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.llm_usage +( + \`id\` UUID, + \`org_id\` String, + \`proj_id\` String, + \`user_id\` String, + \`timestamp\` DateTime64(6, 'UTC'), + \`model\` String, + \`input_token\` UInt32, + \`output_token\` UInt32, + \`cost\` Decimal128(12), + \`input_cost\` Decimal128(12), + \`output_cost\` Decimal128(12) +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (org_id, timestamp, model)" + +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.embed_usage +( + \`id\` UUID, + \`org_id\` String, + \`proj_id\` String, + \`user_id\` String, + \`timestamp\` DateTime64(6, 'UTC'), + \`model\` String, + \`num_token\` UInt32, + \`cost\` Decimal128(12) +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (org_id, timestamp, model)" + +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.rerank_usage +( + \`id\` UUID, + \`org_id\` String, + \`proj_id\` String, + \`user_id\` String, + \`timestamp\` DateTime64(6, 'UTC'), + \`model\` String, + \`num_search\` UInt32, + \`cost\` Decimal128(12) +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (org_id, timestamp, model)" + +# Table to record egress usage +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.egress_usage +( + \`id\` UUID, + \`org_id\` String, + \`proj_id\` String, + \`user_id\` String, + \`timestamp\` DateTime64(6, 'UTC'), + \`amount_gib\` Decimal128(12), + \`cost\` Decimal128(12) +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (org_id, timestamp)" + +# Table to record file storage usage +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.file_storage_usage +( + \`id\` UUID, + \`org_id\` String, + \`proj_id\` String, + \`user_id\` String, + \`timestamp\` DateTime64(6, 'UTC'), + \`amount_gib\` Decimal128(12), + \`cost\` Decimal128(12), + \`snapshot_gib\` Decimal128(12) +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (org_id, timestamp)" + +# Table to record db storage usage +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.db_storage_usage +( + \`id\` UUID, + \`org_id\` String, + \`proj_id\` String, + \`user_id\` String, + \`timestamp\` DateTime64(6, 'UTC'), + \`amount_gib\` Decimal128(12), + \`cost\` Decimal128(12), + \`snapshot_gib\` Decimal128(12) +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (org_id, timestamp)" + +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.owl_traces +( + \`Timestamp\` DateTime64(9) CODEC(Delta(8), ZSTD(1)), + \`TraceId\` String CODEC(ZSTD(1)), + \`SpanId\` String CODEC(ZSTD(1)), + \`ParentSpanId\` String CODEC(ZSTD(1)), + \`TraceState\` String CODEC(ZSTD(1)), + \`SpanName\` LowCardinality(String) CODEC(ZSTD(1)), + \`SpanKind\` LowCardinality(String) CODEC(ZSTD(1)), + \`ServiceName\` LowCardinality(String) CODEC(ZSTD(1)), + \`ResourceAttributes\` Map(LowCardinality(String), String) CODEC(ZSTD(1)), + \`ScopeName\` String 
CODEC(ZSTD(1)), + \`ScopeVersion\` String CODEC(ZSTD(1)), + \`SpanAttributes\` Map(LowCardinality(String), String) CODEC(ZSTD(1)), + \`Duration\` Int64 CODEC(ZSTD(1)), + \`StatusCode\` LowCardinality(String) CODEC(ZSTD(1)), + \`StatusMessage\` String CODEC(ZSTD(1)), + \`Events.Timestamp\` Array(DateTime64(9)) CODEC(ZSTD(1)), + \`Events.Name\` Array(LowCardinality(String)) CODEC(ZSTD(1)), + \`Events.Attributes\` Array(Map(LowCardinality(String), String)) CODEC(ZSTD(1)), + \`Links.TraceId\` Array(String) CODEC(ZSTD(1)), + \`Links.SpanId\` Array(String) CODEC(ZSTD(1)), + \`Links.TraceState\` Array(String) CODEC(ZSTD(1)), + \`Links.Attributes\` Array(Map(LowCardinality(String), String)) CODEC(ZSTD(1)), + INDEX idx_trace_id TraceId TYPE bloom_filter(0.001) GRANULARITY 1, + INDEX idx_res_attr_key mapKeys(ResourceAttributes) TYPE bloom_filter(0.01) GRANULARITY 1, + INDEX idx_res_attr_value mapValues(ResourceAttributes) TYPE bloom_filter(0.01) GRANULARITY 1, + INDEX idx_span_attr_key mapKeys(SpanAttributes) TYPE bloom_filter(0.01) GRANULARITY 1, + INDEX idx_span_attr_value mapValues(SpanAttributes) TYPE bloom_filter(0.01) GRANULARITY 1, + INDEX idx_duration Duration TYPE minmax GRANULARITY 1 +) +ENGINE = MergeTree +PARTITION BY toDate(Timestamp) +ORDER BY (ServiceName, SpanName, toUnixTimestamp(Timestamp), TraceId) +TTL toDateTime(Timestamp) + toIntervalDay(3) +SETTINGS index_granularity = 8192, ttl_only_drop_parts = 1" + + +clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.owl_traces_trace_id_ts +( + \`TraceId\` String CODEC(ZSTD(1)), + \`Start\` DateTime64(9) CODEC(Delta(8), ZSTD(1)), + \`End\` DateTime64(9) CODEC(Delta(8), ZSTD(1)), + INDEX idx_trace_id TraceId TYPE bloom_filter(0.01) GRANULARITY 1 +) +ENGINE = MergeTree +ORDER BY (TraceId, toUnixTimestamp(Start)) +TTL toDateTime(Start) + toIntervalDay(3)" + + +clickhouse-client --query="CREATE MATERIALIZED VIEW IF NOT EXISTS jamaibase_owl.owl_traces_trace_id_ts_mv TO jamaibase_owl.owl_traces_trace_id_ts +( + \`TraceId\` String, + \`Start\` DateTime64(9), + \`End\` DateTime64(9) +) AS +SELECT + TraceId, + min(Timestamp) AS Start, + max(Timestamp) AS End +FROM jamaibase_owl.owl_traces +WHERE TraceId != '' +GROUP BY TraceId" + +# Table using Json data type +# clickhouse-client --query="CREATE TABLE IF NOT EXISTS jamaibase_owl.owl_usage +# ( +# \`id\` UUID, +# \`org_id\` String, +# \`timestamp\` DateTime64(6, 'UTC'), +# \`data\` JSON() +# ) +# ENGINE=MergeTree ORDER BY (org_id, timestamp)" + +### --- Migrations --- ### + +clickhouse-client --query="ALTER TABLE jamaibase_owl.egress_usage RENAME COLUMN IF EXISTS amount_gb to amount_gib" +clickhouse-client --query="ALTER TABLE jamaibase_owl.llm_usage MODIFY COLUMN cost Decimal128(12)" +clickhouse-client --query="ALTER TABLE jamaibase_owl.llm_usage MODIFY COLUMN input_cost Decimal128(12)" +clickhouse-client --query="ALTER TABLE jamaibase_owl.llm_usage MODIFY COLUMN output_cost Decimal128(12)" +clickhouse-client --query="ALTER TABLE jamaibase_owl.embed_usage MODIFY COLUMN cost Decimal128(12)" +clickhouse-client --query="ALTER TABLE jamaibase_owl.rerank_usage MODIFY COLUMN cost Decimal128(12)" +clickhouse-client --query="ALTER TABLE jamaibase_owl.egress_usage MODIFY COLUMN cost Decimal128(12)" +clickhouse-client --query="ALTER TABLE jamaibase_owl.egress_usage MODIFY COLUMN amount_gib Decimal128(12)" \ No newline at end of file diff --git a/docker/compose.amd.yml b/docker/compose.amd.yml deleted file mode 100644 index a77af99..0000000 --- a/docker/compose.amd.yml +++ /dev/null @@ -1,4 
+0,0 @@ -include: - - path: - - compose.cpu.yml - - amd.yml diff --git a/docker/compose.bake.hcl b/docker/compose.bake.hcl new file mode 100644 index 0000000..ab77323 --- /dev/null +++ b/docker/compose.bake.hcl @@ -0,0 +1,15 @@ +group "default" { + targets = ["owl", "jambu"] +} + +target "owl" { + dockerfile = "docker/Dockerfile.owl" + cache-from = ["type=azblob,name=owl-cache,account_url=AZURE_STORAGE_ACCOUNT_URL,secret_access_key=AZURE_STORAGE_ACCESS_KEY"] + cache-to = ["type=azblob,name=owl-cache,mode=max,account_url=AZURE_STORAGE_ACCOUNT_URL,secret_access_key=AZURE_STORAGE_ACCESS_KEY"] +} + +target "jambu" { + dockerfile = "docker/Dockerfile.frontend" + cache-from = ["type=azblob,name=jambu-cache,account_url=AZURE_STORAGE_ACCOUNT_URL,secret_access_key=AZURE_STORAGE_ACCESS_KEY"] + cache-to = ["type=azblob,name=jambu-cache,mode=max,account_url=AZURE_STORAGE_ACCOUNT_URL,secret_access_key=AZURE_STORAGE_ACCESS_KEY"] +} \ No newline at end of file diff --git a/docker/compose.base.yml b/docker/compose.base.yml new file mode 100644 index 0000000..1e303ef --- /dev/null +++ b/docker/compose.base.yml @@ -0,0 +1,327 @@ +services: + dragonfly: + image: ghcr.io/embeddedllm/dragonflydb/dragonfly:v1.27.0-ubuntu + ulimits: + memlock: -1 + healthcheck: + test: ["CMD-SHELL", "nc -z localhost 6379 || exit 1"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + # For better performance, consider `host` mode instead `port` to avoid docker NAT. + # `host` mode is NOT currently supported in Swarm Mode. + # https://docs.docker.com/compose/compose-file/compose-file-v3/#network_mode + # network_mode: "host" + # volumes: + # - ${PWD}/docker_data/dragonfly:/data + networks: + - jamai + + otel-collector: + image: otel/opentelemetry-collector-contrib:0.113.0 + command: ["--config=/etc/otelcol/config.yaml"] + volumes: + - ${PWD}/docker/otel_configs/otel-collector-config.yaml:/etc/otelcol/config.yaml + networks: + - jamai + + victoriametrics: + image: victoriametrics/victoria-metrics:v1.124.0 + command: + - "--selfScrapeInterval=15s" + - "--retentionPeriod=100y" + volumes: + - ${PWD}/docker_data/vm_data:/victoria-metrics-data + networks: + - jamai + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://0.0.0.0:8428/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + vmauth: + image: victoriametrics/vmauth:v1.124.0 + command: + - "--auth.config=/etc/config.yml" + volumes: + - ${PWD}/docker/vmauth/config.yml:/etc/config.yml + networks: + - jamai + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://0.0.0.0:8427/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + victorialogs: + image: victoriametrics/victoria-logs:v1.28.0 + command: + - "--retentionPeriod=100y" + volumes: + - ${PWD}/docker_data/vl_data:/victoria-logs-data + networks: + - jamai + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://0.0.0.0:9428/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + vmagent: + image: victoriametrics/vmagent:v1.124.0 + depends_on: + victoriametrics: + condition: service_healthy + vmauth: + condition: service_healthy + victorialogs: + condition: service_healthy + volumes: + - ${PWD}/docker/vmagent/streamingAggregation.yaml:/etc/config/streamingAggregation.yaml + - ${PWD}/docker_data/vmagent:/tmp/vmagent + command: + # - "--promscrape.config=/etc/prometheus/prometheus.yml" + - "--remoteWrite.url=http://vmauth:8427/vm/api/v1/write" + - "--remoteWrite.basicAuth.username=${VMAUTH_USER:-owl}" + - 
"--remoteWrite.basicAuth.password=${VMAUTH_PASSWORD:-owl-vm}" + - "--remoteWrite.streamAggr.config=/etc/config/streamingAggregation.yaml" + - "--remoteWrite.streamAggr.enableWindows=true" + - "--remoteWrite.tmpDataPath=/tmp/vmagent" + # - "--promscrape.config.strictParse=false" + networks: + - jamai + + clickhouse: + image: clickhouse:24.10.2.80 + volumes: + - ${PWD}/docker_data/ch_data:/var/lib/clickhouse/ + - ${PWD}/docker_data/ch_logs:/var/log/clickhouse-server/ + - ${PWD}/docker/ch_configs/clickhouse_config.xml:/etc/clickhouse-server/config.d/custom_config.xml + - ${PWD}/docker/ch_configs/clickhouse_user_config.xml:/etc/clickhouse-server/users.d/custom_config.xml + - ${PWD}/docker/ch_configs/create_ch_prom_db.sh:/docker-entrypoint-initdb.d/create_ch_prom_db.sh + environment: + - CLICKHOUSE_ALWAYS_RUN_INITDB_SCRIPTS=1 + - CLICKHOUSE_USER=${CLICKHOUSE_USER:-owluser} + - CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT=1 + - CLICKHOUSE_PASSWORD=${CLICKHOUSE_PASSWORD:-owlpassword} + - CLICKHOUSE_DB=${CLICKHOUSE_DB:-jamaibase_owl} + ulimits: + nofile: + soft: 262144 + hard: 262144 + cap_add: + - SYS_NICE + - NET_ADMIN + - IPC_LOCK + networks: + - jamai + healthcheck: + test: ["CMD", "wget", "--spider", "--quiet", "http://localhost:8123/ping"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + + postgresql: + image: ghcr.io/embeddedllm/jamaibase/postgres:20250305 + command: + - "-c" + - "max_connections=${PG_MAX_CONNECTIONS:-100}" + - "-c" + - "max_locks_per_transaction=512" + - "-c" + - "pgroonga.enable_wal_resource_manager=on" + - "-c" + - "shared_preload_libraries=pgroonga_wal_resource_manager,pg_stat_statements" + environment: + POSTGRES_USER: owlpguser + POSTGRES_PASSWORD: owlpgpassword + POSTGRES_DB: jamaibase_owl + PGUSER: owlpguser + PGPASSWORD: owlpgpassword + PGDATABASE: jamaibase_owl + volumes: + - ${PWD}/docker_data/postgres_db:/var/lib/postgresql/data + networks: + - jamai + healthcheck: + test: ["CMD", "pg_isready", "-U", "owlpguser", "-d", "jamaibase_owl"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 30s + + pgbouncer: + image: edoburu/pgbouncer:v1.24.0-p0 + depends_on: + postgresql: + condition: service_healthy + ports: + - 5432:5432 + environment: + DB_USER: owlpguser + DB_PASSWORD: owlpgpassword + DB_HOST: postgresql + DB_PORT: 5432 + DB_NAME: jamaibase_owl + AUTH_TYPE: scram-sha-256 + POOL_MODE: transaction + ADMIN_USERS: owlpguser + MAX_CLIENT_CONN: ${PB_MAX_CLIENT_CONN:-100} + DEFAULT_POOL_SIZE: ${PB_MAX_CLIENT_CONN:-80} + SERVER_IDLE_TIMEOUT: 600 + QUERY_WAIT_TIMEOUT: 120 + SERVER_RESET_QUERY: DISCARD ALL + healthcheck: + test: ["CMD", "pg_isready", "-h", "localhost", "-U", "owlpguser", "-d", "jamaibase_owl"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 5s + networks: + - jamai + + minio: + image: minio/minio:RELEASE.2025-05-24T17-08-30Z + entrypoint: /bin/sh -c " minio server /data --console-address ':9001' & until (mc config host add myminio http://localhost:9000 $${MINIO_ROOT_USER} $${MINIO_ROOT_PASSWORD}) do echo '...waiting...' 
&& sleep 1; done; mc mb myminio/file; wait " + environment: + MINIO_ROOT_USER: minioadmin + MINIO_ROOT_PASSWORD: minioadmin + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + networks: + - jamai + volumes: + - ${PWD}/docker_data/minio:/data + + docling: + image: ghcr.io/embeddedllm/docling-serve:20250528 + healthcheck: + test: ["CMD-SHELL", "curl --fail http://localhost:5001/health || exit 1"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + restart: unless-stopped + networks: + - jamai + + owl: + build: + context: .. + dockerfile: docker/Dockerfile.owl + env_file: + - ../.env + entrypoint: + - /bin/bash + - -c + - uv run python -m owl.entrypoints.api + ports: + - "${API_PORT:-6969}:${OWL_PORT:-6969}" + networks: + - jamai + volumes: + - ${PWD}/docker_data/owl/db:/usr/src/app/db + - ${PWD}/docker_data/owl/logs:/usr/src/app/logs + - ${PWD}/docker_data/owl/file:/usr/src/app/file + depends_on: + dragonfly: + condition: service_healthy + otel-collector: # Ensure otel-collector is running before owl starts + condition: service_started + clickhouse: + condition: service_healthy + victorialogs: + condition: service_healthy + victoriametrics: + condition: service_healthy + vmagent: + condition: service_started + vmauth: + condition: service_healthy + pgbouncer: + condition: service_healthy + minio: + condition: service_healthy + docling: + condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "curl --fail localhost:6969/api/health || exit 1"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + restart: unless-stopped + + starling: + extends: + service: owl + entrypoint: + - /bin/bash + - -c + - | + uv run --no-sync celery -A owl.entrypoints.starling worker --loglevel=info --max-memory-per-child 65536 --autoscale=4,2 --beat + command: !reset [] + depends_on: !override + owl: + condition: service_healthy + ports: !reset [] + healthcheck: !reset [] + + frontend: + build: + context: .. 
+ dockerfile: docker/Dockerfile.frontend + args: + PUBLIC_JAMAI_URL: ${PUBLIC_JAMAI_URL} + PUBLIC_IS_SPA: ${PUBLIC_IS_SPA} + CHECK_ORIGIN: ${CHECK_ORIGIN} + command: ["node", "server"] + depends_on: + owl: + condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "curl --fail localhost:4000 || exit 1"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + restart: unless-stopped + environment: + - NODE_ENV=production + - BODY_SIZE_LIMIT=Infinity + env_file: + - ../.env + ports: + - "${FRONTEND_PORT:-4000}:4000" + networks: + - jamai + + kopi: + image: hoipangg/v8-kopi:0.4 + healthcheck: + test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/health || exit 1"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + restart: unless-stopped + environment: + - PORT=3000 + - MAX_SIZE_BYTES=20971520 + networks: + - jamai + +networks: + jamai: + driver_opts: + com.docker.network.driver.mtu: 1442 diff --git a/docker/compose.ci.yml b/docker/compose.ci.yml new file mode 100644 index 0000000..be969b1 --- /dev/null +++ b/docker/compose.ci.yml @@ -0,0 +1,6 @@ +include: + - path: + - compose.base.yml + - override.ci.yml + - path: + - compose.test-llm.yml diff --git a/docker/compose.cpu.ollama.yml b/docker/compose.cpu.ollama.yml deleted file mode 100644 index 1b9749e..0000000 --- a/docker/compose.cpu.ollama.yml +++ /dev/null @@ -1,43 +0,0 @@ -include: - - path: - - compose.cpu.yml - - ollama.yml - -services: - ollama: - image: ollama/ollama - volumes: - - ${PWD}/ollama:/root/.ollama - ports: - - "11434:11434" - entrypoint: [ - "sh", - "-c", - "ollama serve & \ - sleep 1; \ - ATTEMPTS=0; \ - MAX_ATTEMPTS=5; \ - while [ $$ATTEMPTS -lt $$MAX_ATTEMPTS ]; do \ - ollama ps > /dev/null 2>&1; \ - if [ $$? -eq 0 ]; then \ - break; \ - fi; \ - sleep 3; \ - ATTEMPTS=$$((ATTEMPTS+1)); \ - done; \ - if [ $$ATTEMPTS -eq $$MAX_ATTEMPTS ]; then \ - echo 'ollama serve did not start in time'; \ - exit 1; \ - fi; \ - ollama pull qwen2.5:3b && ollama cp qwen2.5:3b Qwen/Qwen2.5-3B-Instruct; \ - tail -f /dev/null", - ] - restart: unless-stopped - healthcheck: - test: ["CMD", "sh", "-c", "ollama show Qwen/Qwen2.5-3B-Instruct || exit 1"] - interval: 20s - timeout: 2s - retries: 20 - start_period: 20s - networks: - - jamai diff --git a/docker/compose.cpu.yml b/docker/compose.cpu.yml deleted file mode 100644 index 6ff3f25..0000000 --- a/docker/compose.cpu.yml +++ /dev/null @@ -1,196 +0,0 @@ -services: - infinity: - image: michaelf34/infinity:0.0.70-cpu - container_name: jamai_infinity - command: ["v2", "--engine", "torch", "--port", "6909", "--model-warmup", "--model-id", "${EMBEDDING_MODEL}", "--model-id", "${RERANKER_MODEL}"] - healthcheck: - test: ["CMD-SHELL", "curl --fail http://localhost:6909/health"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - restart: unless-stopped - env_file: - - ../.env - volumes: - - ${PWD}/infinity_cache:/app/.cache - networks: - - jamai - - unstructuredio: - image: downloads.unstructured.io/unstructured-io/unstructured-api:latest - entrypoint: ["/usr/bin/env", "bash", "-c", "uvicorn prepline_general.api.app:app --log-config logger_config.yaml --port 6989 --host 0.0.0.0"] - healthcheck: - test: ["CMD-SHELL", "wget http://localhost:6989/healthcheck -O /dev/null || exit 1"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - restart: unless-stopped - networks: - - jamai - - docio: - build: - context: .. 
- dockerfile: docker/Dockerfile.docio - image: jamai/docio - pull_policy: build - command: ["python", "-m", "docio.entrypoints.api"] - healthcheck: - test: ["CMD-SHELL", "curl --fail http://localhost:6979/health || exit 1"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - restart: unless-stopped - env_file: - - ../.env - networks: - - jamai - - dragonfly: - image: "ghcr.io/embeddedllm/dragonfly" - ulimits: - memlock: -1 - healthcheck: - test: ["CMD-SHELL", "nc -z localhost 6379 || exit 1"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - # For better performance, consider `host` mode instead `port` to avoid docker NAT. - # `host` mode is NOT currently supported in Swarm Mode. - # https://docs.docker.com/compose/compose-file/compose-file-v3/#network_mode - # network_mode: "host" - # volumes: - # - ${PWD}/dragonflydata:/data - networks: - - jamai - - owl: - build: - context: .. - dockerfile: docker/Dockerfile.owl - image: jamai/owl - pull_policy: build - command: ["python", "-m", "owl.entrypoints.api"] - depends_on: - infinity: - condition: service_healthy - unstructuredio: - condition: service_healthy - docio: - condition: service_healthy - dragonfly: - condition: service_healthy - healthcheck: - test: ["CMD-SHELL", "curl --fail localhost:6969/api/health || exit 1"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - restart: unless-stopped - env_file: - - ../.env - volumes: - - ${PWD}/db:/app/api/db - - ${PWD}/logs:/app/api/logs - - ${PWD}/file:/app/api/file - ports: - - "${API_PORT:-6969}:6969" - networks: - - jamai - - starling: - extends: - service: owl - entrypoint: - - /bin/bash - - -c - - | - celery -A owl.entrypoints.starling worker --loglevel=info --max-memory-per-child 65536 --autoscale=2,4 & \ - celery -A owl.entrypoints.starling beat --loglevel=info & \ - FLOWER_UNAUTHENTICATED_API=1 celery -A owl.entrypoints.starling flower --loglevel=info - command: !reset [] - depends_on: - owl: - condition: service_healthy - healthcheck: - test: ["CMD-SHELL", "curl --fail http://localhost:5555/api/workers || exit 1"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - ports: !override - - "${STARLING_PORT:-5555}:5555" - - frontend: - build: - context: .. - dockerfile: docker/Dockerfile.frontend - args: - JAMAI_URL: ${JAMAI_URL} - PUBLIC_JAMAI_URL: ${PUBLIC_JAMAI_URL} - PUBLIC_IS_SPA: ${PUBLIC_IS_SPA} - CHECK_ORIGIN: ${CHECK_ORIGIN} - image: jamai/frontend - pull_policy: build - command: ["node", "server"] - depends_on: - owl: - condition: service_healthy - healthcheck: - test: ["CMD-SHELL", "curl --fail localhost:4000 || exit 1"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - restart: unless-stopped - environment: - - NODE_ENV=production - - BODY_SIZE_LIMIT=Infinity - env_file: - - ../.env - ports: - - "${FRONTEND_PORT:-4000}:4000" - networks: - - jamai - - # By default, minio service is not enabled, and only used for testing. use --profile minio along docker compose up if minio is needed. - minio: - profiles: ["minio"] - image: minio/minio - entrypoint: /bin/sh -c " minio server /data --console-address ':9001' & until (mc config host add myminio http://localhost:9000 $${MINIO_ROOT_USER} $${MINIO_ROOT_PASSWORD}) do echo '...waiting...' 
&& sleep 1; done; mc mb myminio/file; wait " - environment: - MINIO_ROOT_USER: minioadmin - MINIO_ROOT_PASSWORD: minioadmin - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - ports: - - "9000:9000" - - "9001:9001" - networks: - - jamai - - # By default, kopi service is not enabled, and only used for testing. use --profile kopi along docker compose up if kopi is needed. - kopi: - profiles: ["kopi"] - image: hoipangg/kopi - healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:5569/health')"] - interval: 10s - timeout: 2s - retries: 20 - start_period: 10s - ports: - - "5569:5569" - networks: - - jamai - -networks: - jamai: diff --git a/docker/compose.dev.yml b/docker/compose.dev.yml new file mode 100644 index 0000000..8677308 --- /dev/null +++ b/docker/compose.dev.yml @@ -0,0 +1,6 @@ +include: + - path: + - compose.base.yml + - override.dev.yml + - path: + - compose.test-llm.yml diff --git a/docker/compose.nvidia.yml b/docker/compose.nvidia.yml deleted file mode 100644 index 5424af5..0000000 --- a/docker/compose.nvidia.yml +++ /dev/null @@ -1,4 +0,0 @@ -include: - - path: - - compose.cpu.yml - - nvidia.yml diff --git a/docker/compose.test-llm.yml b/docker/compose.test-llm.yml new file mode 100644 index 0000000..04c865c --- /dev/null +++ b/docker/compose.test-llm.yml @@ -0,0 +1,26 @@ +services: + test-llm: + build: + context: .. + dockerfile: docker/Dockerfile.owl + env_file: + - ../.env + entrypoint: + - /bin/bash + - -c + - uv run coverage run --data-file=db/.coverage --rcfile=api/pyproject.toml -m owl.entrypoints.llm + ports: + - 6970:6970 + networks: + - jamai + volumes: + - ${PWD}/docker_data/owl/db:/usr/src/app/db + # depends_on: + # rqlite: + # condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "curl --fail localhost:6970/health || exit 1"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s diff --git a/docker/infinity.yml b/docker/infinity.yml new file mode 100644 index 0000000..2dd42d8 --- /dev/null +++ b/docker/infinity.yml @@ -0,0 +1,18 @@ +services: + infinity: + image: michaelf34/infinity:0.0.70-cpu + container_name: jamai_infinity + command: ["v2", "--engine", "torch", "--port", "6909", "--model-warmup", "--model-id", "${EMBEDDING_MODEL}", "--model-id", "${RERANKER_MODEL}"] + healthcheck: + test: ["CMD-SHELL", "curl --fail http://localhost:6909/health"] + interval: 10s + timeout: 2s + retries: 20 + start_period: 10s + restart: unless-stopped + env_file: + - ../.env + volumes: + - ${PWD}/docker_data/infinity_cache:/app/.cache + networks: + - jamai diff --git a/docker/kong.yml b/docker/kong.yml new file mode 100644 index 0000000..f7988d0 --- /dev/null +++ b/docker/kong.yml @@ -0,0 +1,42 @@ +services: + kong-external: + image: kong/kong-gateway:3.6.1.4 + volumes: + - ../services/gateway/kong_external.yml:/kong/declarative/kong.yml + environment: + - KONG_DATABASE=off + - KONG_DECLARATIVE_CONFIG=/kong/declarative/kong.yml + - KONG_PROXY_ACCESS_LOG=/dev/stdout + - KONG_ADMIN_ACCESS_LOG=/dev/stdout + - KONG_PROXY_ERROR_LOG=/dev/stderr + - KONG_ADMIN_ERROR_LOG=/dev/stderr + - KONG_ADMIN_LISTEN=0.0.0.0:8001 + - KONG_ADMIN_GUI_PATH=/ + ports: + - "8000:8000" # HTTP requests + - "8443:8443" # HTTPS requests + - "127.0.0.1:8001:8001" # HTTP Admin listen + - "127.0.0.1:8444:8444" # HTTPS Admin listen + networks: + - jamai + + kong-internal: + image: kong/kong-gateway:3.6.1.4 + volumes: + - 
../services/gateway/kong_internal.yml:/kong/declarative/kong.yml + environment: + - KONG_DATABASE=off + - KONG_DECLARATIVE_CONFIG=/kong/declarative/kong.yml + - KONG_PROXY_ACCESS_LOG=/dev/stdout + - KONG_ADMIN_ACCESS_LOG=/dev/stdout + - KONG_PROXY_ERROR_LOG=/dev/stderr + - KONG_ADMIN_ERROR_LOG=/dev/stderr + - KONG_ADMIN_LISTEN=0.0.0.0:8001 + - KONG_ADMIN_GUI_PATH=/ + ports: + - "8010:8000" # HTTP requests + - "8453:8443" # HTTPS requests + - "127.0.0.1:8011:8001" # HTTP Admin listen + - "127.0.0.1:8454:8444" # HTTPS Admin listen + networks: + - jamai diff --git a/docker/ollama.yml b/docker/ollama.yml deleted file mode 100644 index 8c0cc58..0000000 --- a/docker/ollama.yml +++ /dev/null @@ -1,4 +0,0 @@ -services: - owl: - environment: - - OWL_MODELS_CONFIG=models_ollama.json diff --git a/docker/otel_configs/otel-collector-config.yaml b/docker/otel_configs/otel-collector-config.yaml new file mode 100644 index 0000000..8359b15 --- /dev/null +++ b/docker/otel_configs/otel-collector-config.yaml @@ -0,0 +1,69 @@ +receivers: + otlp: + protocols: + grpc: + endpoint: "0.0.0.0:4317" + http: + endpoint: "0.0.0.0:4318" + +exporters: + # prometheus: + # endpoint: "0.0.0.0:8889" + # namespace: "owl" + + otlphttp/victoriametrics: + endpoint: http://vmagent:8429/opentelemetry + compression: gzip + encoding: proto + + clickhouse: + endpoint: http://clickhouse:8123?dial_timeout=10s&compress=lz4&async_insert=1 + ttl: 24h + database: jamaibase_owl + traces_table_name: owl_traces + username: owluser + password: owlpassword + create_schema: false + timeout: 5s + sending_queue: + queue_size: 1000 + retry_on_failure: + enabled: true + initial_interval: 5s + max_interval: 30s + max_elapsed_time: 300s + + debug: + verbosity: detailed + + otlphttp: + logs_endpoint: http://victorialogs:9428/insert/opentelemetry/v1/logs + +processors: + batch: + timeout: 5s + + filter/ottl: + traces: + span: + - 'IsMatch(attributes["db.statement"], "^PRAGMA")' + +extensions: + health_check: + endpoint: "0.0.0.0:13133" + +service: + extensions: [health_check] + pipelines: + traces: + receivers: [otlp] + processors: [filter/ottl, batch] + exporters: [clickhouse] + metrics: + receivers: [otlp] + processors: [batch] + exporters: [otlphttp/victoriametrics] + logs: + receivers: [otlp] + processors: [batch] + exporters: [otlphttp] diff --git a/docker/override.ci.yml b/docker/override.ci.yml new file mode 100644 index 0000000..93969d9 --- /dev/null +++ b/docker/override.ci.yml @@ -0,0 +1,25 @@ +services: + dragonfly: + ports: + - 6379:6379 + + clickhouse: + ports: + - 8123:8123 + + docling: + ports: + - 5001:5001 + + owl: + volumes: + - ${PWD}/docker_data/owl/db:/usr/src/app/db + - ${HOME}/.kube/config:/root/.kube/config + environment: + - KUBECONFIG=/root/.kube/config + ports: + - "${API_PORT:-6969}:${OWL_PORT:-6969}" + entrypoint: + - /bin/bash + - -c + - uv run coverage run --data-file=db/.coverage --rcfile=api/pyproject.toml -m owl.entrypoints.api diff --git a/docker/override.dev.yml b/docker/override.dev.yml new file mode 100644 index 0000000..e679b28 --- /dev/null +++ b/docker/override.dev.yml @@ -0,0 +1,65 @@ +services: + dragonfly: + ports: + - 6379:6379 + + otel-collector: + ports: + - 8889:8889 # Prometheus metrics endpoint + - 4317:4317 # OTLP gRPC receiver + - 4318:4318 # OTLP HTTP receiver + - 13133:13133 # health_check extension + + vmauth: + ports: + - 8427:8427 + + vmagent: + ports: + - 8429:8429 + + victoriametrics: + ports: + - 8428:8428 + + victorialogs: + ports: + - 9428:9428 + + clickhouse: + ports: + - 8123:8123 + - 
19000:9000 + - 9363:9363 + + postgresql: + ports: + - 5431:5432 + + pgbouncer: + ports: + - 5432:5432 + + minio: + ports: + - 9000:9000 + - 9001:9001 + + docling: + ports: + - 5001:5001 + + owl: + ports: + - "${API_PORT:-6969}:${OWL_PORT:-6969}" + volumes: + - ${PWD}/docker_data/owl/db:/usr/src/app/db + - ${PWD}/services/api/src:/usr/src/app/api/src + + kopi: + ports: + - 5569:3000 + + frontend: + ports: + - "${FRONTEND_PORT:-4000}:4000" diff --git a/docker/nvidia.yml b/docker/override.nvidia.yml similarity index 97% rename from docker/nvidia.yml rename to docker/override.nvidia.yml index 0c788b2..757d348 100644 --- a/docker/nvidia.yml +++ b/docker/override.nvidia.yml @@ -9,7 +9,7 @@ services: device_ids: ["0"] capabilities: [gpu] - docio: + docling: deploy: resources: reservations: diff --git a/docker/vmagent/streamingAggregation.yaml b/docker/vmagent/streamingAggregation.yaml new file mode 100644 index 0000000..9b2e32f --- /dev/null +++ b/docker/vmagent/streamingAggregation.yaml @@ -0,0 +1,73 @@ +- match: "flower.task.runtime.seconds_bucket" + interval: 1m + without: [service.instance.id, worker] + outputs: [total] + keep_metric_names: true + +- match: '{__name__=~"flower.task.runtime.seconds_(count|sum)"}' + interval: 1m + without: [service.instance.id, worker] + outputs: [total] + keep_metric_names: true + +- match: '{__name__=~"http.+_bucket"}' + interval: 1m + without: [service.instance.id, http.server_name, http.host] + outputs: [total] + keep_metric_names: true + +- match: '{__name__=~"http.+_(count|sum)"}' + interval: 1m + without: [service.instance.id, http.server_name, http.host] + outputs: [total] + keep_metric_names: true + +- match: "db.client.connections.usage" + interval: 1m + without: [service.instance.id] + outputs: [total] + keep_metric_names: true + +- match: "http.server.active_requests" + interval: 1m + without: [service.instance.id, http.server_name, http.host] + outputs: [total] + keep_metric_names: true + +- match: + - "request_count" + interval: 1m + without: [service.instance.id] + outputs: [total] + keep_metric_names: true +# - match: "storage_usage" +# interval: 5m +# without: [service.instance.id] +# outputs: [max] +# flush_on_shutdown: true +# keep_metric_names: true + +# - match: +# - "bandwidth_usage" +# interval: 5m +# without: [service.instance.id] +# outputs: [total] +# flush_on_shutdown: true +# keep_metric_names: true + +# - match: +# - "llm_token_usage" +# - "embedding_token_usage" +# - "reranker_search_usage" +# - "spent" +# interval: 1m +# without: [service.instance.id] +# outputs: [total] +# flush_on_shutdown: true +# keep_metric_names: true +# output_relabel_configs: +# - action: replace +# source_labels: [__name__] +# regex: "^([^:]+):.*$" +# target_label: __name__ +# replacement: "${1}_agg" diff --git a/docker/vmauth/config.yml b/docker/vmauth/config.yml new file mode 100644 index 0000000..f10d91e --- /dev/null +++ b/docker/vmauth/config.yml @@ -0,0 +1,12 @@ +users: + - username: owl + password: owl-vm + url_map: + - src_paths: + - "/vm/.*" + drop_src_path_prefix_parts: 1 + url_prefix: "http://victoriametrics:8428/" + - src_paths: + - "/vl/.*" + drop_src_path_prefix_parts: 1 + url_prefix: "http://victorialogs:9428/" diff --git a/docs/alert_guide.md b/docs/alert_guide.md new file mode 100644 index 0000000..be8d4bf --- /dev/null +++ b/docs/alert_guide.md @@ -0,0 +1,255 @@ +# JamAIBase Alerting Guide + +> A quick reference for what’s already in `vm-alert-config.yaml`. 
Update thresholds only if needed, and **remember to set the Discord `webhook_url` correctly**. As of 2025-10-22 + +--- + +## Recording Rules + +**Group:** `storage.rules` (interval: **30s**) + +- **ClickHouse free disk (%)** + + - **Record:** `chi_clickhouse_disk_free_percentage` + - **Expr:** `(DiskFreeBytes / DiskTotalBytes) * 100` + - **Labels:** `unit=percent`, `component=clickhouse` + +- **VictoriaMetrics cluster disk usage (%)** + - **Record:** `vmcluster_disk_usage_percentage` + - **Expr:** (derived from `vm_data_size_bytes` and free disk) + - **Labels:** `unit=percent`, `component=vmcluster` + +--- + +## Alert Rules + +> Update the threshold according to your need, especially for PostgreSQL + +- **ClickHouseDiskSpaceLow** + + - **Expr:** `chi_clickhouse_disk_free_percentage < 10` + - **For:** `1h` + - **Labels:** `severity=critical`, `component=clickhouse` + - **Meaning:** Free space < **10%** for 1h. + +- **PostgreSQLDatabaseSizeTooLarge** + + - **Expr:** `cnpg_pg_database_size_bytes / 1024 ^ 3 > 10` + - **For:** `1h` + - **Labels:** `severity=warning`, `component=postgresql` + - **Meaning:** DB size > **10 GiB** for 1h. + +- **VMClusterDiskSpaceHigh** + - **Expr:** `vmcluster_disk_usage_percentage > 80` + - **For:** `1h` + - **Labels:** `severity=critical`, `component=vmcluster` + - **Meaning:** VM storage usage > **80%** for 1h. + +**Routing & Inhibition (as configured):** + +- All alerts route to **`jamaibase-discord`**. +- Sub-routes match `component` but currently target the same receiver. +- Inhibition: a `critical` alert suppresses a `warning` of the **same `alertname`**. + +--- + +## Log Alerts (VictoriaLogs → vmalert) + +> Structured alerts from app logs (`owl`, `starling`) evaluated every **30s** by `vmalert-log` against **VictoriaLogs**. + +- **JamAIBase Exception** + + - **Match:** `severity:i(critical)` and exception fields present + - **Agg by:** `service.name`, `code.filepath`, `code.function`, `code.lineno`, `_msg`, `exception.message`, `exception.stacktrace` + - **When:** `count() > 0` + - **Labels:** `severity=exception`, `component=jamaibase-log` + +- **JamAIBase Error** + + - **Match:** `severity:i(error)` + - **Agg by:** `service.name`, `code.filepath`, `code.function`, `code.lineno`, `_msg` + - **When:** `count() > 0` + - **Labels:** `severity=critical`, `component=jamaibase-log` + +- **JamAIBase Warning** _(optional; enable/disable per need)_ + - **Match:** `severity:i(warning)` + - **Agg by:** `service.name`, `code.filepath`, `code.function`, `code.lineno`, `_msg` + - **When:** `count() > 0` + - **Labels:** `severity=warning`, `component=jamaibase-log` + +**vmalert (logs):** `vmalert-log` selects `vmalert/rule-type=logs`, datasource `VictoriaLogs :9428`, notifier `Alertmanager :9093`, interval **30s**. + +--- + +## Discord Integration + +- Get the webhook url from discord (server settings > integration > webhook) +- Update the webhook_url with the url generated + +```yaml +- webhook_url: "https://discord.com/api/webhooks/XXXX/YYYY" +``` + +## Manual Trigger (Alertmanager API) + +> Prefer rule-based alerts in Prometheus/vmalert for real monitoring. + +**Prereq:** Reach Alertmanager on `:9093`. 
If inside the cluster, port-forward:
+
+```bash
+kubectl -n vm-operator port-forward svc/vmalertmanager-vmalertmanager 9093:9093
+```
+
+**Fire a test alert (firing):**
+
+```bash
+curl -XPOST http://127.0.0.1:9093/api/v2/alerts -H 'Content-Type: application/json' -d '[
+  {
+    "labels": {
+      "alertname": "ManualSmokeTest",
+      "severity": "critical",
+      "component": "clickhouse"
+    },
+    "annotations": {
+      "summary": "Manual test alert",
+      "description": "End-to-end Discord delivery check."
+    }
+  }
+]'
+```
+
+**Resolve the same alert (identical labels + endsAt):**
+
+```bash
+curl -XPOST http://127.0.0.1:9093/api/v2/alerts -H 'Content-Type: application/json' -d '[
+  {
+    "labels": {
+      "alertname": "ManualSmokeTest",
+      "severity": "critical",
+      "component": "clickhouse"
+    },
+    "annotations": {
+      "summary": "Manual test resolved",
+      "description": "Marking alert resolved."
+    },
+    "startsAt": "'"$(date -u -d "-10m" +"%Y-%m-%dT%H:%M:%SZ")"'",
+    "endsAt": "'"$(date -u +"%Y-%m-%dT%H:%M:%SZ")"'"
+  }
+]'
+```
+
+**Force a fresh notification immediately (change fingerprint):**
+
+```bash
+curl -XPOST http://127.0.0.1:9093/api/v2/alerts -H 'Content-Type: application/json' -d '[
+  {
+    "labels": {
+      "alertname": "ManualSmokeTest",
+      "severity": "critical",
+      "component": "clickhouse",
+      "run_id": "'"$(date +%s)"'"
+    },
+    "annotations": {
+      "summary": "Another ping",
+      "description": "Forcing a new notification."
+    }
+  }
+]'
+```
+
+**Notes (matches current config):**
+
+- `group_by: ["alertname"]` → alerts with the same `alertname` are grouped.
+- `group_wait: 10s` / `group_interval: 10s` → small, intentional delays before/between sends.
+- `repeat_interval: 3h` → duplicate alerts with the **same label set** won't resend within 3 hours.
+- All routes currently go to **`jamaibase-discord`**; `component` is mainly for title templating/inhibition now, but keep it for future per-component routing.
+
+---
+
+### Field reference (JSON payload)
+
+> Minimal payload = `labels` + `annotations`. Timestamps are optional but recommended for clarity.
+
+- **`labels`** _(object, required)_ — define the alert's identity (**fingerprint**) and routing.
+
+  - **`alertname`**: Logical name (e.g., `ManualSmokeTest`). Used by `group_by: ["alertname"]`.
+  - **`severity`**: Freeform (`warning`, `critical`, …). Your template adds `@here` for `critical`.
+  - **`component`**: Freeform (`clickhouse`, `postgresql`, `vmcluster`, …). Useful for titles and future routing.
+  - **(optional)** e.g., `instance`, `hostname`, or **`run_id`** to force a new fingerprint/notification.
+
+- **`annotations`** _(object, required)_ — human-readable content for messages.
+
+  - **`summary`**: One-liner.
+  - **`description`**: A few sentences of detail.
+
+- **`startsAt`** _(RFC3339 UTC, optional)_ — when the alert **started firing**. If omitted, Alertmanager treats it as "now". Example: `2025-10-11T05:00:00Z`.
+
+- **`endsAt`** _(RFC3339 UTC, optional)_ — when the alert **stops**.
+
+  - **Firing**: omit `endsAt`, or set it **in the future**; Alertmanager considers it active.
+  - **Resolved**: set `endsAt` **in the past** (and keep the _same labels_) to immediately resolve it.
+
+#### Firing vs. Resolved — quick examples
+
+- **Fire (no timestamps):**
+
+```json
+[
+  {
+    "labels": { "alertname": "ManualSmokeTest", "severity": "critical", "component": "clickhouse" },
+    "annotations": { "summary": "Manual test", "description": "End-to-end Discord check." }
+  }
+]
+```
+
+- **Resolve (same labels + past `endsAt`):**
+
+```json
+[
+  {
+    "labels": { "alertname": "ManualSmokeTest", "severity": "critical", "component": "clickhouse" },
+    "annotations": { "summary": "Manual test resolved", "description": "Marking resolved." },
+    "startsAt": "2025-10-11T04:50:00Z",
+    "endsAt": "2025-10-11T05:00:00Z"
+  }
+]
+```
+
+- **Force a fresh notification (new fingerprint via `run_id`):**
+
+```json
+[
+  {
+    "labels": {
+      "alertname": "ManualSmokeTest",
+      "severity": "critical",
+      "component": "clickhouse",
+      "run_id": "1697000000"
+    },
+    "annotations": { "summary": "Another ping", "description": "New fingerprint via run_id." }
+  }
+]
+```
+
+#### Practical tips
+
+- **Fingerprint ≈ labels only.** Same labels → same alert; change any label (e.g., `run_id`) → new alert.
+- **Grouping.** With `group_by: ["alertname"]`, different severities/components sharing the same `alertname` are batched after `group_wait`.
+- **Time helpers (UTC/Zulu):**
+
+```bash
+date -u +"%Y-%m-%dT%H:%M:%SZ" # now
+date -u -d "-10 minutes" +"%Y-%m-%dT%H:%M:%SZ"
+date -u -d "+30 minutes" +"%Y-%m-%dT%H:%M:%SZ"
+```
+
+- **Auto-resolution.** If a client doesn't refresh a firing alert, Alertmanager eventually considers it resolved. For manual tests, either send a resolved payload or let it expire.
+- The resolution timeout is defined in the manifest; ref. https://prometheus.io/docs/alerting/latest/configuration/
+
+```
+global:
+  # ResolveTimeout is the default value used by alertmanager if the alert does not
+  # include EndsAt, after this time passes it can declare the alert as resolved if it has not been updated.
+  # This has no impact on alerts from Prometheus, as they always include EndsAt.
+  [ resolve_timeout: <duration> | default = 5m ]
+```
diff --git a/docs/pgaudit_guide.md b/docs/pgaudit_guide.md
new file mode 100644
index 0000000..275acca
--- /dev/null
+++ b/docs/pgaudit_guide.md
@@ -0,0 +1,71 @@
+# JamAIBase — pgaudit Setup
+
+1. PostgreSQL cluster-wide auditing is set by the pgaudit parameters in cnpg-cluster-deploy.yaml
+2. Object auditing is set in owl db/\_\_init\_\_.py
+
+---
+
+## 1) CNPG Cluster Config (role + pgaudit parameters)
+
+- You can customize the logging options via the parameters; ref. https://github.com/pgaudit/pgaudit/blob/main/README.md
+- jamaibase_auditor is the role required for object auditing
+
+```yaml
+spec:
+  managed:
+    roles:
+      - name: jamaibase_auditor
+        ensure: present
+        comment: pgaudit role for jamaibase
+        login: false
+
+  postgresql:
+    parameters:
+      # pgaudit for logging DDL and role changes; include useful context
+      # NOTE: Ensure pgaudit is actually loaded via shared_preload_libraries in your cluster.
+      # If not already set elsewhere, add:
+      #   shared_preload_libraries: "pgaudit"
+      pgaudit.log: "ddl, role"
+      pgaudit.log_catalog: "off"
+      pgaudit.log_parameter: "on"
+      pgaudit.log_client: "on"
+      pgaudit.role: "jamaibase_auditor" # object-based auditing role
+```
+
+---
+
+## 2) Object Auditing Grants (in db/\_\_init\_\_.py)
+
+- Customize `audit_statement` based on which DML statements you want to monitor
+
+```python
+async def _grant_auditor_priviledge(engine: AsyncEngine) -> bool:
+    """
+    Apply the necessary grants to allow the auditor role to audit the database.
+ """ + auditor_role = "jamaibase_auditor" + audit_statement = "UPDATE, DELETE" + + async with engine.connect() as conn: + role_exists = await conn.scalar( + text(f"SELECT 1 FROM pg_roles WHERE rolname = '{auditor_role}'") + ) + if role_exists is None: + return False + + # FUTURE tables in this schema + await conn.execute( + text( + f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{SCHEMA}" ' + f"GRANT {audit_statement} ON TABLES TO {auditor_role};" + ) + ) + + # EXISTING tables now + await conn.exec_driver_sql( + f'GRANT {audit_statement} ON ALL TABLES IN SCHEMA "{SCHEMA}" TO {auditor_role};' + ) + await conn.commit() + + return True +``` diff --git a/scripts/compile_docio_exe.ps1 b/scripts/compile_docio_exe.ps1 deleted file mode 100644 index 6799fd9..0000000 --- a/scripts/compile_docio_exe.ps1 +++ /dev/null @@ -1,9 +0,0 @@ -.\scripts\remove_cloud_modules.ps1 -cd .\clients\python -pip install . -cd .\..\..\services\docio -pip install -e . -pip install pyinstaller==6.9.0 -pip install cryptography==42.0.8 -pip install python-magic-bin -pyinstaller docio.spec \ No newline at end of file diff --git a/scripts/migrate_model_json.py b/scripts/migrate_model_json.py index ec82f8c..2bf99ee 100644 --- a/scripts/migrate_model_json.py +++ b/scripts/migrate_model_json.py @@ -1,7 +1,7 @@ import json import sys -from owl.protocol import ModelListConfig +from owl.types import ModelListConfig def transform_json(original_json): @@ -16,7 +16,7 @@ def transform_json(original_json): # Create the ModelDeploymentConfig instance deployment_config = { - "litellm_id": config.get("litellm_id", ""), + "routing_id": config.get("litellm_id", ""), "api_base": config.get("api_base", ""), "provider": provider, } diff --git a/scripts/migrate_v1_to_v2.py b/scripts/migrate_v1_to_v2.py new file mode 100644 index 0000000..f358e9e --- /dev/null +++ b/scripts/migrate_v1_to_v2.py @@ -0,0 +1,394 @@ +import argparse +import json +import logging +import sqlite3 +from pathlib import Path +from typing import Any, Dict, List + +import lancedb +from filelock import FileLock + +from owl.db.gen_table import ( + ColumnDtype, + ColumnMetadata, + GenerativeTableCore, + TableMetadata, +) +from owl.types import ColName, TableName, TableType + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class V1DatabaseReader: + """Class to read data from v1 database format.""" + + def __init__(self, base_path: str): + self.base_path = Path(base_path) + self.locks: Dict[str, FileLock] = {} + + def get_org_projects(self) -> List[Dict[str, str]]: + """Get list of all org/project directories.""" + org_projects = [] + for org_dir in self.base_path.iterdir(): + if org_dir.is_dir(): + for project_dir in org_dir.iterdir(): + if project_dir.is_dir(): + org_projects.append( + { + "org_id": org_dir.name, + "project_id": project_dir.name, + "path": str(project_dir), + } + ) + return org_projects + + def get_tables_for_project(self, project_path: str) -> List[Dict[str, str]]: + """Get list of tables for a project.""" + tables = [] + for table_type in ["action", "chat", "knowledge"]: + db_path = Path(project_path) / f"{table_type}.db" + if db_path.exists(): + # Get all .lance directories in the table_type folder + table_dir = db_path.parent / table_type + if table_dir.exists(): + for lance_dir in table_dir.iterdir(): + if lance_dir.is_dir() and lance_dir.suffix == ".lance": + tables.append( + { + "type": table_type, + "sqlite_path": str(db_path), + 
"lance_path": str(lance_dir), + "table_name": lance_dir.stem, + } + ) + return tables + + def read_table_metadata(self, sqlite_path: str) -> List[Dict[str, Any]]: + """Read table metadata from SQLite database.""" + with sqlite3.connect(sqlite_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + # Get all table metadata + cursor.execute("SELECT * FROM TableMeta") + meta_rows = cursor.fetchall() + if not meta_rows: + return [] + + # Process all metadata rows + metadata_list = [] + for row in meta_rows: + metadata = dict(row) + + # Parse columns if present + if "cols" in metadata and metadata["cols"]: + import json + + # Handle both string and dict formats + if isinstance(metadata["cols"], str): + try: + metadata["cols"] = json.loads(metadata["cols"]) + except json.JSONDecodeError: + logger.warning(f"Failed to parse columns metadata for {sqlite_path}") + metadata["cols"] = [] + # Convert columns to structured format + if isinstance(metadata["cols"], (list, dict)): + if isinstance(metadata["cols"], dict): + metadata["cols"] = [metadata["cols"]] + for col in metadata["cols"]: + col["dtype"] = col.get("dtype", "str") + col["vlen"] = col.get("vlen", 0) + if "gen_config" in col and col["gen_config"]: + if isinstance(col["gen_config"], str): + try: + col["gen_config"] = json.loads(col["gen_config"]) + except json.JSONDecodeError: + col["gen_config"] = {} + + metadata_list.append(metadata) + + return metadata_list + + def _process_state_column(self, value: Any) -> Any: + """Process state column values.""" + if isinstance(value, str): + if value == "": + return {} + try: + return json.loads(value) + except json.JSONDecodeError: + logger.warning(f"Failed to parse state column value: {value}") + return {} + return value + + def read_table_data(self, lance_path: str) -> List[Dict[str, Any]]: + """Read table data from LanceDB.""" + # Connect to parent directory of the .lance folder + db = lancedb.connect(str(Path(lance_path).parent)) + # Open table using the directory name + table_name = Path(lance_path).stem + data = db.open_table(table_name).to_pandas().to_dict("records") + + # Process state columns + for row in data: + for col_name in list(row.keys()): + if col_name.endswith("_"): + row[col_name] = self._process_state_column(row[col_name]) + return data + + def lock_table(self, table_path: str) -> FileLock: + """Acquire a file lock for the table.""" + lock_path = f"{table_path}.lock" + self.locks[lock_path] = FileLock(lock_path) + self.locks[lock_path].acquire() + return self.locks[lock_path] + + def release_table_lock(self, table_path: str) -> None: + """Release the file lock for the table.""" + lock_path = f"{table_path}.lock" + if lock_path in self.locks: + self.locks[lock_path].release() + del self.locks[lock_path] + + +class V2Migrator: + """Class to handle v2 migration using GenerativeTableCore.""" + + def __init__(self, migrate: bool = False): + self.v1_conn = None + self.migrate = migrate + + # Mapping between v1 and v2 ColumnDtype values + _DTYPE_MAPPING = { + "int": "INTEGER", + "int8": "INTEGER", + "float": "FLOAT", + "float32": "FLOAT", + "float16": "FLOAT", + "bool": "BOOL", + "str": "TEXT", + "date-time": "TIMESTAMPTZ", + "image": "TEXT", + "audio": "TEXT", + "document": "TEXT", + } + + def _map_dtype(self, dtype: str) -> str: + """Map v1 dtype to v2 ColumnDtype.""" + dtype = dtype.lower() + return self._DTYPE_MAPPING.get(dtype, "TEXT") + + async def connect(self): + """Connect to SQLite database""" + self.v1_conn = sqlite3.connect(":memory:") # Will attach v1 
databases + + async def close(self): + """Close database connections""" + if self.v1_conn: + self.v1_conn.close() + + async def migrate_table( + self, + proj_id: str, + table_type: TableType, + table_name: TableName, + metadata_list: List[Dict[str, Any]], + data: List[Dict[str, Any]], + ): + """Migrate a single table""" + logger.info(f"Validating table {table_name} for migration") + + # Validate metadata + if not metadata_list: + logger.warning(f"No metadata found for table {table_name}") + return + + # Find metadata for this specific table + metadata = next((m for m in metadata_list if m.get("id") == table_name), None) + if not metadata: + logger.warning(f"No matching metadata found for table {table_name}") + return + + # Log migration details + logger.info(f"Table {table_name} would be migrated with:") + if data: + logger.info(f"- {len(data)} rows") + logger.info(f"- Columns: {list(data[0].keys())}") + else: + logger.info("- Empty table (0 rows)") + + # Skip actual migration unless --migrate is specified + if not self.migrate: + logger.info(f"Dry-run mode: Table {table_name} would be migrated") + return + + # Create PostgreSQL schema and metadata tables + schema_id = f"{proj_id}_{table_type}" + # clean up before migration + await GenerativeTableCore.drop_schema(proj_id, table_type) + await GenerativeTableCore.create_schema(proj_id, table_type) + await GenerativeTableCore.create_metadata_tables(schema_id) + + # System columns that are handled automatically + SYSTEM_COLUMNS = ["ID", "Updated at"] + # TODO: Are these columns really migrated? There seems to be a mismatch between the data model and this + + # Create PostgreSQL table + columns = [] + if metadata.get("cols"): + # Use column metadata from v1 if available + col_order_counter = 1 # Initialize the counter + for col in metadata["cols"]: + if col["id"] not in SYSTEM_COLUMNS and not col["id"].endswith("_"): + columns.append( + ColumnMetadata( + column_id=ColName(col["id"]), + table_id=table_name, + dtype=ColumnDtype.FLOAT + if col.get("vlen") + else ColumnDtype(self._map_dtype(col.get("dtype", "str"))), + vlen=col.get("vlen"), + gen_config=col.get("gen_config"), + column_order=col_order_counter, # Use the counter here + ) + ) + col_order_counter += 1 # Increment the counter only when the condition is True + else: + raise ValueError("No column metadata found for table") + # elif data: + # # Fallback to creating metadata from data if no v1 metadata + # columns = [ + # ColumnMetadataCreate( + # column_id=ColName(col), + # table_id=table_name, + # dtype=ColumnDtype.STR, # Default to STR if no type info + # vlen=None, + # gen_config=None, + # column_order=idx + 1 + # ) + # for idx, col in enumerate(data[0].keys()) + # if col not in SYSTEM_COLUMNS + # ] + logger.info(f"Creating table {table_name} with {[c.column_id for c in columns]} columns") + table = await GenerativeTableCore.create_data_table( + project_id=proj_id, + table_id=table_name, + table_type=table_type, + table_metadata=TableMetadata( + table_id=table_name, + title=metadata.get("title", ""), + parent_id=metadata.get("parent_id", ""), + ), + column_metadata_list=columns, + ) + + # Migrate data if present + if data: + await table.add_rows(data_list=data) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Migrate data from v1 to v2 database format") + parser.add_argument("-i", "--input", required=True, help="Path to v1 database directory") + parser.add_argument( + "--org-project", help="Specific org_id/project_id to migrate 
(format: org_id/project_id)" + ) + parser.add_argument("--org-id", help="Specific org_id to migrate") + parser.add_argument("--project-id", help="Specific project_id to migrate") + parser.add_argument( + "--migrate", + action="store_true", + help="Actually perform the migration (default is dry-run)", + ) + return parser.parse_args() + + +async def main(): + args = parse_args() + + # Initialize reader and migrator + reader = V1DatabaseReader(args.input) + migrator = V2Migrator(args.migrate) + + try: + await migrator.connect() + + # Get org/project directories to process + org_projects = reader.get_org_projects() + + # Filter for specific org/project if specified + if args.org_project: + org_id, project_id = args.org_project.split("/") + org_projects = [ + p for p in org_projects if p["org_id"] == org_id and p["project_id"] == project_id + ] + if not org_projects: + logger.error(f"Could not find org/project: {args.org_project}") + return + else: + # Filter by org_id if specified + if args.org_id: + org_projects = [p for p in org_projects if p["org_id"] == args.org_id] + if not org_projects: + logger.error(f"Could not find org: {args.org_id}") + return + + # Filter by project_id if specified + if args.project_id: + org_projects = [p for p in org_projects if p["project_id"] == args.project_id] + if not org_projects: + logger.error(f"Could not find project: {args.project_id}") + return + + # Process each project + for project in org_projects: + logger.info(f"Processing project: {project['project_id']}") + + # Get tables for project + tables = reader.get_tables_for_project(project["path"]) + + # Process each table + for table in tables: + logger.info( + f"Processing table: {table['type']}, sqlite: {table['sqlite_path']}, lance: {table['lance_path']}" + ) + try: + # Acquire lock + reader.lock_table(table["lance_path"]) + + # Read metadata and data + metadata = reader.read_table_metadata(table["sqlite_path"]) + data = reader.read_table_data(table["lance_path"]) + + # Migrate table + await migrator.migrate_table( + project["project_id"], + TableType(table["type"]), + TableName(table["table_name"]), + metadata, + data, + ) + if args.migrate: + logger.info(f"Migrated table: {table['type']} with {len(data)} rows") + + except Exception as e: + logger.error(f"Error processing table {table['type']}: {str(e)}") + raise e + finally: + # Release lock and log + reader.release_table_lock(table["lance_path"]) + logger.debug(f"Released lock for table: {table['type']}") + + finally: + await migrator.close() + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/scripts/migration_s3_v1_to_v2.py b/scripts/migration_s3_v1_to_v2.py new file mode 100644 index 0000000..240d6b9 --- /dev/null +++ b/scripts/migration_s3_v1_to_v2.py @@ -0,0 +1,486 @@ +import concurrent.futures +import math +import os +import sys +import time + +import boto3 +from botocore.exceptions import ClientError, NoCredentialsError +from dotenv import load_dotenv +from loguru import logger + + +def logger_config(max_workers: int = 10): + logger.remove() + logger.add( + sys.stderr, + level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + ) + logger.add( + f"s3_migration_{max_workers}.log", + level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", + rotation="5 MB", + enqueue=True, + backtrace=True, + diagnose=True, + ) + logger.info("Logger configured. 
Starting S3 migration script...") + + +def get_s3_client(endpoint, access_key, secret_key): + try: + client = boto3.client( + "s3", + endpoint_url=f"http://{endpoint}", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + ) + client.list_buckets() + logger.info(f"Successfully connected to MinIO at {endpoint}") + return client + except (NoCredentialsError, ClientError) as e: + logger.error(f"Failed to connect to MinIO at {endpoint}. Error: {e}") + return None + + +def get_all_organization_ids(s3_client, bucket_name: str) -> list[str]: + org_ids = set() + prefix_to_scan = "raw/" + logger.info(f"Discovering all organization IDs in s3://{bucket_name}/{prefix_to_scan}...") + try: + paginator = s3_client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix_to_scan, Delimiter="/") + for page in pages: + if "CommonPrefixes" in page: + for common_prefix in page["CommonPrefixes"]: + parts = common_prefix.get("Prefix", "").strip("/").split("/") + if len(parts) > 1: + org_ids.add(parts[1]) + except ClientError as e: + logger.error(f"Failed to scan for organization IDs in s3://{bucket_name}/. Error: {e}") + return [] + found_ids = list(org_ids) + if found_ids: + logger.info(f"Found {len(found_ids)} organization IDs: {found_ids}") + else: + logger.warning(f"No organization IDs found under the '{prefix_to_scan}' prefix.") + return found_ids + + +def format_bytes(size_bytes: int) -> str: + """Converts a size in bytes to a human-readable format (KB, MB, GB, etc.).""" + if size_bytes == 0: + return "0B" + size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") + i = int(math.floor(math.log(size_bytes, 1024))) + p = math.pow(1024, i) + s = round(size_bytes / p, 2) + return f"{s} {size_name[i]}" + + +def get_organization_storage_size( + s3_client, bucket_name: str, organization_id: str, log_summary: bool = True +) -> tuple[int, int]: + """ + Calculates the total number of files and storage size for a specific organization. + The `log_summary` parameter controls if the function prints its own summary. + """ + if log_summary: + logger.info( + f"Calculating storage size for organization '{organization_id}' in bucket '{bucket_name}'..." + ) + + total_bytes, total_files = 0, 0 + prefixes_to_scan = [f"raw/{organization_id}/", f"thumb/{organization_id}/"] + + try: + for prefix in prefixes_to_scan: + paginator = s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix): + for obj in page.get("Contents", []): + total_bytes += obj["Size"] + total_files += 1 + + if log_summary: + readable_size = format_bytes(total_bytes) + logger.info("=" * 40) + logger.info(f"Storage Summary for Organization: '{organization_id}'") + logger.info(f" Total Files: {total_files:,}") + logger.info(f" Total Size: {readable_size} ({total_bytes:,} bytes)") + logger.info("=" * 40) + + return total_bytes, total_files + except ClientError as e: + logger.error(f"Could not calculate storage for org '{organization_id}'. Error: {e}") + return 0, 0 + + +def analyze_all_organizations_storage(s3_client, bucket_name: str): + """ + Analyzes storage for all organizations, then logs a sorted report. 
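+
+    Discovers every organization ID, totals the file count and byte size per
+    organization (suppressing the per-org summary logs), sorts the results from
+    smallest to largest, and logs a report table followed by the grand totals.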
+ """ + logger.info(f"Starting storage analysis for all organizations in bucket '{bucket_name}'...") + + organization_ids = get_all_organization_ids(s3_client, bucket_name) + if not organization_ids: + logger.warning("No organizations found to analyze.") + return + + storage_data = [] + grand_total_bytes = 0 + grand_total_files = 0 + + analysis_start_time = time.time() + for org_id in organization_ids: + # Call with log_summary=False to prevent noisy individual logs + total_bytes, total_files = get_organization_storage_size( + s3_client, bucket_name, org_id, log_summary=False + ) + if total_files > 0: + storage_data.append( + { + "org_id": org_id, + "total_bytes": total_bytes, + "total_files": total_files, + } + ) + grand_total_bytes += total_bytes + grand_total_files += total_files + + # Sort the collected data by size, from lowest to highest + sorted_storage_data = sorted(storage_data, key=lambda x: x["total_bytes"]) + + analysis_end_time = time.time() + logger.info( + f"Completed storage analysis in {analysis_end_time - analysis_start_time:.2f} seconds." + ) + + # --- Log the formatted report --- + logger.info("=" * 70) + logger.info("Storage Size Report by Organization (Sorted Lowest to Highest)") + logger.info("-" * 70) + logger.info(f"{'Organization ID':<40} | {'Total Files':>12} | {'Total Size':>12}") + logger.info("-" * 70) + + for data in sorted_storage_data: + readable_size = format_bytes(data["total_bytes"]) + # Use f-string alignment and formatting for a clean table + logger.info(f"{data['org_id']:<40} | {data['total_files']:>12,} | {readable_size:>12}") + + logger.info("-" * 70) + readable_grand_total = format_bytes(grand_total_bytes) + logger.info(f"{'GRAND TOTAL':<40} | {grand_total_files:>12,} | {readable_grand_total:>12}") + logger.info("=" * 70) + + +def _copy_single_object( + source_s3_client, dest_s3_client, source_bucket, source_key, dest_bucket, dest_key +): + """ + Worker function executed by each thread. Handles one object. + Returns a status string: "COPIED", "SKIPPED", "FAILED". + """ + source_loc = f"s3://{source_bucket}/{source_key}" + dest_loc = f"s3://{dest_bucket}/{dest_key}" + try: + # 1. Check if the object already exists at the destination + dest_s3_client.head_object(Bucket=dest_bucket, Key=dest_key) + logger.info(f"[SKIP-EXISTING] Destination object already exists: {dest_loc}") + return "SKIPPED" + except ClientError as e: + if e.response["Error"]["Code"] != "404": + logger.error(f"[FAIL] Failed to check destination {dest_loc}: {e}") + return "FAILED" + + # 2. If it doesn't exist, copy it + try: + response = source_s3_client.get_object(Bucket=source_bucket, Key=source_key) + dest_s3_client.put_object( + Bucket=dest_bucket, + Key=dest_key, + Body=response["Body"].read(), + ContentType=response.get("ContentType", "application/octet-stream"), + ) + logger.info(f"[COPIED] {source_loc} -> {dest_loc}") + return "COPIED" + except ClientError as e: + logger.error(f"[FAIL] Failed during copy for {source_loc}: {e}") + return "FAILED" + + +def migrate_s3_structure_across_endpoints( + source_s3_client, + dest_s3_client, + old_organization_id: str, + source_bucket: str, + dest_bucket: str, + new_organization_id: str = None, + max_workers: int = 10, + dry_run: bool = True, +): + """ + Migrates a SINGLE organization's files in parallel, skipping existing files. + """ + if new_organization_id is None: + new_organization_id = old_organization_id + + if dry_run: + logger.info(f"DRY RUN for org '{old_organization_id}'. 
No changes will be made.") + # Perform a simple listing for the dry run plan + total_planned = 0 + for prefix in [f"raw/{old_organization_id}/", f"thumb/{old_organization_id}/"]: + paginator = source_s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=source_bucket, Prefix=prefix): + for obj in page.get("Contents", []): + logger.info(f"[PLAN-COPY] s3://{source_bucket}/{obj['Key']}") + total_planned += 1 + logger.info( + f"Dry run summary for org '{old_organization_id}': Planned to copy {total_planned} objects." + ) + return total_planned, 0, 0 + + total_copied, total_skipped, total_failed = 0, 0, 0 + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {} + for prefix in [f"raw/{old_organization_id}/", f"thumb/{old_organization_id}/"]: + try: + paginator = source_s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=source_bucket, Prefix=prefix): + for obj in page.get("Contents", []): + source_key = obj["Key"] + dest_key = source_key.replace( + f"/{old_organization_id}/", f"/{new_organization_id}/", 1 + ) + + future = executor.submit( + _copy_single_object, + source_s3_client, + dest_s3_client, + source_bucket, + source_key, + dest_bucket, + dest_key, + ) + futures[future] = source_key + except ClientError as e: + logger.error( + f"Could not list objects in s3://{source_bucket}/{prefix}. Error: {e}" + ) + total_failed += 1 # Count listing itself as a failure + + for future in concurrent.futures.as_completed(futures): + status = future.result() + if status == "COPIED": + total_copied += 1 + elif status == "SKIPPED": + total_skipped += 1 + else: + total_failed += 1 + + logger.info( + f"Summary for org '{old_organization_id}' (Workers: {max_workers}): " + f"Copied={total_copied}, Skipped={total_skipped}, Failed={total_failed}" + ) + return total_copied, total_skipped, total_failed + + +def migrate_all_organizations( + source_s3_client, + dest_s3_client, + source_bucket: str, + dest_bucket: str = None, + max_workers: int = 10, + dry_run: bool = True, +): + """ + Discovers all organization IDs and migrates their files in parallel, logging time per org. + """ + if dest_bucket is None: + dest_bucket = source_bucket + if not dry_run: + try: + dest_s3_client.create_bucket(Bucket=dest_bucket) + logger.info(f"Ensured destination bucket '{dest_bucket}' exists.") + except ClientError as e: + if e.response["Error"]["Code"] not in [ + "BucketAlreadyOwnedByYou", + "BucketAlreadyExists", + ]: + logger.error( + f"Could not create/verify destination bucket '{dest_bucket}'. Aborting. Error: {e}" + ) + return + + organization_ids = get_all_organization_ids(source_s3_client, source_bucket) + if not organization_ids: + logger.warning("No organizations found to migrate. 
Exiting.") + return + + grand_total_copied, grand_total_skipped, grand_total_failed = 0, 0, 0 + for i, org_id in enumerate(organization_ids): + logger.info("-" * 50) + logger.info( + f"Starting migration for Organization {i + 1}/{len(organization_ids)}: '{org_id}'" + ) + + org_start_time = time.time() + + copied, skipped, failed = migrate_s3_structure_across_endpoints( + source_s3_client=source_s3_client, + dest_s3_client=dest_s3_client, + source_bucket=source_bucket, + dest_bucket=dest_bucket, + old_organization_id=org_id, + new_organization_id="0" if org_id == "org_82d01c923f25d5939b9d4188" else org_id, + max_workers=max_workers, + dry_run=dry_run, + ) + + org_end_time = time.time() + org_time_taken_sec = org_end_time - org_start_time + org_time_taken_min = org_time_taken_sec / 60 + logger.warning( + f"\n'{org_id}' migration completed in {org_time_taken_sec:.3f} seconds ({org_time_taken_min:.3f} minutes)." + ) + + grand_total_copied += copied + grand_total_skipped += skipped + grand_total_failed += failed + + logger.info("=" * 50) + logger.info("BULK MIGRATION COMPLETE") + logger.info(f"Total organizations processed: {len(organization_ids)}") + logger.info(f"Grand total objects copied: {grand_total_copied}") + logger.info(f"Grand total objects skipped (already exist): {grand_total_skipped}") + logger.info(f"Grand total failures: {grand_total_failed}") + if dry_run: + logger.warning("This was a DRY RUN. No actual data was moved.") + logger.info("=" * 50) + + +def setup_dummy_v1_data(s3_client, bucket_name, org_id, project_id, uuid): + try: + s3_client.create_bucket(Bucket=bucket_name) + except ClientError as e: + if e.response["Error"]["Code"] not in ["BucketAlreadyOwnedByYou", "BucketAlreadyExists"]: + raise + raw_key = f"raw/{org_id}/{project_id}/{uuid}/report.pdf" + s3_client.put_object( + Bucket=bucket_name, Key=raw_key, Body=b"pdf content", ContentType="application/pdf" + ) + thumb_key = f"thumb/{org_id}/{project_id}/{uuid}/report.webp" + s3_client.put_object( + Bucket=bucket_name, Key=thumb_key, Body=b"thumbnail", ContentType="image/webp" + ) + logger.info(f" Created dummy data for org '{org_id}' in bucket '{bucket_name}'") + + +if __name__ == "__main__": + script_start_time = time.time() + load_dotenv() + + MAX_WORKERS = int(os.getenv("MIGRATION_MAX_WORKERS", 12)) + logger_config(MAX_WORKERS) + + SOURCE_MINIO_ENDPOINT = os.getenv("SOURCE_MINIO_ENDPOINT", "localhost:9000") + SOURCE_MINIO_ACCESS_KEY = os.getenv("OWL_S3_ACCESS_KEY_ID") + SOURCE_MINIO_SECRET_KEY = os.getenv("OWL_S3_SECRET_ACCESS_KEY") + SOURCE_BUCKET_NAME = os.getenv("SOURCE_BUCKET_NAME", "v1-company-bucket") + + DEST_MINIO_ENDPOINT = os.getenv("DEST_MINIO_ENDPOINT", "localhost:9000") + DEST_MINIO_ACCESS_KEY = os.getenv("OWL_S3_ACCESS_KEY_ID") + DEST_MINIO_SECRET_KEY = os.getenv("OWL_S3_SECRET_ACCESS_KEY") + DEST_BUCKET_NAME = os.getenv("DEST_BUCKET_NAME", "v2-migrated-data") + + logger.info(f"Source Endpoint: {SOURCE_MINIO_ENDPOINT}, Source Bucket: {SOURCE_BUCKET_NAME}") + logger.info( + f"Destination Endpoint: {DEST_MINIO_ENDPOINT}, Destination Bucket: {DEST_BUCKET_NAME}" + ) + logger.info(f"Using a maximum of {MAX_WORKERS} parallel workers.") + + s3_source = get_s3_client( + SOURCE_MINIO_ENDPOINT, SOURCE_MINIO_ACCESS_KEY, SOURCE_MINIO_SECRET_KEY + ) + s3_dest = get_s3_client(DEST_MINIO_ENDPOINT, DEST_MINIO_ACCESS_KEY, DEST_MINIO_SECRET_KEY) + + if not s3_source or not s3_dest: + logger.error("Could not establish connection to MinIO. 
Exiting.") + sys.exit(1) + + # # --- Setup Dummy Data for Testing --- + # logger.info("\n--- Setting up test data ---") + # setup_dummy_v1_data( + # s3_source, SOURCE_BUCKET_NAME, "org-acme-corp", "proj-q1-reports", "uuid-acme-1" + # ) + # setup_dummy_v1_data( + # s3_source, SOURCE_BUCKET_NAME, "org-acme-corp", "proj-q1-reports", "uuid-acme-2" + # ) + # setup_dummy_v1_data( + # s3_source, SOURCE_BUCKET_NAME, "org-globex-inc", "proj-doomsday", "uuid-globex-1" + # ) + # setup_dummy_v1_data( + # s3_source, SOURCE_BUCKET_NAME, "org-stark-industries", "proj-arc-reactor", "uuid-stark-1" + # ) + + logger.info("\n--- Starting ALL-ORG Migration (Dry Run) ---") + migrate_all_organizations( + source_s3_client=s3_source, + dest_s3_client=s3_dest, + source_bucket=SOURCE_BUCKET_NAME, + dest_bucket=DEST_BUCKET_NAME, + max_workers=MAX_WORKERS, + dry_run=True, + ) + + logger.info("\n--- Starting ALL-ORG Migration (Actual Run) ---") + # This run will copy some files and skip the one that was pre-seeded. + migrate_all_organizations( + source_s3_client=s3_source, + dest_s3_client=s3_dest, + source_bucket=SOURCE_BUCKET_NAME, + dest_bucket=DEST_BUCKET_NAME, + max_workers=MAX_WORKERS, + dry_run=False, + ) + + logger.info("\n--- Re-running Migration to demonstrate idempotency ---") + # This second run should skip all files, as they were all copied in the previous step. + migrate_all_organizations( + source_s3_client=s3_source, + dest_s3_client=s3_dest, + source_bucket=SOURCE_BUCKET_NAME, + dest_bucket=DEST_BUCKET_NAME, + max_workers=MAX_WORKERS, + dry_run=False, + ) + + script_end_time = time.time() + time_taken_min = (script_end_time - script_start_time) / 60 + time_taken_hrs = time_taken_min / 60 + logger.warning( + f"\nScript completed in {time_taken_min:.3f} minutes ({time_taken_hrs:.3f} hours)." 
+ ) + + # source_org_size = get_organization_storage_size( + # s3_client=s3_source, + # bucket_name=SOURCE_BUCKET_NAME, + # organization_id="org_82d01c923f25d5939b9d4188", + # ) + # dest_org_size = get_organization_storage_size( + # s3_client=s3_dest, bucket_name=DEST_BUCKET_NAME, organization_id="0" + # ) + # assert ( + # source_org_size[0] == dest_org_size[0] + # ), f"Source size {source_org_size[0]} does not match destination size {dest_org_size[0]}" + # assert ( + # source_org_size[1] == dest_org_size[1] + # ), f"Source files {source_org_size[1]} do not match destination files {dest_org_size[1]}" + + # logger.info("\n--- Generating Storage Analysis Report for All Organizations ---") + # analyze_all_organizations_storage(s3_client=s3_source, bucket_name=SOURCE_BUCKET_NAME) + + logger.info("\n--- Generating Storage Analysis Report for All Organizations ---") + analyze_all_organizations_storage(s3_client=s3_dest, bucket_name=DEST_BUCKET_NAME) diff --git a/scripts/migration_v030.py b/scripts/migration_v030.py index 4d31a07..d279ada 100644 --- a/scripts/migration_v030.py +++ b/scripts/migration_v030.py @@ -10,7 +10,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict import owl -from jamaibase.protocol import GEN_CONFIG_VAR_PATTERN, ColumnSchema, LLMGenConfig +from jamaibase.types import GEN_CONFIG_VAR_PATTERN, ColumnSchema, LLMGenConfig class EnvConfig(BaseSettings): @@ -53,7 +53,7 @@ def restore(db_dir: str): ) ) src_path = join(proj_dir, bak_files[0]) - dst_path = join(proj_dir, f'{bak_files[0].split("_")[0]}.db') + dst_path = join(proj_dir, f"{bak_files[0].split('_')[0]}.db") os.remove(dst_path) copy2(src_path, dst_path) @@ -121,7 +121,7 @@ def update_gen_table(db_path: str): cols = orjson.loads(record[1]) updated_cols = [] - print(f"└─ (Table {i+1:,d}/{len(records):,d}) Checking table: {table_id}") + print(f"└─ (Table {i + 1:,d}/{len(records):,d}) Checking table: {table_id}") for col in cols: col = ColumnSchema.model_validate(col) if db_path.endswith("chat.db") and col.id.lower() == "ai": @@ -166,7 +166,7 @@ def update_gen_table(db_path: str): os.makedirs(backup_dir, exist_ok=False) for j, db_file in enumerate(sqlite_files): - print(f"(DB {j+1:,d}/{len(sqlite_files):,d}): Processing: {db_file}") + print(f"(DB {j + 1:,d}/{len(sqlite_files):,d}): Processing: {db_file}") backup_db(db_file, backup_dir) add_table_meta_columns(db_file) update_gen_table(db_file) diff --git a/scripts/migration_v040.py b/scripts/migration_v040.py index 2d09e11..776418a 100644 --- a/scripts/migration_v040.py +++ b/scripts/migration_v040.py @@ -10,7 +10,7 @@ from loguru import logger from pydantic_settings import BaseSettings, SettingsConfigDict -from jamaibase.protocol import ColumnSchema +from jamaibase.types import ColumnSchema class EnvConfig(BaseSettings): diff --git a/scripts/remove_cloud_modules.ps1 b/scripts/remove_cloud_modules.ps1 deleted file mode 100644 index 2a86c15..0000000 --- a/scripts/remove_cloud_modules.ps1 +++ /dev/null @@ -1,22 +0,0 @@ -Get-ChildItem -Recurse -File -Filter "cloud_*.py" | Remove-Item -Force -Get-ChildItem -Recurse -File -Filter "cloud_*.json" | Remove-Item -Force -Get-ChildItem -Recurse -File -Filter "*_cloud.json" | Remove-Item -Force -Get-ChildItem -Recurse -File -Filter "compose.*.cloud.yml" | Remove-Item -Force -Get-ChildItem -Recurse -Directory -Filter "(cloud)" | Remove-Item -Recurse -Force -if (Test-Path -Path "docker\enterprise") { - Remove-Item -Path "docker\enterprise" -Recurse -Force -} - -# Remove a file or folder quietly -# Like linux "rm -rf" 
-function quiet_rm($item) -{ - if (Test-Path $item) { - echo "Removing $item" - Remove-Item -Force $item - } -} -quiet_rm "services/app/ecosystem.config.cjs" -quiet_rm "services/appecosystem.json" -quiet_rm ".github/workflows/trigger-push-gh-image.yml" -quiet_rm ".github/workflows/ci.cloud.yml" \ No newline at end of file diff --git a/scripts/remove_cloud_modules.sh b/scripts/remove_cloud_modules.sh index 0ef410f..748c54b 100644 --- a/scripts/remove_cloud_modules.sh +++ b/scripts/remove_cloud_modules.sh @@ -1,11 +1,14 @@ #!/usr/bin/env bash +rm -rf k8s/ rm -rf docker/enterprise/ -find . -type f -name "cloud_*.py" -delete -find . -type f -name "cloud_*.json" -delete -find . -type f -name "*_cloud.json" -delete -find . -type f -name "compose.*.cloud.yml" -delete -find . -type d -name "(cloud)" -exec rm -rf {} + +find . -type f -iname "*cloud*.md" -delete +find . -type f -iname "*cloud*.py" -delete +find . -type f -iname "cloud_*.json" -delete +find . -type f -iname "*_cloud.json" -delete +find . -type f -iname "compose.*.cloud.yml" -delete +find . -type d -iname "*cloud" -exec rm -rf {} + +find . -type d -iname "(cloud)" -exec rm -rf {} + rm -f services/app/ecosystem.config.cjs rm -f services/app/ecosystem.json rm -f .github/workflows/trigger-push-gh-image.yml diff --git a/scripts/update_model_id.py b/scripts/update_model_id.py new file mode 100644 index 0000000..0e6cf66 --- /dev/null +++ b/scripts/update_model_id.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +""" +A script to update model IDs with a clear, subcommand-based interface. + +Usage: + python update.py one-to-one + python update.py many-to-one ... --to + python update.py file +""" + +import asyncio +import json +import pathlib +import sys +from typing import Dict, List, Optional, Tuple + +import typer +from asyncpg import Connection, Record +from loguru import logger + +from owl.db import async_session +from owl.db.gen_table import ( + GENTABLE_ENGINE, + ActionTable, + ChatTable, + KnowledgeTable, + TableType, +) +from owl.db.models import ModelConfig, Organization, Project +from owl.types import ( + CodeGenConfig, + DiscriminatedGenConfig, + EmbedGenConfig, + LLMGenConfig, + OrganizationRead, + ProjectRead, +) + +# Typer app with subcommands +app = typer.Typer( + help="A tool to update model IDs using different modes.", + context_settings={"help_option_names": ["-h", "--help"]}, + no_args_is_help=True, +) + +# Common options that can be used with any subcommand +_DRY_RUN_OPTION = typer.Option( + False, "--dry-run", "-n", help="Show what would be updated without making changes." +) +_ORGANIZATIONS_OPTION = typer.Option( + None, + "--organizations", + "-o", + help="Comma-separated list of specific organization IDs to target.", +) + + +# ------------------------------------------------------------------ +# 1. HELPER FUNCTIONS +# ------------------------------------------------------------------ + + +async def validate_models(mapping: Dict[str, str]) -> bool: + """ + Check that all new models exist + Check that no embedding models. + """ + async with async_session() as session: + # Use a set to check each new ID only once, improving efficiency + for new_id in set(mapping.values()): + new_model = await ModelConfig.get(session, new_id) + if not new_model: + logger.error( + f"Validation failed: New model ID '{new_id}' does not exist in the system." 
+ ) + return False + if "embed" in new_model.capabilities: + logger.error(f"Validation failed: New model ID '{new_id}' is an embedding model.") + return False + for old_id in set(mapping.keys()): + old_model = await ModelConfig.get(session, old_id) + if not old_model: + logger.warning( + f"Old Model ID '{old_id}' does not exist in the system, assumed to be deleted. Skipping validation." + ) + continue + if "embed" in old_model.capabilities: + logger.error(f"Validation failed: Old model ID '{old_id}' is an embedding model.") + return False + logger.success("All new models validated successfully.") + return True + + +async def validate_model_mapping(mapping: Dict[str, str]) -> bool: + """Check that all new models has at least the same capability as old models, and not embedding models.""" + async with async_session() as session: + # Use a set to check each new ID only once, improving efficiency + for old_id, new_id in mapping.items(): + new_model = await ModelConfig.get(session, new_id) + old_model = await ModelConfig.get(session, old_id) + if not new_model: + logger.error( + f"Validation failed: New model ID '{new_id}' does not exist in the system." + ) + return False + if not old_model: + logger.warning( + f"Old Model ID '{old_id}' does not exist in the system, assumed to be deleted. Skipping validation." + ) + continue + if "embed" in new_model.capabilities: + logger.error(f"Validation failed: New model ID '{new_id}' is an embedding model.") + return False + if not [c for c in old_model.capabilities if c in new_model.capabilities]: + logger.error( + f"Validation failed: New model ID '{new_id}' does not have the same capabilities as old model ID '{old_id}'." + ) + return False + logger.success("All new model IDs validated successfully.") + return True + + +# ------------------------------------------------------------------ +# 2. CORE PROCESSING LOGIC +# ------------------------------------------------------------------ + + +class ModelIDUpdater: + """The engine that performs the single-pass update. 
This is generic and works with any mapping.""" + + def __init__( + self, + model_mapping: Dict[str, str], + dry_run: bool = False, + organization_ids: Optional[List[str]] = None, + ): + self.model_mapping = model_mapping + self.dry_run = dry_run + self.organization_ids = organization_ids + self.updated_count = 0 + + @staticmethod + def _gen_config_model_validate(gen_config_json: dict) -> Optional[DiscriminatedGenConfig]: + obj_type = gen_config_json.get("object") + try: + if obj_type in ("gen_config.llm", "gen_config.chat"): + return LLMGenConfig.model_validate(gen_config_json) + elif obj_type == "gen_config.embed": + return EmbedGenConfig.model_validate(gen_config_json) + elif obj_type == "gen_config.code": + return CodeGenConfig.model_validate(gen_config_json) + except Exception as e: + logger.warning(f"Skipping column due to validation error in its gen_config: {e}") + return None + + async def get_all_organizations( + self, organization_ids: Optional[List[str]] = None + ) -> List[OrganizationRead]: + """Retrieve all organizations from the database, optionally filtered.""" + async with async_session() as session: + if not organization_ids: + orgs = await Organization.list_(session=session, return_type=OrganizationRead) + return orgs.items + + all_organizations = [] + for org_id in organization_ids: + org = await Organization.get(session, org_id) + org = OrganizationRead.model_validate(org) + if org: + all_organizations.append(org) + else: + logger.warning(f"Organization with ID '{org_id}' not found. Skipping.") + return all_organizations + + async def get_all_projects_from_org(self, organization_id: str) -> List[ProjectRead]: + """Retrieve all projects from a specific organization.""" + async with async_session() as session: + projects = await Project.list_( + session=session, + return_type=ProjectRead, + filters=dict(organization_id=organization_id), + ) + return projects.items + + async def get_tables_for_project( + self, conn: Connection, project_id: str, table_type: TableType + ) -> List[Record]: + """Get all tables for a project and table type.""" + schema_id = f"{project_id}_{table_type.value}" + try: + return await conn.fetch(f'SELECT table_id FROM "{schema_id}"."TableMetadata"') + except Exception as e: + logger.warning(f"Could not access schema '{schema_id}': {e}") + return [] + + async def get_columns_with_gen_config( + self, conn: Connection, project_id: str, table_type: TableType, table_id: str + ) -> List[Record]: + """Get all columns with gen_config for a specific table.""" + schema_id = f"{project_id}_{table_type.value}" + return await conn.fetch( + f'SELECT column_id, gen_config FROM "{schema_id}"."ColumnMetadata" ' + "WHERE table_id = $1 AND gen_config IS NOT NULL", + table_id, + ) + + def update_model_id_in_config( + self, gen_config: DiscriminatedGenConfig + ) -> Tuple[bool, DiscriminatedGenConfig]: + """Update model IDs in a gen_config if they match the mapping.""" + updated = False + config = gen_config.model_copy(deep=True) + + if isinstance(config, LLMGenConfig): + if config.model in self.model_mapping: + config.model = self.model_mapping[config.model] + updated = True + if config.rag_params and config.rag_params.reranking_model in self.model_mapping: + config.rag_params.reranking_model = self.model_mapping[ + config.rag_params.reranking_model + ] + updated = True + elif isinstance(config, EmbedGenConfig): + if config.embedding_model in self.model_mapping: + config.embedding_model = self.model_mapping[config.embedding_model] + updated = True + + return updated, 
config + + @staticmethod + async def get_table_instance( + project_id: str, table_type: TableType, table_id: str + ) -> ActionTable | KnowledgeTable | ChatTable: + table_classes = { + TableType.ACTION: ActionTable, + TableType.KNOWLEDGE: KnowledgeTable, + TableType.CHAT: ChatTable, + } + table_class = table_classes[table_type] + return await table_class.open_table(project_id=project_id, table_id=table_id) + + async def update_column_gen_config( + self, + project_id: str, + table_type: TableType, + table_id: str, + update_mapping: dict[str, DiscriminatedGenConfig], + ): + """Update the gen_config for a specific column in the database.""" + if self.dry_run: + return + + table = await self.get_table_instance(project_id, table_type, table_id) + await table.update_gen_config(update_mapping=update_mapping, allow_nonexistent_refs=True) + + async def process_organization(self, organization: OrganizationRead): + """Iterate through all projects and tables in an organization and update them.""" + logger.info(f"Processing organization: {organization.id}") + projects = await self.get_all_projects_from_org(organization.id) + + for project in projects: + logger.info(f" Processing project: {project.id}") + for table_type in [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT]: + col_updated_count = 0 + async with GENTABLE_ENGINE.transaction() as conn: + tables = await self.get_tables_for_project(conn, project.id, table_type) + if not tables: + continue + + for table in tables: + table_id = table["table_id"] + columns = await self.get_columns_with_gen_config( + conn, project.id, table_type, table_id + ) + to_be_update = {} + for column in columns: + gen_config = self._gen_config_model_validate(column["gen_config"]) + if not gen_config: + continue + was_updated, updated_config = self.update_model_id_in_config( + gen_config + ) + if was_updated: + found_old_model = None + if isinstance(gen_config, LLMGenConfig): + if gen_config.model in self.model_mapping: + found_old_model = gen_config.model + elif ( + gen_config.rag_params + and gen_config.rag_params.reranking_model + in self.model_mapping + ): + found_old_model = gen_config.rag_params.reranking_model + elif isinstance(gen_config, EmbedGenConfig): + if gen_config.embedding_model in self.model_mapping: + found_old_model = gen_config.embedding_model + + log_msg = ( + f" - Found '{found_old_model}' in column '{column['column_id']}' (table: {table_id})" + if found_old_model + else f" - Updating column '{column['column_id']}' in table '{table_id}'" + ) + logger.info(log_msg) + to_be_update[column["column_id"]] = updated_config + if len(to_be_update) > 0: + await self.update_column_gen_config( + project.id, + table_type, + table_id, + to_be_update, + ) + col_updated_count += len(to_be_update) + + if col_updated_count > 0: + self.updated_count += col_updated_count + logger.info( + f" Updated {col_updated_count} columns in {table_type.value} tables for this project." 
+ ) + + async def run(self): + """Run the entire model ID update process.""" + mode = "DRY RUN" if self.dry_run else "UPDATE MODE" + logger.info(f"Starting model ID update in {mode}.") + logger.info( + f"Applying {len(self.model_mapping)} model mapping(s): {json.dumps(self.model_mapping, indent=2)}" + ) + + organizations = await self.get_all_organizations(self.organization_ids) + logger.info(f"Found {len(organizations)} organization(s) to process.") + + for organization in organizations: + await self.process_organization(organization) + + summary = "Dry run complete" if self.dry_run else "Update complete" + logger.success( + f"{summary}. Total columns affected across all organizations: {self.updated_count}" + ) + + +# ------------------------------------------------------------------ +# 3. CLI SUBCOMMANDS +# ------------------------------------------------------------------ + + +@app.command("one-to-one") +def cmd_one( + old_model_id: str = typer.Argument(..., help="The single old model ID to be replaced."), + new_model_id: str = typer.Argument(..., help="The single new model ID to use."), + dry_run: bool = _DRY_RUN_OPTION, + organizations: Optional[str] = _ORGANIZATIONS_OPTION, +): + """Replaces one model ID with another.""" + mapping = {old_model_id: new_model_id} + run_update(mapping, dry_run, organizations) + + +@app.command("many-to-one") +def cmd_many_to_one( + old_model_ids: List[str] = typer.Argument(..., help="A list of old model IDs to be replaced."), # noqa: B008 + new_model_id: str = typer.Option( + ..., "--to", "-t", help="The target model ID that will replace all old models. (Required)" + ), + dry_run: bool = _DRY_RUN_OPTION, + organizations: Optional[str] = _ORGANIZATIONS_OPTION, +): + """Maps many old model IDs to a single new model ID.""" + mapping = {old_id: new_model_id for old_id in old_model_ids} + run_update(mapping, dry_run, organizations) + + +@app.command("file") +def cmd_file( + mapping_file: pathlib.Path = typer.Argument( # noqa: B008 + ..., + help="Path to a JSON file with {old_id: new_id} mappings.", + exists=True, + readable=True, + dir_okay=False, + ), + dry_run: bool = _DRY_RUN_OPTION, + organizations: Optional[str] = _ORGANIZATIONS_OPTION, +): + """Replaces model IDs based on a JSON mapping file.""" + try: + mapping = json.loads(mapping_file.read_text()) + if not isinstance(mapping, dict): + raise TypeError("Mapping file must contain a JSON object (a dictionary).") + except Exception as e: + logger.error(f"Cannot read or parse mapping file '{mapping_file}': {type(e)}") + raise typer.Exit(code=1) from None + run_update(mapping, dry_run, organizations) + + +# ------------------------------------------------------------------ +# 4. SHARED RUNNER +# ------------------------------------------------------------------ + + +def run_update(mapping: Dict[str, str], dry_run: bool, org_string: Optional[str]): + """A central runner that validates and executes the update process.""" + + if not mapping: + logger.warning("The model mapping is empty. Nothing to do.") + return + + organization_ids = org_string.split(",") if org_string else None + + async def _inner(): + # 1. Validate models before doing anything else + if not await validate_models(mapping): + raise typer.Exit(code=1) + + # 2. 
Create and run the updater + updater = ModelIDUpdater(mapping, dry_run, organization_ids) + await updater.run() + + try: + asyncio.run(_inner()) + except Exception as e: + logger.error(f"An unexpected error occurred during the update process: {type(e)} {str(e)}") + sys.exit(1) + + +# ------------------------------------------------------------------ +# 5. SCRIPT ENTRYPOINT +# ------------------------------------------------------------------ + +if __name__ == "__main__": + app() diff --git a/services/api/Chat.md b/services/api/Chat.md new file mode 100644 index 0000000..b123097 --- /dev/null +++ b/services/api/Chat.md @@ -0,0 +1,71 @@ +# JamAI Chat + +# Chat Message Format + +A user message can include: + +- Text +- Images (zero or more) +- Audio (zero or more) +- Document (zero or more) + +User can also request for RAG to be used so that the model can get context from a Knowledge Table. In this case, the assistant's reply will have references attached. + +## RAG References + +RAG references contain the following data: + +- Search query (text sentence) used to retrieve the chunks/data +- A list of chunks/data, each containing: + - Title: Text sentence + - Text: Text sentences/paragraphs + - Page: Integer or null + - Chunk ID: UUID text-string, can be empty + - Context: + - Any arbitrary number of : + - Metadata: + - Any arbitrary number of : + - Project ID + - Knowledge Table ID (can be used together with Project ID to display a hyperlink to the user) + +For example: + +```json +{ + "search_query": "pet rabbit name", + "chunks": [ + { + "title": "Pet names", + "text": "My rabbit's name is Latte.", + "page": 1, + "chunk_id": "066a8a49-6dcc-764f-8000-a7bfc34f863c", + "context": { + "Colour": "White", + "Weight": "1 kg" + }, + "metadata": { + "bm25-score": 1.5, + "rrf-score": 0.8, + "project_id": "proj_f37ff1cf46aaa453143ca50b", + "table_id": "pet-names" + } + }, + { + "title": "Pet names", + "text": "My deer's name is Daisy.", + "page": null, + "chunk_id": "066a8a49-6dcc-764f-8000-a7bfc34f864c", + "context": { + "Colour": "Brown", + "Weight": "8 kg" + }, + "metadata": { + "bm25-score": 1.95, + "rrf-score": 0.6, + "project_id": "proj_f37ff1cf46aaa453143ca50b", + "table_id": "pet-names" + } + } + ] +} +``` diff --git a/services/api/README.md b/services/api/README.md index 4d73f61..e283eee 100644 --- a/services/api/README.md +++ b/services/api/README.md @@ -1,13 +1,76 @@ # JamAI Base API service -## Compiling Executable +## Note for VSCode Users -### Windows +In order for Ruff settings to apply correctly, you must open the repo folder directly via `Open Folder...` instead of as a `Workspace`. Workspace does not work correctly for some unknown reason. -1. Create fresh python environment: `conda create -n jamaiapi python=3.10`. -2. Activate the python environment: `conda activate jamaiapi`. -3. Remove any of the cloud modules in PowerShell: `.\scripts\remove_cloud_modules.ps1`. -4. Install JamAI Base Python SDK: `pip install .\clients\python` -5. Install api service: `cd services\api ; pip install -e .` -6. Install Pyinstaller: `pip install pyinstaller` -7. Create Pyinstaller executable: `pyinstaller api.spec` +## Getting Started + +1. Create an environment `.env` file. You can modify it from the provided `.env.example` file. +2. Start the services using Docker Compose. Depending on your needs, you can choose to either start everything or excluding the API server `owl` (for easier dev work, for example). 
+ + - Launch all services + ```bash + docker compose -p jm --env-file .env -f docker/compose.dev.yml up --quiet-pull + ``` + - Launch all except `owl`, `frontend` + ```bash + docker compose -p jm --env-file .env -f docker/compose.dev.yml up --quiet-pull -d --scale owl=0 --scale frontend=0 + ``` + +3. If you choose to launch `owl` manually, then run these steps to setup your environment + + 1. Create a Python 3.12 environment and install `owl` (here we use [micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) but you can use other tools such as [conda](https://conda.io/projects/conda/en/latest/user-guide/getting-started.html), virtualenv, etc): + + ```bash + micromamba create -n jamai312 python=3.12 -y + micromamba activate jamai312 + cd services/api + python -m pip install -e . + ``` + + 2. Uncomment the "Service connection" section of `.env.example` and copy them into your `.env` file. + 3. Start `owl`. + + ```bash + OWL_WORKERS=2 OTEL_PYTHON_FASTAPI_EXCLUDED_URLS="api/health" OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST='X-.*' python -m owl.entrypoints.api + + # Delete existing DB data, start owl + OWL_DB_RESET=1 OWL_WORKERS=2 OTEL_PYTHON_FASTAPI_EXCLUDED_URLS="api/health" OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST='X-.*' python -m owl.entrypoints.api + ``` + +4. To run Stripe tests: + 1. Add Stripe API keys into `.env`: + - `OWL_STRIPE_API_KEY` + - `OWL_STRIPE_PUBLISHABLE_KEY_TEST` + - `OWL_STRIPE_WEBHOOK_SECRET_TEST` + 2. Run Stripe event forwarding `stripe listen --forward-to localhost:6969/api/v2/organizations/webhooks/stripe --api-key ` + + + +> [!TIP] +> - You can launch the Docker services in background mode by appending `-d --wait` +> - You can rebuild the `owl` image by appending `--build --force-recreate owl` + + + +## Backend Dev Tips + +- How to run tests (can refer to `.github/workflows/ci.yml` for more info) + + 1. Launch services via `compose.dev.yml` by following the steps above + 2. `pytest services/api/tests` + + + +> [!TIP] +> - Run all tests except those that require on-prem setup: `pytest services/api/tests -m "not onprem"` +> - Run a specific test or a subset: `pytest services/api/tests -k ` + + + +- How to have your code reflected in the Docker environment + + 1. Launch services via `compose.dev.yml` by following the steps above + 2. Modify backend code + 3. 
Restart `owl` by issuing: `docker compose -p jm --env-file .env -f docker/compose.ci.yml restart owl` diff --git a/services/api/api.spec b/services/api/api.spec deleted file mode 100644 index c94c036..0000000 --- a/services/api/api.spec +++ /dev/null @@ -1,73 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- - -import glob -from pathlib import Path -from PyInstaller.utils.hooks import collect_all - -binaries_list = [] - -print(Path("src/owl/entrypoints/api.py").resolve().as_posix()) - -datas_list = [ - (Path("src/owl/entrypoints/api.py").resolve().as_posix(), 'owl/entrypoints'), - (Path("src/owl/configs/models_aipc.json").resolve().as_posix(), 'owl/configs'), -] - -# Add parquet and JSON files from templates directory -template_files = glob.glob("src/owl/templates/**/*.parquet", recursive=True) -template_files += glob.glob("src/owl/templates/**/*.json", recursive=True) -for file in template_files: - datas_list.append((file, str(Path(file).parent.relative_to("src")))) - -hiddenimports_list = ['multipart', "tiktoken_ext.openai_public", "tiktoken_ext"] - -def add_package(package_name): - datas, binaries, hiddenimports = collect_all(package_name) - datas_list.extend(datas) - binaries_list.extend(binaries) - hiddenimports_list.extend(hiddenimports) - -add_package('litellm') -# add_package('fastapi') - -a = Analysis( - ['src\\owl\\entrypoints\\api.py'], - pathex=[], - binaries=binaries_list, - datas=datas_list, - hiddenimports=hiddenimports_list, - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - noarchive=False, - optimize=0, -) -pyz = PYZ(a.pure) - -exe = EXE( - pyz, - a.scripts, - [], - exclude_binaries=True, - name='api', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) -coll = COLLECT( - exe, - a.binaries, - a.datas, - strip=False, - upx=True, - upx_exclude=[], - name='api', -) \ No newline at end of file diff --git a/services/api/pyproject.toml b/services/api/pyproject.toml index 1f45a29..3360db2 100644 --- a/services/api/pyproject.toml +++ b/services/api/pyproject.toml @@ -9,7 +9,7 @@ timeout = 90 log_cli = true asyncio_mode = "auto" # log_cli_level = "DEBUG" -addopts = "--cov=owl --doctest-modules" +addopts = "--doctest-modules -vv -ra --strict-markers --no-flaky-report --durations=200 --durations-min=0.05" testpaths = ["tests"] filterwarnings = [ "ignore::DeprecationWarning:tensorflow.*", @@ -17,6 +17,24 @@ filterwarnings = [ "ignore::DeprecationWarning:matplotlib.*", "ignore::DeprecationWarning:flatbuffers.*", ] +markers = [ + "oss: Cloud-only tests", + "cloud: Cloud-only tests", + "stripe: Stripe tests", +] + +# ----------------------------------------------------------------------------- +# Coverage configuration +# https://coverage.readthedocs.io/en + +[tool.coverage.run] +source = ["owl"] +relative_files = true +concurrency = ["multiprocessing", "thread", "greenlet"] +parallel = true + +[tool.coverage.paths] +source = ["src", "api/src"] # ----------------------------------------------------------------------------- # Ruff configuration @@ -57,7 +75,7 @@ unfixable = ["B"] "**/{tests,docs,tools}/*" = ["E402"] [tool.ruff.lint.isort] -known-first-party = ["jamaibase", "owl", "docio"] +known-first-party = ["jamaibase", "owl"] [tool.ruff.lint.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. 
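Note: with `--cov=owl` removed from the pytest `addopts` and `[tool.coverage.run]` now setting `parallel = true`, coverage data presumably ends up in per-process `.coverage.*` files that must be merged after the test run. A minimal sketch of that assumed workflow (standard coverage.py commands, not taken from this repo's CI):

```bash
# Run the suite under coverage, then merge the per-process data files and report.
coverage run -m pytest services/api/tests
coverage combine   # merges the .coverage.* files produced when parallel = true
coverage report
```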
@@ -74,91 +92,140 @@ extend-immutable-calls = [ # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html [build-system] -# setuptools-scm considers all files tracked by git to be data files -requires = ["setuptools>=62.0", "setuptools-scm"] +requires = ["setuptools>=62.0"] build-backend = "setuptools.build_meta" [project] name = "owl" description = "Owl: API server for JamAI Base." readme = "README.md" -requires-python = "~=3.10" +requires-python = "~=3.12.0" # keywords = ["one", "two"] -license = { text = "Apache 2.0" } +license = "Apache-2.0" classifiers = [ # https://pypi.org/classifiers/ "Development Status :: 3 - Alpha", "Programming Language :: Python :: 3 :: Only", "Intended Audience :: Information Technology", "Operating System :: Unix", ] +# Sort your dependencies https://sortmylist.com/ +# In general, for v1 and above, we pin to minor version using ~= dependencies = [ "aioboto3~=7.0.0", - "aiobotocore~=2.15.0", + "aiobotocore~=2.21.0", # Took long time to resolve "aiofiles~=24.1.0", - "authlib~=1.3.2", - "boto3~=1.35.7", - "celery~=5.4.0", - "duckdb~=1.1.3", - "fastapi[standard]~=0.115.2", - "filelock~=3.15.1", - "flower~=2.0.1", + "aiosqlite~=0.21.0", + "async-lru~=2.0.0", + "asyncpg~=0.30.0", + "authlib~=1.6.0", + "bm25s~=0.2.0", + "boto3==1.37.1", # Took long time to resolve + "celery~=5.5.0", + "clickhouse-connect~=0.8.0", + "cloudevents~=1.12.0", + "coredis~=4.24.0", + "coverage~=7.10.0", + "cryptography", + "duckdb~=1.3.0", + "email-validator~=2.2.0", + "fastapi[standard]~=0.115.0", + "flower~=2.0.0", "gunicorn~=22.0.0", "httpx~=0.27.0", "itsdangerous~=2.2.0", - "jamaibase>=0.2.1", - # lancedb 0.9.0 has issues with row deletion + "jamaibase>=0.4.1", "lancedb==0.12.0", - "langchain-community~=0.2.12", - "langchain~=0.2.14", - "litellm~=1.50.0", - "loguru~=0.7.2", - "natsort[fast]>=8.4.0", - "numpy>=1.26.4", - "openai>=1.51.0", - "openmeter~=1.0.0b89", - "orjson~=3.10.7", - "pandas~=2.2", - "Pillow~=10.4.0", + "langchain~=0.2.0", + "limits~=3.14.0", + "litellm~=1.75.0", + "loguru~=0.7.0", + "natsort[fast]~=8.4.0", + "nltk~=3.9.0", + "numpy>=1.26.0", + "openai~=1.99.0", + "opentelemetry-api~=1.36.0", + "opentelemetry-distro~=0.57b0", + "opentelemetry-exporter-otlp~=1.36.0", + "opentelemetry-instrumentation-aiohttp-client~=0.57b0", + "opentelemetry-instrumentation-aiohttp-server~=0.57b0", + "opentelemetry-instrumentation-asgi~=0.57b0", + "opentelemetry-instrumentation-asyncio~=0.57b0", + "opentelemetry-instrumentation-boto3sqs~=0.57b0", + "opentelemetry-instrumentation-botocore~=0.57b0", + "opentelemetry-instrumentation-celery~=0.57b0", + "opentelemetry-instrumentation-dbapi~=0.57b0", + "opentelemetry-instrumentation-fastapi~=0.57b0", + "opentelemetry-instrumentation-grpc~=0.57b0", + "opentelemetry-instrumentation-httpx~=0.57b0", + "opentelemetry-instrumentation-jinja2~=0.57b0", + "opentelemetry-instrumentation-logging~=0.57b0", + "opentelemetry-instrumentation-redis~=0.57b0", + "opentelemetry-instrumentation-requests~=0.57b0", + "opentelemetry-instrumentation-sqlalchemy~=0.57b0", + "opentelemetry-instrumentation-sqlite3~=0.57b0", + "opentelemetry-instrumentation-starlette~=0.57b0", + "opentelemetry-instrumentation-system-metrics~=0.57b0", + "opentelemetry-instrumentation-threading~=0.57b0", + "opentelemetry-instrumentation-tornado~=0.57b0", + "opentelemetry-instrumentation-tortoiseorm~=0.57b0", + "opentelemetry-instrumentation-urllib3~=0.57b0", + "opentelemetry-instrumentation-urllib~=0.57b0", + "opentelemetry-instrumentation-wsgi~=0.57b0", + 
"opentelemetry-sdk~=1.36.0", + "orjson~=3.11.0", + "pandas~=2.3.0", + "pdf2image~=1.17.0", + "pgvector~=0.3.0", + "Pillow~=11.3.0", + "pottery~=3.0.0", + "prometheus-api-client~=0.5.0", + "psutil~=7.0.0", + "psycopg[binary]~=3.2.0", + "pwdlib[argon2]>=0.2.0", "pyarrow~=17.0.0", - "pycryptodomex~=3.20.0", - "pydantic-settings~=2.4.0", - "pydantic[email,timezone]~=2.8.2", - "pydub~=0.25.1", - "pyjwt~=2.9.0", - # pylance 0.13.0 has issues with row deletion + "pycountry~=24.6.0", + "pycryptodomex~=3.23.0", + "pydantic-extra-types~=2.10.0", + "pydantic-settings~=2.10.0", + "pydantic[email,timezone]~=2.11.0", + "pydub~=0.25.0", + "pyjwt~=2.10.0", "pylance==0.16.0", - "python-multipart~=0.0.9", - "redis[hiredis]~=5.0.8", - "SQLAlchemy>=2.0", - "sqlmodel~=0.0.21", - "srsly~=2.4.8", - # starlette 0.38.3 and 0.38.4 seem to have issues with background tasks - "starlette~=0.41.3", + "python-multipart~=0.0.20", + "redis[hiredis]~=5.3.0", + "SQLAlchemy~=2.0.0", + "sqlmodel~=0.0.20", + "sqlparse~=0.5.0", + "starlette~=0.41.0", "stripe~=9.12.0", + "tabulate~=0.9.0", "tantivy~=0.22.0", "tenacity~=8.5.0", "tiktoken~=0.7.0", - "toml~=0.10.2", - "tqdm~=4.66.5", - "typer[all]~=0.12.4", - "typing_extensions>=4.12.2", - "unstructured-client @ git+https://github.com/EmbeddedLLM/unstructured-python-client.git@fix-nested-asyncio-conflict-with-uvloop#egg=unstructured-client", + "toml~=0.10.0", + "tqdm~=4.67.0", + "typer~=0.17.0", + "typing_extensions~=4.14.0", "uuid-utils~=0.9.0", "uuid7~=0.1.0", # uvicorn 0.29.x shutdown seems unclean and 0.30.x child process sometimes dies - "uvicorn[standard]~=0.28.1", -] # Sort your dependencies https://sortmylist.com/ + "uvicorn[standard]~=0.28.0", + "xmltodict~=0.14.0", +] dynamic = ["version"] [project.optional-dependencies] -lint = ["ruff~=0.6.1"] +lint = ["ruff~=0.12.9"] test = [ - "flaky~=3.8.1", - "mypy~=1.11.1", + "flaky~=3.8", + "freezegun~=1.5", + "junitparser", + "locust~=2.36", + "mcp[cli]~=1.12", + "mypy~=1.11", "pytest-asyncio>=0.23.8", - "pytest-cov~=5.0.0", - "pytest-timeout>=2.3.1", - "pytest~=8.3.2", + "pytest-timeout~=2.3", + "pytest~=8.3", ] docs = [ "furo~=2024.8.6", # Sphinx theme (nice looking, with dark mode) @@ -186,4 +253,4 @@ version = { attr = "owl.version.__version__" } where = ["src"] [tool.setuptools.package-data] -owl = ["**/*.json", "**/*.parquet"] +owl = ["**/*.json", "**/*.parquet", "**/*.ttf"] diff --git a/services/api/scripts/recreate_template.py b/services/api/scripts/recreate_template.py new file mode 100644 index 0000000..df990e0 --- /dev/null +++ b/services/api/scripts/recreate_template.py @@ -0,0 +1,83 @@ +from contextlib import contextmanager +from typing import Generator + +from sqlalchemy import NullPool +from sqlmodel import Session, create_engine, delete, select, text + +from owl.db.models import ( + Organization, + OrgMember, + Project, + ProjectMember, +) + + +@contextmanager +def sync_session() -> Generator[Session, None, None]: + engine = create_engine( + "postgresql+psycopg://:@/jamaibase_owl", + poolclass=NullPool, + ) + with Session(engine) as session: + yield session + + +def main(): + template_id = "template" + template_owner = "github|16820751" + with sync_session() as sess: + # Re-assign owner + org = sess.get(Organization, template_id) + org.owner = template_owner + org.created_by = template_owner + sess.add(org) + sess.commit() + # Re-build template membership + sess.exec(delete(OrgMember).where(OrgMember.organization_id == template_id)) + sys_members = sess.exec(select(OrgMember).where(OrgMember.organization_id == 
"0")).all() + for m in sys_members: + sess.add(OrgMember(user_id=m.user_id, organization_id=template_id, role=m.role)) + sess.commit() + + # Get list of orphaned Gen Table schemas + orphaned = sess.exec( + text(""" + SELECT + s.schema_name + FROM + information_schema.schemata s + WHERE + ( + s.schema_name LIKE 'proj_%_action' OR + s.schema_name LIKE 'proj_%_knowledge' OR + s.schema_name LIKE 'proj_%_chat' + ) + AND NOT EXISTS ( + -- Check if a project exists with an ID matching the extracted identifier + SELECT 1 + FROM jamai."Project" p + WHERE p.id = substring(s.schema_name from '(proj_[^_]+)_') + ) + ORDER BY + s.schema_name; + """) + ).all() + project_ids = list({"_".join(o[0].split("_")[:2]) for o in orphaned}) + # Re-create projects + for project_id in project_ids: + sess.add( + Project( + id=project_id, + name=project_id, + organization_id=template_id, + created_by=template_owner, + owner=template_owner, + ) + ) + sess.commit() + for m in sys_members: + sess.add(ProjectMember(user_id=m.user_id, project_id=project_id, role=m.role)) + sess.commit() + + +main() diff --git a/services/api/src/owl/__init__.py b/services/api/src/owl/__init__.py index 8c77e2c..ff2110e 100644 --- a/services/api/src/owl/__init__.py +++ b/services/api/src/owl/__init__.py @@ -1,6 +1,3 @@ -from loguru import logger - from owl.version import __version__ -logger.disable("owl") __all__ = ["__version__"] diff --git a/services/api/src/owl/assets/Roboto-Regular.ttf b/services/api/src/owl/assets/Roboto-Regular.ttf new file mode 100644 index 0000000..7e3bb2f Binary files /dev/null and b/services/api/src/owl/assets/Roboto-Regular.ttf differ diff --git a/services/api/src/owl/assets/icons/csv.webp b/services/api/src/owl/assets/icons/csv.webp new file mode 100644 index 0000000..1ea38e7 Binary files /dev/null and b/services/api/src/owl/assets/icons/csv.webp differ diff --git a/services/api/src/owl/assets/icons/docx.webp b/services/api/src/owl/assets/icons/docx.webp new file mode 100644 index 0000000..6aec104 Binary files /dev/null and b/services/api/src/owl/assets/icons/docx.webp differ diff --git a/services/api/src/owl/assets/icons/html.webp b/services/api/src/owl/assets/icons/html.webp new file mode 100644 index 0000000..c7a9e8d Binary files /dev/null and b/services/api/src/owl/assets/icons/html.webp differ diff --git a/services/api/src/owl/assets/icons/jpg.webp b/services/api/src/owl/assets/icons/jpg.webp new file mode 100644 index 0000000..b0aa542 Binary files /dev/null and b/services/api/src/owl/assets/icons/jpg.webp differ diff --git a/services/api/src/owl/assets/icons/json.webp b/services/api/src/owl/assets/icons/json.webp new file mode 100644 index 0000000..db20691 Binary files /dev/null and b/services/api/src/owl/assets/icons/json.webp differ diff --git a/services/api/src/owl/assets/icons/jsonl.webp b/services/api/src/owl/assets/icons/jsonl.webp new file mode 100644 index 0000000..249257f Binary files /dev/null and b/services/api/src/owl/assets/icons/jsonl.webp differ diff --git a/services/api/src/owl/assets/icons/md.webp b/services/api/src/owl/assets/icons/md.webp new file mode 100644 index 0000000..3065acb Binary files /dev/null and b/services/api/src/owl/assets/icons/md.webp differ diff --git a/services/api/src/owl/assets/icons/pdf.webp b/services/api/src/owl/assets/icons/pdf.webp new file mode 100644 index 0000000..1e3f837 Binary files /dev/null and b/services/api/src/owl/assets/icons/pdf.webp differ diff --git a/services/api/src/owl/assets/icons/pptx.webp b/services/api/src/owl/assets/icons/pptx.webp new file 
mode 100644 index 0000000..efefc92 Binary files /dev/null and b/services/api/src/owl/assets/icons/pptx.webp differ diff --git a/services/api/src/owl/assets/icons/tsv.webp b/services/api/src/owl/assets/icons/tsv.webp new file mode 100644 index 0000000..9cb4b33 Binary files /dev/null and b/services/api/src/owl/assets/icons/tsv.webp differ diff --git a/services/api/src/owl/assets/icons/txt.webp b/services/api/src/owl/assets/icons/txt.webp new file mode 100644 index 0000000..5837468 Binary files /dev/null and b/services/api/src/owl/assets/icons/txt.webp differ diff --git a/services/api/src/owl/assets/icons/xlsx.webp b/services/api/src/owl/assets/icons/xlsx.webp new file mode 100644 index 0000000..d7babe2 Binary files /dev/null and b/services/api/src/owl/assets/icons/xlsx.webp differ diff --git a/services/api/src/owl/assets/icons/xml.webp b/services/api/src/owl/assets/icons/xml.webp new file mode 100644 index 0000000..95b567c Binary files /dev/null and b/services/api/src/owl/assets/icons/xml.webp differ diff --git a/services/api/src/owl/billing.py b/services/api/src/owl/billing.py deleted file mode 100644 index 788dd7f..0000000 --- a/services/api/src/owl/billing.py +++ /dev/null @@ -1,663 +0,0 @@ -from collections import defaultdict -from datetime import datetime, timedelta, timezone -from time import perf_counter - -import stripe -from cloudevents.conversion import to_dict -from cloudevents.http import CloudEvent -from fastapi import Request -from loguru import logger -from openmeter.aio import Client as OpenMeterAsyncClient - -from jamaibase import JamAIAsync -from jamaibase.exceptions import InsufficientCreditsError -from jamaibase.protocol import EventCreate, OrganizationRead -from owl.configs.manager import CONFIG, ENV_CONFIG, ProductType -from owl.db.gen_table import GenerativeTable -from owl.protocol import ( - EmbeddingModelConfig, - LLMGenConfig, - LLMModelConfig, - RerankingModelConfig, - UserAgent, -) -from owl.utils import uuid7_str - -if ENV_CONFIG.stripe_api_key_plain.strip() == "": - STRIPE_CLIENT = None -else: - STRIPE_CLIENT = stripe.StripeClient( - api_key=ENV_CONFIG.stripe_api_key_plain, - http_client=stripe.RequestsClient(), - max_network_retries=5, - ) -if ENV_CONFIG.openmeter_api_key_plain.strip() == "": - OPENMETER_CLIENT = None -else: - # Async client can be initialized by importing the `Client` from `openmeter.aio` - OPENMETER_CLIENT = OpenMeterAsyncClient( - endpoint="https://openmeter.cloud", - headers={ - "Accept": "application/json", - "Authorization": f"Bearer {ENV_CONFIG.openmeter_api_key_plain}", - }, - retry_status=3, - retry_total=5, - ) -CLIENT = JamAIAsync(token=ENV_CONFIG.service_key_plain) - - -class BillingManager: - def __init__( - self, - *, - organization: OrganizationRead | None = None, - project_id: str = "", - user_id: str = "", - openmeter_client: OpenMeterAsyncClient = OPENMETER_CLIENT, - client: JamAIAsync | None = CLIENT, - request: Request | None = None, - ) -> None: - self.org = organization - self.project_id = project_id - self.user_id = user_id - self.openmeter_client = openmeter_client - self.client = client - self.request = request - if request is None: - self.user_agent = UserAgent(is_browser=False, agent="") - else: - self.user_agent: UserAgent = request.state.user_agent - self.is_oss = ENV_CONFIG.is_oss - self._events = [] - self._deltas = defaultdict(float) - self._values = defaultdict(float) - self._cost = 0.0 - - @property - def total_balance(self) -> float: - if self.is_oss or self.org is None: - return 0.0 - return self.org.credit + 
self.org.credit_grant - - def _compute_cost( - self, - product_type: ProductType, - remaining_quota: float, - usage: float, - ) -> float: - if self.org is None: - return 0.0 - prices = CONFIG.get_pricing() - try: - product = prices.plans[self.org.tier].products[product_type] - except Exception as e: - logger.warning(f"Failed to fetch product: {e}") - return 0.0 - cost = 0.0 - remaining_usage = (usage - remaining_quota) if remaining_quota > 0 else usage - for tier in product.tiers: - if remaining_usage <= 0: - break - if tier.up_to is not None and remaining_usage > tier.up_to: - tier_usage = tier.up_to - else: - tier_usage = remaining_usage - cost += tier_usage * float(tier.unit_amount_decimal) - remaining_usage -= tier_usage - if cost > 0: - self._cost += cost - self._events += [ - CloudEvent( - attributes={ - "type": "spent", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "spent_usd": cost, - "category": product_type, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - return cost - - async def process_all(self) -> None: - try: - if self.is_oss or self.org is None: - return - # No billing events for admin API - if self.request is not None and "api/admin" in self.request.url.path: - return - - if self.request is not None and self.request.scope.get("route", None): - # https://stackoverflow.com/a/72239186 - path = self.request.scope.get("root_path", "") + self.request.scope["route"].path - self._events += [ - CloudEvent( - attributes={ - "type": "request_count", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "method": self.request.method, - "path": path, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ), - ] - - # Process credits - # Deduct from credit_grant first - if self.org.credit_grant >= self._cost: - credit_deduct = 0.0 - credit_grant_deduct = self._cost - else: - credit_deduct = self._cost - self.org.credit_grant - credit_grant_deduct = self.org.credit_grant - if credit_deduct > 0: - self._deltas[ProductType.CREDIT] -= credit_deduct - if credit_grant_deduct > 0: - self._deltas[ProductType.CREDIT_GRANT] -= credit_grant_deduct - # Update records - if len(self._deltas) > 0 or len(self._values) > 0: - await self.client.admin.backend.add_event( - EventCreate( - id=uuid7_str(), - organization_id=self.org.id, - deltas=self._deltas, - values=self._values, - ) - ) - # Send OpenMeter events - if ( - self.openmeter_client is not None - and self.org.openmeter_id is not None - and len(self._events) > 0 - ): - t0 = perf_counter() - await self.openmeter_client.ingest_events([to_dict(e) for e in self._events]) - logger.info( - ( - f"{self.request.state.id} - OpenMeter events ingestion: " - if self.request is not None - else "OpenMeter events ingestion: " - ) - + ( - f"t={(perf_counter() - t0) * 1e3:,.2f} ms " - 
f"num_events={len(self._events):,d}" - ) - ) - except Exception as e: - logger.exception(f"Failed to process billing events due to error: {e}") - - def _quota_ok( - self, - quota: float, - usage: float, - provider: str | None = None, - ): - # OSS has no billing - if self.is_oss: - return True - # If there is credit left - if self.total_balance > 0: - return True - # If user provides their own API key - if self.org.external_keys.get(provider, "").strip(): - return True - # If it's a ELLM model and there is quota left - has_quota = (quota - usage) > 0 - if provider is None: - return has_quota - elif provider.startswith("ellm") and has_quota: - return True - return False - - # --- LLM Usage --- # - - def check_llm_quota(self, model_id: str) -> None: - if self.is_oss or self.org is None: - return - provider = model_id.split("/")[0] - if self._quota_ok( - self.org.llm_tokens_quota_mtok, self.org.llm_tokens_usage_mtok, provider - ): - return - # Return different error message depending if request came from browser - if self.request is not None and self.user_agent.is_browser: - model_id = self.request.state.all_models.get_llm_model_info(model_id).name - raise InsufficientCreditsError( - f"Insufficient LLM token quota or credits for model: {model_id}" - ) - - def check_gen_table_llm_quota( - self, - table: GenerativeTable, - table_id: str, - ) -> None: - if self.is_oss or self.org is None: - return - with table.create_session() as session: - meta = table.open_meta(session, table_id) - for c in meta.cols_schema: - if not isinstance(c.gen_config, LLMGenConfig): - continue - self.check_llm_quota(c.gen_config.model) - - def create_llm_events( - self, - model: str, - input_tokens: int, - output_tokens: int, - ) -> None: - if self.is_oss or self.org is None: - return - if input_tokens < 1: - logger.warning(f"Input token count should be > 0, received: {input_tokens}") - input_tokens = 1 - if output_tokens < 1: - logger.warning(f"Output token count should be > 0, received: {output_tokens}") - output_tokens = 1 - self._events += [ - CloudEvent( - attributes={ - "type": ProductType.LLM_TOKENS, - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "model": model, - "tokens": v, - "type": t, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - for t, v in [("input", input_tokens), ("output", output_tokens)] - ] - provider = model.split("/")[0] - model_config: LLMModelConfig = self.request.state.all_models.get_llm_model_info(model) - input_cost_per_mtoken = model_config.input_cost_per_mtoken - output_cost_per_mtoken = model_config.output_cost_per_mtoken - llm_credit_mtok = max(0.0, self.org.llm_tokens_quota_mtok - self.org.llm_tokens_usage_mtok) - input_mtoken = input_tokens / 1e6 - output_mtoken = output_tokens / 1e6 - - if provider.startswith("ellm"): - self._deltas[ProductType.LLM_TOKENS] += input_mtoken + output_mtoken - if provider.startswith("ellm") and llm_credit_mtok > 0: - # Deduct input tokens first - if llm_credit_mtok >= input_mtoken: - input_mtoken = 0.0 - output_mtoken = max(0.0, output_mtoken - llm_credit_mtok) - else: - input_mtoken = max(0.0, input_mtoken - llm_credit_mtok) - cost = input_cost_per_mtoken * input_mtoken + 
output_cost_per_mtoken * output_mtoken - elif self.org.external_keys.get(provider, "").strip(): - cost = 0.0 - else: - cost = input_cost_per_mtoken * input_mtoken + output_cost_per_mtoken * output_mtoken - - if cost > 0: - self._cost += cost - self._events += [ - CloudEvent( - attributes={ - "type": "spent", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "spent_usd": cost, - "category": ProductType.LLM_TOKENS, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - - # --- Embedding Usage --- # - - def check_embedding_quota(self, model_id: str) -> None: - if self.is_oss or self.org is None: - return - provider = model_id.split("/")[0] - if self._quota_ok( - self.org.embedding_tokens_quota_mtok, self.org.embedding_tokens_usage_mtok, provider - ): - return - # Return different error message depending if request came from browser - if self.request is not None and self.user_agent.is_browser: - model_id = self.request.state.all_models.get_embed_model_info(model_id).name - raise InsufficientCreditsError( - f"Insufficient Embedding token quota or credits for model: {model_id}" - ) - - def create_embedding_events( - self, - model: str, - token_usage: int, - ) -> None: - if self.is_oss or self.org is None: - return - if token_usage < 1: - logger.warning(f"Token usage should be >= 1, received: {token_usage}") - token_usage = 1 - # Create the CloudEvent for embedding token usage - self._events += [ - CloudEvent( - attributes={ - "type": ProductType.EMBEDDING_TOKENS, - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "model": model, - "tokens": token_usage, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - - # Determine the provider from the model string - provider = model.split("/")[0] - # Get tokens in per mtoken unit - model_config: EmbeddingModelConfig = self.request.state.all_models.get_embed_model_info( - model - ) - cost_per_mtoken = model_config.cost_per_mtoken - embedding_credit_mtok = max( - 0.0, self.org.embedding_tokens_quota_mtok - self.org.embedding_tokens_usage_mtok - ) - token_usage_mtok = token_usage / 1e6 - - if provider.startswith("ellm"): - self._deltas[ProductType.EMBEDDING_TOKENS] += token_usage_mtok - - if provider.startswith("ellm") and embedding_credit_mtok > 0: - cost = max(0.0, token_usage_mtok - embedding_credit_mtok) * cost_per_mtoken - elif self.org.external_keys.get(provider, "").strip(): - cost = 0.0 - else: - cost = token_usage_mtok * cost_per_mtoken - - # If there is a cost, update the total cost and create a CloudEvent for the spending - if cost > 0: - self._cost += cost - self._events += [ - CloudEvent( - attributes={ - "type": "spent", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "spent_usd": cost, - "category": ProductType.EMBEDDING_TOKENS, - "org_id": self.org.id, - "project_id": 
self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - - # --- Reranker Usage --- # - - def check_reranker_quota(self, model_id: str) -> None: - if self.is_oss or self.org is None: - return - provider = model_id.split("/")[0] - if self._quota_ok( - self.org.reranker_quota_ksearch, self.org.reranker_usage_ksearch, provider - ): - return - # Return different error message depending if request came from browser - if self.request is not None and self.user_agent.is_browser: - model_id = self.request.state.all_models.get_rerank_model_info(model_id).name - raise InsufficientCreditsError( - f"Insufficient Reranker search quota or credits for model: {model_id}" - ) - - def create_reranker_events( - self, - model: str, - num_searches: int, - ) -> None: - if self.is_oss or self.org is None: - return - if num_searches < 1: - logger.warning(f"Number of searches should be >= 1, received: {num_searches}") - num_searches = 1 - - # Create the CloudEvent for rerank search usage - self._events += [ - CloudEvent( - attributes={ - "type": ProductType.RERANKER_SEARCHES, - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "model": model, - "searches": num_searches, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - - # Determine the provider from the model string - provider = model.split("/")[0] - - # Get search cost per ksearch unit - model_config: RerankingModelConfig = self.request.state.all_models.get_rerank_model_info( - model - ) - cost_per_ksearch = model_config.cost_per_ksearch - - remaining_rerank_ksearches = ( - self.org.reranker_quota_ksearch - self.org.reranker_usage_ksearch - ) - num_ksearches = num_searches / 1e3 - - if provider.startswith("ellm"): - self._deltas[ProductType.RERANKER_SEARCHES] += num_ksearches - - if provider.startswith("ellm") and remaining_rerank_ksearches > 0: - cost = max(0.0, num_ksearches - remaining_rerank_ksearches) * cost_per_ksearch - elif self.org.external_keys.get(provider, "").strip(): - cost = 0.0 - else: - cost = cost_per_ksearch * num_ksearches - - # If there is a cost, update the total cost and create a CloudEvent for the spending - if cost > 0: - self._cost += cost - self._events += [ - CloudEvent( - attributes={ - "type": "spent", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "spent_usd": cost, - "category": ProductType.RERANKER_SEARCHES, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - - # --- Egress Usage --- # - - def check_egress_quota(self) -> None: - if self.is_oss or self.org is 
None: - return - if self._quota_ok(self.org.egress_quota_gib, self.org.egress_usage_gib): - return - raise InsufficientCreditsError("Insufficient egress quota or credits.") - - def create_egress_events(self, amount_gb: float) -> None: - if self.is_oss or self.org is None: - return - if amount_gb <= 0: - logger.warning(f"Egress amount should be > 0, received: {amount_gb}") - return - self._events += [ - CloudEvent( - attributes={ - "type": "bandwidth", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "amount_gb": amount_gb, - "type": ProductType.EGRESS, - "org_id": self.org.id, - "project_id": self.project_id, - "user_id": self.user_id, - "agent": self.user_agent.agent, - "agent_version": self.user_agent.agent_version, - "architecture": self.user_agent.architecture, - "system": self.user_agent.system, - "system_version": self.user_agent.system_version, - "language": self.user_agent.language, - "language_version": self.user_agent.language_version, - }, - ) - ] - self._compute_cost( - ProductType.EGRESS, self.org.egress_quota_gib - self.org.egress_usage_gib, amount_gb - ) - self._deltas[ProductType.EGRESS] += amount_gb - - # --- Storage Usage --- # - - def check_db_storage_quota(self) -> None: - if self.is_oss or self.org is None: - return - if self._quota_ok(self.org.db_quota_gib, self.org.db_usage_gib): - return - raise InsufficientCreditsError("Insufficient DB storage quota.") - - def check_file_storage_quota(self) -> None: - if self.is_oss or self.org is None: - return - if self._quota_ok(self.org.file_quota_gib, self.org.file_usage_gib): - return - raise InsufficientCreditsError("Insufficient file storage quota.") - - def create_storage_events(self, db_usage_gib: float, file_usage_gib: float) -> None: - if self.is_oss or self.org is None: - return - if db_usage_gib <= 0: - logger.warning(f"DB storage usage should be > 0, received: {db_usage_gib}") - return - if file_usage_gib <= 0: - logger.warning(f"File storage usage should be > 0, received: {file_usage_gib}") - return - # Wait for at least `min_wait` before recomputing - now = datetime.now(timezone.utc) - min_wait = timedelta(minutes=max(5.0, ENV_CONFIG.owl_compute_storage_period_min)) - # Wait because quota refresh might be called a few times - quota_reset_at = datetime.fromisoformat(self.org.quota_reset_at) - if (now - quota_reset_at) <= min_wait: - return - self._events += [ - CloudEvent( - attributes={ - "type": "storage", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "amount_gb": db_usage_gib, - "type": "db", - "org_id": self.org.id, - }, - ), - CloudEvent( - attributes={ - "type": "storage", - "source": "owl", - "subject": self.org.openmeter_id, - }, - data={ - "amount_gb": file_usage_gib, - "type": "file", - "org_id": self.org.id, - }, - ), - ] - self._values[ProductType.DB_STORAGE] = db_usage_gib - self._values[ProductType.FILE_STORAGE] = file_usage_gib diff --git a/services/api/src/owl/client.py b/services/api/src/owl/client.py new file mode 100644 index 0000000..de9c27f --- /dev/null +++ b/services/api/src/owl/client.py @@ -0,0 +1,125 @@ +import base64 +from typing import Any, Type + +import httpx +from fastapi import FastAPI +from pydantic import BaseModel + +from jamaibase.client import _ClientAsync +from owl.configs import ENV_CONFIG +from owl.version import __version__ + + +class VictoriaMetricsAsync(_ClientAsync): + def __init__( + self, + api_base: str = f"http://{ENV_CONFIG.victoria_metrics_host}:{ENV_CONFIG.victoria_metrics_port}", + user: str = 
ENV_CONFIG.victoria_metrics_user, + password: str = ENV_CONFIG.victoria_metrics_password_plain, + timeout: float | None = 10.0, + ) -> None: + """ + Creates an async Emu client. + + Args: + api_base (str, optional): The base URL for the API. + Defaults to "http://{ENV_CONFIG.victoria_metrics_host}:{ENV_CONFIG.victoria_metrics_port}". + user (str, optional): Victoria Metrics Basic authentication Username. + Defaults to ENV_CONFIG.victoria_metrics_user. + password (str, optional): Victoria Metrics Basic authentication Password. + Defaults to ENV_CONFIG.victoria_metrics_password_plain. + timeout (float | None, optional): The timeout to use when sending requests. + Defaults to 10 seconds. + """ + http_client = httpx.AsyncClient( + timeout=timeout, + transport=httpx.AsyncHTTPTransport(retries=3), + ) + + def basic_auth(username, password): + token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii") + return f"Basic {token}" + + headers = {"Authorization": basic_auth(user, password)} + kwargs = dict( + user_id="", + project_id="", + token="", + api_base=api_base, + headers=headers, + http_client=http_client, + timeout=timeout, + ) + super().__init__(**kwargs) + + async def _fetch_victoria_metrics( + self, endpoint: str, params: dict | None = None + ) -> httpx.Response: + """ + Send a GET request to the specified VictoriaMetrics API endpoint. + + Args: + endpoint (str): The API endpoint to send the request to. + params (dict | None, optional): Query parameters to include in the request. + + Returns: + httpx.Response | None: The HTTP response object if the request is successful, or None if the request fails. + """ + return await self._get(endpoint, params=params) + + +class JamaiASGIAsync(_ClientAsync): + def __init__( + self, + app: FastAPI, + timeout: float | None = None, + ) -> None: + """ + Creates an async Owl ASGI client. + + Args: + timeout (float | None, optional): The timeout to use when sending requests. + Defaults to None. 
+ """ + super().__init__( + user_id="", + project_id="", + token="", + api_base="", + headers=None, + http_client=httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://apiserver", + timeout=timeout, + ), + timeout=timeout, + ) + + async def request( + self, + method: str, + endpoint: str, + *, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | BaseModel | None = None, + body: BaseModel | None = None, + response_model: Type[BaseModel] | None = None, + timeout: float | None = None, + **kwargs, + ) -> httpx.Response | BaseModel: + if headers is None: + headers = {} + headers["User-Agent"] = headers.get("User-Agent", f"MCP-Server/{__version__}") + return await self._request( + method=method, + address="", + endpoint=endpoint, + headers=headers, + params=params, + body=body, + response_model=response_model, + timeout=timeout, + ignore_code=None, + process_body_kwargs=None, + **kwargs, + ) diff --git a/services/api/src/owl/configs/__init__.py b/services/api/src/owl/configs/__init__.py index e69de29..6ac6281 100644 --- a/services/api/src/owl/configs/__init__.py +++ b/services/api/src/owl/configs/__init__.py @@ -0,0 +1,35 @@ +import os +from os.path import join + +from celery import Celery + +from owl.utils.cache import Cache + +try: + from owl.configs.cloud import EnvConfig +except ImportError: + from owl.configs.oss import EnvConfig + +ENV_CONFIG = EnvConfig() +CACHE = Cache( + redis_url=f"redis://{ENV_CONFIG.redis_host}:{ENV_CONFIG.redis_port}/1", + clickhouse_buffer_key=ENV_CONFIG.clickhouse_buffer_key, +) + + +celery_app = Celery("tasks", broker=f"redis://{ENV_CONFIG.redis_host}:{ENV_CONFIG.redis_port}/0") + +# Configure Celery +CELERY_SCHEDULER_DB = "_scheduler" +os.makedirs(CELERY_SCHEDULER_DB, exist_ok=True) +celery_app.conf.update( + result_backend=f"redis://{ENV_CONFIG.redis_host}:{ENV_CONFIG.redis_port}/0", + task_serializer="json", + accept_content=["json"], + result_serializer="json", + timezone="UTC", + enable_utc=True, + result_expires=36000, + # TODO: Update to use DB via sqlalchemy-celery-beat + beat_schedule_filename=join(CELERY_SCHEDULER_DB, "celerybeat-schedule"), +) diff --git a/services/api/src/owl/configs/manager.py b/services/api/src/owl/configs/manager.py deleted file mode 100644 index a434edc..0000000 --- a/services/api/src/owl/configs/manager.py +++ /dev/null @@ -1,553 +0,0 @@ -import os -from decimal import Decimal -from enum import Enum -from functools import cached_property, lru_cache -from os.path import abspath -from pathlib import Path -from typing import Annotated, Any - -import redis -from loguru import logger -from pydantic import BaseModel, Field, SecretStr, computed_field, model_validator -from pydantic_settings import BaseSettings, SettingsConfigDict -from redis.backoff import ExponentialBackoff -from redis.exceptions import ConnectionError, TimeoutError -from redis.retry import Retry - -from owl.protocol import ( - EXAMPLE_CHAT_MODEL_IDS, - EXAMPLE_EMBEDDING_MODEL_IDS, - EXAMPLE_RERANKING_MODEL_IDS, - ModelListConfig, -) - -CURR_DIR = Path(__file__).resolve().parent - - -class EnvConfig(BaseSettings): - model_config = SettingsConfigDict( - # env_prefix="owl_", # TODO: Enable this - env_file=".env", - env_file_encoding="utf-8", - extra="ignore", - cli_parse_args=False, - ) - # API configs - owl_is_prod: bool = False - owl_cache_purge: bool = False - owl_db_dir: str = "db" - owl_file_dir: str = "file://file" - owl_log_dir: str = "logs" - owl_file_proxy_url: str = "localhost:6969" - owl_host: str = "0.0.0.0" - 
owl_port: int = 6969 - owl_workers: int = 1 - owl_max_concurrency: int = 300 - default_org_id: str = "default" - default_project_id: str = "default" - owl_redis_host: str = "dragonfly" - owl_redis_port: int = 6379 - owl_internal_org_id: str = "org_82d01c923f25d5939b9d4188" - # Configs - owl_embed_file_upload_max_bytes: int = 200 * 1024 * 1024 # 200MB in bytes - owl_image_file_upload_max_bytes: int = 20 * 1024 * 1024 # 20MB in bytes - owl_audio_file_upload_max_bytes: int = 120 * 1024 * 1024 # 120MB in bytes - owl_compute_storage_period_min: float = 1 - owl_models_config: str = "models.json" - owl_pricing_config: str = "cloud_pricing.json" - # Starling configs - s3_endpoint: str = "" - s3_access_key_id: str = "" - s3_secret_access_key: SecretStr = "" - s3_backup_bucket_name: str = "" - # Generative Table configs - owl_table_lock_timeout_sec: int = 15 - owl_reindex_period_sec: int = 60 - owl_immediate_reindex_max_rows: int = 2000 - owl_optimize_period_sec: int = 60 - owl_remove_version_older_than_mins: float = 5.0 - owl_concurrent_rows_batch_size: int = 3 - owl_concurrent_cols_batch_size: int = 5 - owl_max_write_batch_size: int = 1000 - # Code Executor configs - code_executor_endpoint: str = "http://kopi:5569" - # Loader configs - docio_url: str = "http://docio:6979/api/docio" - unstructuredio_url: str = "http://unstructuredio:6989" - # PDF Loader configs - owl_fast_pdf_parsing: bool = True - # LLM configs - owl_llm_timeout_sec: Annotated[int, Field(gt=0, le=60 * 60)] = 60 - owl_embed_timeout_sec: Annotated[int, Field(gt=0, le=60 * 60)] = 60 - cohere_api_base: str = "https://api.cohere.ai/v1" - jina_api_base: str = "https://api.jina.ai/v1" - voyage_api_base: str = "https://api.voyageai.com/v1" - clip_api_base: str = "http://localhost:51010" - # Auth Keys - owl_session_secret: SecretStr = "oh yeah" - owl_github_client_id: str = "" - owl_github_client_secret: SecretStr = "" - owl_encryption_key: SecretStr = "" - service_key: SecretStr = "" - service_key_alt: SecretStr = "" - # Keys - unstructuredio_api_key: SecretStr = "ellm" - stripe_api_key: SecretStr = "" - openmeter_api_key: SecretStr = "" - custom_api_key: SecretStr = "" - openai_api_key: SecretStr = "" - anthropic_api_key: SecretStr = "" - gemini_api_key: SecretStr = "" - cohere_api_key: SecretStr = "" - groq_api_key: SecretStr = "" - together_api_key: SecretStr = "" - jina_api_key: SecretStr = "" - voyage_api_key: SecretStr = "" - hyperbolic_api_key: SecretStr = "" - cerebras_api_key: SecretStr = "" - sambanova_api_key: SecretStr = "" - deepseek_api_key: SecretStr = "" - - @model_validator(mode="after") - def make_paths_absolute(self): - self.owl_db_dir = abspath(self.owl_db_dir) - self.owl_log_dir = abspath(self.owl_log_dir) - self.owl_models_config: str = str(CURR_DIR / self.owl_models_config) - self.owl_pricing_config: str = str(CURR_DIR / self.owl_pricing_config) - return self - - @model_validator(mode="after") - def check_alternate_service_key(self): - if self.service_key_alt.get_secret_value().strip() == "": - self.service_key_alt = self.service_key - return self - - @cached_property - def is_oss(self): - if self.service_key.get_secret_value() == "": - return True - return not (CURR_DIR.parent / "routers" / "cloud_admin.py").is_file() - - @property - def s3_secret_access_key_plain(self): - return self.s3_secret_access_key.get_secret_value() - - @property - def owl_encryption_key_plain(self): - return self.owl_encryption_key.get_secret_value() - - @property - def owl_session_secret_plain(self): - return 
self.owl_session_secret.get_secret_value() - - @property - def owl_github_client_secret_plain(self): - return self.owl_github_client_secret.get_secret_value() - - @property - def service_key_plain(self): - return self.service_key.get_secret_value() - - @property - def service_key_alt_plain(self): - return self.service_key_alt.get_secret_value() - - @property - def unstructuredio_api_key_plain(self): - return self.unstructuredio_api_key.get_secret_value() - - @property - def stripe_api_key_plain(self): - return self.stripe_api_key.get_secret_value() - - @property - def openmeter_api_key_plain(self): - return self.openmeter_api_key.get_secret_value() - - @property - def custom_api_key_plain(self): - return self.custom_api_key.get_secret_value() - - @property - def openai_api_key_plain(self): - return self.openai_api_key.get_secret_value() - - @property - def anthropic_api_key_plain(self): - return self.anthropic_api_key.get_secret_value() - - @property - def gemini_api_key_plain(self): - return self.gemini_api_key.get_secret_value() - - @property - def cohere_api_key_plain(self): - return self.cohere_api_key.get_secret_value() - - @property - def groq_api_key_plain(self): - return self.groq_api_key.get_secret_value() - - @property - def together_api_key_plain(self): - return self.together_api_key.get_secret_value() - - @property - def jina_api_key_plain(self): - return self.jina_api_key.get_secret_value() - - @property - def voyage_api_key_plain(self): - return self.voyage_api_key.get_secret_value() - - @property - def hyperbolic_api_key_plain(self): - return self.hyperbolic_api_key.get_secret_value() - - @property - def cerebras_api_key_plain(self): - return self.cerebras_api_key.get_secret_value() - - @property - def sambanova_api_key_plain(self): - return self.sambanova_api_key.get_secret_value() - - @property - def deepseek_api_key_plain(self): - return self.deepseek_api_key.get_secret_value() - - -MODEL_CONFIG_KEY = " models" -PRICES_KEY = " prices" -INTERNAL_ORG_ID_KEY = " internal_org_id" -ENV_CONFIG = EnvConfig() -# Create db dir -try: - os.makedirs(ENV_CONFIG.owl_db_dir, exist_ok=False) -except OSError: - pass - - -class PlanName(str, Enum): - DEFAULT = "default" - FREE = "free" - PRO = "pro" - TEAM = "team" - DEMO = "_demo" - PARTNER = "_partner" - DEBUG = "_debug" - - def __str__(self) -> str: - return self.value - - -_product2column = dict( - credit=("credit",), - credit_grant=("credit_grant",), - llm_tokens=("llm_tokens_quota_mtok", "llm_tokens_usage_mtok"), - embedding_tokens=( - "embedding_tokens_quota_mtok", - "embedding_tokens_usage_mtok", - ), - reranker_searches=("reranker_quota_ksearch", "reranker_usage_ksearch"), - db_storage=("db_quota_gib", "db_usage_gib"), - file_storage=("file_quota_gib", "file_usage_gib"), - egress=("egress_quota_gib", "egress_usage_gib"), -) - - -class ProductType(str, Enum): - CREDIT = "credit" - CREDIT_GRANT = "credit_grant" - LLM_TOKENS = "llm_tokens" - EMBEDDING_TOKENS = "embedding_tokens" - RERANKER_SEARCHES = "reranker_searches" - DB_STORAGE = "db_storage" - FILE_STORAGE = "file_storage" - EGRESS = "egress" - - def __str__(self) -> str: - return self.value - - @property - def quota_column(self) -> str: - return _product2column[self.value][0] - - @property - def usage_column(self) -> str: - return _product2column[self.value][-1] - - @classmethod - def exclude_credits(cls) -> list["ProductType"]: - return [p for p in cls if not p.value.startswith("credit")] - - -class Tier(BaseModel): - """ - 
https://docs.stripe.com/api/prices/object#price_object-tiers - """ - - unit_amount_decimal: Decimal = Field( - description="Per unit price for units relevant to the tier.", - ) - up_to: float | None = Field( - description=( - "Up to and including to this quantity will be contained in the tier. " - "None means infinite quantity." - ), - ) - - -class Product(BaseModel): - name: str = Field( - min_length=1, - description="Plan name.", - ) - included: Tier = Tier(unit_amount_decimal=0, up_to=0) - tiers: list[Tier] - unit: str = Field( - description="Unit of measurement.", - ) - - -class Plan(BaseModel): - name: str - stripe_price_id_live: str - stripe_price_id_test: str - flat_amount_decimal: Decimal = Field( - description="Base price for the entire tier.", - ) - credit_grant: float = Field( - description="Credit amount included in USD.", - ) - max_users: int = Field( - description="Maximum number of users per organization.", - ) - products: dict[ProductType, Product] = Field( - description="Mapping of price name to tier list where each element represents a pricing tier.", - ) - - @computed_field - @property - def stripe_price_id(self) -> str: - return ( - self.stripe_price_id_live - if ENV_CONFIG.stripe_api_key_plain.startswith("sk_live") - else self.stripe_price_id_test - ) - - -class Price(BaseModel): - object: str = Field( - default="prices.plans", - description="Type of API response object.", - examples=["prices.plans"], - ) - plans: dict[PlanName, Plan] = Field( - description="Mapping of price plan name to price plan.", - ) - - -class _ModelPrice(BaseModel): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - "Users will specify this to select a model." - ), - examples=[ - EXAMPLE_CHAT_MODEL_IDS[0], - EXAMPLE_EMBEDDING_MODEL_IDS[0], - EXAMPLE_RERANKING_MODEL_IDS[0], - ], - ) - name: str = Field( - description="Name of the model.", - examples=["OpenAI GPT-4o Mini"], - ) - - -class LLMModelPrice(_ModelPrice): - input_cost_per_mtoken: float = Field( - description="Cost in USD per million (mega) input / prompt token.", - ) - output_cost_per_mtoken: float = Field( - description="Cost in USD per million (mega) output / completion token.", - ) - - -class EmbeddingModelPrice(_ModelPrice): - cost_per_mtoken: float = Field( - description="Cost in USD per million embedding tokens.", - ) - - -class RerankingModelPrice(_ModelPrice): - cost_per_ksearch: float = Field(description="Cost in USD for a thousand searches.") - - -class ModelPrice(BaseModel): - object: str = Field( - default="prices.models", - description="Type of API response object.", - examples=["prices.models"], - ) - llm_models: list[LLMModelPrice] = [] - embed_models: list[EmbeddingModelPrice] = [] - rerank_models: list[RerankingModelPrice] = [] - - -class Config: - def __init__(self): - self.use_redis = ENV_CONFIG.owl_workers > 1 - if self.use_redis: - logger.debug("Using Redis as cache.") - self._redis = redis.Redis( - host=ENV_CONFIG.owl_redis_host, - port=ENV_CONFIG.owl_redis_port, - db=0, - # https://redis.io/kb/doc/22wxq63j93/how-to-manage-client-reconnections-in-case-of-errors-with-redis-py - retry=Retry(ExponentialBackoff(cap=10, base=1), 25), - retry_on_error=[ConnectionError, TimeoutError, ConnectionResetError], - health_check_interval=1, - ) - else: - logger.debug("Using in-memory dict as cache.") - self._data = {} - - def get(self, key: str) -> Any: - return self[key] - - def set(self, key: str, value: str) -> None: - self[key] = value - - def purge(self): - if self.use_redis: 
- for key in self._redis.scan_iter("*"): - self._redis.delete(key) - else: - self._data = {} - - def __setitem__(self, key: str, value: str) -> None: - if not isinstance(value, str): - raise TypeError(f"`value` must be a str, received: {type(value)}") - if not (isinstance(key, str) and key.startswith("")): - raise ValueError(f'`key` must be a str that starts with "", received: {key}') - if self.use_redis: - self._redis.set(key, value) - else: - self._data[key] = value - - def __getitem__(self, key: str) -> str | None: - if self.use_redis: - item = self._redis.get(key) - return None if item is None else item.decode("utf-8") - else: - try: - return self._data[key] - except KeyError: - return None - - def __delitem__(self, key) -> None: - if self.use_redis: - self._redis.delete(key) - else: - if key in self._data: - del self._data[key] - - def __contains__(self, key) -> bool: - if self.use_redis: - self._redis.exists(key) - else: - return key in self._data - - def __repr__(self) -> str: - if self.use_redis: - _data = {key.decode("utf-8"): self[key] for key in self._redis.scan_iter("*")} - else: - _data = self._data - return repr(_data) - - def get_internal_organization_id(self) -> str: - org_id = self[INTERNAL_ORG_ID_KEY] - if org_id is None: - org_id = ENV_CONFIG.owl_internal_org_id - self[INTERNAL_ORG_ID_KEY] = org_id - return org_id - - def set_internal_organization_id(self, organization_id: str) -> None: - self[INTERNAL_ORG_ID_KEY] = organization_id - logger.info(f"Internal organization ID set to: {organization_id}") - - @property - def internal_organization_id(self) -> str: - return self.get_internal_organization_id() - - @staticmethod - @lru_cache(maxsize=1) - def _load_model_config_from_json(json: str) -> ModelListConfig: - models = ModelListConfig.model_validate_json(json) - return models - - def _load_model_config_from_file(self) -> ModelListConfig: - # Validate JSON file - with open(ENV_CONFIG.owl_models_config, "r") as f: - models = self._load_model_config_from_json(f.read()) - return models - - def get_model_json(self) -> str: - model_json = self[MODEL_CONFIG_KEY] - if model_json is None: - model_json = self._load_model_config_from_file().model_dump_json() - self[MODEL_CONFIG_KEY] = model_json - return model_json - - def get_model_config(self) -> ModelListConfig: - model_json = self[MODEL_CONFIG_KEY] - if model_json is None: - model_json = self.get_model_json() - return self._load_model_config_from_json(model_json) - - def set_model_config(self, body: ModelListConfig) -> None: - self[MODEL_CONFIG_KEY] = body.model_dump_json() - logger.info(f"Model config set to: {body}") - try: - with open(ENV_CONFIG.owl_models_config, "w") as f: - f.write(body.model_dump_json(exclude_defaults=True)) - except Exception as e: - logger.warning(f"Failed to update `{ENV_CONFIG.owl_models_config}`: {e}") - - def get_model_pricing(self) -> ModelPrice: - return ModelPrice.model_validate(self.get_model_config().model_dump(exclude={"object"})) - - @staticmethod - @lru_cache(maxsize=1) - def _load_pricing_from_json(json: str) -> Price: - pricing = Price.model_validate_json(json) - return pricing - - def _load_pricing_from_file(self) -> Price: - # Validate JSON file - with open(ENV_CONFIG.owl_pricing_config, "r") as f: - pricing = self._load_pricing_from_json(f.read()) - return pricing - - def get_pricing(self) -> Price: - pricing_json = self[PRICES_KEY] - if pricing_json is None: - pricing = self._load_pricing_from_file() - self[PRICES_KEY] = pricing.model_dump_json() - logger.warning(f"Pricing set to: 
{pricing}") - return pricing - return self._load_pricing_from_json(pricing_json) - - def set_pricing(self, body: Price) -> None: - self[PRICES_KEY] = body.model_dump_json() - logger.info(f"Pricing set to: {body}") - try: - with open(ENV_CONFIG.owl_pricing_config, "w") as f: - f.write(body.model_dump_json(exclude_defaults=True)) - except Exception as e: - logger.warning(f"Failed to update `{ENV_CONFIG.owl_pricing_config}`: {e}") - - -CONFIG = Config() diff --git a/services/api/src/owl/configs/models.json b/services/api/src/owl/configs/models.json deleted file mode 100644 index 7ec0332..0000000 --- a/services/api/src/owl/configs/models.json +++ /dev/null @@ -1,155 +0,0 @@ -{ - "llm_models": [ - { - "id": "openai/gpt-4o-mini", - "name": "OpenAI GPT-4o Mini", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat", "image"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "anthropic/claude-3-haiku-20240307", - "name": "Anthropic Claude 3 Haiku", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - "name": "Together AI Meta Llama 3.1 (8B)", - "context_length": 130000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "together_ai" - } - ] - } - ], - "embed_models": [ - { - "id": "ellm/BAAI/bge-small-en-v1.5", - "name": "ELLM BAAI BGE Small EN v1.5", - "context_length": 512, - "embedding_size": 384, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "openai/BAAI/bge-small-en-v1.5", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "openai/text-embedding-3-large-3072", - "name": "OpenAI Text Embedding 3 Large (3072-dim)", - "context_length": 8192, - "embedding_size": 3072, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-large", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/text-embedding-3-large-256", - "name": "OpenAI Text Embedding 3 Large (256-dim)", - "context_length": 8192, - "embedding_size": 256, - "dimensions": 256, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-large", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/text-embedding-3-small-512", - "name": "OpenAI Text Embedding 3 Small (512-dim)", - "context_length": 8192, - "embedding_size": 512, - "dimensions": 512, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-small", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "cohere/embed-multilingual-v3.0", - "name": "Cohere Embed Multilingual v3.0", - "context_length": 512, - "embedding_size": 1024, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "embed-multilingual-v3.0", - "api_base": "", - "provider": "cohere" - } - ] - } - ], - "rerank_models": [ - { - "id": "ellm/mixedbread-ai/mxbai-rerank-xsmall-v1", - "name": "ELLM MxBAI Rerank XSmall v1", - "context_length": 512, - "languages": ["en"], - "capabilities": ["rerank"], - "deployments": [ - { - "litellm_id": "", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": 
"cohere/rerank-multilingual-v3.0", - "name": "Cohere Rerank Multilingual v3.0", - "context_length": 512, - "languages": ["mul"], - "capabilities": ["rerank"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "cohere" - } - ] - } - ] -} diff --git a/services/api/src/owl/configs/models_aipc.json b/services/api/src/owl/configs/models_aipc.json deleted file mode 100644 index 3ba623c..0000000 --- a/services/api/src/owl/configs/models_aipc.json +++ /dev/null @@ -1,241 +0,0 @@ -{ - "llm_models": [ - { - "id": "ellm/phi3-mini-int4", - "name": "ELLM Phi-3 Instruct", - "context_length": 4096, - "languages": ["en", "cn"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "openai/phi3-mini-int4", - "api_base": "http://localhost:5555/v1", - "provider": "ellm" - } - ] - }, - { - "id": "openai/gpt-4o-mini", - "name": "OpenAI GPT-4o Mini", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat", "image"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/gpt-4o", - "name": "OpenAI GPT-4o", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat", "image"], - "deployments": [ - { - "litellm_id": "openai/gpt-4o-2024-08-06", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/gpt-4-turbo", - "name": "OpenAI GPT-4 Turbo", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat", "image"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "anthropic/claude-3.5-sonnet", - "name": "Anthropic Claude 3.5 Sonnet", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "anthropic/claude-3-5-sonnet-20240620", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "anthropic/claude-3-haiku-20240307", - "name": "Anthropic Claude 3 Haiku", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "anthropic/claude-3-sonnet-20240229", - "name": "Anthropic Claude 3 Sonnet", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "anthropic/claude-3-opus-20240229", - "name": "Anthropic Claude 3 Opus", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - "name": "Together AI Meta Llama 3.1 (405B)", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "together_ai" - } - ] - } - ], - "embed_models": [ - { - "id": "ellm/BAAI/bge-small-en-v1.5", - "name": "ELLM BAAI BGE Small EN v1.5", - "context_length": 512, - "embedding_size": 384, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "openai/BAAI/bge-small-en-v1.5", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "openai/text-embedding-3-large-3072", - "name": "OpenAI Text Embedding 3 Large (3072-dim)", - "context_length": 8192, - "embedding_size": 3072, - "languages": ["mul"], - "capabilities": ["embed"], - 
"deployments": [ - { - "litellm_id": "text-embedding-3-large", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/text-embedding-3-large-256", - "name": "OpenAI Text Embedding 3 Large (256-dim)", - "context_length": 8192, - "embedding_size": 256, - "dimensions": 256, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-large", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/text-embedding-3-small-512", - "name": "OpenAI Text Embedding 3 Small (512-dim)", - "context_length": 8192, - "embedding_size": 512, - "dimensions": 512, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-small", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "cohere/embed-multilingual-v3.0", - "name": "Cohere Embed Multilingual v3.0", - "context_length": 512, - "embedding_size": 1024, - "languages": ["mul"], - "capabilities": ["embed"], - "owned_by": "cohere", - "deployments": [ - { - "litellm_id": "embed-multilingual-v3.0", - "api_base": "", - "provider": "cohere" - } - ] - } - ], - "rerank_models": [ - { - "id": "ellm/mixedbread-ai/mxbai-rerank-xsmall-v1", - "name": "ELLM MxBAI Rerank XSmall v1", - "context_length": 512, - "languages": ["en"], - "capabilities": ["rerank"], - "deployments": [ - { - "litellm_id": "", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "cohere/rerank-multilingual-v3.0", - "name": "Cohere Rerank Multilingual v3.0", - "context_length": 512, - "languages": ["mul"], - "capabilities": ["rerank"], - "owned_by": "cohere", - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "cohere" - } - ] - } - ] -} diff --git a/services/api/src/owl/configs/models_ci.json b/services/api/src/owl/configs/models_ci.json deleted file mode 100644 index fcc62a0..0000000 --- a/services/api/src/owl/configs/models_ci.json +++ /dev/null @@ -1,139 +0,0 @@ -{ - "llm_models": [ - { - "id": "openai/gpt-4o-mini", - "name": "OpenAI GPT-4o Mini", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat", "image", "tool"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "anthropic/claude-3-haiku-20240307", - "name": "Anthropic Claude 3 Haiku", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat", "tool"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "meta/Llama3.2-3b-instruct", - "name": "Meta Llama 3.2 (3B)", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "openai/meta/Llama3.2-3b-instruct", - "api_base": "https://llmci.embeddedllm.com/chat/v1", - "provider": "custom" - } - ] - }, - { - "id": "ellm/Qwen/Qwen-2-Audio-7B", - "object": "model", - "name": "Qwen 2 Audio 7B (Audio, internal)", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat", "audio"], - "deployments": [ - { - "litellm_id": "openai/Qwen/Qwen-2-Audio-7B", - "api_base": "https://llmci.embeddedllm.com/audio/v1", - "provider": "custom" - } - ] - } - ], - "embed_models": [ - { - "id": "ellm/sentence-transformers/all-MiniLM-L6-v2", - "name": "ELLM MiniLM L6 v2", - "context_length": 512, - "embedding_size": 384, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "openai/sentence-transformers/all-MiniLM-L6-v2", - "api_base": 
"http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "openai/text-embedding-3-small-512", - "name": "OpenAI Text Embedding 3 Small (512-dim)", - "context_length": 8192, - "embedding_size": 512, - "dimensions": 512, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-small", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "cohere/embed-multilingual-v3.0", - "name": "Cohere Embed Multilingual v3.0", - "context_length": 512, - "embedding_size": 1024, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "embed-multilingual-v3.0", - "api_base": "", - "provider": "cohere" - } - ] - } - ], - "rerank_models": [ - { - "id": "ellm/cross-encoder/ms-marco-TinyBERT-L-2", - "name": "ELLM TinyBERT L2", - "context_length": 512, - "languages": ["en"], - "capabilities": ["rerank"], - "deployments": [ - { - "litellm_id": "", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "cohere/rerank-multilingual-v3.0", - "name": "Cohere Rerank Multilingual v3.0", - "context_length": 512, - "languages": ["mul"], - "capabilities": ["rerank"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "cohere" - } - ] - } - ] -} diff --git a/services/api/src/owl/configs/models_ollama.json b/services/api/src/owl/configs/models_ollama.json deleted file mode 100644 index 00f1ed0..0000000 --- a/services/api/src/owl/configs/models_ollama.json +++ /dev/null @@ -1,171 +0,0 @@ -{ - "llm_models": [ - { - "id": "openai/gpt-4o-mini", - "name": "OpenAI GPT-4o Mini", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "anthropic/claude-3-haiku-20240307", - "name": "Anthropic Claude 3 Haiku", - "context_length": 200000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "anthropic" - } - ] - }, - { - "id": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - "name": "Together AI Meta Llama 3.1 (405B)", - "context_length": 128000, - "languages": ["mul"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "together_ai" - } - ] - }, - { - "id": "ellm/Qwen/Qwen2.5-3B-Instruct", - "name": "ELLM Qwen2.5 (3B)", - "context_length": 32000, - "languages": ["en"], - "capabilities": ["chat"], - "deployments": [ - { - "litellm_id": "openai/Qwen/Qwen2.5-3B-Instruct", - "api_base": "http://ollama:11434/v1", - "provider": "ellm" - } - ] - } - ], - "embed_models": [ - { - "id": "ellm/BAAI/bge-small-en-v1.5", - "name": "ELLM BAAI BGE Small EN v1.5", - "context_length": 512, - "embedding_size": 384, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "openai/BAAI/bge-small-en-v1.5", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "openai/text-embedding-3-large-3072", - "name": "OpenAI Text Embedding 3 Large (3072-dim)", - "context_length": 8192, - "embedding_size": 3072, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-large", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/text-embedding-3-large-256", - "name": "OpenAI Text Embedding 3 Large (256-dim)", - "context_length": 8192, - "embedding_size": 256, - "dimensions": 256, - "languages": 
["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-large", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "openai/text-embedding-3-small-512", - "name": "OpenAI Text Embedding 3 Small (512-dim)", - "context_length": 8192, - "embedding_size": 512, - "dimensions": 512, - "languages": ["mul"], - "capabilities": ["embed"], - "deployments": [ - { - "litellm_id": "text-embedding-3-small", - "api_base": "", - "provider": "openai" - } - ] - }, - { - "id": "cohere/embed-multilingual-v3.0", - "name": "Cohere Embed Multilingual v3.0", - "context_length": 512, - "embedding_size": 1024, - "languages": ["mul"], - "capabilities": ["embed"], - "owned_by": "cohere", - "deployments": [ - { - "litellm_id": "embed-multilingual-v3.0", - "api_base": "", - "provider": "cohere" - } - ] - } - ], - "rerank_models": [ - { - "id": "ellm/mixedbread-ai/mxbai-rerank-xsmall-v1", - "name": "ELLM MxBAI Rerank XSmall v1", - "context_length": 512, - "languages": ["en"], - "capabilities": ["rerank"], - "deployments": [ - { - "litellm_id": "", - "api_base": "http://infinity:6909", - "provider": "ellm" - } - ] - }, - { - "id": "cohere/rerank-multilingual-v3.0", - "name": "Cohere Rerank Multilingual v3.0", - "context_length": 512, - "languages": ["mul"], - "capabilities": ["rerank"], - "owned_by": "cohere", - "deployments": [ - { - "litellm_id": "", - "api_base": "", - "provider": "cohere" - } - ] - } - ] -} diff --git a/services/api/src/owl/configs/oss.py b/services/api/src/owl/configs/oss.py new file mode 100644 index 0000000..87d9338 --- /dev/null +++ b/services/api/src/owl/configs/oss.py @@ -0,0 +1,306 @@ +from functools import cached_property +from os.path import abspath +from pathlib import Path +from typing import Annotated, Literal, Self + +from loguru import logger +from pydantic import Field, SecretStr, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +CURR_DIR = Path(__file__).resolve().parent + + +class EnvConfig(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="owl_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + cli_parse_args=False, + ) + # API configs + db_path: str = "postgresql+psycopg://owlpguser:owlpgpassword@pgbouncer:5432/jamaibase_owl" # Default to Postgres + log_dir: str = "logs" + host: str = "0.0.0.0" + port: int = 6969 + workers: int = 1 # The suggested number of workers is (2*CPU)+1 + max_concurrency: int = 300 + db_init: bool | None = None # None means unset + db_reset: bool = False + db_init_max_users: int = 5 + cache_reset: bool = False + enable_byok: bool = True + disable_billing: bool = False + log_timings: bool = False + # Services + redis_host: str = "dragonfly" + redis_port: int = 6379 + file_proxy_url: str = "localhost:6969" + file_dir: str = "s3://file" + s3_endpoint: str = "http://minio:9000" + s3_access_key_id: str = "minioadmin" + s3_secret_access_key: SecretStr = "minioadmin" + code_executor_endpoint: str = "http://kopi:3000" + docling_url: str = "http://docling:5001" + docling_timeout_sec: Annotated[int, Field(gt=0, le=60 * 60)] = 20 * 60 + test_llm_api_base: str = "http://test-llm:6970/v1" + # Configs + embed_file_upload_max_bytes: int = 200 * 1024 * 1024 # 200MiB in bytes + image_file_upload_max_bytes: int = 20 * 1024 * 1024 # 20MiB in bytes + audio_file_upload_max_bytes: int = 120 * 1024 * 1024 # 120MiB in bytes + compute_storage_period_sec: Annotated[float, Field(ge=0, le=60 * 60)] = 60 * 5 + document_loader_cache_ttl_sec: int = 60 * 15 # 15 minutes + 
# Starling configs + s3_backup_bucket_name: str = "" + # Starling database configs + flush_clickhouse_buffer_sec: int = 60 + # Generative Table configs + concurrent_rows_batch_size: int = 3 + concurrent_cols_batch_size: int = 5 + max_write_batch_size: int = 100 + max_file_cache_size: int = 20 + # PDF Loader configs + fast_pdf_parsing: bool = True + # LLM configs + llm_timeout_sec: Annotated[int, Field(gt=0, le=60 * 60)] = 60 + embed_timeout_sec: Annotated[int, Field(gt=0, le=60 * 60)] = 60 + code_timeout_sec: Annotated[int, Field(gt=0, le=60 * 60)] = 120 + cohere_api_base: str = "https://api.cohere.ai/v1" + jina_ai_api_base: str = "https://api.jina.ai/v1" + voyage_api_base: str = "https://api.voyageai.com/v1" + # Keys + encryption_key: SecretStr = "" + service_key: SecretStr = "" + service_key_alt: SecretStr = "" + # OpenTelemetry configs + opentelemetry_host: str = "otel-collector" + opentelemetry_port: int = 4317 + # VictoriaMetrics configs + victoria_metrics_host: str = "vmauth" + victoria_metrics_port: int = 8427 + victoria_metrics_user: str = "owl" + victoria_metrics_password: SecretStr = "owl-vm" + # Clickhouse configs + clickhouse_host: str = "clickhouse" + clickhouse_port: int = 8123 + clickhouse_user: str = "owluser" + clickhouse_password: SecretStr = "owlpassword" + clickhouse_db: str = "jamaibase_owl" + clickhouse_max_buffer_queue_size: int = 10000 + # Clickhouse Redis queue buffer + clickhouse_buffer_key: str = "clickhouse_insert_buffer" + # Stripe & Billing + stripe_api_key: SecretStr = "" + stripe_publishable_key_live: SecretStr = "" + stripe_publishable_key_test: SecretStr = "" + stripe_webhook_secret_live: SecretStr = "" + stripe_webhook_secret_test: SecretStr = "" + payment_lapse_max_days: int = 7 + # Auth0 + auth0_api_key: SecretStr = "" + # Keys + unstructuredio_api_key: SecretStr = "ellm" + anthropic_api_key: SecretStr = "" + azure_api_key: SecretStr = "" + azure_ai_api_key: SecretStr = "" + bedrock_api_key: SecretStr = "" + cerebras_api_key: SecretStr = "" + cohere_api_key: SecretStr = "" + deepseek_api_key: SecretStr = "" + ellm_api_key: SecretStr = "" + gemini_api_key: SecretStr = "" + groq_api_key: SecretStr = "" + hyperbolic_api_key: SecretStr = "" + jina_ai_api_key: SecretStr = "" + openai_api_key: SecretStr = "" + openrouter_api_key: SecretStr = "" + sagemaker_api_key: SecretStr = "" + sambanova_api_key: SecretStr = "" + together_ai_api_key: SecretStr = "" + vertex_ai_api_key: SecretStr = "" + voyage_api_key: SecretStr = "" + + @model_validator(mode="after") + def check_db_init(self) -> Self: + if self.db_init is None: + self.db_init = True if self.is_oss else False + return self + + @model_validator(mode="after") + def make_paths_absolute(self) -> Self: + self.log_dir = abspath(self.log_dir) + return self + + @model_validator(mode="after") + def check_alternate_service_key(self) -> Self: + if self.service_key_alt.get_secret_value().strip() == "": + self.service_key_alt = self.service_key + return self + + @model_validator(mode="after") + def validate_db_path(self) -> Self: + """ + Validates that `db_path` starts with either `rqlite+pyrqlite://` or `sqlite://` or `sqlite+libsql://` or `postgresql`. 
+ """ + if not ( + self.db_path.startswith("rqlite+pyrqlite://") + or self.db_path.startswith("sqlite://") + or self.db_path.startswith("sqlite+libsql://") + or self.db_path.startswith("postgresql") + ): + raise ValueError(f'`db_path` "{self.db_path}" has an invalid dialect.') + return self + + @property + def db_dialect(self) -> Literal["rqlite", "libsql", "postgresql", "sqlite"]: + """ + Show the sqlite dialect that's in use based on the `db_path`. + """ + if self.db_path.startswith("rqlite+pyrqlite://"): + return "rqlite" + elif self.db_path.startswith("sqlite+libsql://"): + return "libsql" + elif self.db_path.startswith("postgresql"): + return "postgresql" + elif self.db_path.startswith("sqlite://"): + return "sqlite" + + @cached_property + def is_oss(self) -> bool: + logger.opt(colors=True).info("Launching in OSS mode.") + return True + + @cached_property + def is_cloud(self) -> bool: + return not self.is_oss + + @property + def s3_secret_access_key_plain(self) -> str: + return self.s3_secret_access_key.get_secret_value() + + @property + def victoria_metrics_password_plain(self) -> str: + return self.victoria_metrics_password.get_secret_value().strip() + + @property + def is_stripe_live(self) -> bool: + return self.stripe_api_key_plain.startswith("sk_live") + + @property + def stripe_api_key_plain(self) -> str: + return self.stripe_api_key.get_secret_value() + + @property + def stripe_webhook_secret_plain(self) -> str: + return ( + self.stripe_webhook_secret_live.get_secret_value() + if self.is_stripe_live + else self.stripe_webhook_secret_test.get_secret_value() + ) + + @property + def stripe_publishable_key_plain(self) -> str: + return ( + self.stripe_publishable_key_live.get_secret_value() + if self.is_stripe_live + else self.stripe_publishable_key_test.get_secret_value() + ) + + @property + def auth0_api_key_plain(self) -> str: + return self.auth0_api_key.get_secret_value() + + @property + def encryption_key_plain(self) -> str: + return self.encryption_key.get_secret_value() + + @property + def service_key_plain(self) -> str: + return self.service_key.get_secret_value() + + @property + def service_key_alt_plain(self) -> str: + return self.service_key_alt.get_secret_value() + + @property + def unstructuredio_api_key_plain(self) -> str: + return self.unstructuredio_api_key.get_secret_value() + + @property + def anthropic_api_key_plain(self) -> str: + return self.anthropic_api_key.get_secret_value() + + @property + def azure_api_key_plain(self) -> str: + return self.azure_api_key.get_secret_value() + + @property + def azure_ai_api_key_plain(self) -> str: + return self.azure_ai_api_key.get_secret_value() + + @property + def bedrock_api_key_plain(self) -> str: + return self.azure_ai_api_key.get_secret_value() + + @property + def cerebras_api_key_plain(self) -> str: + return self.cerebras_api_key.get_secret_value() + + @property + def cohere_api_key_plain(self) -> str: + return self.cohere_api_key.get_secret_value() + + @property + def deepseek_api_key_plain(self) -> str: + return self.deepseek_api_key.get_secret_value() + + @property + def ellm_api_key_plain(self) -> str: + return self.ellm_api_key.get_secret_value() + + @property + def gemini_api_key_plain(self) -> str: + return self.gemini_api_key.get_secret_value() + + @property + def groq_api_key_plain(self) -> str: + return self.groq_api_key.get_secret_value() + + @property + def hyperbolic_api_key_plain(self) -> str: + return self.hyperbolic_api_key.get_secret_value() + + @property + def jina_ai_api_key_plain(self) -> str: + 
return self.jina_ai_api_key.get_secret_value() + + @property + def openai_api_key_plain(self) -> str: + return self.openai_api_key.get_secret_value() + + @property + def openrouter_api_key_plain(self) -> str: + return self.openrouter_api_key.get_secret_value() + + @property + def sagemaker_api_key_plain(self) -> str: + return self.sagemaker_api_key.get_secret_value() + + @property + def sambanova_api_key_plain(self) -> str: + return self.sambanova_api_key.get_secret_value() + + @property + def together_ai_api_key_plain(self) -> str: + return self.together_ai_api_key.get_secret_value() + + @property + def vertex_ai_api_key_plain(self) -> str: + return self.vertex_ai_api_key.get_secret_value() + + @property + def voyage_api_key_plain(self) -> str: + return self.voyage_api_key.get_secret_value() + + def get_api_key(self, provider: str, default: str = "") -> str: + return getattr(self, f"{provider}_api_key_plain", default) diff --git a/services/api/src/owl/configs/preset_models.json b/services/api/src/owl/configs/preset_models.json index 34ef5ce..fb60cc0 100644 --- a/services/api/src/owl/configs/preset_models.json +++ b/services/api/src/owl/configs/preset_models.json @@ -7,13 +7,14 @@ "name": "OpenAI GPT-4.1", "type": "llm", "context_length": 1047576, + "max_output_tokens": 32768, "capabilities": ["chat", "image"], "languages": ["en", "mul"], "llm_input_cost_per_mtoken": 2.0, "llm_output_cost_per_mtoken": 8.0, "deployments": [ { - "name": "OpenAI GPT-4.1", + "name": "OpenAI GPT-4.1 Deployment", "provider": "openai", "routing_id": "openai/gpt-4.1", "api_base": "" @@ -28,13 +29,14 @@ "name": "OpenAI GPT-4.1 Mini", "type": "llm", "context_length": 1047576, + "max_output_tokens": 32768, "capabilities": ["chat", "image"], "languages": ["en", "mul"], "llm_input_cost_per_mtoken": 0.4, "llm_output_cost_per_mtoken": 1.6, "deployments": [ { - "name": "OpenAI GPT-4.1 Mini", + "name": "OpenAI GPT-4.1 Mini Deployment", "provider": "openai", "routing_id": "openai/gpt-4.1-mini", "api_base": "" @@ -49,13 +51,14 @@ "name": "OpenAI GPT-4.1 Nano", "type": "llm", "context_length": 1047576, + "max_output_tokens": 32768, "capabilities": ["chat", "image"], "languages": ["en", "mul"], "llm_input_cost_per_mtoken": 0.1, "llm_output_cost_per_mtoken": 0.4, "deployments": [ { - "name": "OpenAI GPT-4.1 Nano", + "name": "OpenAI GPT-4.1 Nano Deployment", "provider": "openai", "routing_id": "openai/gpt-4.1-nano", "api_base": "" @@ -69,37 +72,61 @@ "id": "openai/gpt-4o", "name": "OpenAI GPT-4o", "type": "llm", - "context_length": 1047576, + "context_length": 128000, + "max_output_tokens": 16384, "capabilities": ["chat", "image"], "languages": ["en", "mul"], "llm_input_cost_per_mtoken": 2.5, "llm_output_cost_per_mtoken": 10.0, "deployments": [ { - "name": "OpenAI GPT-4o", + "name": "OpenAI GPT-4o Deployment", "provider": "openai", "routing_id": "openai/gpt-4o", "api_base": "" } ] }, + { + "meta": { + "icon": "openai" + }, + "id": "openai/gpt-4o-mini", + "name": "OpenAI GPT-4o Mini", + "type": "llm", + "context_length": 128000, + "max_output_tokens": 16384, + "capabilities": ["chat", "image"], + "languages": ["en", "mul"], + "llm_input_cost_per_mtoken": 0.15, + "llm_output_cost_per_mtoken": 0.6, + "deployments": [ + { + "name": "OpenAI GPT-4o Mini Deployment", + "provider": "openai", + "routing_id": "openai/gpt-4o-mini", + "api_base": "" + } + ] + }, { "meta": { "icon": "anthropic" }, - "id": "anthropic/claude-3.5-haiku", - "name": "Anthropic Claude 3.5 Haiku", + "id": "anthropic/claude-opus-4", + "name": "Anthropic Claude Opus 
4", "type": "llm", "context_length": 200000, + "max_output_tokens": 32000, "capabilities": ["chat", "image"], "languages": ["en", "mul"], - "llm_input_cost_per_mtoken": 0.8, - "llm_output_cost_per_mtoken": 4.0, + "llm_input_cost_per_mtoken": 3.0, + "llm_output_cost_per_mtoken": 15.0, "deployments": [ { - "name": "Anthropic Claude 3.5 Haiku", + "name": "Anthropic Claude Opus 4 Deployment", "provider": "anthropic", - "routing_id": "anthropic/claude-3-5-haiku-20241022", + "routing_id": "anthropic/claude-opus-4-0", "api_base": "" } ] @@ -108,19 +135,20 @@ "meta": { "icon": "anthropic" }, - "id": "anthropic/claude-3.5-sonnet", - "name": "Anthropic Claude 3.5 Sonnet", + "id": "anthropic/claude-sonnet-4", + "name": "Anthropic Claude Sonnet 4", "type": "llm", "context_length": 200000, + "max_output_tokens": 64000, "capabilities": ["chat", "image"], "languages": ["en", "mul"], "llm_input_cost_per_mtoken": 3.0, "llm_output_cost_per_mtoken": 15.0, "deployments": [ { - "name": "Anthropic Claude 3.5 Sonnet", + "name": "Anthropic Claude Sonnet 4 Deployment", "provider": "anthropic", - "routing_id": "anthropic/claude-3-5-sonnet-20241022", + "routing_id": "anthropic/claude-sonnet-4-0", "api_base": "" } ] @@ -130,22 +158,67 @@ "icon": "anthropic" }, "id": "anthropic/claude-3.7-sonnet", - "name": "Anthropic Claude 3.7 Sonnet", + "name": "Anthropic Claude Sonnet 3.7", "type": "llm", "context_length": 200000, + "max_output_tokens": 64000, "capabilities": ["chat", "image"], "languages": ["en", "mul"], "llm_input_cost_per_mtoken": 3.0, "llm_output_cost_per_mtoken": 15.0, "deployments": [ { - "name": "Anthropic Claude 3.7 Sonnet", + "name": "Anthropic Claude Sonnet 3.7 Deployment", "provider": "anthropic", "routing_id": "anthropic/claude-3-7-sonnet-latest", "api_base": "" } ] }, + { + "meta": { + "icon": "anthropic" + }, + "id": "anthropic/claude-3.5-sonnet", + "name": "Anthropic Claude Sonnet 3.5", + "type": "llm", + "context_length": 200000, + "max_output_tokens": 8192, + "capabilities": ["chat", "image"], + "languages": ["en", "mul"], + "llm_input_cost_per_mtoken": 3.0, + "llm_output_cost_per_mtoken": 15.0, + "deployments": [ + { + "name": "Anthropic Claude Sonnet 3.5 Deployment", + "provider": "anthropic", + "routing_id": "anthropic/claude-3-5-sonnet-latest", + "api_base": "" + } + ] + }, + { + "meta": { + "icon": "anthropic" + }, + "id": "anthropic/claude-3.5-haiku", + "name": "Anthropic Claude Haiku 3.5", + "type": "llm", + "context_length": 200000, + "max_output_tokens": 8192, + "capabilities": ["chat", "image"], + "languages": ["en", "mul"], + "llm_input_cost_per_mtoken": 0.8, + "llm_output_cost_per_mtoken": 4.0, + "deployments": [ + { + "name": "Anthropic Claude Haiku 3.5 Deployment", + "provider": "anthropic", + "routing_id": "anthropic/claude-3-5-haiku-latest", + "api_base": "" + } + ] + }, { "meta": { "icon": "google" @@ -160,7 +233,7 @@ "llm_output_cost_per_mtoken": 15.0, "deployments": [ { - "name": "Google Gemini 2.5 Pro Preview", + "name": "Google Gemini 2.5 Pro Preview Deployment", "provider": "gemini", "routing_id": "gemini/gemini-2.5-pro-preview-03-25", "api_base": "" @@ -181,7 +254,7 @@ "llm_output_cost_per_mtoken": 0.6, "deployments": [ { - "name": "Google Gemini 2.5 Flash Preview", + "name": "Google Gemini 2.5 Flash Preview Deployment", "provider": "gemini", "routing_id": "gemini/gemini-2.5-flash-preview-04-17", "api_base": "" @@ -202,13 +275,13 @@ "llm_output_cost_per_mtoken": 0.5, "deployments": [ { - "name": "Meta Llama 4 Scout (109B-A17B, MoE)", + "name": "Meta Llama 4 Scout (109B-A17B, 
MoE) Deployment", "huggingface_id": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "cpu_count": "4", "memory_gb": "24", "required_vram": "140", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -226,13 +299,13 @@ "llm_output_cost_per_mtoken": 0.95, "deployments": [ { - "name": "Meta Llama 4 Maverick (400B-A17B, MoE)", + "name": "Meta Llama 4 Maverick (400B-A17B, MoE) Deployment", "huggingface_id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "cpu_count": "4", "memory_gb": "24", "required_vram": "320", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -250,13 +323,13 @@ "llm_output_cost_per_mtoken": 1.2, "deployments": [ { - "name": "DeepSeek V3 (685B-A22B, MoE)", + "name": "DeepSeek V3 (685B-A22B, MoE) Deployment", "huggingface_id": "deepseek-ai/DeepSeek-V3-0324", "cpu_count": "4", "memory_gb": "8", "required_vram": "1100", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -274,13 +347,13 @@ "llm_output_cost_per_mtoken": 2.5, "deployments": [ { - "name": "DeepSeek R1 (685B-A22B, MoE)", + "name": "DeepSeek R1 (685B-A22B, MoE) Deployment", "huggingface_id": "deepseek-ai/DeepSeek-R1", "cpu_count": "4", "memory_gb": "8", "required_vram": "1100", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -289,7 +362,7 @@ "icon": "qwen" }, "id": "Qwen/Qwen3-235B-A22B-FP8", - "name": "Qwen3 (235B-A22B, MoE)", + "name": "Qwen 3 (235B-A22B, MoE)", "type": "llm", "context_length": 40960, "capabilities": ["chat"], @@ -298,13 +371,13 @@ "llm_output_cost_per_mtoken": 0.45, "deployments": [ { - "name": "Qwen3 (235B-A22B, MoE)", + "name": "Qwen 3 (235B-A22B, MoE) Deployment", "huggingface_id": "Qwen/Qwen3-235B-A22B-FP8", "cpu_count": "4", "memory_gb": "8", "required_vram": "280", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -313,7 +386,7 @@ "icon": "qwen" }, "id": "Qwen/Qwen3-32B-FP8", - "name": "Qwen3 (32B)", + "name": "Qwen 3 (32B)", "type": "llm", "context_length": 40960, "capabilities": ["chat"], @@ -322,13 +395,13 @@ "llm_output_cost_per_mtoken": 0.45, "deployments": [ { - "name": "Qwen3 (32B)", + "name": "Qwen 3 (32B) Deployment", "huggingface_id": "Qwen/Qwen3-32B-FP8", "cpu_count": "4", "memory_gb": "8", "required_vram": "42", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -337,7 +410,7 @@ "icon": "qwen" }, "id": "Qwen/Qwen3-8B-FP8", - "name": "Qwen3 (8B)", + "name": "Qwen 3 (8B)", "type": "llm", "context_length": 40960, "capabilities": ["chat"], @@ -346,13 +419,13 @@ "llm_output_cost_per_mtoken": 0.22, "deployments": [ { - "name": "Qwen3 (8B)", + "name": "Qwen 3 (8B) Deployment", "huggingface_id": "Qwen/Qwen3-8B-FP8", "cpu_count": "4", "memory_gb": "8", "required_vram": "15", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -361,7 +434,7 @@ "icon": "qwen" }, "id": "Qwen/Qwen2.5-VL-32B-Instruct", - "name": "Qwen2.5 VL (32B)", + "name": "Qwen 2.5 VL (32B)", "type": "llm", "context_length": 32768, "capabilities": ["chat"], @@ -370,13 +443,13 @@ "llm_output_cost_per_mtoken": 0.4, "deployments": [ { - "name": "Qwen2.5 VL (32B)", + "name": "Qwen 2.5 VL (32B) Deployment", "huggingface_id": "Qwen/Qwen2.5-VL-32B-Instruct", "cpu_count": "4", "memory_gb": "16", "required_vram": "52", "num_replicas": 1, - "service_type": "vllm" + "provider": "vllm" } ] }, @@ -394,7 +467,7 @@ "embedding_cost_per_mtoken": 0.13, "deployments": [ { - "name": "OpenAI Text Embedding 3 Large (3072-dim)", + "name": "OpenAI Text Embedding 3 Large 
(3072-dim) Deployment", "provider": "openai", "routing_id": "openai/text-embedding-3-large", "api_base": "" @@ -411,12 +484,12 @@ "context_length": 8192, "capabilities": ["embed"], "languages": ["en", "mul"], - "embedding_size": 256, + "embedding_size": 3072, "embedding_dimensions": 256, "embedding_cost_per_mtoken": 0.13, "deployments": [ { - "name": "OpenAI Text Embedding 3 Large (256-dim)", + "name": "OpenAI Text Embedding 3 Large (256-dim) Deployment", "provider": "openai", "routing_id": "openai/text-embedding-3-large", "api_base": "" @@ -434,16 +507,60 @@ "capabilities": ["embed"], "languages": ["en", "mul"], "embedding_size": 1536, - "embedding_cost_per_mtoken": 0.022, + "embedding_cost_per_mtoken": 0.02, "deployments": [ { - "name": "OpenAI Text Embedding 3 Small (1536-dim)", + "name": "OpenAI Text Embedding 3 Small (1536-dim) Deployment", "provider": "openai", "routing_id": "openai/text-embedding-3-small", "api_base": "" } ] }, + { + "meta": { + "icon": "openai" + }, + "id": "openai/text-embedding-3-small-256", + "name": "OpenAI Text Embedding 3 Small (256-dim)", + "type": "embed", + "context_length": 8192, + "capabilities": ["embed"], + "languages": ["en", "mul"], + "embedding_size": 1536, + "embedding_dimensions": 256, + "embedding_cost_per_mtoken": 0.02, + "deployments": [ + { + "name": "OpenAI Text Embedding 3 Small (256-dim) Deployment", + "provider": "openai", + "routing_id": "openai/text-embedding-3-small", + "api_base": "" + } + ] + }, + { + "meta": { + "icon": "cohere" + }, + "id": "cohere/embed-v4.0-256", + "name": "Cohere Embed v4.0 (256-dim)", + "type": "embed", + "context_length": 128000, + "capabilities": ["embed"], + "languages": ["en", "mul"], + "embedding_size": 1536, + "embedding_dimensions": 256, + "embedding_cost_per_mtoken": 0.12, + "deployments": [ + { + "name": "Cohere Embed v4.0 (256-dim) Deployment", + "provider": "cohere", + "routing_id": "embed-v4.0", + "api_base": "" + } + ] + }, { "meta": { "icon": "cohere" @@ -458,7 +575,7 @@ "embedding_cost_per_mtoken": 0.11, "deployments": [ { - "name": "Cohere Embed Multilingual v3.0", + "name": "Cohere Embed Multilingual v3.0 Deployment", "provider": "cohere", "routing_id": "embed-multilingual-v3.0", "api_base": "" @@ -470,7 +587,7 @@ "icon": "generic" }, "id": "BAAI/bge-m3", - "name": "BAAI bge-m3", + "name": "BAAI BGE-M3", "type": "embed", "context_length": 8192, "capabilities": ["embed"], @@ -478,13 +595,13 @@ "embedding_cost_per_mtoken": 0.022, "deployments": [ { - "name": "BAAI bge-m3", + "name": "BAAI BGE-M3 Deployment", "huggingface_id": "BAAI/bge-m3", "cpu_count": "2", "memory_gb": "4", "required_vram": "14", "num_replicas": 1, - "service_type": "infinity" + "provider": "infinity" } ] }, @@ -501,7 +618,7 @@ "reranking_cost_per_ksearch": 2, "deployments": [ { - "name": "Cohere Rerank Multilingual v3.0", + "name": "Cohere Rerank Multilingual v3.0 Deployment", "provider": "cohere", "routing_id": "rerank-multilingual-v3.0", "api_base": "" @@ -521,13 +638,13 @@ "reranking_cost_per_ksearch": 2, "deployments": [ { - "name": "BGE Reranker V2 M3", + "name": "BGE Reranker V2 M3 Deployment", "huggingface_id": "BAAI/bge-reranker-v2-m3", "cpu_count": "2", "memory_gb": "4", "required_vram": "14", "num_replicas": 1, - "service_type": "infinity" + "provider": "infinity" } ] } diff --git a/services/api/src/owl/db/__init__.py b/services/api/src/owl/db/__init__.py index 9af6c28..1b47542 100644 --- a/services/api/src/owl/db/__init__.py +++ b/services/api/src/owl/db/__init__.py @@ -1,88 +1,572 @@ +from contextlib import 
asynccontextmanager, contextmanager from functools import lru_cache -from os import makedirs -from os.path import dirname -from typing import Type -from urllib.parse import urlsplit +from typing import AsyncGenerator, Callable, Generator +from async_lru import alru_cache from loguru import logger -from sqlalchemy import Engine, NullPool, Pool, QueuePool, event +from sqlalchemy import Connection, Engine, NullPool, TextClause, text from sqlalchemy.exc import OperationalError -from sqlmodel import MetaData, SQLModel, create_engine, text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlmodel import Session, create_engine +from sqlmodel.ext.asyncio.session import AsyncSession -from owl.configs.manager import ENV_CONFIG +from owl.configs import CACHE, ENV_CONFIG +from owl.db.models import TEMPLATE_ORG_ID, JamaiSQLModel # noqa: F401 +from owl.utils import uuid7_str +SCHEMA = JamaiSQLModel.metadata.schema -def _pragma_on_connect(dbapi_con, con_record): - dbapi_con.execute("pragma foreign_keys = ON;\n") - dbapi_con.execute("pragma journal_mode = WAL;\n") - dbapi_con.execute("pragma synchronous = normal;\n") - dbapi_con.execute("pragma journal_size_limit = 6144000;\n") - # dbapi_con.execute("pragma temp_store = memory;\n") - # dbapi_con.execute("pragma mmap_size = 30000000000;\n") - -def _do_connect(dbapi_connection, connection_record): - # Disable pysqlite's emitting of the BEGIN statement entirely. - # Also stops it from emitting COMMIT before any DDL. - dbapi_connection.isolation_level = None - - -def _do_begin(conn): - # Emit our own BEGIN - conn.exec_driver_sql("BEGIN") - - -def create_sqlite_engine( +def _create_db_engine( db_url: str, *, connect_args: dict | None = None, - poolclass: Type[Pool] | None = None, + engine_create_fn: Callable[..., Engine | AsyncEngine] | None = None, echo: bool = False, - **kwargs, + dialect: str = "sqlite", ) -> Engine: - db_dir = dirname(urlsplit(db_url).path.replace("/", "", 1)) - makedirs(db_dir, exist_ok=True) - engine = create_engine( - db_url, - connect_args=connect_args or {"check_same_thread": False}, - poolclass=poolclass or NullPool, - echo=echo, - **kwargs, + if connect_args is None: + if dialect == "postgresql": + connect_args = {} + else: + connect_args = {"check_same_thread": False} + if engine_create_fn is None: + engine_create_fn = create_engine + if dialect == "postgresql": + logger.debug("Using PostgreSQL DB.") + if "asyncpg" in db_url: + connect_args["prepared_statement_name_func"] = lambda: f"__asyncpg_{uuid7_str()}__" + engine = engine_create_fn( + db_url, + connect_args=connect_args, + poolclass=NullPool, + echo=echo, + ) + else: + raise ValueError(f'Dialect "{dialect}" is not supported.') + try: + from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor + except ImportError: + logger.warning("Skip sqlalchemy instrumentation.") + else: + SQLAlchemyInstrumentor().instrument( + engine=engine if isinstance(engine, Engine) else engine.sync_engine, + enable_commenter=True, + commenter_options={}, + ) + return engine + + +@lru_cache(maxsize=1) +def create_db_engine() -> Engine: + engine = _create_db_engine( + ENV_CONFIG.db_path, + dialect=ENV_CONFIG.db_dialect, + ) + return engine + + +@alru_cache(maxsize=1) +async def create_db_engine_async() -> AsyncEngine: + engine = _create_db_engine( + ENV_CONFIG.db_path, + engine_create_fn=create_async_engine, + dialect=ENV_CONFIG.db_dialect, ) - event.listen(engine, "connect", _pragma_on_connect) - # Enabling these seems to lead to DB locking issues - # 
https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#pysqlite-serializable - # https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#aiosqlite-serializable - # event.listen(engine, "connect", _do_connect) - # event.listen(engine, "begin", _do_begin) return engine -def create_sql_tables(db_class: Type[SQLModel], engine: Engine): +def yield_session() -> Generator[Session, None, None]: + with Session(create_db_engine()) as session: + yield session + + +async def yield_async_session() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSession(await create_db_engine_async(), expire_on_commit=False) as session: + yield session + + +# Sync Session context manager +sync_session = contextmanager(yield_session) +# Async Session context manager +async_session = asynccontextmanager(yield_async_session) + + +@lru_cache(maxsize=10000) +def cached_text(query: str) -> TextClause: + return text(query) + + +async def reset_db(*, reset_max_users: int = 3): + from sqlmodel import func, select + + from owl.db.models import User + + # Only allow DB reset in dev with localhost + if "@localhost:" not in ENV_CONFIG.db_path: + raise ValueError("DB reset is only allowed in dev with localhost DB.") + + async with async_session() as session: + # As a safety measure, reset DB only if it has less than `init_max_users` users + # Just in case we accidentally tried to nuke a prod DB + user_table_exists = ( + await session.exec( + text( + ( + f"SELECT EXISTS (" + f"SELECT FROM information_schema.tables WHERE table_schema = '{SCHEMA}' AND table_name = 'User'" + ");" + ) + ) + ) + ).scalar() + if user_table_exists: + user_count = (await session.exec(select(func.count(User.id)))).one() + if user_count >= reset_max_users: + logger.info( + f"Found {user_count:,d} users, abort database reset (>= {reset_max_users} users)." 
+ ) + return + + # Delete all tables + logger.warning(f'Resetting database (dropping schema "{SCHEMA}")...') + await session.exec(text(f"DROP SCHEMA IF EXISTS {SCHEMA} CASCADE")) + await session.exec(text(f"CREATE SCHEMA {SCHEMA}")) + # Reapply default privileges for the new schema OID + await _grant_auditor_privilege(await create_db_engine_async()) + await session.commit() + stmt = """ + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name ~ '^proj_.*(_action|_knowledge|_chat)$'; + """ + schemas = [r[0] for r in (await session.exec(text(stmt))).all()] + logger.warning(f'Dropping Generative Table schemas: "{schemas}"') + for schema in schemas: + await session.exec(text(f"DROP SCHEMA {schema} CASCADE")) + await session.commit() + conn = await session.connection() + await conn.run_sync(JamaiSQLModel.metadata.create_all) + await conn.commit() + logger.success("All application tables dropped and recreated.") + + +async def _create_schema(engine: AsyncEngine) -> bool: + async with engine.begin() as conn: + await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA}")) + await conn.commit() + return False + + +async def _create_tables(engine: AsyncEngine) -> bool: try: - db_class.metadata.create_all(engine) + async with engine.begin() as conn: + await conn.run_sync(JamaiSQLModel.metadata.create_all) + await conn.commit() except Exception as e: logger.exception(f"Failed to create DB tables: {e}") if not isinstance(e, OperationalError): raise + return False -@lru_cache(maxsize=1000) -def cached_text(query: str): - return text(query) +async def _create_pg_functions(engine: AsyncEngine) -> bool: + async with engine.connect() as conn: + await conn.execute( + text(f""" + CREATE OR REPLACE FUNCTION {SCHEMA}.deduct_cost( + organization_id TEXT, + cost NUMERIC(21, 12) + ) + RETURNS {SCHEMA}."Organization" AS $$ + DECLARE + updated_org {SCHEMA}."Organization"%ROWTYPE; + BEGIN + -- Ensure the cost is a positive number to prevent misuse + IF cost < 0 THEN + RAISE EXCEPTION 'Cost must be a non-negative number.'; + END IF; + + UPDATE {SCHEMA}."Organization" + SET + -- Logic for credit_grant column + credit_grant = CASE + -- If grant is enough to cover the cost, deduct from grant + WHEN credit_grant >= cost THEN credit_grant - cost + -- Otherwise, the grant is fully used up + ELSE 0 + END, + -- Logic for credit column + credit = CASE + -- If grant is enough, credit is unchanged + WHEN credit_grant >= cost THEN credit + -- Otherwise, deduct the remainder of the cost from credit + ELSE credit - (cost - credit_grant) + END + WHERE id = organization_id + RETURNING * INTO updated_org; -- Capture the updated row into a variable + + RETURN updated_org; + END; + $$ LANGUAGE plpgsql; + """) + ) + await conn.execute( + text(f""" + CREATE OR REPLACE FUNCTION {SCHEMA}.add_credit_grant( + organization_id TEXT, + grant_to_add NUMERIC(21, 12) + ) + RETURNS {SCHEMA}."Organization" AS $$ + DECLARE + updated_org {SCHEMA}."Organization"%ROWTYPE; + BEGIN + -- Treat negative grant amounts as zero + grant_to_add := GREATEST(grant_to_add, 0); + + -- Atomically update the organization's credits. + UPDATE {SCHEMA}."Organization" + SET + credit_grant = GREATEST(credit_grant + grant_to_add + LEAST(credit, 0), 0), + credit = CASE + -- Case 1: No debt. Credit is unchanged. 
+ WHEN credit >= 0 THEN credit + + -- Case 2: Debt exists + ELSE LEAST(credit + credit_grant + grant_to_add, 0) + END + WHERE id = organization_id + RETURNING * INTO updated_org; + + RETURN updated_org; + END; + $$ LANGUAGE plpgsql; + """) + ) + await conn.commit() + return False + + +async def _check_column_exists( + conn: Connection, + table_name: str, + column_name: str, +) -> bool: + sql = text(f""" + SELECT 1 + FROM information_schema.columns + WHERE table_schema = '{SCHEMA}' AND table_name = '{table_name}' AND column_name = '{column_name}' + LIMIT 1; + """) + exists = (await conn.execute(sql)).scalar() + if exists: + logger.info(f'Column "{column_name}" found in "{table_name}" table.') + return True + return False + + +async def _add_egress_updated_at_column(engine: AsyncEngine) -> bool: + async with engine.connect() as conn: + if await _check_column_exists(conn, "Organization", "egress_usage_updated_at"): + return False + await conn.execute( + text(f""" + ALTER TABLE {SCHEMA}."Organization" + ADD COLUMN egress_usage_updated_at TIMESTAMPTZ DEFAULT NOW(); + """) + ) + await conn.commit() + return True + + +async def _add_project_description_column(engine: AsyncEngine) -> bool: + """ + Add project description column. + """ + table_name = "Project" + column_name = "description" + + async with engine.connect() as conn: + # Check if the column already exists + if await _check_column_exists(conn, table_name, column_name): + return False + await conn.execute( + text( + f"""ALTER TABLE {SCHEMA}."{table_name}" ADD COLUMN {column_name} TEXT DEFAULT ''""" + ) + ) + await conn.commit() + logger.success(f'Successfully added column "{column_name}" to "{table_name}".') + return True + + +async def _grant_auditor_privilege(engine: AsyncEngine) -> bool: + """ + Apply the necessary grants to allow the auditor role to audit the database. 
+ """ + auditor_role = "jamaibase_auditor" + audit_statement = "UPDATE, DELETE" + async with engine.connect() as conn: + role_exists = await conn.scalar( + text(f"SELECT 1 FROM pg_roles WHERE rolname = '{auditor_role}'") + ) + if role_exists is None: + return False + + # alter default privileges for FUTURE tables + await conn.execute( + text( + f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{SCHEMA}" ' + f"GRANT {audit_statement} ON TABLES TO {auditor_role};" + ) + ) + + # grant privileges for existing tables right now + await conn.exec_driver_sql( + f'GRANT {audit_statement} ON ALL TABLES IN SCHEMA "{SCHEMA}" TO {auditor_role};' + ) + await conn.commit() + return False + + +async def _migrate_verification_codes(engine: AsyncEngine) -> bool: + """ + - Add columns: + - `purpose`: str | None + - `used_at`: DatetimeUTC | None + - `revoked_at`: DatetimeUTC | None + - If `meta` JSONB contains "purpose" key, update `purpose` column and delete "purpose" key + """ + if ENV_CONFIG.is_oss: + return False + + table_name = "VerificationCode" + async with engine.connect() as conn: + if ( + await _check_column_exists(conn, table_name, "purpose") + and await _check_column_exists(conn, table_name, "revoked_at") + and await _check_column_exists(conn, table_name, "used_at") + ): + return False + async with engine.begin() as conn: + await conn.execute(text(f'LOCK TABLE {SCHEMA}."{table_name}" IN SHARE MODE;')) + # Add columns + await conn.execute( + text( + f""" + ALTER TABLE {SCHEMA}."{table_name}" + ADD COLUMN IF NOT EXISTS purpose TEXT DEFAULT NULL, + ADD COLUMN IF NOT EXISTS used_at TIMESTAMPTZ DEFAULT NULL, + ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ DEFAULT NULL; + """ + ) + ) + # If `meta` JSONB contains "purpose" key, update `purpose` column and delete "purpose" key + await conn.execute( + text( + f""" + UPDATE {SCHEMA}."{table_name}" SET purpose = meta ->> 'purpose' WHERE meta ->> 'purpose' IS NOT NULL; + UPDATE {SCHEMA}."{table_name}" SET meta = meta - 'purpose' WHERE meta ->> 'purpose' IS NOT NULL; + """ + ) + ) + logger.info(f'Successfully migrated "{table_name}".') + return True + + +async def migrate_db(): + engine = await create_db_engine_async() + migrated = [ + await _create_schema(engine), + await _grant_auditor_privilege(engine), + await _create_tables(engine), + await _create_pg_functions(engine), + await _add_egress_updated_at_column(engine), + await _add_project_description_column(engine), + await _migrate_verification_codes(engine), + ] + if any(migrated): + logger.success("DB migrations performed.") + else: + logger.success("No DB migrations performed.") + # Clean up connection pool + # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork + await engine.dispose() + # Always clear cache + await CACHE.clear_all_async() + await CACHE.aclose() + + +async def init_db(*, init_max_users: int = 3): + from fastapi import Request + from sqlmodel import func, select + from starlette.datastructures import URL, Headers + + from owl.db.models import ModelConfig, Organization, User + from owl.routers import models + from owl.routers.organizations import oss as organizations_oss + from owl.routers.projects import oss as projects_oss + from owl.routers.users import oss as users_oss + from owl.types import OrganizationRead, UserRead + from owl.utils.exceptions import ResourceNotFoundError + from owl.utils.test import ( + GPT_41_NANO_CONFIG, + TEXT_EMBEDDING_3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_CONFIG, + ) + + async with async_session() as session: 
+ # As a safety measure, init DB only if it has less than `init_max_users` users + # Just in case we accidentally tried to nuke a prod DB + user_count = (await session.exec(select(func.count(User.id)))).one() + if user_count >= init_max_users: + logger.info( + f"Found {user_count:,d} users, abort database initialisation (>= {init_max_users} users)." + ) + return + + # Only enforce OSS check if db_init=False + if ENV_CONFIG.is_oss and user_count != 0: + logger.info("OSS mode: Skipping initialization (non-empty DB).") + return + + logger.info("Initialising database...") + + # Create a mock Request object + request = Request( + { + "type": "http", + "method": "POST", + "headers": Headers({"content-type": "application/json"}).raw, + "url": URL("/v2/users"), + "state": {"id": uuid7_str()}, + } + ) + + # User + try: + user = await User.get(session, "0") + except ResourceNotFoundError: + await users_oss.create_user( + request=request, + token="", + session=session, + body=users_oss.UserCreate( + name="Admin user", + email="user@local.com", + password="jambubu", + ), + ) + user = await User.get(session, "0", populate_existing=True) + + # Manually verify email + user.email_verified = True + session.add(user) + await session.commit() + await session.refresh(user) + user = UserRead.model_validate(user) + + # Organization + if await session.get(Organization, "0") is None: + await organizations_oss.create_organization( + request=request, + user=user, + session=session, + body=organizations_oss.OrganizationCreate( + name="Admin org", + external_keys={ + "anthropic": ENV_CONFIG.anthropic_api_key_plain, + "azure": ENV_CONFIG.azure_api_key_plain, + "azure_ai": ENV_CONFIG.azure_ai_api_key_plain, + "bedrock": ENV_CONFIG.bedrock_api_key_plain, + "cerebras": ENV_CONFIG.cerebras_api_key_plain, + "cohere": ENV_CONFIG.cohere_api_key_plain, + "deepseek": ENV_CONFIG.deepseek_api_key_plain, + "gemini": ENV_CONFIG.gemini_api_key_plain, + "groq": ENV_CONFIG.groq_api_key_plain, + "hyperbolic": ENV_CONFIG.hyperbolic_api_key_plain, + "jina_ai": ENV_CONFIG.jina_ai_api_key_plain, + "openai": ENV_CONFIG.openai_api_key_plain, + "openrouter": ENV_CONFIG.openrouter_api_key_plain, + "sagemaker": ENV_CONFIG.sagemaker_api_key_plain, + "sambanova": ENV_CONFIG.sambanova_api_key_plain, + "together_ai": ENV_CONFIG.together_ai_api_key_plain, + "vertex_ai": ENV_CONFIG.vertex_ai_api_key_plain, + "voyage": ENV_CONFIG.voyage_api_key_plain, + }, + ), + ) + if ENV_CONFIG.is_oss: + return + # Continue creating sample data for Cloud mode + user = UserRead.model_validate(await User.get(session, user.id, populate_existing=True)) + # Add credit grant + org = await session.get(Organization, "0", populate_existing=True) + org.credit_grant = 150.0 + session.add(org) + await session.commit() + await session.refresh(org) + org = OrganizationRead.model_validate(org) -MAIN_ENGINE = create_sqlite_engine( - f"sqlite:///{ENV_CONFIG.owl_db_dir}/main.db", - # https://github.com/bluesky/tiled/issues/663 - poolclass=QueuePool, - pool_pre_ping=True, - pool_size=ENV_CONFIG.owl_max_concurrency, - max_overflow=ENV_CONFIG.owl_max_concurrency, - pool_timeout=30, - pool_recycle=300, -) + # Project + await projects_oss.create_project( + request=request, + user=user, + session=session, + body=projects_oss.ProjectCreate( + organization_id=org.id, + name="Admin project", + ), + project_id="proj_bee957b5881f35e120909510", + ) + model_count = (await session.exec(select(func.count(ModelConfig.id)))).one() + model_list: list[models.ModelConfig] = [] + if model_count == 0: + 
# Chat models + model_list.append( + await models.create_model_config( + request=request, + user=user, + session=session, + body=GPT_41_NANO_CONFIG, + ) + ) + # Embedding model + model_list.append( + await models.create_model_config( + request=request, + user=user, + session=session, + body=TEXT_EMBEDDING_3_SMALL_CONFIG, + ) + ) + # Reranking model + model_list.append( + await models.create_model_config( + request=request, + user=user, + session=session, + body=RERANK_ENGLISH_v3_SMALL_CONFIG, + ) + ) -class UserSQLModel(SQLModel): - metadata = MetaData() + # Model Deployments + for model in model_list: + provider = model.id.split("/")[0] + # We need to deploy non-standard models manually + if provider not in models.CloudProvider: + continue + await models.create_deployment( + request=request, + user=user, + session=session, + body=models.DeploymentCreate( + model_id=model.id, + name=f"{model.name} deployment 1", + provider=provider, + routing_id=model.id, + api_base="", + ), + ) diff --git a/services/api/src/owl/db/gen_executor.py b/services/api/src/owl/db/gen_executor.py index 493ce07..79112ae 100644 --- a/services/api/src/owl/db/gen_executor.py +++ b/services/api/src/owl/db/gen_executor.py @@ -1,163 +1,139 @@ -import asyncio import base64 import re -from dataclasses import dataclass -from os.path import splitext -from time import time +from asyncio import Queue, TaskGroup +from os.path import basename, splitext +from time import perf_counter, time from typing import Any, AsyncGenerator, Literal import numpy as np +from async_lru import alru_cache from fastapi import Request from fastapi.exceptions import RequestValidationError from loguru import logger - -from jamaibase.exceptions import BadInputError, JamaiException, ResourceNotFoundError -from owl.db.gen_table import GenerativeTable -from owl.llm import LLMEngine -from owl.models import CloudEmbedder -from owl.protocol import ( +from pydantic import BaseModel + +from owl.configs import ENV_CONFIG +from owl.db.gen_table import GenerativeTableCore, KnowledgeTable +from owl.docparse import GeneralDocLoader +from owl.types import ( + AUDIO_FILE_EXTENSIONS, + DOCUMENT_FILE_EXTENSIONS, GEN_CONFIG_VAR_PATTERN, - ChatCompletionChoiceDelta, - ChatCompletionChunk, + IMAGE_FILE_EXTENSIONS, + AudioContent, + AudioContentData, + CellCompletionResponse, + CellReferencesResponse, + ChatCompletionChoice, + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionUsage, ChatEntry, ChatRequest, + ChatRole, + ChatThreadEntry, + Chunk, CodeGenConfig, + ColumnDtype, + DiscriminatedGenConfig, EmbedGenConfig, - ExternalKeys, - GenTableChatCompletionChunks, - GenTableRowsChatCompletionChunks, - GenTableStreamChatCompletionChunk, - GenTableStreamReferences, + ImageContent, + ImageContentData, LLMGenConfig, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowRegenRequest, + OrganizationRead, + ProjectRead, + PythonGenConfig, + References, RegenStrategy, RowAdd, - RowAddRequest, + RowCompletionResponse, RowRegen, - RowRegenRequest, - TableMeta, + TextContent, ) from owl.utils import mask_string, uuid7_draft2_str +from owl.utils.billing import BillingManager from owl.utils.code import code_executor +from owl.utils.exceptions import ( + BadInputError, + JamaiException, + ResourceNotFoundError, + UpStreamError, +) from owl.utils.io import open_uri_async +from owl.utils.lm import LMEngine -@dataclass(slots=True) -class Task: - type: Literal["embed", "chat", "code"] +class Task(BaseModel, validate_assignment=True): output_column_name: str - body: 
ChatRequest | EmbedGenConfig | CodeGenConfig dtype: str + body: DiscriminatedGenConfig + status: Literal["pending", "running", "done"] = "pending" + + +class Result(BaseModel, validate_assignment=True): + row_id: str + + +class TaskResult(Result): + response: CellReferencesResponse | CellCompletionResponse | ChatCompletionResponse + output_column_name: str + + +class RowResult(Result): + data: dict[str, Any] + +ResultT = TaskResult | RowResult -class MultiRowsGenExecutor: + +class _Executor: def __init__( self, *, - table: GenerativeTable, - meta: TableMeta, request: Request, - body: RowAddRequest | RowRegenRequest, - rows_batch_size: int, - cols_batch_size: int, - max_write_batch_size: int, + table: GenerativeTableCore, + organization: OrganizationRead, + project: ProjectRead, + body: MultiRowAddRequest | MultiRowRegenRequest | RowAdd | RowRegen, ) -> None: - self.table = table - self.meta = meta self.request = request + self._request_id: str = request.state.id + self.table = table + self._table_id = table.table_id + self._col_map = {c.column_id: c for c in self.table.column_metadata} + self.organization = organization + self.project = project + if body.table_id != table.table_id: + raise ValueError(f"{body.table_id=} but {table.table_id=}") self.body = body - self.is_regen = isinstance(body, RowRegenRequest) - self.bodies = ( - [ - RowAdd( - table_id=self.body.table_id, - data=row_data, - stream=self.body.stream, - concurrent=self.body.concurrent, - ) - for row_data in self.body.data - ] - if isinstance(body, RowAddRequest) - else [ - RowRegen( - table_id=body.table_id, - row_id=row_id, - regen_strategy=body.regen_strategy, - output_column_id=body.output_column_id, - stream=body.stream, - concurrent=self.body.concurrent, - ) - for row_id in body.row_ids - ] + self._stream = self.body.stream + # Determine batch sizes + self._multi_turn = ( + sum(getattr(col.gen_config, "multi_turn", False) for col in table.column_metadata) > 0 ) - self.rows_batch_size = rows_batch_size - self.cols_batch_size = cols_batch_size - self.max_write_batch_size = max_write_batch_size - self.external_keys: ExternalKeys = request.state.external_keys - - # Accumulated rows for batch write - self.batch_rows = [] - self.write_batch_size = self.optimal_write_batch_size() - - def _log_exception(self, exc: Exception, error_message: str): - if not isinstance(exc, (JamaiException, RequestValidationError)): - logger.exception(f"{self.request.state.id} - {error_message}") - - def _create_executor(self, body_: RowAdd | RowRegen): - self.executor = GenExecutor( - table=self.table, - meta=self.meta, - request=self.request, - body=body_, - cols_batch_size=self.cols_batch_size, - ) - - async def _execute( - self, body_, tmp_id=None - ) -> Any | tuple[GenTableChatCompletionChunks, dict]: - self._create_executor(body_) - if self.body.stream: - try: - async for chunk in await self.executor.gen_row(): - await self.queue.put(chunk) - except Exception as e: - self._log_exception(e, f'Error executing task "{tmp_id}" with body: {body_}') - await self.queue.put("data: [DONE]\n\n") + self._col_batch_size = ENV_CONFIG.concurrent_cols_batch_size if body.concurrent else 1 + self._row_batch_size = 1 if self._multi_turn else ENV_CONFIG.concurrent_rows_batch_size + + @classmethod + def _log(cls, msg: str, level: str = "INFO", request_id: str = "", **kwargs): + _log = f"{cls.__name__}: {msg}" + if request_id: + _log = f"{request_id} - {_log}" + logger.log(level, _log, **kwargs) + + def log(self, msg: str, level: str = "INFO", **kwargs): + 
self._log(msg, level, request_id=self._request_id, **kwargs) + + def log_exception(self, message: str, exc: Exception, **kwargs) -> None: + if isinstance(exc, (JamaiException, RequestValidationError)): + logger.info(f"{self._request_id} - {self.__class__.__name__}: {message}", **kwargs) else: - return await self.executor.gen_row() - - async def _gen_stream_rows(self): - content_length = 0 - self.queue = asyncio.Queue() - for i in range(0, len(self.bodies), self.rows_batch_size): - batch_bodies = self.bodies[i : i + self.rows_batch_size] - # Accumulate rows within the row batch - for j, body_ in enumerate(batch_bodies): - asyncio.create_task(self._execute(body_, j)) - - done_row_count = 0 - while done_row_count < len(batch_bodies): - chunk = await self.queue.get() - if isinstance(chunk, dict) or isinstance(chunk, tuple): - # Accumulate complete row - self.batch_rows.append(chunk) - if len(self.batch_rows) >= self.write_batch_size: - await self._write_rows_to_table() - else: - if chunk == "data: [DONE]\n\n": - done_row_count += 1 - else: - content_length += len(chunk.encode("utf-8")) - yield chunk - - # Write the remaining rows to table - if len(self.batch_rows) > 0: - await self._write_rows_to_table() - # Final yield after writing is done - chunk = "data: [DONE]\n\n" - yield chunk - content_length += len(chunk.encode("utf-8")) - - self.request.state.billing.create_egress_events(content_length / (1024**3)) + logger.exception( + f"{self._request_id} - {self.__class__.__name__}: {message}", **kwargs + ) @staticmethod def _log_item(x: Any) -> str: @@ -168,905 +144,1091 @@ def _log_item(x: Any) -> str: else: return f"type={type(x)}" - def optimal_write_batch_size(self): - """ - Dynamically adjust batch size for progress updates, capped at `max_write_batch_size`. - """ - total_rows = len(self.bodies) - - # Aim for 5-10 batches, but ensure at least 10 rows per batch - target_batches = min(max(5, total_rows // 10), 10) - write_batch_size = max(total_rows // target_batches, 10) - - # Cap at max_write_batch_size - write_batch_size = min(write_batch_size, self.max_write_batch_size) - # Handle edge cases for small datasets - if total_rows <= self.max_write_batch_size: - write_batch_size = total_rows - logger.info(f"Write to table: {total_rows} row(s) at once.") +class MultiRowGenExecutor(_Executor): + def __init__( + self, + *, + request: Request, + table: GenerativeTableCore, + organization: OrganizationRead, + project: ProjectRead, + body: MultiRowAddRequest | MultiRowRegenRequest, + ) -> None: + _kwargs = dict(request=request, table=table, organization=organization, project=project) + super().__init__(body=body, **_kwargs) + # Executors + if isinstance(body, MultiRowAddRequest): + self._is_regen = False + self._executors = [ + GenExecutor( + body=RowAdd( + table_id=body.table_id, + data=row_data, + stream=body.stream, + concurrent=body.concurrent, + ), + **_kwargs, + ) + for row_data in body.data + ] else: - full_batches = total_rows // write_batch_size - remainder = total_rows % write_batch_size - - logger.info( - f"Write to table: {full_batches} batches with {write_batch_size} row(s) each." 
- ) - - if remainder: - logger.info(f"Write to table: 1 additional batch with {remainder} row(s).") + self._is_regen = True + self._executors = [ + GenExecutor( + body=RowRegen( + table_id=body.table_id, + row_id=row_id, + regen_strategy=body.regen_strategy, + output_column_id=body.output_column_id, + stream=body.stream, + concurrent=body.concurrent, + ), + **_kwargs, + ) + for row_id in body.row_ids + ] + # Determine write batch size + if self._multi_turn: + self._write_batch_size = 1 + else: + # Write batch size will be [10, max_write_batch_size] + _bs = len(self._executors) // 10 + self._write_batch_size = max(min(_bs, ENV_CONFIG.max_write_batch_size), 10) + # Task result queue + self._queue: Queue[ResultT | None] = Queue() + # Accumulated rows for batch write + self._batch_rows: list[dict[str, Any]] = [] + # Billing + self.content_length = 0 + self._billing: BillingManager = self.request.state.billing + + async def generate(self) -> AsyncGenerator[str, None] | MultiRowCompletionResponse: + if self._stream: + return self._generate() + else: + return await anext(self._generate()) - return write_batch_size + async def _generate(self) -> AsyncGenerator[str | MultiRowCompletionResponse, None]: + rows = { + exe.row_id: RowCompletionResponse(columns={}, row_id=exe.row_id) + for exe in self._executors + } + async with TaskGroup() as tg: + pending_executors = [exe for exe in self._executors] + while len(pending_executors) > 0: + _execs = pending_executors[: self._row_batch_size] + for exe in _execs: + tg.create_task(exe.generate(self._queue)) + done_rows = 0 + while done_rows < len(_execs): + res = await self._queue.get() + self.log( + "len(_execs)={a} done_rows={b} res={c}", + "DEBUG", + a=len(_execs), + b=done_rows, + c=res, + ) + if res is None: + pass + elif isinstance(res, TaskResult): + # logger.debug(f"{res.response.content=}") + if self._stream: + _sse = f"data: {res.response.model_dump_json()}\n\n" + self.content_length += len(_sse.encode("utf-8")) + yield _sse + else: + rows[res.row_id].columns[res.output_column_name] = res.response + else: + self._batch_rows.append(res.data) + if len(self._batch_rows) >= self._write_batch_size: + await self._write_rows_to_table() + done_rows += 1 + pending_executors = pending_executors[self._row_batch_size :] + # Write any remaining rows + await self._write_rows_to_table() + # End of all tasks + if self._stream: + _sse = "data: [DONE]\n\n" + yield _sse + self.content_length += len(_sse.encode("utf-8")) + self._billing.create_egress_events(self.content_length / (1024**3)) + else: + yield MultiRowCompletionResponse(rows=list(rows.values())) - async def _write_rows_to_table(self): + async def _write_rows_to_table(self) -> None: """ Writes accumulated rows to the table in batches. 
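+
+        Illustrative sizing note (the cap value used below is an assumption, not read from
+        config): rows accumulate in `self._batch_rows` and are flushed once the count reaches
+        `self._write_batch_size`, which is 1 for multi-turn tables and otherwise
+        `max(min(len(self._executors) // 10, ENV_CONFIG.max_write_batch_size), 10)`;
+        with an assumed cap of 1000, a 1,000-row request writes in batches of 100 and a
+        50-row request writes in batches of 10.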
""" - with self.table.create_session() as session: - if not self.is_regen: - logger.info( - f"{self.request.state.id} - Writing {len(self.batch_rows)} rows to table '{self.body.table_id}'" + if len(self._batch_rows) == 0: + return + if self._is_regen: + self.log(f'Table "{self._table_id}": Updating {len(self._batch_rows):,d} rows.') + try: + await self.table.update_rows( + {row["ID"]: row for row in self._batch_rows}, ignore_state_columns=False + ) + except Exception as e: + _data = [ + {k: self._log_item(v) for k, v in row.items()} for row in self._batch_rows + ] + self.log_exception( + f'Table "{self._table_id}": Failed to update {len(self._batch_rows):,d} rows: {_data}', + e, ) - try: - await self.table.add_rows(session, self.body.table_id, self.batch_rows) - except Exception as e: - _data = [ - {k: self._log_item(v) for k, v in row.items()} for row in self.batch_rows - ] - self._log_exception(e, f"Error adding {len(self.batch_rows)} rows: {_data}") - else: - # Updating existing rows - for row_id, row in self.batch_rows: - _data = {k: self._log_item(v) for k, v in row.items()} - logger.info( - f"{self.request.state.id} - Updating row with ID '{row_id}' in table '{self.body.table_id}': " - f"{_data}" - ) - try: - self.table.update_rows( - session, self.body.table_id, where=f"`ID` = '{row_id}'", values=row - ) - except Exception as e: - self._log_exception(e, f'Error updating row "{row_id}" with values: {row}') - self.batch_rows.clear() - - async def _gen_nonstream_rows(self): - rows: list[GenTableChatCompletionChunks] = [] - for i in range(0, len(self.bodies), self.rows_batch_size): - batched_bodies = self.bodies[i : i + self.rows_batch_size] - rows_and_column_dicts = await asyncio.gather( - *[self._execute(body_) for body_ in batched_bodies] - ) - # Accumulate generated rows - for rows_, column_dict in rows_and_column_dicts: - rows.append(rows_) - - if self.is_regen: - self.batch_rows.append((rows_.row_id, column_dict)) - else: - self.batch_rows.append(column_dict) - - if len(self.batch_rows) >= self.write_batch_size: - await self._write_rows_to_table() - - # Write the reminding rows to table - if len(self.batch_rows) > 0: - await self._write_rows_to_table() - - return GenTableRowsChatCompletionChunks(rows=rows) - - async def gen_rows(self) -> Any | GenTableChatCompletionChunks: - if self.body.stream: - return self._gen_stream_rows() else: - return await self._gen_nonstream_rows() + self.log(f'Table "{self._table_id}": Writing {len(self._batch_rows):,d} rows.') + try: + await self.table.add_rows( + self._batch_rows, ignore_info_columns=False, ignore_state_columns=False + ) + except Exception as e: + _data = [ + {k: self._log_item(v) for k, v in row.items()} for row in self._batch_rows + ] + self.log_exception( + f'Table "{self._table_id}": Failed to add {len(self._batch_rows):,d} rows: {_data}', + e, + ) + self._batch_rows.clear() -class GenExecutor: +class GenExecutor(_Executor): def __init__( self, *, - table: GenerativeTable, - meta: TableMeta, request: Request, + table: GenerativeTableCore, + organization: OrganizationRead, + project: ProjectRead, body: RowAdd | RowRegen, - cols_batch_size: int, ) -> None: - self.table = table - self.meta = meta - self.body = body - self.is_row_add = isinstance(self.body, RowAdd) - self.column_dict = {} - self.regen_column_dict = {} - self.tasks = [] - self.table_id = body.table_id - self.request = request - if isinstance(body, RowAdd): - body.data["ID"] = body.data.get("ID", uuid7_draft2_str()) - self.row_id = body.data["ID"] - else: - self.row_id 
= body.row_id - self.cols_batch_size = cols_batch_size if self.body.concurrent else 1 - self.external_keys: ExternalKeys = request.state.external_keys - self.llm = LLMEngine(request=request) - self.error_columns = [] - self.tag_regen_columns = [] - self.skip_regen_columns = [] - self.image_columns = [] - self.audio_columns = [] - self.audio_gen_columns = [] - self.image_column_dict = {} - self.document_column_dict = {} - self.audio_column_dict = {} - - def _log_exception(self, exc: Exception, error_message: str): - if not isinstance(exc, (JamaiException, RequestValidationError)): - logger.exception(f"{self.request.state.id} - {error_message}") - - async def _get_file_binary(self, uri: str) -> bytes: - async with open_uri_async(uri) as file_handle: - return await file_handle.read() - - # TODO: resolve duplicated code - async def _convert_uri_to_base64(self, uri: str, col_id: str) -> tuple[dict, bool]: - """ - Converts a URI to a base64-encoded string with the appropriate prefix and determines the file type. - - Args: - uri (str): The URI of the file. - col_id (str): The column ID for error context. - - Returns: - tuple: A tuple containing: - - dict: A dictionary with the base64-encoded data and its prefix. - - bool: A boolean indicating whether the file is audio. - - Raises: - ValueError: If the file format is unsupported. - """ - if not uri.startswith(("file://", "s3://")): - raise ValueError( - f"Invalid URI format for column {col_id}. URI must start with 'file://' or 's3://'" - ) - - # uri -> file binary -> base64 - file_binary = await self._get_file_binary(uri) - base64_data = self._binary_to_base64(file_binary) - - # uri -> file extension -> prefix - extension = splitext(uri)[1].lower() - - if extension in [".mp3", ".wav"]: - prefix = f"data:audio/{"mpeg" if extension == ".mp3" else "x-wav"};base64," - return { - "data": base64_data, - "format": extension[1:], - "url": prefix + base64_data, - }, True - elif extension in [".jpeg", ".jpg", ".png", ".gif", ".webp"]: - extension = ".jpeg" if extension == ".jpg" else extension - prefix = f"data:image/{extension[1:]};base64," - return {"url": prefix + base64_data}, False + super().__init__( + request=request, table=table, organization=organization, project=project, body=body + ) + # Engines + self.lm = LMEngine(organization=organization, project=project, request=request) + # Tasks + self._tasks: list[Task] = [] + if isinstance(self.body, RowAdd): + self.body.data["ID"] = uuid7_draft2_str() + self.body.data.pop("Updated at", None) + self._row_id = self.body.data["ID"] + self._regen_strategy = None else: - raise ValueError( - "Unsupported file type. Supported formats are: " - "['jpeg/jpg', 'png', 'gif', 'webp'] for images and ['mp3', 'wav'] for audio." - ) - - async def gen_row(self) -> Any | tuple[GenTableChatCompletionChunks, dict]: - cols = self.meta.cols_schema - col_ids = set(c.id for c in cols) - if self.is_row_add: - self.column_dict = {k: v for k, v in self.body.data.items() if k in col_ids} + self._row_id = self.body.row_id + self._regen_strategy = self.body.regen_strategy + if not self.body.output_column_id: + if self._regen_strategy != RegenStrategy.RUN_ALL: + raise BadInputError( + f'`output_column_id` is required when `regen_strategy` is not "{str(RegenStrategy.RUN_ALL)}".' 
+ ) + else: + output_column_ids = [ + col.column_id for col in self.table.column_metadata if col.is_output_column + ] + if self.body.output_column_id not in output_column_ids: + output_column_ids = [f'"{c}"' for c in output_column_ids] + raise ResourceNotFoundError( + ( + f'Column "{self.body.output_column_id}" not found in table "{self._table_id}". ' + f"Available output columns: {output_column_ids}" + ) + ) + self._column_dict: dict[str, Any] = {} + self._error_columns: list[str] = [] + self._task_signal: Queue[None] = Queue() + + @property + def row_id(self) -> str: + return self._row_id + + @property + def tasks(self) -> list[Task]: + return self._tasks + + @property + def column_dict(self) -> dict[str, Any]: + return self._column_dict + + # @property + # def done(self) -> bool: + # return all(task.status == "done" for task in self._tasks) + + async def _setup_tasks(self) -> None: + cols = self.table.column_metadata + # Process inputs and dependencies + if self._regen_strategy is None: + _body: RowAdd = self.body + self._column_dict = {k: v for k, v in _body.data.items() if k in self._col_map} else: - self.column_dict = self.table.get_row(self.table_id, self.row_id) + _body: RowRegen = self.body + _row = await self.table.get_row(self._row_id) + match self._regen_strategy: + case RegenStrategy.RUN_ALL: + # Keep all input columns + self._column_dict = { + k: v + for k, v in _row.items() + if not ( + self._col_map[k].is_output_column + or self._col_map[k.rstrip("_")].is_output_column + ) + } + case RegenStrategy.RUN_SELECTED: + # Keep all columns except the one being generated + self._column_dict = { + k: v + for k, v in _row.items() + if k not in (_body.output_column_id, f"{_body.output_column_id}_") + } + case RegenStrategy.RUN_BEFORE | RegenStrategy.RUN_AFTER: + _cols = [col.column_id for col in cols if col.is_output_column] + try: + idx = _cols.index(_body.output_column_id) + except ValueError as e: + raise BadInputError( + f'Column "{_body.output_column_id}" not found in table "{self._table_id}".' 
+ ) from e + # Keep columns that are not being generated + if self._regen_strategy == RegenStrategy.RUN_BEFORE: + _cols = _cols[idx + 1 :] + else: + _cols = _cols[:idx] + _cols += [f"{c}_" for c in _cols] + _cols += [col.column_id for col in cols if not col.is_output_column] + self._column_dict = { + k: v for k, v in _row.items() if k in _cols or k.lower() == "id" + } + case _: + raise BadInputError(f'Invalid regen strategy: "{str(self._regen_strategy)}".') + # # Filter out state columns + # self._column_dict = {k: v for k, v in self._column_dict.items() if not k.endswith("_")} + self.log("self._column_dict={column_dict}", "DEBUG", column_dict=self._column_dict) - self.tasks = [] + self._tasks = [] for col in cols: - # Skip info columns - if col.id.lower() in ("id", "updated at"): + # Skip info and state columns + if col.is_info_column or col.is_state_column: continue - # Skip state column - if col.id.endswith("_"): - continue - # If user provides value, skip - if self.is_row_add and col.id in self.column_dict: - continue - # If gen_config not defined, set None and skip + # Create task if col.gen_config is None: - if self.is_row_add: - self.column_dict[col.id] = None + # Default value for missing column during row add + # Even though this is also handled by `GenerativeTableCore`, + # we need this to avoid hanging tasks due to missing inputs + self._column_dict[col.column_id] = self._column_dict.get(col.column_id) continue - if isinstance(col.gen_config, EmbedGenConfig): - task_type = "embed" - if col.vlen <= 0: - raise ValueError( - f'"gen_config" is EmbedGenConfig but `col.vlen` is {col.vlen}' - ) - gen_config = col.gen_config - elif isinstance(col.gen_config, LLMGenConfig): - task_type = "chat" - if col.gen_config.multi_turn: - messages = self.table.get_conversation_thread( - table_id=self.table_id, - column_id=col.id, - row_id="" if self.is_row_add else self.row_id, - include=False, - ).thread - user_message = col.gen_config.prompt - messages.append(ChatEntry.user(content=user_message if user_message else ".")) - if len(messages) == 0: - continue - else: - messages = [ - ChatEntry.system(col.gen_config.system_prompt), - ChatEntry.user(col.gen_config.prompt), - ] - gen_config = ChatRequest( - id=self.request.state.id, messages=messages, **col.gen_config.model_dump() + if col.column_id in self._column_dict: + self.log(f'Skipped generation for column "{col.column_id}".') + continue + self._tasks.append( + Task( + output_column_name=col.column_id, + dtype=col.dtype, + body=col.gen_config, ) - if gen_config.model != "": - model_config = self.request.state.all_models.get_llm_model_info( - gen_config.model + ) + self.log("self._tasks={tasks}", "DEBUG", tasks=self._tasks) + column_dict_keys = set(self._column_dict.keys()) + col_ids = set(self._col_map.keys()) + if len(column_dict_keys - col_ids) > 0: + logger.warning( + f'Table "{self._table_id}": There are unexpected columns: {column_dict_keys - col_ids}' + ) + self.log(f"Prepared {len(self._tasks):,d} tasks.", "DEBUG") + + async def generate(self, q: Queue[ResultT | None]) -> None: + await self._setup_tasks() + async with TaskGroup() as tg: + pending_tasks = [task for task in self._tasks if task.status == "pending"] + self.log("Pending tasks: {pending_tasks}", "DEBUG", pending_tasks=pending_tasks) + while len(pending_tasks) > 0: + # Go through pending tasks + ready_tasks = [task for task in pending_tasks if self._is_task_ready(task)] + for task in ready_tasks[: self._col_batch_size]: + if not self._is_task_ready(task): + continue + 
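+                    # A task is "ready" once every upstream column it references already has
+                    # a value in `self._column_dict`; at most `self._col_batch_size` ready
+                    # tasks are launched per pass, and `self._task_signal` wakes this loop
+                    # whenever one finishes so newly unblocked tasks get picked up next.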
task.status = "running" + tg.create_task(self._execute_task(task, q)) + # Wait for a task to complete + await self._task_signal.get() + pending_tasks = [task for task in self._tasks if task.status == "pending"] + self.log("Pending tasks: {pending_tasks}", "DEBUG", pending_tasks=pending_tasks) + # Put row data + await q.put(RowResult(data=self._column_dict, row_id=self._row_id)) + self.log("All tasks completed.", "DEBUG") + + def _is_task_ready(self, task: Task) -> bool: + match task.body: + case LLMGenConfig(): + inputs = self._extract_upstream_columns(task.body.prompt) + case EmbedGenConfig() | CodeGenConfig(): + inputs = [task.body.source_column] + case PythonGenConfig(): + inputs = self._extract_all_upstream_columns(task.output_column_name) + case _: + raise ValueError(f'Table "{self._table_id}": Unexpected task type: {task.body}') + # Only consider input references that exist in table + inputs = [i for i in inputs if i in self._col_map] + task_ready = all(col in self._column_dict for col in inputs) + return task_ready + + async def _execute_task(self, task: Task, q: Queue[ResultT | None]) -> None: + logger.debug(f"Processing column: {task.output_column_name}") + match task.body: + case LLMGenConfig(): + await self._execute_chat_task(task, q) + case EmbedGenConfig(): + await self._execute_embed_task(task, q) + case CodeGenConfig(): + await self._execute_code_task(task, q) + case PythonGenConfig(): + await self._execute_python_task(task, q) + case _: + raise ValueError(f'Table "{self._table_id}": Unexpected task type: {task.body}') + + async def _execute_chat_task(self, task: Task, q: Queue[ResultT | None]) -> None: + output_column = task.output_column_name + body: LLMGenConfig = task.body + # Check if a value is provided + try: + # TODO: Perhaps we need to emit references too + result = self._column_dict[output_column] + # response_kwargs = dict( + # id=self._request_id, + # created=int(time()), + # model="", + # usage=ChatCompletionUsage(), + # choices=[ + # ChatCompletionChoice( + # message=ChatCompletionMessage(content=result), + # index=0, + # ) + # ], + # ) + # if self._stream: + # response = CellCompletionResponse( + # **response_kwargs, + # output_column_name=output_column, + # row_id=self._row_id, + # ) + # else: + # response = ChatCompletionResponse(**response_kwargs) + # self.log(f'Skipped completion for column "{output_column}".') + # if self._regen_strategy is not None: + # # TODO: Perhaps we should always emit column value even if it is provided? 
+ # await q.put( + # TaskResult( + # response=response, + # output_column_name=output_column, + # row_id=self._row_id, + # ) + # ) + await q.put(None) + await self._signal_task_completion(task, result) + return + except KeyError: + pass + + # Perform completion + result = "" + references = None + try: + # Error circuit breaker + self._check_upstream_error(self._extract_upstream_columns(body.prompt)) + # Form the request body + if body.multi_turn: + messages = ( + await self.table.get_conversation_thread( + column_id=output_column, + row_id="" if self._regen_strategy is None else self._row_id, + include_row=False, ) - if ( - "audio" in model_config.capabilities - and model_config.deployments[0].provider == "openai" - ): - self.audio_gen_columns.append(col.id) - elif isinstance(col.gen_config, CodeGenConfig): - task_type = "code" - gen_config = col.gen_config + ).thread else: - raise ValueError(f'Unexpected "gen_config" type: {type(col.gen_config)}') - self.tasks.append( - Task(type=task_type, output_column_name=col.id, body=gen_config, dtype=col.dtype) + messages = [ChatThreadEntry.system(body.system_prompt)] + messages.append( + ChatThreadEntry.user( + content=self.table.interpolate_column( + body.prompt if body.prompt else ".", self._column_dict + ) + ) ) - - self.image_columns = [col.id for col in cols if col.dtype == "image"] - self.audio_columns = [col.id for col in cols if col.dtype == "audio"] - for col_id in self.image_columns + self.audio_columns: - if self.column_dict.get(col_id, None) is not None: - uri = self.column_dict[col_id] - b64, is_audio = await self._convert_uri_to_base64(uri, col_id) - - if is_audio: - if col_id not in self.audio_columns: - raise ValueError( - f"Column {col_id} is not marked as an audio column but contains audio data." + # Load files for each user message + messages = [await self._load_files(m) for m in messages] + req = ChatRequest( + id=self._request_id, + messages=[ChatEntry.model_validate(m.model_dump()) for m in messages], + **body.model_dump(), + ) + req, references = await self._setup_rag(req) + if self._stream: + reasoning = "" + result = "" + if references is not None: + ref = CellReferencesResponse( + **references.model_dump(exclude=["object"]), + output_column_name=output_column, + row_id=self._row_id, + ) + await q.put( + TaskResult( + response=ref, + output_column_name=output_column, + row_id=self._row_id, ) - self.audio_column_dict[col_id] = ( - { - "data": b64["data"], - "format": b64["format"], - }, # for audio gen model - {"url": b64["url"]}, # for audio model ) - else: - if col_id not in self.image_columns: - raise ValueError( - f"Column {col_id} is not marked as a file column but contains image data." 
+ async for chunk in self.lm.chat_completion_stream( + messages=req.messages, + **req.hyperparams, + ): + reasoning += chunk.reasoning_content + result += chunk.content + # if chunk.content is None and chunk.usage is None: + # continue + chunk = CellCompletionResponse( + **chunk.model_dump(exclude={"object"}), + output_column_name=output_column, + row_id=self._row_id, + ) + await q.put( + TaskResult( + response=chunk, + output_column_name=output_column, + row_id=self._row_id, ) - self.image_column_dict[col_id] = b64 - - column_dict_keys = set(self.column_dict.keys()) - if len(column_dict_keys - col_ids) > 0: - raise ValueError(f"There are unexpected columns: {column_dict_keys - col_ids}") - - if self.body.stream: - return self._stream_concurrent_execution() - else: - return await self._nonstream_concurrent_execution() - - async def _run_embed_tasks(self): - """ - Executes embedding tasks sequentially. - """ - embed_tasks = [task for task in self.tasks if task.type == "embed"] - for task in embed_tasks: - output_column_name = task.output_column_name - body: EmbedGenConfig = task.body - embedding_model = body.embedding_model - embedder = CloudEmbedder(request=self.request) - source = self.column_dict[body.source_column] - embedding = await embedder.embed_documents( - embedding_model, texts=["." if source is None else source] - ) - embedding = np.asarray(embedding.data[0].embedding, dtype=task.dtype) - embedding = embedding / np.linalg.norm(embedding) - self.column_dict[output_column_name] = embedding - self.regen_column_dict[output_column_name] = embedding - - def _extract_upstream_columns(self, text: str) -> list[str]: - matches = re.findall(GEN_CONFIG_VAR_PATTERN, text) - # return the content inside ${...} - return matches - - def _extract_upstream_image_columns(self, text: str) -> list[str]: - matches = re.findall(GEN_CONFIG_VAR_PATTERN, text) - # return the content inside ${...} - return [match for match in matches if self.llm_tasks[matches].dtype == "img"] - - def _binary_to_base64(self, binary_data: bytes) -> str: - return base64.b64encode(binary_data).decode("utf-8") - - def _interpolate_column(self, prompt: str, base_column_name: str) -> str | dict[str, Any]: - """ - Replaces / interpolates column references in the prompt with their contents. - - Args: - prompt (str): The original prompt with zero or more column references. - - Returns: - new_prompt (str | dict[str, Any]): The prompt with column references replaced. 
- """ - - image_column_names = [] - audio_column_names = [] + ) + if chunk.finish_reason == "error": + self._error_columns.append(output_column) + else: + response = await self.lm.chat_completion( + messages=req.messages, + **req.hyperparams, + ) + response.references = references + await q.put( + TaskResult( + response=response, + output_column_name=output_column, + row_id=self._row_id, + ) + ) + result = response.content + reasoning = response.reasoning_content - def replace_match(match): - column_name = match.group(1) # Extract the column_name from the match - try: - if column_name in self.image_column_dict: - image_column_names.append(column_name) - return "" - elif column_name in self.audio_column_dict: - audio_column_names.append(column_name) - if base_column_name in self.audio_gen_columns: - return "" # follow the content type - else: - return "" - elif column_name in self.document_column_dict: - return self.document_column_dict[column_name] - return str(self.column_dict[column_name]) # Data can be non-string - except KeyError as e: - raise BadInputError(f"Requested column '{column_name}' is not found.") from e - - content_ = re.sub(GEN_CONFIG_VAR_PATTERN, replace_match, prompt) - content = [{"type": "text", "text": content_}] - - if len(image_column_names) > 0 and len(audio_column_names) > 0: - raise BadInputError("Either image or audio is supported per completion.") - - if len(image_column_names) > 0: - if len(image_column_names) > 1: - raise BadInputError("Only one image is supported per completion.") - - content.append( - { - "type": "image_url", - "image_url": self.image_column_dict[image_column_names[0]], - } + except Exception as e: + response_kwargs = dict( + id=self._request_id, + created=int(time()), + model="", + usage=ChatCompletionUsage(), + choices=[ + ChatCompletionChoice( + message=ChatCompletionMessage(content=f"[ERROR] {str(e)}"), + index=0, + finish_reason="error", + ) + ], ) - return content - elif len(audio_column_names) > 0: - if len(audio_column_names) > 1: - raise BadInputError("Only one audio is supported per completion.") - - if base_column_name in self.audio_gen_columns: - content.append( - { - "type": "input_audio", - "input_audio": self.audio_column_dict[audio_column_names[0]][0], - } + if self._stream: + response = CellCompletionResponse( + **response_kwargs, + output_column_name=output_column, + row_id=self._row_id, ) else: - content.append( - { - "type": "audio_url", - "audio_url": self.audio_column_dict[audio_column_names[0]][1], - } + response = ChatCompletionResponse(**response_kwargs) + await q.put( + TaskResult( + response=response, + output_column_name=output_column, + row_id=self._row_id, ) - return content - else: - return content_ - - def _check_upstream_error_chunk(self, content: str) -> None: - matches = re.findall(GEN_CONFIG_VAR_PATTERN, content) - if any([match in self.error_columns for match in matches]): - raise Exception - - def _validate_model(self, body: LLMGenConfig, output_column_name: str): - for input_column_name in self.dependencies[output_column_name]: - if input_column_name in self.image_column_dict: - try: - body.model = self.llm.validate_model_id(body.model, ["image"]) - break - except ResourceNotFoundError as e: - raise BadInputError( - f'Column "{output_column_name}" referred to image file input but using a chat model ' - f'"{self.llm.get_model_name(body.model) if self.llm.is_browser else body.model}", ' - "select image model instead.", - ) from e - if input_column_name in self.audio_column_dict: - try: - body.model = 
self.llm.validate_model_id(body.model, ["audio"]) - break - except ResourceNotFoundError as e: - raise BadInputError( - f'Column "{output_column_name}" referred to audio file input but using a chat model ' - f'"{self.llm.get_model_name(body.model) if self.llm.is_browser else body.model}", ' - "select audio model instead.", - ) from e - - async def _execute_code(self, task: Task) -> str: - output_column_name = task.output_column_name - body: CodeGenConfig = task.body - dtype = task.dtype - source_code = self.column_dict[body.source_column] - + ) + result = response.content + reasoning = response.reasoning_content + self._error_columns.append(output_column) + self.log_exception( + f'Table "{self._table_id}": Failed to generate completion for column "{output_column}": {repr(e)}', + e, + ) + finally: + await q.put(None) + state_col = f"{task.output_column_name}_" + state = self._column_dict.get(state_col, {}) + if references is not None: + state["references"] = references.model_dump(mode="json") + if reasoning: + state["reasoning"] = reasoning + self._column_dict[state_col] = state + await self._signal_task_completion(task, result) + self.log(f'Streamed completion for column "{output_column}": <{mask_string(result)}>.') + + async def _execute_embed_task(self, task: Task, q: Queue[ResultT | None]) -> None: + output_column = task.output_column_name + # Check if a value is provided try: - new_column_value = await code_executor(source_code, dtype, self.request) - except Exception as e: - new_column_value = f"[ERROR] {str(e)}" - self._log_exception(e, f'Error executing code for column "{output_column_name}": {e}') - - if dtype == "image" and new_column_value is not None: + embedding = self._column_dict[output_column] + if isinstance(embedding, np.ndarray): + pass + elif isinstance(embedding, list): + embedding = np.asarray(embedding) + else: + raise TypeError( + f"Unexpected embedding type, expected `np.ndarray` or `list`, got `{type(embedding)}`." + ) + # Perform embedding + except (KeyError, TypeError): + body: EmbedGenConfig = task.body try: - ( - self.image_column_dict[output_column_name], - _, - ) = await self._convert_uri_to_base64(new_column_value, output_column_name) - except ValueError as e: - self._log_exception(e, f"Invalid file path for column '{output_column_name}'") - new_column_value = None - - return new_column_value + # Error circuit breaker + self._check_upstream_error([body.source_column]) + # TODO: We can find a way to batch embedding tasks + source = self._column_dict.get(body.source_column, None) + embedding = await self.lm.embed_documents( + model=body.embedding_model, + texts=["." if source is None else source], + ) + embedding = np.asarray(embedding.data[0].embedding, dtype=task.dtype) + embedding = embedding / np.linalg.norm(embedding) + except Exception as e: + self.log_exception( + f'Table "{self._table_id}": Failed to embed for column "{output_column}": {repr(e)}', + e, + ) + embedding = None + # TODO: Perhaps we need to emit embeddings + await q.put(None) + await self._signal_task_completion(task, embedding) - async def _execute_task_stream(self, task: Task) -> AsyncGenerator[str, None]: - """ - Executes a single task in a streaming manner, returning an asynchronous generator of chunks. 
- """ - output_column_name = task.output_column_name - body: ChatRequest = task.body + async def _execute_code_task(self, task: Task, q: Queue[ResultT | None]) -> None: + output_column = task.output_column_name + body: CodeGenConfig = task.body + # Check if a value is provided try: - logger.debug(f"Processing column: {output_column_name}") + result = self._column_dict[output_column] + # response_kwargs = dict( + # id=self._request_id, + # created=int(time()), + # model="code_execution", + # usage=ChatCompletionUsage(), + # choices=[ + # ChatCompletionChoice( + # message=ChatCompletionMessage(content=result), + # index=0, + # ) + # ], + # ) + # if self._stream: + # response = CellCompletionResponse( + # **response_kwargs, + # output_column_name=output_column, + # row_id=self._row_id, + # ) + # else: + # response = ChatCompletionResponse(**response_kwargs) + + # self.log(f'Skipped code execution for column "{output_column}".') + # if self._regen_strategy is not None: + # await q.put( + # TaskResult( + # response=response, + # output_column_name=output_column, + # row_id=self._row_id, + # ) + # ) + await q.put(None) + await self._signal_task_completion(task, result) + return + except KeyError: + pass - if output_column_name in self.skip_regen_columns: - new_column_value = self.column_dict[output_column_name] - logger.debug( - f"Skipped regen for `{output_column_name}`, value: {new_column_value}" + # Perform code execution + result = "" + try: + # Error circuit breaker + self._check_upstream_error([body.source_column]) + source_code = self._column_dict.get(body.source_column, "") + + # Extract bytes from ColumnDtype.AUDIO and ColumnDtype.IMAGE and put it into a dictionary + row_data = self._column_dict.copy() + self.table.postprocess_rows([row_data], include_state=False) + for k, v in row_data.items(): + col = next((col for col in self.table.column_metadata if col.column_id == k), None) + if col and (col.dtype == ColumnDtype.AUDIO or col.dtype == ColumnDtype.IMAGE): + row_data[k] = await _load_uri_as_bytes(v) + + if source_code and row_data: + result = await code_executor( + request=self.request, + organization_id=self.organization.id, + project_id=self.project.id, + source_code=source_code, + output_column=output_column, + row_data=row_data, + dtype=task.dtype, ) + else: + result = "" - elif isinstance(body, CodeGenConfig): - new_column_value = await self._execute_code(task) - logger.info(f"Executed Code Execution Column: '{output_column_name}'") - chunk = GenTableStreamChatCompletionChunk( - id=self.request.state.id, - object="gen_table.completion.chunk", - created=int(time()), - model="code_execution", - usage=None, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(new_column_value), - index=0, - ) - ], - output_column_name=output_column_name, - row_id=self.row_id, + response_kwargs = dict( + id=self._request_id, + created=int(time()), + model="code_execution", + usage=ChatCompletionUsage(), + choices=[ + ChatCompletionChoice( + message=ChatCompletionMessage(content=result), + index=0, + ) + ], + ) + if self._stream: + response = CellCompletionResponse( + **response_kwargs, + output_column_name=output_column, + row_id=self._row_id, ) - yield f"data: {chunk.model_dump_json()}\n\n" + else: + response = ChatCompletionResponse(**response_kwargs) - elif isinstance(body, ChatRequest): - self._check_upstream_error_chunk(body.messages[-1].content) - body.messages[-1].content = self._interpolate_column( - body.messages[-1].content, output_column_name + await q.put( + TaskResult( 
+ response=response, + output_column_name=output_column, + row_id=self._row_id, ) + ) - if isinstance(body.messages[-1].content, list): - self._validate_model(body, output_column_name) - - if output_column_name in self.image_columns + self.audio_columns: - new_column_value = None - logger.info( - f"Identified output column `{output_column_name}` as image / audio type, set value to {new_column_value}" - ) - chunk = GenTableStreamChatCompletionChunk( - id=self.request.state.id, - object="gen_table.completion.chunk", - created=int(time()), - model="", - usage=None, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(new_column_value), - index=0, - ) - ], - output_column_name=output_column_name, - row_id=self.row_id, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - else: - new_column_value = "" - kwargs = body.model_dump() - messages, references = await self.llm.retrieve_references( - messages=kwargs.pop("messages"), - rag_params=kwargs.pop("rag_params", None), - **kwargs, - ) - if references is not None: - ref = GenTableStreamReferences( - **references.model_dump(exclude=["object"]), - output_column_name=output_column_name, - ) - yield f"data: {ref.model_dump_json()}\n\n" - async for chunk in self.llm.generate_stream(messages=messages, **kwargs): - new_column_value += chunk.text - chunk = GenTableStreamChatCompletionChunk( - **chunk.model_dump(exclude=["object"]), - output_column_name=output_column_name, - row_id=self.row_id, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - if chunk.finish_reason == "error": - self.error_columns.append(output_column_name) - else: - raise ValueError(f"Unsupported task type: {type(body)}") + self.log(f'Executed code for column "{output_column}": <{mask_string(result)}>.') except Exception as e: - error_chunk = GenTableStreamChatCompletionChunk( - id=self.request.state.id, - object="gen_table.completion.chunk", + response_kwargs = dict( + id=self._request_id, created=int(time()), - model="", - usage=None, + model="code_execution", + usage=ChatCompletionUsage(), choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(f"[ERROR] {e}"), + ChatCompletionChoice( + message=ChatCompletionMessage(content=f"[ERROR] {str(e)}"), index=0, finish_reason="error", ) ], - output_column_name=output_column_name, - row_id=self.row_id, ) - yield f"data: {error_chunk.model_dump_json()}\n\n" - new_column_value = error_chunk.text - self.error_columns.append(output_column_name) - self._log_exception( - e, f'Error generating completion for column "{output_column_name}": {e}' + response = ( + CellCompletionResponse( + **response_kwargs, output_column_name=output_column, row_id=self._row_id + ) + if self._stream + else ChatCompletionResponse(**response_kwargs) ) - finally: - # Append new column data for subsequent tasks - self.column_dict[output_column_name] = new_column_value - self.regen_column_dict[output_column_name] = new_column_value - logger.info( - f"{self.request.state.id} - Streamed completion for " - f"{output_column_name}: <{mask_string(new_column_value)}>" + + await q.put( + TaskResult( + response=response, + output_column_name=output_column, + row_id=self._row_id, + ) + ) + result = response.content + self._error_columns.append(output_column) + self.log_exception( + f'Table "{self._table_id}": Failed to execute code for column "{output_column}": {repr(e)}', + e, ) + finally: + await q.put(None) + await self._signal_task_completion(task, result) - async def _execute_task_nonstream(self, task: Task): - """ - Executes a single task in a 
non-streaming manner. - """ - output_column_name = task.output_column_name - body: ChatRequest = task.body + async def _execute_python_task(self, task: Task, q: Queue[ResultT | None]) -> None: + output_column = task.output_column_name + body: PythonGenConfig = task.body + # Check if a value is provided try: - if output_column_name in self.skip_regen_columns: - new_column_value = self.column_dict[output_column_name] - response = ChatCompletionChunk( - id=self.request.state.id, - object="chat.completion.chunk", - created=int(time()), - model="", - usage=None, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(new_column_value), - index=0, - ) - ], - ) - logger.debug( - f"Skipped regen for `{output_column_name}`, value: {new_column_value}" - ) + result = self._column_dict[output_column] + # response_kwargs = dict( + # id=self._request_id, + # created=int(time()), + # model="python_fixed_function", + # usage=ChatCompletionUsage(), + # choices=[ + # ChatCompletionChoice( + # message=ChatCompletionMessage(content=result), + # index=0, + # ) + # ], + # ) + # if self._stream: + # response = CellCompletionResponse( + # **response_kwargs, + # output_column_name=output_column, + # row_id=self._row_id, + # ) + # else: + # response = ChatCompletionResponse(**response_kwargs) + + # self.log(f'Skipped python fixed function execution for column "{output_column}".') + # if self._regen_strategy is not None: + # await q.put( + # TaskResult( + # response=response, + # output_column_name=output_column, + # row_id=self._row_id, + # ) + # ) + await q.put(None) + await self._signal_task_completion(task, result) + return + except KeyError: + pass - elif isinstance(body, CodeGenConfig): - new_column_value = await self._execute_code(task) - response = ChatCompletionChunk( - id=self.request.state.id, - object="chat.completion.chunk", - created=int(time()), - model="code_execution", - usage=None, - choices=[ - ChatCompletionChoiceDelta( - index=0, - message=ChatEntry.assistant(new_column_value), - ) - ], - ) - logger.debug( - f"Identified as Code Execution Column: {task.output_column_name}, executing code." 
+ # Perform python fixed function execution + result = "" + try: + # Error circuit breaker + # Extract all columns to the left and check for upstream errors + self._check_upstream_error(self._extract_all_upstream_columns(output_column)) + + # Extract bytes from ColumnDtype.AUDIO and ColumnDtype.IMAGE and put it into a dictionary + row_data = self._column_dict.copy() + self.table.postprocess_rows([row_data], include_state=False) + for k, v in row_data.items(): + col = next((col for col in self.table.column_metadata if col.column_id == k), None) + if col and (col.dtype == ColumnDtype.AUDIO or col.dtype == ColumnDtype.IMAGE): + row_data[k] = await _load_uri_as_bytes(v) + + if body.python_code and row_data: + result = await code_executor( + request=self.request, + organization_id=self.organization.id, + project_id=self.project.id, + source_code=body.python_code, + output_column=output_column, + row_data=row_data, + dtype=task.dtype, ) - elif isinstance(body, ChatRequest): - self._check_upstream_error_chunk(body.messages[-1].content) - try: - body.messages[-1].content = self._interpolate_column( - body.messages[-1].content, output_column_name - ) - except IndexError: - pass - - if isinstance(body.messages[-1].content, list): - self._validate_model(body, output_column_name) - - if output_column_name in self.image_columns + self.audio_columns: - new_column_value = None - response = ChatCompletionChunk( - id=self.request.state.id, - object="chat.completion.chunk", - created=int(time()), - model="", - usage=None, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(new_column_value), - index=0, - ) - ], - ) - logger.debug( - f"Identified output column `{output_column_name}` as image / audio type, set value to {new_column_value}" + + response_kwargs = dict( + id=self._request_id, + created=int(time()), + model="python_fixed_function", + usage=ChatCompletionUsage(), + choices=[ + ChatCompletionChoice( + message=ChatCompletionMessage(content=result), + index=0, ) - else: - response = await self.llm.rag(**body.model_dump()) - new_column_value = response.text - else: - raise ValueError(f"Unsupported task type: {type(body)}") - - # append new column data for subsequence tasks - self.column_dict[output_column_name] = new_column_value - self.regen_column_dict[output_column_name] = new_column_value - logger.info( - ( - f"{self.request.state.id} - Generated completion for {output_column_name}: " - f"<{mask_string(new_column_value)}>" + ], + ) + response = ( + CellCompletionResponse( + **response_kwargs, output_column_name=output_column, row_id=self._row_id + ) + if self._stream + else ChatCompletionResponse(**response_kwargs) + ) + + await q.put( + TaskResult( + response=response, + output_column_name=output_column, + row_id=self._row_id, ) ) - return response + + self.log( + f'Executed python code for column "{output_column}": <{mask_string(result)}>.' 
+ ) except Exception as e: - error_chunk = ChatCompletionChunk( - id=self.request.state.id, - object="gen_table.completion.chunk", + response_kwargs = dict( + id=self._request_id, created=int(time()), - model="", - usage=None, + model="python_fixed_function", + usage=ChatCompletionUsage(), choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant( - f'[ERROR] Column "{output_column_name}" referred to image file input but using a chat model ' - f'"{self.llm.get_model_name(body.model) if self.llm.is_browser else body.model}", ' - "select image model instead.", - ) - if isinstance(e, ResourceNotFoundError) - else ChatEntry.assistant(f"[ERROR] {e}"), + ChatCompletionChoice( + message=ChatCompletionMessage(content=f"[ERROR] {str(e)}"), index=0, finish_reason="error", ) ], ) - new_column_value = error_chunk.text - self.column_dict[output_column_name] = new_column_value - self.regen_column_dict[output_column_name] = new_column_value - self._log_exception( - e, f'Error generating completion for column "{output_column_name}": {e}' + response = ( + CellCompletionResponse( + **response_kwargs, output_column_name=output_column, row_id=self._row_id + ) + if self._stream + else ChatCompletionResponse(**response_kwargs) ) - return error_chunk - - def _setup_dependencies(self) -> None: - """ - Sets up dependencies for the tasks. - - This method initializes the dependencies for the tasks that need to be executed. It creates a dictionary - called `llm_tasks` where the keys are the output column names of the tasks and the values are the tasks themselves. - It also creates a dictionary called `dependencies` where the keys are the output column names of the tasks and the - values are the dependencies for each task. The dependencies are extracted from the content of the last message in the task's body. - - Examples: - ```python - # Example usage of _setup_dependencies method - llm_tasks = { - "task1_output": Task(...), - "task2_output": Task(...), - # ... - } - dependencies = { - "task1_output": self._extract_upstream_columns(task1_body["messages"][-1]["content"]), - "task2_output": self._extract_upstream_columns(task2_body["messages"][-1]["content"]), - # ... - } - ``` - """ - self.llm_tasks = { - task.output_column_name: task for task in self.tasks if task.type == "chat" - } - self.code_tasks = { - task.output_column_name: task for task in self.tasks if task.type == "code" - } - self.dependencies = { - task.output_column_name: self._extract_upstream_columns(task.body.messages[-1].content) - for task in self.llm_tasks.values() - } - self.dependencies.update( - { - task.output_column_name: [task.body.source_column] - for task in self.code_tasks.values() - } - ) - logger.debug(f"Initial dependencies: {self.dependencies}") - - self.input_column_names = [ - key - for key in self.column_dict.keys() - if key not in self.llm_tasks.keys() and key not in self.code_tasks.keys() - ] - - def _mark_regen_columns(self) -> None: - """ - Tag columns to regenerate based on the chosen regeneration strategy. 
- """ - if self.is_row_add: - return - - # Get the current column order from the table metadata - cols = self.meta.cols_schema - col_ids = [col.id for col in cols] - - if self.body.regen_strategy == RegenStrategy.RUN_ALL: - self.tag_regen_columns = set(self.llm_tasks.keys()).union(self.code_tasks.keys()) - - elif self.body.regen_strategy == RegenStrategy.RUN_SELECTED: - self.tag_regen_columns.append(self.body.output_column_id) - - elif self.body.regen_strategy in ( - RegenStrategy.RUN_BEFORE, - RegenStrategy.RUN_AFTER, - ): - if self.body.regen_strategy == RegenStrategy.RUN_BEFORE: - for column_name in col_ids: - self.tag_regen_columns.append(column_name) - if column_name == self.body.output_column_id: - break - else: # RegenStrategy.RUN_AFTER - reached_column = False - for column_name in col_ids: - if column_name == self.body.output_column_id: - reached_column = True - if reached_column: - self.tag_regen_columns.append(column_name) + await q.put( + TaskResult( + response=response, + output_column_name=output_column, + row_id=self._row_id, + ) + ) + result = response.content + self._error_columns.append(output_column) + self.log_exception( + f'Table "{self._table_id}": Failed to execute python code for column "{output_column}": {repr(e)}', + e, + ) + finally: + await q.put(None) + await self._signal_task_completion(task, result) + + async def _signal_task_completion(self, task: Task, result: Any) -> None: + self._column_dict[task.output_column_name] = result + task.status = "done" + await self._task_signal.put(None) + + async def _load_files(self, message: ChatThreadEntry) -> ChatThreadEntry | ChatEntry: + if not isinstance(message, ChatThreadEntry): + raise TypeError(f"Unexpected message type: {type(message)}") + if message.role != ChatRole.USER: + return message + ### Text-only + if isinstance(message.content, str): + # logger.error(f"{message.content=}") + return ChatEntry.user(content=message.content.strip()) else: - raise ValueError(f"Invalid regeneration strategy: {self.body.regen_strategy}") + content = message.content + ### Multi-modal + contents: list[TextContent, ImageContent, AudioContent] = [] + replacements: dict[str, str] = {} + # Load file + # logger.error(f"{content=}") + for c in content: + if isinstance(c, TextContent): + contents.append(c) + else: + data = await _load_uri_as_base64(str(c.uri)) + if getattr(self._col_map.get(c.column_name, None), "is_document_column", False): + # Document (data could be None) + replacements[c.column_name] = str(data) + # prompt = re.sub(_regex, str(data), prompt) + else: + # Image or audio + if isinstance(data, (ImageContent, AudioContent)): + contents.append(data) + replacements[c.column_name] = "" + # prompt = re.sub(_regex, "", prompt) + # Replace column references + for c in contents: + if not isinstance(c, TextContent): + continue + for col_name, data in replacements.items(): + _regex = r"(? 
list[str]: + col_ids = re.findall(GEN_CONFIG_VAR_PATTERN, prompt) + # return the content inside ${...} + return col_ids - self.skip_regen_columns = [ - column_name for column_name in col_ids if column_name not in self.tag_regen_columns + def _extract_all_upstream_columns(self, output_column_name: str) -> list[str]: + cols = self.table.column_metadata + try: + idx = next(i for i, c in enumerate(cols) if c.column_id == output_column_name) + except StopIteration: + return [] + return [ + c.column_id + for c in cols[:idx] + if not (c.is_info_column or c.is_state_column or c.is_vector_column) ] - async def _nonstream_concurrent_execution(self) -> tuple[GenTableChatCompletionChunks, dict]: - """ - Executes tasks in concurrent in a non-streaming manner, respecting dependencies. - """ - self._setup_dependencies() - self._mark_regen_columns() - - completed = set(self.input_column_names) - tasks_in_progress = set() - responses = {} - - async def execute_task(task_name): - try: - task = self.llm_tasks[task_name] - except Exception: - task = self.code_tasks[task_name] - - try: - responses[task_name] = await self._execute_task_nonstream(task) - except Exception as e: - self._log_exception(e, f'Error executing task "{task_name}": {e}') - finally: - completed.add(task_name) - tasks_in_progress.remove(task_name) - - while len(completed) < ( - len(self.llm_tasks) + len(self.code_tasks) + len(self.input_column_names) - ): - ready_tasks = [ - task_name - for task_name, deps in self.dependencies.items() - if all(dep in completed for dep in deps) - and task_name not in completed - and task_name not in tasks_in_progress - ] + def _check_upstream_error(self, upstream_cols: list[str]) -> None: + if not isinstance(upstream_cols, list): + raise TypeError(f"`upstream_cols` must be a list, got: {type(upstream_cols)}") + error_cols = [f'"{col}"' for col in upstream_cols if col in self._error_columns] + if len(error_cols) > 0: + raise UpStreamError(f"Upstream columns errored out: {', '.join(error_cols)}") - # Process tasks in batches - for i in range(0, len(ready_tasks), self.cols_batch_size): - batched_tasks = ready_tasks[i : i + self.cols_batch_size] - exe_tasks = [execute_task(task) for task in batched_tasks] - tasks_in_progress.update(batched_tasks) - await asyncio.gather(*exe_tasks) - completed.update(batched_tasks) - tasks_in_progress.difference_update(batched_tasks) - - # Post-execution steps - await self._run_embed_tasks() - - return ( - GenTableChatCompletionChunks(columns=responses, row_id=self.row_id), - self.column_dict if self.is_row_add else self.regen_column_dict, + @classmethod + async def setup_rag( + cls, + *, + project: ProjectRead, + lm: LMEngine, + body: ChatRequest, + request_id: str = "", + ) -> tuple[ChatRequest, References | None]: + if body.rag_params is None: + return body, None + kt_id = body.rag_params.table_id.strip() + if kt_id == "": + raise BadInputError( + "`rag_params.table_id` is required when `rag_params` is specified." + ) + kt = await KnowledgeTable.open_table( + project_id=project.id, table_id=kt_id, request_id=request_id ) - - async def _stream_concurrent_execution(self) -> AsyncGenerator[str, None]: - """ - Executes tasks concurrently in a streaming manner, yielding individual chunks. 
- """ - self._setup_dependencies() - self._mark_regen_columns() - - completed = set(self.input_column_names) - queue = asyncio.Queue() - tasks_in_progress = set() - - ready_tasks = [ - task_name - for task_name, deps in self.dependencies.items() - if all(dep in completed for dep in deps) - and task_name not in completed - and task_name not in tasks_in_progress + kt_cols = {c.column_id for c in kt.column_metadata if not c.is_state_column} + t0 = perf_counter() + fts_query, vs_query = await lm.generate_search_query( + messages=body.messages, + rag_params=body.rag_params, + **body.hyperparams, + ) + cls._log( + f'Query rewrite using "{body.model}" took t={(perf_counter() - t0) * 1e3:,.2f} ms.', + request_id=request_id, + ) + rows = await kt.hybrid_search( + fts_query=fts_query, + vs_query=vs_query, + embedding_fn=lm.embed_query_as_vector, + vector_column_names=None, + limit=body.rag_params.k, + offset=0, + remove_state_cols=True, + ) + chunks = [ + Chunk( + text=row.get("Text", "") or "", # could be None + title=row.get("Title", "") or "", # could be None + page=row.get("Page", None), + document_id=row.get("File ID", "") or "", # could be None + chunk_id=str(row.get("ID", "")), + # Context will contain extra columns + context={ + k: str(v) + for k, v in row.items() + if k not in kt.FIXED_COLUMN_IDS and k in kt_cols + }, + # Metadata will contain things like RRF score + metadata={ + k: str(v) + for k, v in row.items() + if k not in kt.FIXED_COLUMN_IDS and k not in kt_cols + }, + ) + for row in rows ] + # Add project and table ID + for chunk in chunks: + chunk.metadata["project_id"] = project.id + chunk.metadata["table_id"] = body.rag_params.table_id + if len(rows) > 0 and body.rag_params.reranking_model is not None: + order = ( + await lm.rerank_documents( + model=body.rag_params.reranking_model, + query=vs_query, + documents=kt.rows_to_documents(rows), + ) + ).results + chunks = [chunks[i.index] for i in order] + chunks = chunks[: body.rag_params.k] + references = References(chunks=chunks, search_query=vs_query) + if body.messages[-1].role == ChatRole.USER: + replacement_idx = -1 + elif body.messages[-2].role == ChatRole.USER: + replacement_idx = -2 + else: + raise BadInputError("The message list should end with user or assistant message.") + rag_prompt = await lm.generate_rag_prompt( + messages=body.messages, + references=references, + inline_citations=body.rag_params.inline_citations, + ) + body.messages[replacement_idx].content = rag_prompt + return body, references + + async def _setup_rag(self, body: ChatRequest) -> tuple[ChatRequest, References | None]: + return await self.setup_rag( + project=self.project, + lm=self.lm, + body=body, + request_id=self._request_id, + ) - async def execute_task(task_name): - try: - task = self.llm_tasks[task_name] - except Exception: - task = self.code_tasks[task_name] - - try: - async for chunk in self._execute_task_stream(task): - await queue.put((task_name, chunk)) - except Exception as e: - self._log_exception(e, f'Error executing task "{task_name}": {e}') - finally: - completed.add(task_name) - await queue.put((task_name, None)) - tasks_in_progress.remove(task_name) - - while len(completed) < ( - len(self.llm_tasks) + len(self.code_tasks) + len(self.input_column_names) - ): - ready_tasks = [ - task_name - for task_name, deps in self.dependencies.items() - if all(dep in completed for dep in deps) - and task_name not in completed - and task_name not in tasks_in_progress - ] - # Process tasks in batches - for i in range(0, len(ready_tasks), 
self.cols_batch_size): - batch_tasks = ready_tasks[i : i + self.cols_batch_size] - for task in batch_tasks: - tasks_in_progress.add(task) - asyncio.create_task(execute_task(task)) - - none_count = 0 - while none_count < len(batch_tasks): - task_name, chunk = await queue.get() - if chunk is None: - none_count += 1 - continue - yield chunk +@alru_cache(maxsize=ENV_CONFIG.max_file_cache_size, ttl=ENV_CONFIG.document_loader_cache_ttl_sec) +async def _load_uri_as_base64(uri: str | None) -> str | AudioContent | ImageContent | None: + """ + Loads a file from URI for LLM inference. - # Post-execution steps - await self._run_embed_tasks() + Args: + uri (str | None): The URI of the file. - # Return the complete row for accumulation in MultiRowsGenExecutor - yield self.column_dict if self.is_row_add else (self.body.row_id, self.regen_column_dict) + Returns: + content (str | AudioContent | ImageContent): The file content. - # Signal the end of stream for a row - yield "data: [DONE]\n\n" + Raises: + BadInputError: If the file format is unsupported. + """ + if not uri: + return None + try: + extension = splitext(uri)[1].lower() + async with open_uri_async(uri) as (file_handle, _): + file_binary = await file_handle.read() + except BadInputError: + raise + except Exception as e: + logger.warning(f'Failed to load file "{uri}" due to error: {repr(e)}') + return None + try: + # Load as document + if extension in DOCUMENT_FILE_EXTENSIONS: + return await GeneralDocLoader().load_document(basename(uri), file_binary) + # Load as audio or image + else: + base64_data = base64.b64encode(file_binary).decode("utf-8") + if extension in AUDIO_FILE_EXTENSIONS: + return AudioContent( + input_audio=AudioContentData(data=base64_data, format=extension[1:]) + ) + elif extension in IMAGE_FILE_EXTENSIONS: + extension = ".jpeg" if extension == ".jpg" else extension + prefix = f"data:image/{extension[1:]};base64," + return ImageContent(image_url=ImageContentData(url=prefix + base64_data)) + else: + raise BadInputError( + ( + "Unsupported file type. Supported formats are: " + f"{', '.join(DOCUMENT_FILE_EXTENSIONS + AUDIO_FILE_EXTENSIONS + IMAGE_FILE_EXTENSIONS)}" + ) + ) + except BadInputError: + raise + except Exception as e: + logger.warning(f'Failed to parse file "{uri}" due to error: {repr(e)}') + return None + + +@alru_cache(maxsize=ENV_CONFIG.max_file_cache_size, ttl=ENV_CONFIG.document_loader_cache_ttl_sec) +async def _load_uri_as_bytes(uri: str | None) -> bytes | None: + """ + Loads a file from URI as raw bytes. + Args: + uri (str): The URI of the file. + Returns: + content (bytes | None): The raw file content as bytes, or None if loading fails. + Raises: + BadInputError: If the URI is invalid or file cannot be accessed. 
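+
+    Usage sketch (the URI below is a placeholder, not a real object):
+
+        data = await _load_uri_as_bytes("s3://bucket/example.png")
+        if data is None:
+            ...  # load failed; a warning was logged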
+ """ + if not uri: + return None + + try: + async with open_uri_async(str(uri)) as (file_handle, _): + file_binary = await file_handle.read() + return file_binary + except BadInputError: + raise + except Exception as e: + logger.warning(f'Failed to load file "{uri}" due to error: {repr(e)}') + return None diff --git a/services/api/src/owl/db/gen_table.py b/services/api/src/owl/db/gen_table.py index 49e32ad..9af611e 100644 --- a/services/api/src/owl/db/gen_table.py +++ b/services/api/src/owl/db/gen_table.py @@ -1,1398 +1,3934 @@ -import os +import asyncio +import contextlib import re +from asyncio import Semaphore from collections import defaultdict -from copy import deepcopy -from datetime import datetime, timedelta -from os import listdir -from os.path import exists, isdir, join +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from functools import lru_cache +from inspect import iscoroutinefunction from pathlib import Path -from shutil import copytree, ignore_patterns, move, rmtree -from time import perf_counter, sleep -from typing import Any, BinaryIO, Literal, override +from time import perf_counter +from typing import ( + Any, + AsyncIterator, + Awaitable, + BinaryIO, + Callable, + ClassVar, + Literal, + Self, + Type, + override, +) +from uuid import UUID -import filetype -import lancedb +import asyncpg +import bm25s +import nltk import numpy as np +import orjson import pandas as pd import pyarrow as pa -from filelock import FileLock -from lancedb.table import LanceTable +import pyarrow.parquet as pq +from asyncpg import Connection, Pool +from asyncpg.exceptions import ( + DataError, + DuplicateColumnError, + DuplicateTableError, + InvalidParameterValueError, + PostgresSyntaxError, + UndefinedColumnError, + UndefinedFunctionError, + UndefinedTableError, + UniqueViolationError, +) from loguru import logger -from sqlmodel import Session, select -from tenacity import retry, stop_after_attempt, wait_exponential -from typing_extensions import Self - -from jamaibase.exceptions import ( - BadInputError, - ResourceExistsError, - ResourceNotFoundError, - TableSchemaFixedError, - make_validation_error, +from numpy import array, ndarray +from pgvector.asyncpg import register_vector +from pydantic import ( + BaseModel, + Field, + GetCoreSchemaHandler, + ValidationError, + create_model, + field_validator, + model_validator, ) -from jamaibase.utils.io import df_to_csv, json_loads -from owl.configs.manager import ENV_CONFIG -from owl.db import cached_text, create_sql_tables, create_sqlite_engine -from owl.models import CloudEmbedder, CloudReranker -from owl.protocol import ( - COL_NAME_PATTERN, +from pydantic_core import core_schema + +from owl.configs import CACHE, ENV_CONFIG +from owl.db import async_session +from owl.db.models.oss import ModelConfig, Project +from owl.types import ( GEN_CONFIG_VAR_PATTERN, - AddChatColumnSchema, - AddKnowledgeColumnSchema, - ChatEntry, - ChatTableSchemaCreate, - ChatThread, - Chunk, + ChatThreadEntry, + ChatThreadResponse, + CodeGenConfig, ColName, ColumnDtype, ColumnSchema, CSVDelimiter, + DatetimeUTC, + DiscriminatedGenConfig, EmbedGenConfig, - GenConfig, - GenConfigUpdateRequest, - GenTableOrderBy, - KnowledgeTableSchemaCreate, - ModelListConfig, + LLMGenConfig, + ModelCapability, + ModelConfig_, + ModelConfigRead, + Page, PositiveInt, - RowAddData, - RowUpdateData, + ProgressState, + Project_, + PythonGenConfig, + S3Content, + SanitisedNonEmptyStr, + SanitisedStr, + TableImportProgress, TableMeta, TableMetaResponse, 
TableName, - TableSchema, - TableSchemaCreate, - TableSQLModel, TableType, + TextContent, +) +from owl.utils import merge_dict, uuid7_draft2_str, validate_where_expr +from owl.utils.crypt import hash_string_blake2b as blake2b_hash +from owl.utils.dates import now, utc_datetime_from_iso +from owl.utils.exceptions import ( + BadInputError, + JamaiException, + ModelCapabilityError, + ResourceExistsError, + ResourceNotFoundError, +) +from owl.utils.io import ( + df_to_csv, + guess_mime, + json_dumps, + json_loads, + open_uri_async, + s3_upload, ) -from owl.utils import datetime_now_iso, uuid7_draft2_str -from owl.utils.io import open_uri_sync, upload_file_to_s3 - -# Lance only support null values in string column -_py_type_default = { - "int": 0, - "int8": 0, - "float": 0.0, - "float32": 0.0, - "float16": 0.0, - "bool": False, - "str": "''", - "image": "''", - "audio": "''", -} - - -class GenerativeTable: +from owl.version import __version__ as owl_version + +# Regex for tokenization +digits = r"([0-9]+)" +letters = r"([a-zA-Z]+)" +hanzi = r"([\u4e00-\u9fff])" +# Other non-whitespace, non-letter, non-digit, non-hanzi characters +other = r"([^\s0-9a-zA-Z\u4e00-\u9fff])" +# Combine patterns with OR (|) +TOKEN_PATTERN = re.compile(f"{digits}|{letters}|{hanzi}|{other}") +stemmer = nltk.stem.SnowballStemmer("english") + + +""" +Postgres has limitation for identifier length at 63 characters. + +We need to support up to 100. + +But we cannot set the limit at 63 since Postgres will add suffix like `_id_pkey` ("ID" column as primary key). + +Solution is to use a mapping from `id` (len <= 46) to `table_id` (len <= 100). + +Consumers of a table will use `table_id`, `id` is for internal use. + +1. `len(table_id) <= 100`: + - `id` will be a truncated version of `table_id`: + 1. If `len(table_id) <= 29`: `id = table_id` + 2. If `len(table_id) > 29`: `id = f"{table_id[:29]}-{blake2b_hash(table_id, 16)}"` where the hash is 16 characters. + - During table duplication with auto-naming: + 1. `len(table_id) <= 70`: Suffix will be appended `{table_id} 2025-10-06-22-03-18 (9999)` + 2. `len(table_id) > 70`: `table_id` will be truncated as `f"{table_id[:53]}-{blake2b_hash(table_id, 16)}"` before appending suffix + 3. In both cases, `id` will be a truncated version of `table_id` as usual +2. `len(table_id) > 100`: + - Raise validation error + +Column ID works the same way with a mapping from `id` (len <= 46) to `column_id` (len <= 100), but care has to be taken for state column IDs. + +Index naming: + +1. FTS index: `f"{table_id[:25]}_{blake2b_hash(table_id, 24)}_fts_idx"` +2. 
Vector index: `f"{short_table_id[:25]}_{blake2b_hash(f"{short_table_id}_{short_column_id}", 24)}_vec_idx"` +""" + + +TABLE_ID_DST_MAX_ITER = 9_999 +IMPORT_BATCH_SIZE = 100 +S3_MAX_CONCURRENCY = 20 + + +def get_internal_id(long_id: str) -> str: + is_file_col = long_id.endswith("__") + is_state_col = long_id.endswith("_") + if is_file_col: + long_id = long_id[:-2] + elif is_state_col: + long_id = long_id[:-1] + else: + pass + if len(long_id) <= 29: + short_id = long_id + else: + short_id = f"{long_id[:29]}-{blake2b_hash(long_id, 16)}" + if is_file_col: + short_id = f"{short_id}__" + elif is_state_col: + short_id = f"{short_id}_" + else: + pass + return short_id + + +def truncate_table_id(table_id: str) -> str: + if len(table_id) <= 70: + return table_id + return f"{table_id[:53]}-{blake2b_hash(table_id, 16)}" + + +def fts_index_id(table_id: str) -> str: + return f"{table_id[:25]}_{blake2b_hash(table_id, 24)}_fts_idx" + + +def vector_index_id(table_id: str, col_id: str) -> str: + return f"{table_id[:25]}_{blake2b_hash(f'{table_id}_{col_id}', 24)}_vec_idx" + + +class NumpyArray: + """Wrapper class for numpy arrays with Pydantic schema support""" + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_after_validator_function( + cls.validate, + core_schema.union_schema( + [ + core_schema.is_instance_schema(ndarray), + core_schema.list_schema(core_schema.float_schema()), + ] + ), + ) + + @staticmethod + def validate(value: Any) -> ndarray: + if isinstance(value, list): + # Convert list of floats to a NumPy array + return array(value, dtype=float) + elif isinstance(value, ndarray): + return value + else: + raise ValueError("Value must be a numpy array or a list of floats") + + +class _TableBase(BaseModel): + version: str = Field( + default=owl_version, + description="Table version, following owl version.", + ) + meta: dict[str, Any] = Field( + default={}, + description="Additional metadata about the table.", + ) + + +class TableMetadata(_TableBase): + """ + Table metadata + - Primary key: table_id + - Data table name: table_id + * Remember to update the SQL when making changes to this model """ - Smart Table class. - Note that by default, this class assumes that each method uses a new LanceDB connection. - Otherwise, consider passing in `read_consistency_interval=timedelta(seconds=0)` during init. + table_id: TableName = Field( + description="Table name.", + ) + short_id: SanitisedNonEmptyStr = Field( + "", + description="Internal short table ID derived from `table_id`.", + ) + title: SanitisedStr = Field( + "", + description='Chat title. Defaults to "".', + ) + parent_id: TableName | None = Field( + None, + description="The parent table ID. If None (default), it means this is a parent table.", + ) + created_by: SanitisedNonEmptyStr | None = Field( + None, + description="ID of the user that created this table. 
Defaults to None.", + ) + updated_at: DatetimeUTC = Field( + default_factory=now, + description="Table last update datetime (UTC).", + ) + + @model_validator(mode="after") + def generate_internal_id(self) -> Self: + if not self.short_id: + self.short_id = get_internal_id(self.table_id) + return self + + @staticmethod + def sql_create(schema_id: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS "{schema_id}"."TableMetadata" ( + table_id TEXT PRIMARY KEY, + short_id TEXT UNIQUE NOT NULL, + title TEXT NOT NULL, + parent_id TEXT, + created_by TEXT, + updated_at TIMESTAMPTZ NOT NULL, + version TEXT NOT NULL, + meta JSONB NOT NULL + ); + """ + + @classmethod + @lru_cache(maxsize=1) + def str_cols(cls) -> list[str]: + """Return every column name that is a string.""" + return [k for k, v in cls.model_fields.items() if v.annotation is str] + + +class ColumnMetadata(_TableBase): + """ + Column metadata + - Primary key: table_id, column_id + - Foreign key: table_id + * Remember to update the SQL when making changes to this model """ - FIXED_COLUMN_IDS = [] + INFO_COLUMNS: ClassVar[set[str]] = {"id", "updated at"} - def __init__( - self, - db_url: str, - vector_db_url: str, - *, - read_consistency_interval: timedelta | None = None, - create_sqlite_tables: bool = True, - ) -> None: - self.db_url = Path(db_url) - self.vector_db_url = Path(vector_db_url) - self.read_consistency_interval = read_consistency_interval - self.sqlite_engine = create_sqlite_engine(db_url) - if create_sqlite_tables: - create_sql_tables(TableSQLModel, self.sqlite_engine) - self._lance_db = None - self.organization_id = db_url.split(os.sep)[-3] - self.project_id = db_url.split(os.sep)[-2] - # Thread and process safe lock - self.lock_name_prefix = vector_db_url - self.locks = {} + table_id: TableName = Field( + description="Associated Table name.", + ) + column_id: str = Field( + pattern=r"^[A-Za-z0-9]([A-Za-z0-9.?!@#$%^&*_()\- ]*[A-Za-z0-9.?!()\-])?_*$", + min_length=1, + max_length=101, + description="Column name.", + ) + short_table_id: SanitisedNonEmptyStr = Field( + "", + description="Internal short table ID derived from `table_id`.", + ) + short_id: SanitisedNonEmptyStr = Field( + "", + description="Internal short column ID derived from `column_id`.", + ) + dtype: ColumnDtype = Field( + ColumnDtype.STR, + description=f"Column data type, one of {list(map(str, ColumnDtype))}.", + ) + vlen: PositiveInt = Field( + 0, + description=( + "_Optional_. vector length. If provided, then this column will be a VECTOR column type." + "ex: embedding size." + ), + examples=[1024], + ) + gen_config: DiscriminatedGenConfig | None = Field( + None, + description=( + '_Optional_. Generation config. If provided, then this column will be an "Output Column". ' + "Table columns on its left can be referenced by `${column-name}`." + ), + ) + column_order: int = Field( + 0, + description="Order of the column in the table. 
Usually you don't need to set this.", + examples=[0, 1], + ) + + @model_validator(mode="after") + def generate_internal_id(self) -> Self: + if not self.short_table_id: + self.short_table_id = get_internal_id(self.table_id) + if not self.short_id: + self.short_id = get_internal_id(self.column_id) + return self + @field_validator("dtype", mode="before") @classmethod - def from_ids( - cls, - org_id: str, - project_id: str, - table_type: str | TableType, - ) -> Self: - lance_path = join(ENV_CONFIG.owl_db_dir, org_id, project_id, table_type) - sqlite_path = f"sqlite:///{lance_path}.db" - read_consistency_interval = timedelta(seconds=0) - if table_type == TableType.ACTION: - return ActionTable( - sqlite_path, - lance_path, - read_consistency_interval=read_consistency_interval, - ) - elif table_type == TableType.KNOWLEDGE: - return KnowledgeTable( - sqlite_path, - lance_path, - read_consistency_interval=read_consistency_interval, - ) - else: - return ChatTable( - sqlite_path, - lance_path, - read_consistency_interval=read_consistency_interval, - ) + def validate_dtype(cls, value: Any) -> str: + """ + Handles some special cases for dtype. + """ + if value in ["float32", "float16"]: + return ColumnDtype.FLOAT + if value == "int8": + return ColumnDtype.INT + return value @property - def lance_db(self): - if self._lance_db is None: - self._lance_db = lancedb.connect( - self.vector_db_url, read_consistency_interval=self.read_consistency_interval - ) - return self._lance_db + def is_output_column(self) -> bool: + return self.gen_config is not None - def lock(self, name: str, timeout: int = ENV_CONFIG.owl_table_lock_timeout_sec): - name = join(self.lock_name_prefix, f"{name}.lock") - self.locks[name] = self.locks.get(name, FileLock(name, timeout=timeout)) - return self.locks[name] + @property + def is_text_column(self) -> bool: + return self.dtype == ColumnDtype.STR and self.column_id.lower() not in self.INFO_COLUMNS - def create_session(self): - return Session(self.sqlite_engine) + @property + def is_chat_column(self) -> bool: + return getattr(self.gen_config, "multi_turn", False) - def has_info_col_names(self, names: list[str]) -> bool: - return sum(n.lower() in ("id", "updated at") for n in names) > 0 + @property + def is_vector_column(self) -> bool: + return self.dtype in (ColumnDtype.FLOAT,) and self.vlen > 0 - def has_state_col_names(self, names: list[str]) -> bool: - return any(n.endswith("_") for n in names) + @property + def is_image_column(self) -> bool: + return self.dtype == ColumnDtype.IMAGE - def num_output_columns(self, meta: TableMeta) -> int: - return len( - [col for col in meta.cols if col["gen_config"] is not None and col["vlen"] == 0] - ) + @property + def is_audio_column(self) -> bool: + return self.dtype == ColumnDtype.AUDIO - def _create_table( - self, - session: Session, - schema: TableSchemaCreate, - remove_state_cols: bool = False, - add_info_state_cols: bool = True, - ) -> tuple[LanceTable, TableMeta]: - table_id = schema.id - with self.lock(table_id): - meta = session.get(TableMeta, table_id) - if meta is None: - # Add metadata - if add_info_state_cols: - schema = schema.add_info_cols().add_state_cols() - meta = TableMeta( - id=table_id, - parent_id=None, - cols=[c.model_dump() for c in schema.cols], - ) - session.add(meta) - session.commit() - session.refresh(meta) - # Create Lance table - table = self.lance_db.create_table(table_id, schema=schema.pyarrow) - else: - raise ResourceExistsError(f'Table "{table_id}" already exists.') - if remove_state_cols: - meta.cols = [c for 
c in meta.cols if not c["id"].endswith("_")] - return table, meta + @property + def is_document_column(self) -> bool: + return self.dtype == ColumnDtype.DOCUMENT - def create_table( - self, - session: Session, - schema: TableSchemaCreate, - remove_state_cols: bool = False, - add_info_state_cols: bool = True, - ) -> tuple[LanceTable, TableMeta]: - if not isinstance(schema, TableSchema): - raise TypeError("`schema` must be an instance of `TableSchema`.") - fixed_cols = set(c.lower() for c in self.FIXED_COLUMN_IDS) - if len(fixed_cols.intersection(set(c.id.lower() for c in schema.cols))) != len(fixed_cols): - raise BadInputError(f"Schema must contain fixed columns: {self.FIXED_COLUMN_IDS}") - return self._create_table( - session=session, - schema=schema, - remove_state_cols=remove_state_cols, - add_info_state_cols=add_info_state_cols, - ) + @property + def is_file_column(self) -> bool: + return self.dtype in (ColumnDtype.IMAGE, ColumnDtype.AUDIO, ColumnDtype.DOCUMENT) - def open_table(self, table_id: TableName) -> LanceTable: - try: - table = self.lance_db.open_table(table_id) - except FileNotFoundError as e: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') from e - return table + @property + def is_info_column(self) -> bool: + return self.column_id.lower() in self.INFO_COLUMNS - def open_meta( - self, - session: Session, - table_id: TableName, - remove_state_cols: bool = False, - ) -> TableMeta: - meta = session.get(TableMeta, table_id) - if meta is None: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') - if remove_state_cols: - meta.cols = [c for c in meta.cols if not c["id"].endswith("_")] - return meta + @property + def is_state_column(self) -> bool: + return self.column_id.endswith("_") - def open_table_meta( - self, - session: Session, - table_id: TableName, - remove_state_cols: bool = False, - ) -> tuple[LanceTable, TableMeta]: - meta = self.open_meta(session, table_id, remove_state_cols=remove_state_cols) - table = self.open_table(table_id) - return table, meta + @staticmethod + def sql_create(schema_id: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS "{schema_id}"."ColumnMetadata" ( + table_id TEXT NOT NULL, + column_id TEXT NOT NULL, + short_table_id TEXT NOT NULL, + short_id TEXT NOT NULL, + dtype TEXT NOT NULL, + vlen INT DEFAULT 0 NOT NULL, + gen_config JSONB, + column_order INT NOT NULL, + version TEXT, + meta JSONB NOT NULL, + PRIMARY KEY (table_id, column_id), + UNIQUE (short_table_id, short_id), + CONSTRAINT "fk_ColumnMetadataTable_table_id" + FOREIGN KEY (table_id) + REFERENCES "{schema_id}"."TableMetadata" (table_id) + ON UPDATE CASCADE + ON DELETE CASCADE, + CONSTRAINT "fk_ColumnMetadataTable_short_id" + FOREIGN KEY (short_table_id) + REFERENCES "{schema_id}"."TableMetadata" (short_id) + ON UPDATE CASCADE + ); + """ - def list_meta( - self, - session: Session, + +class DataTableRow(BaseModel, coerce_numbers_to_str=True): + @classmethod + def get_column_ids( + cls, *, - offset: int, - limit: int, - parent_id: str | None = None, - search_query: str = "", - order_by: str = GenTableOrderBy.UPDATED_AT, - order_descending: bool = True, - count_rows: bool = False, - remove_state_cols: bool = False, - ) -> tuple[list[TableMetaResponse], int]: - t0 = perf_counter() - search_query = search_query.strip() - if parent_id is None: - selection = select(TableMeta) - elif parent_id.lower() == "_agent_": - selection = select(TableMeta).where(TableMeta.parent_id == None) # noqa - elif parent_id.lower() == "_chat_": - selection = 
select(TableMeta).where(TableMeta.parent_id != None) # noqa - else: - selection = select(TableMeta).where(TableMeta.parent_id == parent_id) - if search_query != "": - selection = selection.where(TableMeta.id.ilike(f"%{search_query}%")) - total = len(session.exec(selection).all()) - metas = session.exec( - selection.order_by( - cached_text(f"{order_by} DESC" if order_descending else f"{order_by} ASC") - ) - .offset(offset) - .limit(limit) - ).all() - t1 = perf_counter() - meta_responses = [] - for meta in metas: - try: - num_rows = self.count_rows(meta.id) if count_rows else -1 - except Exception: - table_path = self.vector_db_url / f"{meta.id}.lance" - if exists(table_path) and len(listdir(table_path)) > 0: - logger.error(f"Lance table FAILED to be opened: {meta.id}") - else: - logger.warning(f"Lance table MISSING, removing metadata: {meta.id}") - session.delete(meta) - continue - meta_responses.append( - TableMetaResponse.model_validate(meta, update={"num_rows": num_rows}) - ) - t2 = perf_counter() - num_metas = len(metas) - time_per_table = (t2 - t1) * 1000 / num_metas if num_metas > 0 else 0.0 - logger.info( - ( - f"Listing {num_metas:,d} table metas took: {(t2 - t0) * 1000:.2f} ms " - f"SQLite query = {(t1 - t0) * 1000:.2f} ms " - f"Count rows (total) = {(t2 - t1) * 1000:.2f} ms " - f"Count rows (per table) = {time_per_table:.2f} ms" + exclude_info: bool = False, + exclude_state: bool = False, + ) -> list[str]: + columns = list(cls.model_fields.keys()) + if exclude_info: + columns = [c for c in columns if c.lower() not in ("id", "updated at")] + if exclude_state: + columns = [c for c in columns if not c.endswith("_")] + return columns + + +class DBengine: + _instance = None + _conn_pool: Pool = None + _initialized = False + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + async def get_conn_pool(self) -> Pool: + """Get or create a PostgreSQL connection pool with proper configuration.""" + if self._conn_pool is None: + self._conn_pool = await asyncpg.create_pool( + dsn=re.sub(r"\+\w+", "", ENV_CONFIG.db_path), + min_size=2, + max_size=5, + max_inactive_connection_lifetime=300.0, + timeout=30.0, + command_timeout=60.0, + max_queries=1000, + # Do not cache statement plan since Generative Table's schema can change + statement_cache_size=0, + init=self._setup_connection, ) + self._initialized = True + return self._conn_pool + + async def close(self): + """Close the connection pool.""" + if self._conn_pool and not self._conn_pool._closed: + await self._conn_pool.close() + self._conn_pool = None + self._initialized = False + + async def _setup_connection(self, conn: Connection) -> None: + """Configure a new connection with required settings.""" + # Remember to update the InitApplicationSQL in the yaml for extension creation + await conn.execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE;") + await conn.execute("CREATE EXTENSION IF NOT EXISTS pgroonga;") + # If `transaction_timeout` <= `idle_in_transaction_session_timeout` or `statement_timeout` + # then the longer timeout is ignored. 
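+ # The values below are in milliseconds, i.e. each statement, each open transaction,
+ # and each idle-in-transaction session is capped at roughly 20 seconds per connection.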
+ await conn.execute("SET statement_timeout = 20000") + await conn.execute("SET transaction_timeout = 20000") + await conn.execute("SET idle_in_transaction_session_timeout = 20000") + await register_vector(conn) + await conn.set_type_codec( + "jsonb", + encoder=lambda obj: orjson.dumps(obj).decode("utf-8"), + decoder=orjson.loads, + schema="pg_catalog", ) - if remove_state_cols: - for meta in meta_responses: - meta.cols = [c for c in meta.cols if not c.id.endswith("_")] - return meta_responses, total - def count_rows(self, table_id: TableName, filter: str | None = None) -> int: - return self.open_table(table_id).count_rows(filter) + @contextlib.asynccontextmanager + async def transaction(self, schema_id: str = None) -> AsyncIterator[Connection]: + """Provide a transactional scope for a series of operations.""" + async with (await self.get_conn_pool()).acquire() as conn: + async with conn.transaction(): + try: + if schema_id: + await conn.execute(f'SET search_path TO "{schema_id}"') + yield conn + except JamaiException: + # No need to log these errors + raise + except Exception as e: + logger.error(f"Transaction failed: {e}") + raise + + +GENTABLE_ENGINE = DBengine() - def duplicate_table( + +class GenerativeTableCore: + """ + Core class for managing generative tables in PostgreSQL with schema-based organization. + Devs should use `GenerativeTable` instead. + """ + + INFO_COLUMNS = {"id", "updated at"} + FIXED_COLUMN_IDS = ["ID", "Updated at"] + + def __init__( self, - session: Session, - table_id_src: TableName, - table_id_dst: TableName, - include_data: bool = True, - create_as_child: bool = False, - ) -> TableMeta: - dst_meta = session.get(TableMeta, table_id_dst) - if dst_meta is not None: - raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') - # Duplicate metadata - with self.lock(table_id_src): - meta = self.open_meta(session, table_id_src) - new_meta = TableMeta.model_validate( - meta, - update={ - "id": table_id_dst, - "parent_id": table_id_src if create_as_child else None, - }, - ) - session.add(new_meta) - session.commit() - session.refresh(new_meta) - # Duplicate LanceTable - if include_data: - copytree( - self.vector_db_url / f"{table_id_src}.lance", - self.vector_db_url / f"{table_id_dst}.lance", - ignore=ignore_patterns("_indices"), + *, + # TODO: We should directly pass in `Project_` instead of fetching it again + project_id: str, + table_type: TableType, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + num_rows: int = -1, + request_id: str = "", + ) -> None: + self.project_id = project_id + self.table_type = table_type + self.table_metadata = table_metadata + self.column_metadata = column_metadata_list + self.num_rows = num_rows + self.request_id = request_id + self.schema_id = f"{project_id}_{table_type}" + self.data_table_model = self._create_data_table_row_model( + table_metadata.table_id, column_metadata_list + ) + self.text_column_names = [ + col.column_id for col in self.column_metadata if col.is_text_column + ] + self.vector_column_names = [ + col.column_id for col in self.column_metadata if col.is_vector_column + ] + self.map_to_short_col_id = {c.column_id: c.short_id for c in column_metadata_list} + self.map_to_long_col_id = {c.short_id: c.column_id for c in column_metadata_list} + + @property + def table_id(self) -> str: + return self.table_metadata.table_id + + @table_id.setter + def table_id(self, value: str) -> None: + if not isinstance(value, str): + raise TypeError("`table_id` must be a string.") + short_id = 
get_internal_id(value) + self.table_metadata.table_id = value + self.table_metadata.short_id = short_id + for col in self.column_metadata: + col.table_id = value + col.short_table_id = short_id + + @property + def short_table_id(self) -> str: + return self.table_metadata.short_id + + @property + def v1_meta(self) -> TableMeta: + meta = TableMeta( + id=self.table_id, + version=self.table_metadata.version, + meta=self.table_metadata.meta, + cols=[ + ColumnSchema( + id=col.column_id, + dtype=col.dtype, + vlen=col.vlen, + gen_config=col.gen_config, ) - with self.create_session() as session: - self.create_indexes(session, table_id_dst, force=True) - else: - schema = TableSchema.model_validate(new_meta) - self.lance_db.create_table(table_id_dst, schema=schema.pyarrow) - return new_meta + for col in self.column_metadata + ], + parent_id=self.table_metadata.parent_id, + created_by=self.table_metadata.created_by, + title=self.table_metadata.title, + updated_at=self.table_metadata.updated_at, + num_rows=self.num_rows, + ) + return meta - def rename_table( - self, - session: Session, - table_id_src: TableName, - table_id_dst: TableName, - ) -> TableMeta: - # Check - dst_meta = session.get(TableMeta, table_id_dst) - if dst_meta is not None: - raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') - # Rename metadata - with self.lock(table_id_src): - meta = self.open_meta(session, table_id_src) - meta.id = table_id_dst - meta.updated_at = datetime_now_iso() - session.add(meta) - # Rename all parent IDs - session.exec( - cached_text( - f"UPDATE TableMeta SET parent_id = '{table_id_dst}' WHERE parent_id = '{table_id_src}'" + @property + def v1_meta_response(self) -> TableMetaResponse: + meta = TableMetaResponse( + id=self.table_id, + version=self.table_metadata.version, + meta=self.table_metadata.meta, + cols=[ + ColumnSchema( + id=col.column_id, + dtype=col.dtype, + vlen=col.vlen, + gen_config=col.gen_config, ) - ) - session.commit() - session.refresh(meta) - # Rename LanceTable - # self.lance_db.rename_table(table_id_src, table_id_dst) # Seems like not implemented - move( - self.vector_db_url / f"{table_id_src}.lance", - self.vector_db_url / f"{table_id_dst}.lance", - ) + for col in self.column_metadata + ], + parent_id=self.table_metadata.parent_id, + created_by=self.table_metadata.created_by, + title=self.table_metadata.title, + updated_at=self.table_metadata.updated_at, + num_rows=self.num_rows, + ) return meta - def delete_table(self, session: Session, table_id: TableName) -> None: - with self.lock(table_id): - # Delete LanceTable - for _ in range(10): - # Try 10 times - try: - rmtree(self.vector_db_url / f"{table_id}.lance") - except FileNotFoundError as e: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') from e - except Exception: - # There might be ongoing operations - sleep(0.5) - else: - break - # Delete metadata - meta = session.get(TableMeta, table_id) - if meta is None: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') - session.delete(meta) - session.commit() - return - - def update_gen_config(self, session: Session, updates: GenConfigUpdateRequest) -> TableMeta: - table_id = updates.table_id - meta = session.get(TableMeta, table_id) - if meta is None: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') - meta_col_ids = set(c["id"] for c in meta.cols) - update_col_ids = set(updates.column_map.keys()) - if len(update_col_ids - meta_col_ids) > 0: - raise make_validation_error( - ValueError( - f"Some columns are not found in the 
table: {update_col_ids - meta_col_ids}" - ), - loc=("body", "column_map"), + def _log(self, msg: str, level: str = "INFO"): + _log = f"{self.__class__.__name__}: {msg}" + if self.request_id: + _log = f"{self.request_id} - {_log}" + logger.log(level, _log) + + @staticmethod + async def _fetch_project(project_id: str) -> Project_: + async with async_session() as session: + project = await Project.get(session, project_id) + return Project_.model_validate(project) + + @staticmethod + async def _fetch_model(model: str, organization_id: str) -> ModelConfigRead: + async with async_session() as session: + cfg = await ModelConfig.get(session, model) + cfg = ModelConfigRead.model_validate(cfg) + if (not cfg.is_active) or (not cfg.allowed(organization_id)): + raise ResourceNotFoundError(f'Model "{model}" is not found.') + return cfg + + @staticmethod + async def _fetch_model_with_capabilities( + *, + capabilities: list[ModelCapability], + organization_id: str, + ) -> ModelConfig_: + from owl.utils.lm import LMEngine + + async with async_session() as session: + models = ( + await ModelConfig.list_( + session=session, + return_type=ModelConfig_, + organization_id=organization_id, + capabilities=capabilities, + exclude_inactive=True, + ) + ).items + if len(models) == 0: + raise ModelCapabilityError( + f"No model found with capabilities: {list(map(str, capabilities))}" ) - cols = deepcopy(meta.cols) - for c in cols: - # Validate and update - gen_config = updates.column_map.get(c["id"], c["gen_config"]) - c["gen_config"] = ( - gen_config.model_dump() if isinstance(gen_config, GenConfig) else gen_config + model = LMEngine.pick_best_model(models, capabilities) + return model + + @classmethod + async def _check_columns( + cls, + conn: Connection, + *, + project_id: str, + table_type: TableType, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + set_default_prompts: bool, + replace_unavailable_models: bool, + allow_nonexistent_refs: bool = False, + ) -> list[ColumnMetadata]: + del table_type # Not used for now + table_id = table_metadata.table_id + if len(set(c.column_id.lower() for c in column_metadata_list)) != len( + column_metadata_list + ): + raise BadInputError( + f'Table "{table_id}": There are repeated column names (case-insensitive).' ) - meta.cols = [c.model_dump() for c in TableSchema(id=meta.id, cols=cols).cols] - session.add(meta) - session.commit() - session.refresh(meta) - return meta + project = await cls._fetch_project(project_id) + column_map = {c.column_id: c for c in column_metadata_list} + for i, col in enumerate(column_metadata_list): + gen_config = col.gen_config + if gen_config is None: + continue + available_cols = [ + c + for c in column_metadata_list[:i] + if not (c.is_info_column or c.is_state_column or c.is_vector_column) + ] + valid_col_ids = [c.column_id for c in available_cols] + if isinstance(gen_config, EmbedGenConfig): + if not col.is_vector_column: + raise BadInputError( + f'Table "{table_id}": ' + f'Embedding column "{col.column_id}" must be a vector column with float data type.' + ) + if (not allow_nonexistent_refs) and ( + gen_config.source_column not in valid_col_ids + ): + raise BadInputError( + ( + f'Table "{table_id}": ' + f'Embedding config of column "{col.column_id}" referenced ' + f'an invalid source column "{gen_config.source_column}". ' + "Make sure you only reference non-vector columns on its left. " + f"Available columns: {valid_col_ids}." 
+ ) + ) + # Validate and assign default model + embedding_model = gen_config.embedding_model.strip() + if embedding_model: + # Validate model capabilities + try: + model = await cls._fetch_model(embedding_model, project.organization_id) + if ModelCapability.EMBED not in model.capabilities: + raise ModelCapabilityError( + ( + f'Table "{table_id}": Model "{model.id}" used in Embedding column "{col.column_id}" ' + f"does not support embedding." + ) + ) + except ModelCapabilityError: + # Embedding model is not interchangeable + raise + except ResourceNotFoundError as e: + # Embedding model is not interchangeable + raise BadInputError( + f'Table "{table_id}": ' + f'Embedding model "{embedding_model}" used by column "{col.column_id}" is not found.' + ) from e + # Do not use `elif` here + if not embedding_model: + # Assign default model + try: + model = await cls._fetch_model_with_capabilities( + capabilities=[ModelCapability.EMBED], + organization_id=project.organization_id, + ) + except ModelCapabilityError as e: + raise ModelCapabilityError(f'Table "{table_id}": {e}') from e + gen_config.embedding_model = model.id + elif isinstance(gen_config, LLMGenConfig): + if col.is_vector_column: + raise BadInputError( + f'Table "{table_id}": ' + f'LLM column "{col.column_id}" must not be a vector column.' + ) + if not col.is_text_column: + raise BadInputError( + f'Table "{table_id}": ' + f'LLM column "{col.column_id}" must be a string (text) column.' + ) + # Insert default prompts if needed + if set_default_prompts: + # We only put input columns into default prompt + _input_cols = [c for c in available_cols if c.gen_config is None] + _text_cols = "\n\n".join( + f"{c.column_id}: ${{{c.column_id}}}" + for c in _input_cols + if not (c.is_image_column or c.is_audio_column) + ) + _image_audio_cols = " ".join( + f"${{{c.column_id}}}" + for c in _input_cols + if (c.is_image_column or c.is_audio_column) + ) + # We place image and audio columns first, which will then be replaced with "" and stripped out + if gen_config.multi_turn: + default_system_prompt = ( + f'You are an agent named "{table_id}". Be helpful. ' + "Ensure that your reply is easy to understand and is accessible to all users. " + "Provide answers based on the information given. " + "Be factual and do not hallucinate." + ).strip() + default_user_prompt = f"{_image_audio_cols}\n\n{_text_cols}".strip() + else: + default_system_prompt = ( + "You are a versatile data generator. " + "Your task is to process information from input data and generate appropriate responses " + "based on the specified column name and input data. " + "Adapt your output format and content according to the column name provided." + ).strip() + if _text_cols: + _text_cols = f"{_text_cols}\n\n" + default_user_prompt = ( + f"{_image_audio_cols}\n\n" + f'Table name: "{table_id}"\n\n' + f"{_text_cols}" + "Based on the available information, " + f'provide an appropriate response for the column "{col.column_id}".\n' + "Be factual and do not hallucinate. " + "Remember to act as a cell in a spreadsheet and provide concise, " + "relevant information without explanations unless specifically requested." 
+ ).strip() + if not gen_config.system_prompt: + gen_config.system_prompt = default_system_prompt + if not gen_config.prompt: + gen_config.prompt = default_user_prompt + # Check references + ref_cols = re.findall(GEN_CONFIG_VAR_PATTERN, gen_config.prompt) + if allow_nonexistent_refs: + ref_cols = [c for c in ref_cols if c in column_map] + if len(invalid_cols := [c for c in ref_cols if c not in valid_col_ids]) > 0: + raise BadInputError( + ( + f'Table "{table_id}": ' + f'LLM Generation prompt of column "{col.column_id}" referenced ' + f"invalid source columns: {invalid_cols}. " + "Make sure you only reference non-vector columns on its left. " + f"Available columns: {valid_col_ids}." + ) + ) + # Validate and assign default model + ref_image_cols = [c for c in ref_cols if column_map[c].is_image_column] + ref_audio_cols = [c for c in ref_cols if column_map[c].is_audio_column] + capabilities = [ModelCapability.CHAT] + if len(ref_image_cols) > 0: + capabilities.append(ModelCapability.IMAGE) + if len(ref_audio_cols) > 0: + capabilities.append(ModelCapability.AUDIO) + chat_model = gen_config.model.strip() + if chat_model: + # Validate model capabilities + try: + model = await cls._fetch_model(gen_config.model, project.organization_id) + unsupported = list(set(capabilities) - set(model.capabilities)) + if len(unsupported) > 0: + raise ModelCapabilityError( + ( + f'Table "{table_id}": Model "{model.id}" used in LLM column "{col.column_id}" ' + f"lack these capabilities: {', '.join(unsupported)}." + ) + ) + except ModelCapabilityError: + if replace_unavailable_models: + # We replace the unavailable model with a default model below + chat_model = "" + else: + raise + except ResourceNotFoundError as e: + if replace_unavailable_models: + # We replace the unavailable model with a default model below + chat_model = "" + else: + raise BadInputError( + f'Table "{table_id}": ' + f'LLM model "{gen_config.model}" used by column "{col.column_id}" is not found.' + ) from e + # Do not use `elif` here + if not chat_model: + # Assign default model + try: + model = await cls._fetch_model_with_capabilities( + capabilities=capabilities, + organization_id=project.organization_id, + ) + except ModelCapabilityError as e: + raise ModelCapabilityError(f'Table "{table_id}": {e}') from e + gen_config.model = model.id + # Check RAG params + if gen_config.rag_params is not None: + kt_id = gen_config.rag_params.table_id + if not allow_nonexistent_refs: + if kt_id.strip() == "": + raise BadInputError( + ( + f'Table "{table_id}": Column "{col.column_id}" ' + f"referenced a Knowledge Table with an empty ID." + ) + ) + kt_metadata = await conn.fetchrow( + f'SELECT * FROM "{project_id}_knowledge"."TableMetadata" WHERE table_id = $1', + kt_id, + ) + if kt_metadata is None: + raise BadInputError( + ( + f'Table "{table_id}": Column "{col.column_id}" ' + f'referenced a Knowledge Table "{kt_id}" that does not exist.' + ) + ) + # Validate and assign default Reranking Model + reranking_model = gen_config.rag_params.reranking_model + if reranking_model is not None: + reranking_model = reranking_model.strip() + if reranking_model: + # Validate model capabilities + try: + model = await cls._fetch_model( + reranking_model, project.organization_id + ) + if ModelCapability.RERANK not in model.capabilities: + raise ModelCapabilityError( + ( + f'Table "{table_id}": Model "{reranking_model}" ' + f'used in LLM column "{col.column_id}" ' + f"does not support reranking." 
+ ) + ) + except ModelCapabilityError: + if replace_unavailable_models: + # We replace the unavailable model with a default model below + reranking_model = "" + else: + raise + except ResourceNotFoundError as e: + if replace_unavailable_models: + # We replace the unavailable model with a default model below + reranking_model = "" + else: + raise BadInputError( + f'Table "{table_id}": ' + f'Reranking model "{gen_config.model}" used by column "{col.column_id}" is not found.' + ) from e + # Do not use `elif` here + if not reranking_model: + model = await cls._fetch_model_with_capabilities( + capabilities=[str(ModelCapability.RERANK)], + organization_id=project.organization_id, + ) + gen_config.rag_params.reranking_model = model.id + elif isinstance(gen_config, CodeGenConfig): + if col.is_vector_column: + raise BadInputError( + f'Table "{table_id}": ' + f'Code Execution column "{col.column_id}" must not be a vector column.' + ) + if col.dtype not in (ColumnDtype.STR, ColumnDtype.IMAGE, ColumnDtype.AUDIO): + raise BadInputError( + f'Table "{table_id}": ' + f'Code Execution column "{col.column_id}" must be a string (text) or image column or audio column.' + ) + valid_col_ids = [c.column_id for c in available_cols if c.dtype == ColumnDtype.STR] + if (not allow_nonexistent_refs) and ( + gen_config.source_column not in valid_col_ids + ): + raise BadInputError( + ( + f'Table "{table_id}": ' + f'Code Execution config of column "{col.column_id}" referenced ' + f'an invalid source column "{gen_config.source_column}". ' + "Make sure you only reference string (text) columns on its left. " + f"Available columns: {valid_col_ids}." + ) + ) + elif isinstance(gen_config, PythonGenConfig): + if col.is_vector_column: + raise BadInputError( + f'Table "{table_id}": ' + f'Python Function column "{col.column_id}" must not be a vector column.' + ) + if col.dtype not in (ColumnDtype.STR, ColumnDtype.IMAGE, ColumnDtype.AUDIO): + raise BadInputError( + f'Table "{table_id}": ' + f'Python Function column "{col.column_id}" must be a string (text) or image column or audio column.' + ) - def add_columns( - self, session: Session, schema: TableSchemaCreate - ) -> tuple[LanceTable, TableMeta]: + return column_metadata_list + + @classmethod + async def _create_table( + cls, + *, + project_id: str, + table_type: TableType, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + request_id: str = "", + set_default_prompts: bool = True, + replace_unavailable_models: bool = False, + allow_nonexistent_refs: bool = False, + create_indexes: bool = True, + ) -> Self: """ - Adds one or more input or output column. + Create a new table. + This method is created so that the public method `create_table` can be overridden + without affecting table creation logic. Args: - session (Session): SQLAlchemy session. - schema (TableSchemaCreate): Schema of the columns to be added. - - Raises: - ResourceNotFoundError: If the table is not found. - ValueError: If any of the columns exists. + project_id (str): Project ID. + table_type (str): Table type. + table_metadata (TableMetadata): Table metadata. + column_metadata_list (list[ColumnMetadata]): List of column metadata. + request_id (str, optional): Request ID for logging. Defaults to "". + set_default_prompts (bool, optional): Set default prompts. + Useful when importing table which does not need to set prompts. Defaults to True. + replace_unavailable_models (bool, optional): Replace unavailable models with default models. + Useful when importing old tables. 
Defaults to False. + allow_nonexistent_refs (bool, optional): Ignore non-existent column and Knowledge Table references. + Otherwise will raise an error. Useful when importing old tables and performing maintenance. + Defaults to False. + create_indexes (bool, optional): Create indexes for the table. + Setting to False can be useful when importing tables + where you want to create indexes after all rows are added. + Defaults to True. Returns: - table (LanceTable): Lance table. - meta (TableMeta): Table metadata. - """ - if not isinstance(schema, TableSchema): - raise TypeError("`schema` must be an instance of `TableSchema`.") - table_id = schema.id - # Check - meta = self.open_meta(session, table_id) - schema = schema.add_state_cols() - cols = meta.cols_schema + schema.cols - if len(set(c.id for c in cols)) != len(cols): - raise make_validation_error( - ValueError("Schema and table contain overlapping column names."), - loc=("body", "cols"), + self (GenerativeTableCore): The table instance. + """ + schema_id = f"{project_id}_{table_type}" + + ### --- VALIDATIONS --- ### + # Override info and state columns + column_metadata_list = [ + col for col in column_metadata_list if not (col.is_info_column or col.is_state_column) + ] + state_columns = [ + ColumnMetadata( + table_id=table_metadata.table_id, + column_id=f"{col.column_id}_", + dtype=ColumnDtype.JSON, ) - meta.cols = [ - c.model_dump() - for c in TableSchema(id=meta.id, cols=[c.model_dump() for c in cols]).cols + for col in column_metadata_list ] + info_columns = [ + ColumnMetadata( + table_id=table_metadata.table_id, + column_id="ID", + dtype=ColumnDtype.STR, + ), + ColumnMetadata( + table_id=table_metadata.table_id, + column_id="Updated at", + dtype=ColumnDtype.DATE_TIME, + ), + ] + column_metadata_list = info_columns + column_metadata_list + state_columns - with self.lock(table_id): - # Add columns to LanceDB - table = self.open_table(table_id) - # Non-vector columns can be added using SQL statement - # TODO: Investigate adding vector columns using BatchUDF - cols_to_add = { - c.id: f"{_py_type_default[c.dtype]}" for c in schema.cols if c.vlen == 0 - } - if len(cols_to_add) > 0: - table.add_columns(cols_to_add) - # Add vector columns to Lance Table using merge op (this is very slow) - vectors = [ - [np.zeros(shape=[c.vlen], dtype=c.dtype)] for c in schema.cols if c.vlen > 0 - ] - if len(vectors) > 0: - _id = table.search().limit(1).to_list() - _id = _id[0]["ID"] if len(_id) > 0 else "0" - vec_schema = schema.pa_vec_schema - vec_schema = vec_schema.insert(0, table.schema.field("ID")) - pa_table = pa.table([[_id]] + vectors, schema=vec_schema) - table.merge(pa_table, left_on="ID") - - # Add Table Metadata - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - session.refresh(meta) - return table, meta - - def _drop_columns( - self, - session: Session, - table_id: TableName, - col_names: list[ColName], - ) -> tuple[LanceTable, TableMeta]: - """ - NOTE: This is broken until lance issue is resolved - https://github.com/lancedb/lancedb/pull/1227 + ### --- Create metadata tables --- ### + await cls.create_schemas(project_id) + async with GENTABLE_ENGINE.transaction() as conn: + # Validate column metadata + await cls._check_columns( + conn=conn, + project_id=project_id, + table_type=table_type, + table_metadata=table_metadata, + column_metadata_list=column_metadata_list, + set_default_prompts=set_default_prompts, + replace_unavailable_models=replace_unavailable_models, + allow_nonexistent_refs=allow_nonexistent_refs, + 
) + # Override column order + for i, col_meta in enumerate(column_metadata_list): + col_meta.column_order = i + ### --- Create data table --- ### + # Create the data table + await cls._create_data_table( + conn=conn, + schema_id=schema_id, + table_metadata=table_metadata, + column_metadata_list=column_metadata_list, + create_indexes=create_indexes, + ) + # Create metadata entries + await cls._upsert_table_metadata(conn, schema_id, table_metadata) + for col_metadata in column_metadata_list: + await cls._upsert_column_metadata(conn, schema_id, col_metadata) + # Reload table + async with GENTABLE_ENGINE.transaction() as conn: + return await cls._open_table( + conn=conn, + project_id=project_id, + table_type=table_type, + table_id=table_metadata.table_id, + request_id=request_id, + ) - Drops one or more input or output column. + async def _count_rows(self, conn: Connection) -> int: + """ + Count the number of rows. Args: - session (Session): SQLAlchemy session. - table_id (str): Table ID. - col_names (list[str]): List of column ID to drop. - - Raises: - TypeError: If `col_names` is not a list. - ResourceNotFoundError: If the table is not found. - ResourceNotFoundError: If any of the columns is not found. + conn (Connection): PostgreSQL connection. Returns: - table (LanceTable): Lance table. - meta (TableMeta): Table metadata. - """ - if not isinstance(col_names, list): - raise TypeError("`col_names` must be a list.") - if self.has_state_col_names(col_names): - raise make_validation_error( - ValueError("Cannot drop state columns."), - loc=("body", "column_names"), - ) - if self.has_info_col_names(col_names): - raise make_validation_error( - ValueError('Cannot drop "ID" or "Updated at".'), - loc=("body", "column_names"), + num_rows (int): Number of rows in the table. + """ + # If we don't need a 100% exact count and a very fast, rough estimate is good enough + # SELECT reltuples::bigint AS estimate FROM pg_class WHERE relname = 'your_table'; + try: + self.num_rows = await conn.fetchval( + f'SELECT COUNT("ID") FROM "{self.schema_id}"."{self.short_table_id}"' ) - with self.lock(table_id): - meta = self.open_meta(session, table_id) - col_names += [f"{n}_" for n in col_names] - table = self.open_table(table_id) + except (UndefinedTableError, UndefinedColumnError) as e: + logger.error( + ( + f'Data table `"{self.schema_id}"."{self.short_table_id}"` ' + "is not found but table and column metadata exist !!! " + f"Error: {repr(e)}" + ) + ) + raise ResourceNotFoundError( + f'Table "{self.table_id}" is not found. Please contact support if this is unexpected.' + ) from e + # await conn.fetch("SET LOCAL enable_seqscan = off;") + return self.num_rows + + @classmethod + async def _open_table( + cls, + conn: Connection, + *, + project_id: str, + table_type: TableType, + table_id: str, + request_id: str = "", + ) -> Self: + """ + Open an existing table. + + Args: + conn (Connection): PostgreSQL connection. + project_id (str): Project ID. + table_type (str): Table type. + table_id (str): Name of the table. + request_id (str, optional): Request ID for logging. Defaults to "". + + Returns: + self (GenerativeTableCore): The table instance. 
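+
+ Example:
+ A minimal sketch of typical usage; the project and table IDs are hypothetical:
+
+ async with GENTABLE_ENGINE.transaction() as conn:
+ table = await GenerativeTableCore._open_table(
+ conn,
+ project_id="proj0123",
+ table_type=TableType.ACTION,
+ table_id="My Table",
+ )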
+ """ + schema_id = f"{project_id}_{table_type}" + + ### --- Read table and column metadata --- ### + try: + table_metadata = await conn.fetchrow( + f'SELECT * FROM "{schema_id}"."TableMetadata" WHERE table_id = $1', table_id + ) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{table_id}" is not found.') from e + except Exception as e: + raise BadInputError(e) from e + if table_metadata is None: + raise ResourceNotFoundError(f'Table metadata for "{table_id}" is not found.') + try: + column_metadata = await conn.fetch( + f'SELECT * FROM "{schema_id}"."ColumnMetadata" WHERE table_id = $1 ORDER BY column_order ASC', + table_id, + ) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{table_id}" is not found.') from e + except Exception as e: + raise BadInputError(e) from e + if len(column_metadata) == 0: + raise ResourceNotFoundError(f'Column metadata for "{table_id}" is not found.') + self = cls( + project_id=project_id, + table_type=table_type, + table_metadata=TableMetadata.model_validate(dict(table_metadata)), + column_metadata_list=[ + ColumnMetadata.model_validate(dict(col)) for col in column_metadata + ], + request_id=request_id, + ) + await self._count_rows(conn) + return self + + async def _reload_table(self, conn: Connection) -> Self: + self = await self._open_table( + conn=conn, + project_id=self.project_id, + table_type=self.table_type, + table_id=self.table_id, + request_id=self.request_id, + ) + await self._check_columns( + conn=conn, + project_id=self.project_id, + table_type=self.table_type, + table_metadata=self.table_metadata, + column_metadata_list=self.column_metadata, + set_default_prompts=False, + replace_unavailable_models=False, + ) + return self + + @staticmethod + async def _recreate_fts_index( + conn: Connection, + *, + schema_id: str, + table_id: str, + columns: list[str], + ) -> None: + if len(columns) == 0: + return + index_id = fts_index_id(table_id) + await conn.execute(f'DROP INDEX IF EXISTS "{schema_id}"."{index_id}"') + await conn.execute( + f""" + CREATE INDEX "{index_id}" + ON "{schema_id}"."{get_internal_id(table_id)}" + USING pgroonga ((ARRAY[{", ".join(f'"{get_internal_id(col)}"' for col in columns)}])); + """, + timeout=300.0, + ) + + @staticmethod + async def _recreate_vector_index( + conn: Connection, + *, + schema_id: str, + table_id: str, + columns: list[str], + ) -> None: + if len(columns) == 0: + return + # pgvector doesn't support multi-column index, as of: 2025-03-04 + for col in columns: + index_id = vector_index_id(table_id, col) + await conn.execute(f'DROP INDEX IF EXISTS "{schema_id}"."{index_id}"') + await conn.execute( + f""" + CREATE INDEX "{index_id}" + ON "{schema_id}"."{get_internal_id(table_id)}" + USING diskann ("{get_internal_id(col)}" vector_cosine_ops); + """, + timeout=600.0, + ) + + @staticmethod + def _state_column_sql(short_column_id: str) -> str: + return f""""{short_column_id}_" JSONB NOT NULL DEFAULT '{{}}'::JSONB""" + + @classmethod + async def _create_data_table( + cls, + conn: Connection, + *, + schema_id: str, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + create_indexes: bool = True, + ) -> None: + table_id = table_metadata.table_id + # All data table have "ID" and "Updated at" columns + column_defs = [] + column_defs.append('"ID" UUID PRIMARY KEY') + column_defs.append('"Updated at" TIMESTAMPTZ') + + # Generate the SQL column definitions for the CREATE TABLE statement + text_cols = [] + vec_cols = [] + for col in column_metadata_list: + 
if col.is_info_column or col.is_state_column: + continue + dtype = col.dtype + if col.is_vector_column: + dtype = f"VECTOR({col.vlen})" + vec_cols.append(col.column_id) + else: + dtype = dtype.to_postgres_type() + if col.is_text_column: + text_cols.append(col.column_id) + column_defs.append(f'"{col.short_id}" {dtype}') + column_defs.append(cls._state_column_sql(col.short_id)) + try: + # Create the table in the database + await conn.execute(f""" + CREATE TABLE "{schema_id}"."{table_metadata.short_id}" ( + {", ".join(column_defs)} + ); + """) + if create_indexes: + await cls._recreate_fts_index( + conn, + schema_id=schema_id, + table_id=table_id, + columns=text_cols, + ) + await cls._recreate_vector_index( + conn, + schema_id=schema_id, + table_id=table_id, + columns=vec_cols, + ) + except DuplicateTableError as e: + raise ResourceExistsError(f'Table "{table_id}" already exists.') from e + + @staticmethod + async def _upsert_table_metadata( + conn: Connection, + schema_id: str, + table_metadata: TableMetadata, + ) -> None: + query = f""" + INSERT INTO "{schema_id}"."TableMetadata" ( + table_id, short_id, title, parent_id, created_by, updated_at, version, meta + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (table_id) DO UPDATE SET + title = COALESCE(EXCLUDED.title, "TableMetadata".title), + parent_id = EXCLUDED.parent_id, + created_by = EXCLUDED.created_by, + updated_at = COALESCE(EXCLUDED.updated_at, "TableMetadata".updated_at), + version = COALESCE(EXCLUDED.version, "TableMetadata".version), + meta = COALESCE(EXCLUDED.meta, "TableMetadata".meta); + """ + values = [ + table_metadata.table_id, + table_metadata.short_id, + table_metadata.title, + table_metadata.parent_id, + table_metadata.created_by, + table_metadata.updated_at, + table_metadata.version, + table_metadata.meta, + ] + await conn.execute(query, *values) + + @staticmethod + async def _upsert_column_metadata( + conn: Connection, + schema_id: str, + column_metadata: ColumnMetadata, + ) -> None: + query = f""" + INSERT INTO "{schema_id}"."ColumnMetadata" ( + table_id, column_id, short_table_id, short_id, dtype, vlen, gen_config, column_order, version, meta + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (table_id, column_id) DO UPDATE SET + dtype = COALESCE(EXCLUDED.dtype, "ColumnMetadata".dtype), + vlen = COALESCE(EXCLUDED.vlen, "ColumnMetadata".vlen), + gen_config = EXCLUDED.gen_config, + column_order = EXCLUDED.column_order, + version = COALESCE(EXCLUDED.version, "ColumnMetadata".version), + meta = COALESCE(EXCLUDED.meta, "ColumnMetadata".meta); + """ + values = [ + column_metadata.table_id, + column_metadata.column_id, + column_metadata.short_table_id, + column_metadata.short_id, + column_metadata.dtype, + column_metadata.vlen, + column_metadata.gen_config.model_dump() if column_metadata.gen_config else None, + column_metadata.column_order, + column_metadata.version, + column_metadata.meta, + ] + await conn.execute(query, *values) + + async def _set_updated_at( + self, + conn: Connection, + updated_at: datetime | None = None, + ) -> None: + if updated_at is None: + updated_at = now() + stmt = f'UPDATE "{self.schema_id}"."TableMetadata" SET "updated_at" = $1 WHERE "table_id" = $2;' + await conn.execute(stmt, updated_at, self.table_id) + self.table_metadata.updated_at = updated_at + + @staticmethod + def _create_data_table_row_model( + table_id: str, + columns: list["ColumnMetadata"], + ) -> Type[DataTableRow]: + """ + Dynamically creates the Pydantic model class for a data table row. 
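+
+ For illustration (column names here are hypothetical): a table with a text column
+ "Question" and a vector column "Embed" (vlen=3) yields a model with fields
+ "ID", "Updated at", "Question", "Question_", "Embed" and "Embed_", where the
+ trailing-underscore fields hold per-column state dicts and "Embed" is validated
+ to have length 3.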
+ + Args: + table_id (str): Table ID. + columns (list[ColumnMetadata]): List of column metadata. + + Returns: + model_cls (Type[DataTableRow]): The Pydantic model class. + """ + + @field_validator("ID", mode="before") + @classmethod + def id_validator(cls, v: Any): + if isinstance(v, UUID): + return str(v) + return v + + field_definitions = { + "ID": ( + str, + Field(default_factory=uuid7_draft2_str, description="Row ID."), + ), + "Updated at": ( + DatetimeUTC, + Field(default_factory=now, description="Last updated timestamp."), + ), + } + validators = { + "validate_id": id_validator, + } + + for col in columns: + if col.is_info_column or col.is_state_column: + continue + if col.is_vector_column: + # Create vector validator + def create_vector_validator(col: ColumnMetadata): + @field_validator(col.column_id, mode="after") + @classmethod + def vector_validator(cls, v: np.ndarray | None): + if v is not None and len(v) != col.vlen: + raise ValueError( + f"Array input for column {col.column_id} must have length {col.vlen}" + ) + return v + + return vector_validator + + validators[f"validate_{col.column_id}"] = create_vector_validator(col) + field_definitions[col.column_id] = (NumpyArray | None, Field(default=None)) + else: + # Get the Python type from ColumnDtype + py_type = col.dtype.to_python_type() + field_definitions[col.column_id] = (py_type | None, Field(default=None)) + # Add state column (ending with '_') + state_col_id = f"{col.column_id}_" + field_definitions[state_col_id] = ( + dict[str, Any], + Field(default={}, description=f"State of {col.column_id} column."), + ) + + return create_model( + table_id, + **field_definitions, + __base__=DataTableRow, + __validators__=validators, + ) + + @classmethod + async def create_schemas(cls, project_id: str) -> None: + """ + Create the project's schemas and metadata tables. + """ + try: + async with GENTABLE_ENGINE.transaction() as conn: + for table_type in TableType: + schema_id = f"{project_id}_{table_type}" + await conn.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_id}"') + await conn.execute(TableMetadata.sql_create(schema_id)) + await conn.execute(ColumnMetadata.sql_create(schema_id)) + except (UniqueViolationError, DuplicateTableError): + # Just to be safe, even though catching `UniqueViolationError` is sufficient + return + + @classmethod + async def drop_schemas(cls, project_id: str) -> None: + """ + Drops the project's schemas along with all metadata and data tables. + """ + async with GENTABLE_ENGINE.transaction() as conn: + for table_type in TableType: + schema_id = f"{project_id}_{table_type}" + await conn.execute(f'DROP SCHEMA IF EXISTS "{schema_id}" CASCADE') + + @classmethod + async def drop_schema( + cls, + *, + project_id: str, + table_type: TableType, + ) -> None: + """ + Drops the project's schema along with all metadata and data tables. + """ + schema_id = f"{project_id}_{table_type}" + async with GENTABLE_ENGINE.transaction() as conn: + await conn.execute(f'DROP SCHEMA IF EXISTS "{schema_id}" CASCADE') + + ### --- Table CRUD --- ### + + # Table Create Ops + @classmethod + async def create_table( + cls, + *, + project_id: str, + table_type: TableType, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + set_default_prompts: bool = True, + ) -> Self: + """ + Create a new table. + + Args: + project_id (str): Project ID. + table_type (str): Table type. + table_metadata (TableMetadata): Table metadata. + column_metadata_list (list[ColumnMetadata]): List of column metadata. 
+ set_default_prompts (bool, optional): If True, set default prompts. + Useful when importing table which does not need to set prompts. Defaults to True. + + Returns: + self (GenerativeTableCore): The table instance. + """ + return await cls._create_table( + project_id=project_id, + table_type=table_type, + table_metadata=table_metadata, + column_metadata_list=column_metadata_list, + set_default_prompts=set_default_prompts, + ) + + @classmethod + async def duplicate_table( + cls, + *, + project_id: str, + table_type: TableType, + table_id_src: str, + table_id_dst: str | None = None, + include_data: bool = True, + create_as_child: bool = False, + created_by: str | None = None, + request_id: str = "", + ) -> Self: + """ + Duplicate an existing table including schema, data and metadata. + + Args: + project_id (str): Project ID. + table_type (str): Table type. + table_id_src (str): Name of the table to be duplicated. + table_id_dst (str | None, optional): Name for the new table. + Defaults to None (automatically find the next available table name). + include_data (bool, optional): If True, include data. Defaults to True. + create_as_child (bool, optional): If True, create the new table as a child of the source table. + Defaults to False. + created_by (str | None, optional): User ID of the user who created the table. + Defaults to None. + request_id (str, optional): Request ID for logging. Defaults to "". + + Raises: + BadInputError: If `table_id_dst` is not None or a non-empty string. + ResourceNotFoundError: If table or column metadata cannot be found. + + Returns: + self (GenerativeTableCore): The duplicated table instance. + """ + schema_id = f"{project_id}_{table_type}" + if create_as_child: + include_data = True + if isinstance(table_id_dst, str): + table_id_dst = table_id_dst.strip() + async with GENTABLE_ENGINE.transaction() as conn: + try: + if table_id_dst: + try: + table_metadata = await conn.fetchrow( + f'SELECT * FROM "{schema_id}"."TableMetadata" WHERE table_id = $1', + table_id_dst, + ) + except UndefinedTableError as e: + # TableMetadata does not exist, meaning this schema is empty + raise ResourceNotFoundError(f'Table "{table_id_src}" not found.') from e + if table_metadata is not None: + raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') + else: + # Might need to truncate table name + now_str = now().strftime("%Y-%m-%d-%H-%M-%S") + base_name = f"{truncate_table_id(table_id_src)} {now_str}" + # Automatically find the next available table name + # The function will raise UndefinedTableError if the table does not exist + await conn.execute( + f""" + CREATE OR REPLACE FUNCTION duplicate_table() + RETURNS TEXT AS $$ + DECLARE + new_table_name TEXT; + suffix INTEGER := 1; + max_iterations INTEGER := {TABLE_ID_DST_MAX_ITER}; + BEGIN + -- Loop to find the next available table name + WHILE suffix <= max_iterations LOOP + new_table_name := format('%s (%s)', '{base_name}', suffix); + -- Check if the new table name already exists + IF NOT EXISTS ( + SELECT 1 FROM "{schema_id}"."TableMetadata" + WHERE table_id = new_table_name + ) THEN + RETURN new_table_name; -- Return the new table name + END IF; + suffix := suffix + 1; + END LOOP; + -- If we've reached the maximum number of iterations without finding an available name + RETURN NULL; -- Return NULL to indicate failure + END; + $$ LANGUAGE plpgsql; + """, + ) + table_id_dst: str | None = await conn.fetchval("SELECT duplicate_table();") + if table_id_dst is None: + raise ResourceExistsError( + f'Could not find a name for 
table "{table_id_src}" after {TABLE_ID_DST_MAX_ITER:,d} attempts.' + ) + # Create the data table + # Exclude indexes to set our own index name + short_id_src = get_internal_id(table_id_src) + short_id_dst = get_internal_id(table_id_dst) + if include_data: + await conn.execute( + f'CREATE TABLE "{schema_id}"."{short_id_dst}" AS TABLE "{schema_id}"."{short_id_src}"' + ) + else: + await conn.execute( + ( + f'CREATE TABLE "{schema_id}"."{short_id_dst}" ' + f'(LIKE "{schema_id}"."{short_id_src}" INCLUDING ALL EXCLUDING INDEXES)' + ) + ) + + # It's required to explicitly add primary key + await conn.execute( + f'ALTER TABLE "{schema_id}"."{short_id_dst}" ADD PRIMARY KEY ("ID")' + ) + # Copy metadata + table_meta = await conn.fetchrow( + f'SELECT * FROM "{schema_id}"."TableMetadata" WHERE table_id = $1', + table_id_src, + ) + if table_meta is None: + raise ResourceNotFoundError( + f'Table metadata for "{table_id_src}" is not found.' + ) + table_meta = dict(table_meta) + table_meta["table_id"] = table_id_dst + table_meta.pop("short_id", None) + table_meta["created_by"] = created_by + if create_as_child: + table_meta["parent_id"] = table_id_src + table_meta = TableMetadata.model_validate(table_meta) + await cls._upsert_table_metadata(conn, schema_id, table_meta) + + # Copy column metadata + column_metas = await conn.fetch( + f'SELECT * FROM "{schema_id}"."ColumnMetadata" WHERE table_id = $1', + table_id_src, + ) + if len(column_metas) == 0: + raise ResourceNotFoundError( + f'Column metadata for "{table_id_src}" is not found.' + ) + column_metas = [ColumnMetadata.model_validate(dict(m)) for m in column_metas] + for meta in column_metas: + meta.table_id = table_meta.table_id + meta.short_table_id = table_meta.short_id + await cls._upsert_column_metadata(conn, schema_id, meta) + + # Recreate indexes + text_cols = [col.column_id for col in column_metas if col.is_text_column] + vector_cols = [col.column_id for col in column_metas if col.is_vector_column] + await cls._recreate_fts_index( + conn, schema_id=schema_id, table_id=table_id_dst, columns=text_cols + ) + await cls._recreate_vector_index( + conn, schema_id=schema_id, table_id=table_id_dst, columns=vector_cols + ) + + return await cls._open_table( + conn=conn, + project_id=project_id, + table_type=table_type, + table_id=table_id_dst, + request_id=request_id, + ) + + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{table_id_src}" is not found.') from e + except DuplicateTableError as e: + raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') from e + except ValidationError as e: + raise BadInputError(str(e)) from e + + # Table Create Ops + @classmethod + async def open_table( + cls, + *, + project_id: str, + table_type: TableType, + table_id: str, + created_by: str | None = None, + request_id: str = "", + ) -> Self: + """ + Open an existing table. + + Args: + project_id (str): Project ID. + table_type (str): Table type. + table_id (str): Name of the table. + created_by (str | None, optional): User who created the table. + If provided, will check if the table was created by the user. Defaults to None (any user). + request_id (str, optional): Request ID for logging. Defaults to "". + + Returns: + self (GenerativeTableCore): The table instance. 
+ """ + async with GENTABLE_ENGINE.transaction() as conn: + table = await cls._open_table( + conn=conn, + project_id=project_id, + table_type=table_type, + table_id=table_id, + request_id=request_id, + ) + if created_by is not None and table.table_metadata.created_by != created_by: + raise ResourceNotFoundError(f'Table "{table_id}" not found.') + return table + + @classmethod + async def list_tables( + cls, + *, + project_id: str, + table_type: TableType, + limit: int | None = 100, + offset: int = 0, + order_by: Literal["id", "updated_at"] = "updated_at", + order_ascending: bool = True, + created_by: str | None = None, + parent_id: str | None = None, + search_query: str = "", + search_columns: list[str] = None, + count_rows: bool = False, + ) -> Page[TableMetaResponse]: + """ + List tables. + + Args: + project_id (str): Project ID. + limit (int | None, optional): Maximum number of tables to return. + Defaults to 100. Pass None to return all tables. + offset (int, optional): Offset for pagination. Defaults to 0. + order_by (Literal["id", "updated_at"], optional): Sort tables by this attribute. + Defaults to "updated_at". + order_ascending (bool, optional): Whether to sort by ascending order. + Defaults to True. + created_by (str | None, optional): Return tables created by this user. + Defaults to None (return all tables). + parent_id (str | None, optional): Parent ID of tables to return. + Defaults to None (no parent ID filtering). + Additionally for Chat Table, you can list: + (1) all chat agents by passing in "_agent_"; or + (2) all chats by passing in "_chat_". + search_query (str, optional): A string to search for within table names. + The string is interpreted as both POSIX regular expression and literal string. + Defaults to "". + search_columns (list[str], optional): List of columns to search within. + Defaults to None (search table ID). + count_rows (bool, optional): Whether to count the rows of the tables. + Defaults to False. + + Returns: + tables (Page[TableMetaResponse]): List of tables. 
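+
+        Example:
+            A minimal usage sketch (the project ID and search string are
+            illustrative placeholders):
+
+                tables = await GenerativeTableCore.list_tables(
+                    project_id="proj_123",
+                    table_type=TableType.KNOWLEDGE,
+                    limit=20,
+                    order_by="updated_at",
+                    order_ascending=False,
+                    search_query="invoice",
+                    count_rows=True,
+                )
+                for meta in tables.items:
+                    print(meta.id, meta.num_rows)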
+ """ + schema_id = f"{project_id}_{table_type}" + search_query = search_query.strip() + filters = [] + params = [] + if search_columns is None: + search_columns = ["table_id"] + if created_by: + params.append(str(created_by)) + filters.append(f"(created_by = ${len(params)})") + if parent_id: + if parent_id == "_agent_": + filters.append("(parent_id IS NULL)") + elif parent_id == "_chat_": + filters.append("(parent_id IS NOT NULL)") + else: + params.append(parent_id) + filters.append(f"(parent_id = ${len(params)})") + if search_query: + search_filters = [] + for search_column in search_columns: + search_column = "table_id" if search_column == "id" else search_column + # Literal (escaped) search + params.append(re.escape(search_query)) + literal_expr = f"({search_column}::text ~* ${len(params)})" + # Regex search + params.append(search_query) + regex_expr = f"({search_column}::text ~* ${len(params)})" + search_filters.append(f"({literal_expr} OR {regex_expr})") + filters.append("(" + " OR ".join(search_filters) + ")") + if order_by == "id": + order_by = "table_id" + if order_by in TableMetadata.str_cols(): + order_by = f'LOWER("{order_by}")' + else: + order_by = f'"{order_by}"' + order_direction = "ASC" if order_ascending else "DESC" + where = f"WHERE {' AND '.join(filters)}" if len(filters) > 0 else "" + async with GENTABLE_ENGINE.transaction() as conn: + try: + total = await conn.fetchval( + f'SELECT COUNT(*) FROM "{schema_id}"."TableMetadata" {where}', + *params, + ) + sql = f""" + SELECT * FROM "{schema_id}"."TableMetadata" {where} + ORDER BY {order_by} {order_direction} + """ + if limit is not None: + params.append(limit) + sql += f" LIMIT ${len(params)}" + table_metas = await conn.fetch(f"{sql} OFFSET ${len(params) + 1}", *params, offset) + except UndefinedColumnError as e: + # raise ResourceNotFoundError(f'Attribute "{order_by}" is not found.') from e + raise ResourceNotFoundError(str(e)) from e + except UndefinedTableError: + total = 0 + return Page[TableMetaResponse]( + items=[], + offset=offset, + limit=total if limit is None else limit, + total=total, + ) + meta_responses = [] + for table_meta in table_metas: + table_meta = TableMetadata.model_validate(dict(table_meta)) + column_metas = await conn.fetch( + f""" + SELECT * FROM "{schema_id}"."ColumnMetadata" + WHERE table_id = $1 ORDER BY column_order ASC + """, + table_meta.table_id, + ) + column_metas = [ColumnMetadata.model_validate(dict(col)) for col in column_metas] + if count_rows: + num_rows = await conn.fetchval( + f'SELECT COUNT("ID") FROM "{schema_id}"."{table_meta.short_id}"' + ) + else: + num_rows = -1 + meta_responses.append( + TableMetaResponse( + id=table_meta.table_id, + cols=[ + ColumnSchema( + id=col.column_id, + dtype=col.dtype, + vlen=col.vlen, + gen_config=col.gen_config, + ) + for col in column_metas + ], + parent_id=table_meta.parent_id, + title=table_meta.title, + created_by=table_meta.created_by, + updated_at=table_meta.updated_at.isoformat(), + num_rows=num_rows, + version=table_meta.version, + meta=table_meta.meta, + ) + ) + return Page[TableMetaResponse]( + items=meta_responses, + offset=offset, + limit=total if limit is None else limit, + total=total, + ) + + async def count_rows(self) -> int: + """ + Count the number of rows. + + Returns: + num_rows (int): Number of rows in the table. 
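+
+        Example:
+            A minimal usage sketch on an already-opened table instance:
+
+                num_rows = await table.count_rows()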
+ """ + async with GENTABLE_ENGINE.transaction() as conn: + return await self._count_rows(conn) + return self.num_rows + + # Table Update Ops + async def rename_table(self, table_id_dst: TableName) -> Self: + """ + Rename a table. + + Args: + table_id_dst (str): New name for the table. + + Raises: + ResourceNotFoundError: If the table is not found. + ResourceExistsError: If the table already exists. + + Returns: + self (GenerativeTableCore): The renamed table instance. + """ + table_id_src = self.table_id + short_id_src = self.short_table_id + short_id_dst = get_internal_id(table_id_dst) + async with GENTABLE_ENGINE.transaction() as conn: + try: + # Rename data table + await conn.execute( + f'ALTER TABLE "{self.schema_id}"."{short_id_src}" RENAME TO "{short_id_dst}"' + ) + # Rename primary key index (only for consistency purposes, no operational impact even without rename) + await conn.execute( + f""" + ALTER TABLE "{self.schema_id}"."{short_id_dst}" + RENAME CONSTRAINT "{short_id_src}_pkey" TO "{short_id_dst}_pkey" + """ + ) + # Rename indexes + await conn.execute( + f""" + ALTER INDEX "{self.schema_id}"."{fts_index_id(table_id_src)}" + RENAME TO "{fts_index_id(table_id_dst)}" + """ + ) + for col in self.vector_column_names: + await conn.execute( + f""" + ALTER INDEX "{self.schema_id}"."{vector_index_id(table_id_src, col)}" + RENAME TO "{vector_index_id(table_id_dst, col)}" + """ + ) + # Update table metadata entry + await conn.execute( + f""" + UPDATE "{self.schema_id}"."TableMetadata" + SET table_id = $1, short_id = $2 WHERE table_id = $3 + """, + table_id_dst, + short_id_dst, + table_id_src, + ) + # Update any child tables' parent_id references + await conn.execute( + f'UPDATE "{self.schema_id}"."TableMetadata" SET parent_id = $1 WHERE parent_id = $2', + table_id_dst, + table_id_src, + ) + self.table_id = table_id_dst + # Set updated at time + await self._set_updated_at(conn) + return self + except UndefinedTableError as e: + # Index or table not found + raise ResourceNotFoundError(f'Table "{table_id_src}" is not found.') from e + except DuplicateTableError as e: + raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') from e + + async def update_table_title(self, title: str) -> Self: + """ + Update the table title. + """ + updated_at = now() + query = f""" + UPDATE "{self.schema_id}"."TableMetadata" + SET title = $1, updated_at = $2 + WHERE table_id = $3; + """ + async with GENTABLE_ENGINE.transaction() as conn: + await conn.execute(query, title, updated_at, self.table_id) + self.table_metadata.title = title + self.table_metadata.updated_at = updated_at + return self + + # Table Delete Ops + async def drop_table(self) -> None: + """ + Drop the table. + + Raises: + ResourceNotFoundError: If the table is not found. 
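+
+        Example:
+            A minimal usage sketch; the table is assumed to have been opened
+            with `open_table` beforehand (project and table IDs are
+            illustrative placeholders):
+
+                table = await GenerativeTableCore.open_table(
+                    project_id="proj_123",
+                    table_type=TableType.KNOWLEDGE,
+                    table_id="my-table",
+                )
+                await table.drop_table()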
+ """ + async with GENTABLE_ENGINE.transaction() as conn: + try: + # Drop the data table + await conn.execute( + f'DROP TABLE IF EXISTS "{self.schema_id}"."{self.short_table_id}" CASCADE' + ) + # Drop row from table metadata, this will cascade to the associated column metadata + await conn.execute( + f'DELETE FROM "{self.schema_id}"."TableMetadata" WHERE table_id = $1', + self.table_id, + ) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + + @staticmethod + def _coerce_column_to_pa_dtype( + data: list[Any], + dtype: pa.DataType, + ) -> pa.Array: + """Convert column data to appropriate Arrow array type""" + if len(data) == 0: + return pa.array([], dtype) + if isinstance(data[0], UUID): + data = [str(d) for d in data] + elif isinstance(data[0], dict): + data = [json_dumps(d) for d in data] + return pa.array(data, dtype) + + # Table Import Export Ops + async def export_table( + self, + dest: str | Path | BinaryIO, + *, + compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", + verbose: bool = False, + ) -> None: + """ + Export a table's data and metadata to a specified output path. + + Args: + output_path (str | Path): Path to save the exported data. + compression (str, optional): Compression type for the output file. + Options are "NONE", "ZSTD", "LZ4", or "SNAPPY". Defaults to "ZSTD". + verbose (bool, optional): If True, will produce verbose logging messages. + Defaults to False. + + Raises: + ResourceNotFoundError: If the output path is invalid. + """ + log_level = "INFO" if verbose else "DEBUG" + if isinstance(dest, (str, Path)): + dest = Path(dest) + if dest.is_dir(): + dest = dest / f"{self.table_id}.parquet" + else: + if (suffix := Path(dest).suffix) != ".parquet": + raise BadInputError(f'Output extension "{suffix}" is invalid.') + rows: list[dict[str, Any]] = ( + await self.list_rows( + limit=None, + offset=0, + order_by=["ID"], + order_ascending=True, + columns=None, + remove_state_cols=False, + ) + ).items + col_dtype_map = { + col.column_id: pa.list_(pa.float32()) + if col.is_vector_column + else col.dtype.to_pyarrow_type() + for col in self.column_metadata + } + + # Add file data into Arrow Table + async def _download(uri: str) -> tuple[str, bytes, str]: + async with semaphore: + try: + async with open_uri_async(uri) as (f, mime): + return (uri, await f.read(), mime) + except ResourceNotFoundError: + return (uri, b"", "") + + async def _download_files(col_ids: list[str]) -> dict[str, tuple[bytes, str]]: + download_coros = [] + _uri_bytes: dict[str, tuple[bytes, str]] = {} + for col_id in col_ids: + if f"{col_id}__" in col_dtype_map: + raise BadInputError(f'Table "{self.table_id}" has bad column "{col_id}__".') + for row in rows: + uri = row[col_id] + if uri in _uri_bytes: + continue + # Create the coroutine + download_coros.append(_download(uri)) + _uri_bytes[uri] = (b"", "") + self._log( + ( + f'Importing table "{self.table_id}": ' + f"Downloading {len(download_coros):,d} files " + f"with concurrency limit of {S3_MAX_CONCURRENCY}." 
+ ), + log_level, + ) + for fut in asyncio.as_completed(download_coros): + uri, content, mime = await fut + _uri_bytes[uri] = (content, mime) + return _uri_bytes + + semaphore = Semaphore(S3_MAX_CONCURRENCY) + pa_file_columns = [] + self._log( + f'Importing table "{self.table_id}": Downloading files in file columns.', + log_level, + ) + file_col_ids = [col.column_id for col in self.column_metadata if col.is_file_column] + uri_bytes = await _download_files(file_col_ids) + uris_seen = set() + for col_id in file_col_ids: + col_bytes = [] + for row in rows: + uri = row[col_id] + if uri in uris_seen: + col_bytes.append(b"") + continue + content, mime = uri_bytes.get(uri, (b"", "")) + col_bytes.append(content) + if mime: + row[f"{col_id}_"].update({"_mime_type": mime}) + uris_seen.add(uri) + if len(col_bytes) > 0: + pa_file_columns.append((pa.field(f"{col_id}__", pa.binary()), [col_bytes])) + + # Add Knowledge Table file data + if self.table_type == TableType.KNOWLEDGE: + self._log( + f'Importing table "{self.table_id}": Downloading Knowledge Table files.', + log_level, + ) + file_col_ids = ["File ID"] + uri_bytes = await _download_files(file_col_ids) + uris_seen = set() + for col_id in file_col_ids: + col_bytes = [] + for row in rows: + uri = row[col_id] + if uri in uris_seen: + col_bytes.append(b"") + continue + content, mime = uri_bytes.get(uri, (b"", "")) + col_bytes.append(content) + if mime: + row[f"{col_id}_"].update({"_mime_type": mime}) + uris_seen.add(uri) + if len(col_bytes) > 0: + pa_file_columns.append((pa.field(f"{col_id}__", pa.binary()), [col_bytes])) + # Create Parquet table + self._log(f'Importing table "{self.table_id}": Creating Parquet table.', log_level) + pa_table = pa.table( + { + col.column_id: self._coerce_column_to_pa_dtype( + [row[col.column_id] for row in rows], col_dtype_map[col.column_id] + ) + for col in self.column_metadata + }, + metadata=dict(gen_table_meta=self.v1_meta.model_dump_json()), + ) + # Append byte column + for pa_col in pa_file_columns: + pa_table = pa_table.append_column(*pa_col) + # Write to Parquet + self._log(f'Importing table "{self.table_id}": Writing Parquet table.', log_level) + try: + pq.write_table(pa_table, dest, compression=compression) + except (FileNotFoundError, OSError) as e: + raise ResourceNotFoundError(f'Output path "{dest}" is invalid.') from e + self._log(f'Importing table "{self.table_id}": Export completed.', log_level) + + @classmethod + async def _import_table( + cls, + *, + project_id: str, + table_type: TableType, + source: str | Path | BinaryIO, + table_id_dst: str | None, + reupload_files: bool = True, + progress_key: str = "", + verbose: bool = False, + ) -> Self: + def _measure_ram() -> str: + import psutil + + GiB = 1024**3 + mem = psutil.virtual_memory() + return f"RAM usage: {mem.used / GiB:,.2f} / {mem.total / GiB:,.2f} GiB ({mem.percent:.1f} %)" + + # Check if project exists + project = await cls._fetch_project(project_id) + organization_id = project.organization_id + + # Load Parquet file + filename = source if isinstance(source, str) else getattr(source, "name", "") + try: + pa_table: pa.Table = pq.read_table( + source, columns=None, use_threads=False, memory_map=True + ) + except FileNotFoundError as e: + raise ResourceNotFoundError(f'Parquet file "{filename}" is not found.') from e + except Exception as e: + logger.info(f'Parquet file "{filename}" contains bad data: {repr(e)}') + raise BadInputError(f'Parquet file "{filename}" contains bad data.') from e + try: + pa_meta = 
TableMeta.model_validate_json(pa_table.schema.metadata[b"gen_table_meta"]) + except KeyError as e: + raise BadInputError("Missing table metadata in the Parquet file.") from e + except Exception as e: + logger.warning(f"Invalid table metadata in the Parquet file: {repr(e)}") + raise BadInputError("Invalid table metadata in the Parquet file.") from e + # Check for existing table + if table_id_dst is None: + table_id_dst = pa_meta.id + if verbose: + logger.info( + f'Importing table "{table_id_dst}": Parquet data loaded successfully. {_measure_ram()}' + ) + prog = TableImportProgress(key=progress_key) + if not (await CACHE.set_progress(prog, nx=True)): + raise ResourceExistsError( + f'There is an in-progress import for table "{table_id_dst}".' + ) + prog.data["table_id_dst"] = table_id_dst + prog.load_data.progress = 100 + await CACHE.set_progress(prog) + + async with GENTABLE_ENGINE.transaction() as conn: + schema_id = f"{project_id}_{table_type}" + try: + table_metadata = await conn.fetchrow( + f'SELECT * FROM "{schema_id}"."TableMetadata" WHERE table_id = $1', + table_id_dst, + ) + except UndefinedTableError: + table_metadata = None + if table_metadata is not None: + raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') + # Check for required columns + pa_meta_cols = {c.id for c in pa_meta.cols} + # Sometimes Chat Table has "user" instead of "User" + if table_type == TableType.CHAT and "user" in pa_meta_cols and "User" not in pa_meta_cols: + for col in pa_meta.cols: + if col.id == "user": + col.id = "User" + break + pa_meta_cols = {c.id for c in pa_meta.cols} + required_columns = set(cls.FIXED_COLUMN_IDS) + if len(required_columns - pa_meta_cols) > 0: + raise BadInputError( + f"Missing table columns in the Parquet file: {list(required_columns - pa_meta_cols)}." + ) + # Recreate table and column metadata + table_metadata = TableMetadata( + table_id=table_id_dst, + title=pa_meta.title, + parent_id=pa_meta.parent_id, + updated_at=pa_meta.updated_at, + ) + column_metadata = [] + for col in pa_meta.cols: + if isinstance(col.gen_config, LLMGenConfig): + # LLM columns are always string typed + col.dtype = ColumnDtype.STR + # Handle RAG params + if col.gen_config.rag_params: + params = col.gen_config.rag_params.model_dump(exclude_unset=True) + col.gen_config.rag_params.inline_citations = params.get( + "inline_citations", False + ) + column_metadata.append( + ColumnMetadata( + table_id=table_id_dst, + column_id=col.id, + dtype=col.dtype, + vlen=col.vlen, + gen_config=col.gen_config, + ) + ) + + # Create the new table + if verbose: + logger.info( + f'Importing table "{table_id_dst}": Creating Generative Table. {_measure_ram()}' + ) + prog.parse_data.progress = 50 + await CACHE.set_progress(prog) + self = await cls._create_table( + project_id=project_id, + table_type=table_type, + table_metadata=table_metadata, + column_metadata_list=column_metadata, + set_default_prompts=False, + replace_unavailable_models=True, # Old tables may have deprecated models + allow_nonexistent_refs=True, # Old tables may have non-existent columns + create_indexes=False, + ) + + # Load data + if verbose: + logger.info( + f'Importing table "{self.table_id}": Pre-processing Parquet data. 
{_measure_ram()}' + ) + rows: list[dict[str, Any]] = pa_table.to_pylist() + # Process state JSON + for row in rows: + for col_id in row: + if col_id.endswith("__"): + # File byte column + continue + if not col_id.endswith("_"): + # Regular column + continue + state = json_loads(row[col_id] or "{}") + # Legacy attribute + if state.pop("is_null", False): + row[col_id[:-1]] = None + row[col_id] = state + + # Upload files to S3 + if verbose: + if reupload_files: + logger.info(f'Importing table "{self.table_id}": Uploading files to S3.') + else: + logger.info(f'Importing table "{self.table_id}": Skipped S3 upload.') + prog.parse_data.progress = 100 + await CACHE.set_progress(prog) + + async def _upload( + old_uri: str, + content: bytes, + content_type: str, + filename: str, + ) -> tuple[str, str]: + async with semaphore: + new_uri = await s3_upload( + organization_id, + project_id, + content, + content_type=content_type, + filename=filename, + ) + return (old_uri, new_uri) + + uris_seen: dict[str, str] = {} # Old URI to new URI + semaphore = Semaphore(S3_MAX_CONCURRENCY) + upload_coros = [] + for row in rows: + file_byte_cols = [c for c in row.keys() if c.endswith("__")] + for col_id in file_byte_cols: + state_col_id = col_id[:-1] + uri_col_id = col_id[:-2] + uri: str = row[uri_col_id] + if uri in uris_seen: + continue + if not reupload_files: + uris_seen[uri] = uri + continue + file_bytes = row[col_id] + if len(file_bytes) == 0: + # Could be file download error or duplicate URI + continue + mime_type = row[state_col_id].pop("_mime_type", None) + # Attempt MIME type detection based on URI + if mime_type is None: + mime_type = guess_mime(uri) + # Attempt MIME type detection based on file content + if mime_type is None: + mime_type = guess_mime(file_bytes) + # Create the coroutine + upload_coros.append(_upload(uri, file_bytes, mime_type, uri.split("/")[-1])) + # Set to old URI for now + uris_seen[uri] = uri + total, completed = len(upload_coros), 0 + if verbose: + logger.info( + ( + f'Importing table "{self.table_id}": Uploading {total:,d} files ' + f"with concurrency limit of {S3_MAX_CONCURRENCY}. {_measure_ram()}" + ) + ) + for fut in asyncio.as_completed(upload_coros): + old_uri, new_uri = await fut + uris_seen[old_uri] = new_uri + completed += 1 + prog.upload_files.progress = int((completed / total) * 100) + await CACHE.set_progress(prog) + # Set new URI and remove file byte column from row + for row in rows: + file_byte_cols = [c for c in row.keys() if c.endswith("__")] + for col_id in file_byte_cols: + uri_col_id = col_id[:-2] + row[uri_col_id] = uris_seen.get(row[uri_col_id], None) + state_col_id = col_id[:-1] + row[state_col_id].pop("_mime_type", None) + row.pop(col_id, None) + prog.upload_files.progress = 100 + await CACHE.set_progress(prog) + + # Add data to table batch by batch + n = len(rows) + if verbose: + logger.info(f'Importing table "{self.table_id}": Adding {n:,d} rows. {_measure_ram()}') + for i in range(0, n, IMPORT_BATCH_SIZE): + j = min(i + IMPORT_BATCH_SIZE, n) + self = await self.add_rows( + rows[i:j], + ignore_info_columns=False, + ignore_state_columns=False, + set_updated_at=False, + ) + if verbose: + logger.info( + f'Importing table "{self.table_id}": Added {j:,d} / {n:,d} rows. 
{_measure_ram()}' + ) + prog.add_rows.progress = int((j / n) * 100) + await CACHE.set_progress(prog) + prog.add_rows.progress = 100 + # Perform indexing + async with GENTABLE_ENGINE.transaction() as conn: + await self._recreate_fts_index( + conn, + schema_id=self.schema_id, + table_id=self.table_id, + columns=self.text_column_names, + ) + logger.info(f'Importing table "{self.table_id}": Created FTS index.') + async with GENTABLE_ENGINE.transaction() as conn: + await self._recreate_vector_index( + conn, + schema_id=self.schema_id, + table_id=self.table_id, + columns=self.vector_column_names, + ) + logger.info(f'Importing table "{self.table_id}": Created vector index.') + prog.index.progress = 100 + prog.state = ProgressState.COMPLETED + prog.data["table_meta"] = self.v1_meta_response.model_dump(mode="json") + await CACHE.set_progress(prog) + return self + + @classmethod + async def import_table( + cls, + *, + project_id: str, + table_type: TableType, + source: str | Path | BinaryIO, + table_id_dst: TableName | None, + reupload_files: bool = True, + progress_key: str = "", + verbose: bool = False, + ) -> Self: + """ + Recreate a table (data and metadata) from a Parquet file. + + Args: + project_id (str): Project ID. + table_type (str): Table type. + input_path (str | Path): The path to the import file. + table_id_dst (TableName): Name or ID of the new table. + If None, the table ID in the Parquet metadata will be used. + reupload_files (bool, optional): If True, will reupload files to S3 with new URI. + Otherwise skip reupload and keep the original S3 paths for file columns. + Defaults to True. + progress_key (str, optional): Progress publish key. Defaults to "" (disabled). + verbose (bool, optional): If True, will produce verbose logging messages. + Defaults to False. + + Raises: + ResourceExistsError: If the table already exists. + + Returns: + self (GenerativeTableCore): The table instance. 
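+
+        Example:
+            A minimal usage sketch (the Parquet path and project ID are
+            illustrative placeholders):
+
+                table = await GenerativeTableCore.import_table(
+                    project_id="proj_123",
+                    table_type=TableType.KNOWLEDGE,
+                    source="exports/my-table.parquet",
+                    table_id_dst=None,  # reuse the table ID stored in the Parquet metadata
+                    reupload_files=True,
+                )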
+ """ + try: + self = await cls._import_table( + project_id=project_id, + table_type=table_type, + source=source, + table_id_dst=table_id_dst, + reupload_files=reupload_files, + progress_key=progress_key, + verbose=verbose, + ) + except Exception as e: + if not isinstance(e, JamaiException): + logger.exception(repr(e)) + try: + prog = await CACHE.get_progress(progress_key, TableImportProgress) + if table_id := (prog.data.get("table_id_dst", None)): + # Might need to clean up + async with GENTABLE_ENGINE.transaction() as conn: + try: + schema_id = f"{project_id}_{table_type}" + # Drop the data table + await conn.execute(f'DROP TABLE IF EXISTS "{schema_id}"."{table_id}"') + # Drop row from table metadata, this will automatically drop the associated column metadata + await conn.execute( + f'DELETE FROM "{schema_id}"."TableMetadata" WHERE table_id = $1', + table_id, + ) + except Exception as e: + logger.info( + f'Encountered error cleaning up table "{table_id}" after failed import: {repr(e)}' + ) + prog.state = ProgressState.FAILED + prog.error = repr(e) + await CACHE.set_progress(prog) + except Exception as e: + logger.error(f"Encountered error setting progress after failed import: {repr(e)}") + logger.error(repr(e)) + raise + return self + + def _filter_columns( + self, + columns: list[str] | None, + *, + exclude_state: bool, + ) -> list[str]: + data_columns = self.data_table_model.get_column_ids(exclude_state=exclude_state) + if columns: + if not exclude_state: + columns += [f"{c}_" for c in columns] + columns = [c for c in data_columns if c in columns] + if "Updated at" not in columns: + columns.insert(0, "Updated at") + if "ID" not in columns: + columns.insert(0, "ID") + else: + columns = data_columns + return columns + + async def export_data( + self, + output_path: str | Path, + *, + columns: list[str] | None = None, + where: str = "", + limit: int | None = None, + offset: int = 0, + delimiter: CSVDelimiter = CSVDelimiter.COMMA, + ) -> None: + """ + Export table data to CSV file. + + Args: + output_path (str | Path): Path to save the CSV file. + columns (list[str] | None, optional): A list of column names to include in the returned rows. + Defaults to None (return all columns). + where (str, optional): SQL WHERE clause to filter rows. Defaults to "". + limit (int | None, optional): Maximum number of rows to export. Defaults to None. + offset (int | None, optional): Offset for pagination. Defaults to None. + delimiter (str, optional): CSV delimiter, either "," or "\\t". Defaults to ",". + + Raises: + BadInputError: If the delimiter is invalid. + ResourceNotFoundError: If the table is not found. 
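+
+        Example:
+            A minimal usage sketch on an opened table (the output path and
+            column names are illustrative placeholders):
+
+                await table.export_data(
+                    "exports/my-table.csv",
+                    columns=["Title", "Summary"],
+                    delimiter=CSVDelimiter.COMMA,
+                )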
+ """ + if delimiter not in CSVDelimiter: + raise BadInputError(f"Invalid delimiter: {delimiter}") + columns = self._filter_columns(columns, exclude_state=True) + # Get table data + rows = ( + await self.list_rows( + limit=limit, + offset=offset, + order_by=["ID"], + order_ascending=True, + columns=columns, + where=where, + remove_state_cols=True, + ) + ).items + try: + df = pd.DataFrame(rows, columns=columns) + # Convert special types + col_meta_map = {col.column_id: col for col in self.column_metadata} + dtype = {} + for col in columns: + if col_meta_map[col].dtype == ColumnDtype.DATE_TIME: + df[col] = df[col].apply(lambda x: x.isoformat()) + dtype[col] = pd.StringDtype() + elif col_meta_map[col].is_vector_column: + df[col] = df[col].apply(lambda x: x.tolist()) + dtype[col] = pd.StringDtype() + else: + dtype[col] = col_meta_map[col].dtype.to_pandas_type() + df = df.astype(dtype, errors="raise") + except Exception as e: + raise BadInputError( + f'Failed to export table "{self.table_id}" due to error: {e}' + ) from e + try: + df_to_csv(df=df, file_path=output_path, sep=delimiter) + except (FileNotFoundError, OSError) as e: + raise BadInputError(f'Output path "{output_path}" is not found.') from e + + async def read_csv( + self, + input_path: str | Path | BinaryIO, + *, + column_id_mapping: dict[str, str] | None = None, + delimiter: CSVDelimiter = CSVDelimiter.COMMA, + ignore_info_columns: bool = True, + ) -> Self: + col_meta_map = {col.column_id: col for col in self.column_metadata} + dtype = { + col.column_id: pd.StringDtype() if col.is_vector_column else col.dtype.to_pandas_type() + for col in self.column_metadata + } + # Read CSV file + try: + df = pd.read_csv(input_path, dtype=dtype, delimiter=delimiter, keep_default_na=True) + except FileNotFoundError as e: + raise ResourceNotFoundError(f'Input file "{input_path}" is not found.') from e + except pd.errors.EmptyDataError as e: + raise BadInputError(f'Input file "{input_path}" is empty.') from e + if len(df) == 0: + raise BadInputError(f'Input file "{input_path}" has no rows.') + try: + # Apply column mapping if provided + if column_id_mapping: + df = df.rename(columns=column_id_mapping) + # Remove "ID" and "Updated at" columns if needed + if ignore_info_columns: + df = df[[col for col in df.columns if col.lower() not in self.INFO_COLUMNS]] + + # Create a mapping of column names to their metadata for faster lookup + col_meta_map = {col.column_id: col for col in self.column_metadata} + # Keep only valid columns + df = df[[col for col in df.columns if col in col_meta_map]] + # Convert special types + for col in df.columns: + if col_meta_map[col].dtype == ColumnDtype.DATE_TIME: + df[col] = df[col].apply(lambda x: utc_datetime_from_iso(x)) + elif col_meta_map[col].is_vector_column: + df[col] = df[col].apply(json_loads) + # Check vector length + array_lengths = df[col].apply(len) + if array_lengths.nunique() != 1: + raise BadInputError("All vectors must have the same length.") + array_length = int(array_lengths[0]) + if array_length != col_meta_map[col].vlen: + raise BadInputError( + ( + f'Vector column "{col}" expects vectors of length {col_meta_map[col].vlen:,d} ' + f"but got vectors of length {array_length:,d}." 
+ ) + ) + # Convert to list of dicts + rows = df.to_dict(orient="records") + except Exception as e: + raise BadInputError( + f'Failed to import data into table "{self.table_id}" due to error: {e}' + ) from e + return rows + + async def import_data( + self, + input_path: str | Path, + *, + column_id_mapping: dict[str, str] | None = None, + delimiter: CSVDelimiter = CSVDelimiter.COMMA, + ignore_info_columns: bool = True, + verbose: bool = False, + ) -> Self: + """ + Import data into the Generative Table from a CSV file. + + Args: + input_path (str | Path): Path to the CSV file. + column_id_mapping (dict[str, str] | None, optional): Mapping of CSV column ID to table column ID. + Defaults to None. + delimiter (str, optional): CSV delimiter, either "," or "\\t". Defaults to ",". + ignore_info_columns (bool, optional): Whether to ignore "ID" and "Updated at" columns. + Defaults to True. + verbose (bool, optional): If True, will produce verbose logging messages. + Defaults to False. + + Raises: + ResourceNotFoundError: If the file or table is not found. + + Returns: + self (GenerativeTableCore): The table instance. + """ + rows = await self.read_csv( + input_path=input_path, + column_id_mapping=column_id_mapping, + delimiter=delimiter, + ignore_info_columns=ignore_info_columns, + ) + if verbose: + self._log(f'Importing table "{self.table_id}": Import data loaded successfully.') + # Insert rows + n = len(rows) + if verbose: + self._log(f'Importing table "{self.table_id}": Adding {n:,d} rows.') + for i in range(0, n, IMPORT_BATCH_SIZE): + j = min(i + IMPORT_BATCH_SIZE, n) + self = await self.add_rows(rows[i:j]) + if verbose: + self._log(f'Importing table "{self.table_id}": Added {j:,d} / {n:,d} rows.') + return self + + ### --- Column CRUD --- ### + + # Column Create Ops + async def add_column( + self, + metadata: ColumnMetadata, + request_id: str = "", + ) -> Self: + """ + Add a new column to the table. + + Args: + metadata (ColumnMetadata): Metadata for the new column. + request_id (str, optional): Request ID for logging. Defaults to "". + + Raises: + BadInputError: If the column is a state column. + ResourceNotFoundError: If table cannot be found. + ResourceExistsError: If the column already exists in the table. + + Returns: + self (GenerativeTableCore): The table instance. 
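+
+        Example:
+            A minimal usage sketch on an opened table (the column name is an
+            illustrative placeholder):
+
+                table = await table.add_column(
+                    ColumnMetadata(
+                        table_id=table.table_id,
+                        column_id="Summary",
+                        dtype=ColumnDtype.STR,
+                    )
+                )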
+ """ + if self.table_metadata.parent_id is not None: + # TODO: Test this + raise BadInputError(f'Table "{self.table_id}": Cannot add column to a child table.') + if metadata.is_state_column: + # TODO: Test this + raise BadInputError(f'Table "{self.table_id}": Cannot add state column.') + async with GENTABLE_ENGINE.transaction() as conn: + column_metadata_list = await self._check_columns( + conn=conn, + project_id=self.project_id, + table_type=self.table_type, + table_metadata=self.table_metadata, + column_metadata_list=self.column_metadata + [metadata], + set_default_prompts=True, + replace_unavailable_models=False, + ) + metadata = column_metadata_list[-1] + # Define column definition + if metadata.is_vector_column: + column_def = f'"{metadata.short_id}" VECTOR({metadata.vlen})' + else: + column_def = f'"{metadata.short_id}" {metadata.dtype.to_postgres_type()}' + # Add new and state column to the data table + try: + await conn.execute( + f""" + ALTER TABLE "{self.schema_id}"."{self.short_table_id}" + ADD COLUMN {column_def}, + ADD COLUMN {self._state_column_sql(metadata.short_id)}; + """ + ) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except DuplicateColumnError as e: + raise ResourceExistsError( + f"Column {metadata.column_id} already exists in table {self.table_id}" + ) from e + # Add column metadata + metadata.column_order = len(self.column_metadata) + await self._upsert_column_metadata(conn, self.schema_id, metadata) + state_meta = ColumnMetadata( + table_id=self.table_id, + column_id=f"{metadata.column_id}_", + dtype=ColumnDtype.JSON, + column_order=len(self.column_metadata) + 1, + ) + await self._upsert_column_metadata(conn, self.schema_id, state_meta) + # Set updated at time + await self._set_updated_at(conn) + # Reload table + self = await self._open_table( + conn=conn, + project_id=self.project_id, + table_type=self.table_type, + table_id=self.table_id, + request_id=request_id, + ) + if metadata.is_text_column: + await self._recreate_fts_index( + conn, + schema_id=self.schema_id, + table_id=self.table_id, + columns=self.text_column_names, + ) + elif metadata.is_vector_column: + await self._recreate_vector_index( + conn, + schema_id=self.schema_id, + table_id=self.table_id, + columns=self.vector_column_names, + ) + return self + + # Column Read ops are implemented as table ops + # Column Update Ops + async def rename_columns( + self, + column_map: dict[str, ColName], + ) -> Self: + """ + Rename columns of the Generative Table. + + Args: + column_map (dict[str, str]): Mapping of old column names to new column names. + + Raises: + ResourceNotFoundError: If the table or any of the columns cannot be found. + ResourceExistsError: If any of the new column names already exists in the table. + + Returns: + self (GenerativeTableCore): The table instance. + """ + if self.table_metadata.parent_id is not None: + # TODO: Test this + raise BadInputError( + f'Table "{self.table_id}": Cannot rename columns of a child table.' 
+ ) + fixed_cols = {c.lower() for c in self.FIXED_COLUMN_IDS} + if invalid_cols := {c.lower() for c in column_map}.intersection(fixed_cols): + # TODO: Test this especially for Knowledge Table + raise BadInputError( + f'Table "{self.table_id}": Cannot rename fixed columns: {list(invalid_cols)}' + ) + if invalid_cols := [c for c in column_map if c.endswith("_")]: + # TODO: Test this + raise BadInputError( + f'Table "{self.table_id}": Cannot rename state columns: {invalid_cols}' + ) + async with GENTABLE_ENGINE.transaction() as conn: + for col_id_src, col_id_dst in column_map.items(): + col_meta = next( + (col for col in self.column_metadata if col.column_id == col_id_src), None + ) + if col_meta is None: + continue + # Rename data and state columns + short_table_id = self.short_table_id + short_id_src = get_internal_id(col_id_src) + short_id_dst = get_internal_id(col_id_dst) + try: + await conn.execute( + f""" + ALTER TABLE "{self.schema_id}"."{short_table_id}" + RENAME COLUMN "{short_id_src}" TO "{short_id_dst}" + """ + ) + await conn.execute( + f""" + ALTER TABLE "{self.schema_id}"."{short_table_id}" + RENAME COLUMN "{short_id_src}_" TO "{short_id_dst}_" + """ + ) + # Rename vector index + if col_meta.is_vector_column: + await conn.execute( + ( + f'ALTER INDEX "{self.schema_id}"."{vector_index_id(self.table_id, col_id_src)}" ' + f'RENAME TO "{vector_index_id(self.table_id, col_id_dst)}"' + ) + ) + except UndefinedTableError as e: + # Index or table not found + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except (UndefinedColumnError, IndexError) as e: + raise ResourceNotFoundError( + f'Column "{col_id_src}" is not found in table "{self.table_id}".' + ) from e + except DuplicateColumnError as e: + raise ResourceExistsError( + f'Column "{col_id_dst}" already exists in table "{self.table_id}".' + ) from e + # Update column metadata entries + await conn.execute( + f""" + UPDATE "{self.schema_id}"."ColumnMetadata" + SET column_id = $1, short_id = $2 + WHERE table_id = $3 AND column_id = $4 + """, + col_id_dst, + short_id_dst, + self.table_id, + col_id_src, + ) + await conn.execute( + f""" + UPDATE "{self.schema_id}"."ColumnMetadata" + SET column_id = $1, short_id = $2 + WHERE table_id = $3 AND column_id = $4 + """, + f"{col_id_dst}_", + f"{short_id_dst}_", + self.table_id, + f"{col_id_src}_", + ) + # Update gen config references + for col in self.column_metadata: + if col.column_id == col_id_dst or col.column_id == col_id_src: + continue + if not isinstance(col.gen_config, LLMGenConfig): + continue + for k in ("system_prompt", "prompt"): + setattr( + col.gen_config, + k, + re.sub( + GEN_CONFIG_VAR_PATTERN, + lambda m: f"${{{column_map.get(m.group(1), m.group(1))}}}", + getattr(col.gen_config, k), + ), + ) + await conn.execute( + f""" + UPDATE "{self.schema_id}"."ColumnMetadata" SET gen_config = $1 + WHERE table_id = $2 AND column_id = $3 + """, + col.gen_config.model_dump(), + self.table_id, + col.column_id, + ) + # Set updated at time + await self._set_updated_at(conn) + return await self._reload_table(conn) + + async def update_gen_config( + self, + update_mapping: dict[str, DiscriminatedGenConfig | None], + *, + allow_nonexistent_refs: bool = False, + request_id: str = "", + ) -> Self: + """ + Update the generation configuration for a column. + + Args: + update_mapping (dict[str, DiscriminatedGenConfig]): Mapping of column IDs to new generation configurations. 
+ allow_nonexistent_refs (bool, optional): Ignore non-existent column and Knowledge Table references. + Otherwise will raise an error. Useful when importing old tables and performing maintenance. + Defaults to False. + request_id (str, optional): Request ID for logging. Defaults to "". + + Raises: + ResourceNotFoundError: If the column is not found. + + Returns: + self (GenerativeTableCore): The table instance. + """ + # Verify column exists + columns_to_update = [] + async with GENTABLE_ENGINE.transaction() as conn: + for column_id, config in update_mapping.items(): + column = next( + (col for col in self.column_metadata if col.column_id == column_id), None + ) + if not column: + # TODO: Test this + raise ResourceNotFoundError( + f'Column "{column_id}" is not found in table "{self.table_id}".' + ) + if column.is_state_column: + # TODO: Test this + raise BadInputError( + f'Column "{column_id}" is a state column and cannot be updated.' + ) + # Disallow update of vector column if the table has data + has_data: bool = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{self.schema_id}"."{self.short_table_id}" LIMIT 1)' + ) + if column.is_vector_column and has_data: + # TODO: Test this + raise BadInputError( + f'Column "{column_id}" contains data thus its Embedding config cannot be updated.' + ) + # Update column metadata in-place + if config is None or column.gen_config is None: + column.gen_config = config + else: + column.gen_config = type(column.gen_config).model_validate( + merge_dict( + column.gen_config.model_dump(), + config.model_dump(exclude_unset=True), + ) + ) + columns_to_update.append(column) + # Validate + await self._check_columns( + conn=conn, + project_id=self.project_id, + table_type=self.table_type, + table_metadata=self.table_metadata, + column_metadata_list=self.column_metadata, + set_default_prompts=False, + replace_unavailable_models=False, + allow_nonexistent_refs=allow_nonexistent_refs, + ) + for column in columns_to_update: + await self._upsert_column_metadata(conn, self.schema_id, column) + # Set updated at time + await self._set_updated_at(conn) + self = await self._open_table( + conn=conn, + project_id=self.project_id, + table_type=self.table_type, + table_id=self.table_id, + request_id=request_id, + ) + return self + + async def reorder_columns( + self, + column_names: list[str], + ) -> Self: + """ + Reorder columns in the table. + + Args: + column_names (list[str]): List of column name in the desired order. + + Raises: + BadInputError: If the list of columns to reorder does not match the table columns. + + Returns: + self (GenerativeTableCore): The table instance. 
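+
+        Example:
+            A minimal usage sketch on an opened table ("Title" and "Summary"
+            are illustrative column names):
+
+                table = await table.reorder_columns(
+                    ["ID", "Updated at", "Title", "Summary"]
+                )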
+ """ + if column_names[0].lower() != "id": + raise BadInputError('First column must be "ID".') + if column_names[1].lower() != "updated at": + raise BadInputError('Second column must be "Updated at".') + if len(set(n.lower() for n in column_names)) != len(column_names): + raise BadInputError("Column names must be unique (case-insensitive).") + columns = self.data_table_model.get_column_ids(exclude_state=True) + if set(column_names) != set(columns): + raise BadInputError("The list of columns to reorder does not match the table columns.") + state_columns = [f"{col}_" for col in column_names if col.lower() not in self.INFO_COLUMNS] + async with GENTABLE_ENGINE.transaction() as conn: + # Update column order + for idx, column_id in enumerate(column_names + state_columns): + await conn.execute( + f""" + UPDATE "{self.schema_id}"."ColumnMetadata" + SET column_order = $1 + WHERE table_id = $2 AND column_id = $3 + """, + idx, + self.table_id, + column_id, + ) + # Set updated at time + await self._set_updated_at(conn) + return await self._reload_table(conn) + + # Column Delete Ops + async def drop_columns( + self, + column_ids: list[str], + ) -> Self: + """ + Drop columns from the Generative Table. + + Args: + column_ids (list[str]): List of column IDs to drop. + + Raises: + ResourceNotFoundError: If any of the columns is not found. + """ + if self.table_metadata.parent_id is not None: + # TODO: Test this + raise BadInputError(f'Table "{self.table_id}": Cannot drop column from a child table.') + fixed_cols = {c.lower() for c in self.FIXED_COLUMN_IDS} + if invalid_cols := {c.lower() for c in column_ids}.intersection(fixed_cols): + # TODO: Test this especially for Knowledge Table + raise BadInputError( + f'Table "{self.table_id}": Cannot drop fixed columns: {list(invalid_cols)}' + ) + if len(invalid_cols := [c for c in column_ids if c.endswith("_")]) > 0: + # TODO: Test this + raise BadInputError( + f'Table "{self.table_id}": Cannot drop state columns: {invalid_cols}' + ) + async with GENTABLE_ENGINE.transaction() as conn: + short_table_id = self.short_table_id + for column_id in column_ids: + # Drop column and state column + short_id = get_internal_id(column_id) + try: + await conn.execute( + f'ALTER TABLE "{self.schema_id}"."{short_table_id}" DROP COLUMN "{short_id}"' + ) + await conn.execute( + f'ALTER TABLE "{self.schema_id}"."{short_table_id}" DROP COLUMN "{short_id}_"' + ) + except UndefinedColumnError as e: + raise ResourceNotFoundError( + f'Column "{column_id}" is not found in table "{self.table_id}".' + ) from e + except Exception as e: + raise ResourceNotFoundError( + f'Column "{column_id}" is not found in table "{self.table_id}".' 
+ ) from e + # Remove column metadata and the associated state column + await conn.execute( + f'DELETE FROM "{self.schema_id}"."ColumnMetadata" WHERE table_id = $1 AND column_id = $2', + self.table_id, + column_id, + ) + await conn.execute( + f'DELETE FROM "{self.schema_id}"."ColumnMetadata" WHERE table_id = $1 AND column_id = $2', + self.table_id, + f"{column_id}_", + ) + # Update column order + columns = self.data_table_model.get_column_ids(exclude_state=False) + columns = [col for col in columns if col not in column_ids] + for idx, column_id in enumerate(columns): + await conn.execute( + f""" + UPDATE "{self.schema_id}"."ColumnMetadata" + SET column_order = $1 + WHERE table_id = $2 AND column_id = $3 + """, + idx, + self.table_id, + column_id, + ) + # Set updated at time + await self._set_updated_at(conn) + # Rebuild indexes if needed + if any(c.is_text_column for c in self.column_metadata if c.column_id in column_ids): + await self._recreate_fts_index( + conn, + schema_id=self.schema_id, + table_id=self.table_id, + columns=[ + c.column_id + for c in self.column_metadata + if c.column_id not in column_ids and c.is_text_column + ], + ) + if any(c.is_vector_column for c in self.column_metadata if c.column_id in column_ids): + await self._recreate_vector_index( + conn, + schema_id=self.schema_id, + table_id=self.table_id, + columns=[ + c.column_id + for c in self.column_metadata + if c.column_id not in column_ids and c.is_vector_column + ], + ) + return await self._reload_table(conn) + + ### --- Row CRUD --- ### + @staticmethod + def _jsonify(x: Any) -> Any: + return x.tolist() if isinstance(x, np.ndarray) else x + + def _validate_row_data(self, data: dict[str, Any]) -> DataTableRow: + try: + row = self.data_table_model.model_validate(data, strict=False) + except ValidationError as e: + # Set invalid value to None, and save original value to state + for error in e.errors(): + if len(error["loc"]) > 1: + raise BadInputError(f"Input data contains errors: {e}") from e + col = error["loc"][0] + state = data.get(f"{col}_", {}) + data[col], data[f"{col}_"] = ( + None, + {"original": self._jsonify(data[col]), "error": error.get("msg", ""), **state}, + ) + # Try validating again try: - table.drop_columns(col_names) - except ValueError as e: - raise ResourceNotFoundError(e) from e - meta.cols = [c.model_dump() for c in meta.cols_schema if c.id not in col_names] - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - session.refresh(meta) - return table, meta - - # Look at this instead !! - def drop_columns( + row = self.data_table_model.model_validate(data, strict=False) + except ValidationError as e: + raise BadInputError(f"Input data contains errors: {e}") from e + return row + + # Row Create Ops + async def add_rows( self, - session: Session, - table_id: TableName, - column_names: list[ColName], - ) -> tuple[LanceTable, TableMeta]: + data_list: list[dict[str, Any]], + *, + ignore_info_columns: bool = True, + ignore_state_columns: bool = True, + set_updated_at: bool = True, + ) -> Self: """ - Drops one or more input or output column. + Add multiple rows to the Generative Table. Args: - session (Session): SQLAlchemy session. - table_id (str): Table ID. - column_names (list[str]): List of column ID to drop. + data_list (list[dict[str, Any]]): List of row data dictionaries. + ignore_info_columns (bool, optional): Whether to ignore "ID" and "Updated at" columns. + Defaults to True. + ignore_state_columns (bool, optional): Whether to ignore state columns. + Defaults to True. 
+ set_updated_at (bool, optional): Whether to set the "Updated at" time to now. + Defaults to True. Raises: - TypeError: If `column_names` is not a list. + TypeError: If the data is not a list of dictionaries. ResourceNotFoundError: If the table is not found. - ResourceNotFoundError: If any of the columns is not found. Returns: - table (LanceTable): Lance table. - meta (TableMeta): Table metadata. - """ - if not isinstance(column_names, list): - raise TypeError("`column_names` must be a list.") - if self.has_state_col_names(column_names): - raise BadInputError("Cannot drop state columns.") - if self.has_info_col_names(column_names): - raise BadInputError('Cannot drop "ID" or "Updated at".') - fixed_cols = set(c.lower() for c in self.FIXED_COLUMN_IDS) - if len(fixed_cols.intersection(set(c.lower() for c in column_names))) > 0: - raise BadInputError(f"Cannot drop fixed columns: {self.FIXED_COLUMN_IDS}") - - with self.lock(table_id): - # Get table metadata - meta = self.open_meta(session, table_id) - # Create new table with dropped columns - new_table_id = f"{table_id}_dropped_{uuid7_draft2_str()}" - column_names += [f"{col_name}_" for col_name in column_names] - new_schema = TableSchema( - id=new_table_id, - cols=[c for c in meta.cols_schema if c.id not in column_names], - ) - new_table, new_meta = self._create_table( - session, new_schema, add_info_state_cols=False - ) - - # Copy data from old table to new table - old_table = self.open_table(table_id) - if old_table.count_rows() > 0: - data = old_table._dataset.to_table( - columns=[c.id for c in new_schema.cols] - ).to_pylist() - new_table.add(data) - - # Delete old table and rename - self.delete_table(session, table_id) - new_meta = self.rename_table(session, new_table_id, table_id) - new_table = self.open_table(table_id) - return new_table, new_meta - - def rename_columns( - self, - session: Session, - table_id: TableName, - column_map: dict[ColName, ColName], - ) -> TableMeta: - new_col_names = set(column_map.values()) - if self.has_state_col_names(column_map.keys()): - raise BadInputError("Cannot rename state columns.") - if self.has_info_col_names(column_map.keys()): - raise BadInputError('Cannot rename "ID" or "Updated at".') - fixed_cols = set(c.lower() for c in self.FIXED_COLUMN_IDS) - if len(fixed_cols.intersection(set(c.lower() for c in column_map))) > 0: - raise BadInputError(f"Cannot rename fixed columns: {self.FIXED_COLUMN_IDS}") - if len(new_col_names) != len(column_map): - raise BadInputError("`column_map` contains repeated new column names.") - if not all(re.match(COL_NAME_PATTERN, v) for v in column_map.values()): - raise BadInputError("`column_map` contains invalid new column names.") - meta = self.open_meta(session, table_id) - col_names = set(c.id for c in meta.cols_schema) - overlap_col_names = col_names.intersection(new_col_names) - if len(overlap_col_names) > 0: - raise BadInputError( - ( - "`column_map` contains new column names that " - f"overlap with existing column names: {overlap_col_names}" - ) + self (GenerativeTableCore): The table instance. 
+ """ + if not (isinstance(data_list, list) and all(isinstance(row, dict) for row in data_list)): + # We raise TypeError here since this is a programming error + raise TypeError("`data_list` must be a list of dicts.") + # Filter out non-existent fields + columns = set( + self.data_table_model.get_column_ids( + exclude_info=ignore_info_columns, + exclude_state=ignore_state_columns, ) - not_found = set(column_map.keys()) - col_names - if len(not_found) > 0: - raise ResourceNotFoundError(f"Some columns are not found: {list(not_found)}.") - # Add state columns - for k in list(column_map.keys()): - column_map[f"{k}_"] = f"{column_map[k]}_" - # Modify metadata - cols = [] - for col in meta.cols: - col = deepcopy(col) - _id = col["id"] - col["id"] = column_map.get(_id, _id) - if ( - col["gen_config"] is not None - and col["gen_config"].get("object", "") == "gen_config.llm" - ): - for k in ("system_prompt", "prompt"): - col["gen_config"][k] = re.sub( - GEN_CONFIG_VAR_PATTERN, - lambda m: f"${{{column_map.get(m.group(1), m.group(1))}}}", - col["gen_config"][k], - ) - cols.append(col) - with self.lock(table_id): - meta.cols = cols - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - session.refresh(meta) - # Modify LanceTable - alterations = [{"path": k, "name": v} for k, v in column_map.items()] - table = self.open_table(table_id) - table.alter_columns(*alterations) - return meta - - def reorder_columns( - self, - session: Session, - table_id: TableName, - column_names: list[ColName], - ) -> TableMeta: - column_names_low = [n.lower() for n in column_names] - if len(set(column_names_low)) != len(column_names): - raise BadInputError("Column names must be unique (case-insensitive).") - if self.has_state_col_names(column_names): - raise BadInputError("Cannot reorder state columns.") - if self.has_info_col_names(column_names) and column_names_low[:2] != ["id", "updated at"]: - raise BadInputError('Cannot reorder "ID" or "Updated at".') - order = ["ID", "Updated at"] - for c in column_names: - order += [c, f"{c}_"] - meta = self.open_meta(session, table_id) - try: - meta.cols = [ - c.model_dump() for c in sorted(meta.cols_schema, key=lambda x: order.index(x.id)) - ] - except ValueError as e: - raise ResourceNotFoundError(e) from e - meta.updated_at = datetime_now_iso() - # Validate changes - TableSchema.model_validate(meta.model_dump()) - session.add(meta) - session.commit() - session.refresh(meta) - return meta - - async def add_rows( - self, - session: Session, - table_id: TableName, - data: list[dict[ColName, Any]], - errors: list[list[str]] | None = None, - ) -> Self: - if not isinstance(data, list): - raise TypeError("`data` must be a list.") - with self.lock(table_id): - with await lancedb.connect_async( - uri=self.vector_db_url, - read_consistency_interval=self.read_consistency_interval, - ) as db: + ) + data_list = [{k: v for k, v in row.items() if k in columns} for row in data_list] + data_list = [row for row in data_list if len(row) > 0] + if len(data_list) == 0: + return self + rows = [self._validate_row_data(data) for data in data_list] + # Build SQL statement + all_columns = self.data_table_model.get_column_ids() + _sql_cols = [f'"{self.map_to_short_col_id[c]}"' for c in all_columns] + stmt = ( + f'INSERT INTO "{self.schema_id}"."{self.short_table_id}" ({", ".join(_sql_cols)}) ' + f"VALUES ({', '.join(f'${i + 1}' for i in range(len(all_columns)))})" + ) + values = [[getattr(row, c) for c in all_columns] for row in rows] + async with GENTABLE_ENGINE.transaction() as 
conn: + # Insert rows with retries + for _ in range(3): try: - with await db.open_table(table_id) as table: - meta = self.open_meta(session, table_id) - # Validate data and generate ID & timestamp under write lock - data = RowAddData(table_meta=meta, data=data, errors=errors).set_id().data - # Add to Lance Table - await table.add(data) - # Update metadata - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - except FileNotFoundError as e: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') from e + # Use executemany for batch operations + await conn.executemany(stmt, values) + break + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except DataError as e: + self._log( + f"Failed to insert {len(rows):,d} rows due to: {repr(e)}.\nSQL:\n{stmt}\nValues:\n{values}", + "WARNING", + ) + if isinstance(e, InvalidParameterValueError) and "pgroonga" in str(e): + pass + else: + raise BadInputError(f"Bad input: {e}") from e + # Set updated at time + if set_updated_at: + await self._set_updated_at(conn) return self - def update_rows( + # Row Read Ops + async def list_rows( self, - session: Session, - table_id: TableName, *, - where: str | None, - values: dict[str, Any], - errors: list[str] | None = None, - ) -> Self: - with self.lock(table_id): - table = self.open_table(table_id) - meta = self.open_meta(session, table_id) - # Validate data and generate ID & timestamp under write lock - values = RowUpdateData( - table_meta=meta, - data=[values], - errors=None if errors is None else [errors], - ) - values = values.sql_escape().data[0] - # TODO: Vector column update seems to be broken - values = {k: v for k, v in values.items() if not isinstance(v, np.ndarray)} - table.update(where=where, values=values) - # Update metadata - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - return self - - @staticmethod - def _filter_col( - col_id: str, + limit: int | None = None, + offset: int = 0, + order_by: list[str] | None = None, + order_ascending: bool = True, columns: list[str] | None = None, + where: str = "", + search_query: str = "", + search_columns: list[str] | None = None, remove_state_cols: bool = False, - ) -> bool: - if remove_state_cols and col_id.endswith("_"): - return False - # Hybrid search distance and match scores - if col_id.startswith("_"): - return False - if columns is not None: - columns = {"id", "updated at"} | {c.lower() for c in columns} - return col_id.lower() in columns - return True - - @staticmethod - def _process_cell( - row: dict[str, Any], - col_id: str, - convert_null: bool, - include_original: bool, - float_decimals: int, - vec_decimals: int, - ): - state_id = f"{col_id}_" - data = row[col_id] - if state_id not in row: - # Some columns like "ID", "Updated at" do not have state cols - return data - # Process precision - if float_decimals > 0 and isinstance(data, float): - data = round(data, float_decimals) - elif vec_decimals > 0 and isinstance(data, list): - data = np.asarray(data).round(vec_decimals).tolist() - state = row[state_id] - if state == "" or state is None: - data = None if convert_null else data - return {"value": data} if include_original else data - state = json_loads(state) - data = None if convert_null and state["is_null"] else data - if include_original: - ret = {"value": data} - if "original" in state: - ret["original"] = state["original"] - # if "error" in state: - # ret["error"] = state["error"] - return ret - else: - return data + 
) -> Page[dict[str, Any]]: + """ + List rows with filtering and sorting. - @staticmethod - def _post_process_rows( - rows: list[dict[str, Any]], - *, - columns: list[str] | None = None, - convert_null: bool = True, - remove_state_cols: bool = False, - json_safe: bool = False, - include_original: bool = False, - float_decimals: int = 0, - vec_decimals: int = 0, - ): - if json_safe: - rows = [ - {k: v.isoformat() if isinstance(v, datetime) else v for k, v in row.items()} - for row in rows - ] - rows = [ - { - k: GenerativeTable._process_cell( - row, - k, - convert_null=convert_null, - include_original=include_original, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ) - for k in row - if not (vec_decimals < 0 and isinstance(row[k], list)) - } - for row in rows - ] - rows = [ - { - k: v - for k, v in row.items() - if GenerativeTable._filter_col( - k, columns=columns, remove_state_cols=remove_state_cols - ) - } - for row in rows - ] - return rows + Args: + limit (int | None, optional): Maximum number of rows to return. Defaults to None. + offset (int, optional): Offset for pagination. Defaults to 0. + order_by (list[str] | None, optional): Order the rows by these columns. Defaults to None (order by row ID). + order_ascending (bool, optional): Order the rows in ascending order. Defaults to True. + columns (list[str] | None, optional): A list of column names to include in the returned rows. + Defaults to None (return all columns). + where (str, optional): SQL where clause. Defaults to "" (no filter). + It will be combined other filters using `AND`. + search_query (str, optional): A string to search for within row data. + The string is interpreted as both POSIX regular expression and literal string. + Defaults to "". + search_columns (list[str] | None, optional): A list of column names to search for search_query. + Defaults to None (search all columns). + remove_state_cols (bool, optional): If True, remove state columns. Defaults to False. - @staticmethod - def _post_process_rows_df( - df: pd.DataFrame, - *, - columns: list[str] | None = None, - convert_null: bool = True, - remove_state_cols: bool = False, - json_safe: bool = False, - include_original: bool = False, - float_decimals: int = 0, - vec_decimals: int = 0, - ): - dt_columns = set(df.select_dtypes(include="datetimetz").columns.to_list()) - float_columns = set(df.select_dtypes(include="float").columns.to_list()) + Raises: + ResourceNotFoundError: If the table or column(s) is not found. - def _process_row(row: pd.Series): - for col_id in row.index.to_list(): - state_id = f"{col_id}_" - try: - data = row[col_id] - except KeyError: - # The column is dropped - continue - if json_safe and col_id in dt_columns: - row[col_id] = data.isoformat() - if state_id not in row: - # Some columns like "ID", "Updated at" do not have state cols - # State cols also do not have their state cols + Returns: + rows (Page[dict[str, Any]]): A page of row data dictionaries. 
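+
+        Example:
+            A minimal usage sketch (assumes an already-opened table instance `table`):
+
+                page = await table.list_rows(
+                    limit=100,
+                    order_by=["Updated at"],
+                    order_ascending=False,
+                    remove_state_cols=True,
+                )
+                latest_rows = page.items  # list of row dicts, newest first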
+ """ + columns = self._filter_columns(columns, exclude_state=remove_state_cols) + # Build SQL query + params = [] + query = f""" + SELECT {",".join([f'"{self.map_to_short_col_id[c]}"' for c in columns])} + FROM "{self.schema_id}"."{self.short_table_id}" + """ + total = f'SELECT COUNT("ID") FROM "{self.schema_id}"."{self.short_table_id}"' + filters = [] + where = where.strip() + if where: + try: + where = f"({validate_where_expr(where, id_map=self.map_to_short_col_id)})" + except Exception as e: + raise BadInputError(str(e)) from e + filters.append(where) + if search_query: + _cols = search_columns or [ + col.column_id + for col in self.column_metadata + if not ( + col.is_info_column + or col.is_file_column + or col.is_vector_column + or col.is_state_column + ) + ] + search_filters = [] + for c in _cols: + c = self.map_to_short_col_id.get(c, None) + if c is None: continue - state = row[state_id] - # Process precision - if isinstance(data, np.ndarray): - if vec_decimals < 0: - row.drop([col_id, state_id], inplace=True) + # Literal (escaped) search + params.append(re.escape(search_query)) + literal_expr = f'("{c}"::text ~* ${len(params)})' + # Regex search + params.append(search_query) + regex_expr = f'("{c}"::text ~* ${len(params)})' + search_filters.append(f"({literal_expr} OR {regex_expr})") + filters.append(f"({' OR '.join(search_filters)})") + if filters: + query += f" WHERE {' AND '.join(filters)}" + total += f" WHERE {' AND '.join(filters)}" + async with GENTABLE_ENGINE.transaction() as conn: + # Row count + try: + total = await conn.fetchval(total, *params) + except UndefinedColumnError as e: + raise ResourceNotFoundError( + f'One or more columns is not found in table "{self.table_id}".' + ) from e + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except (PostgresSyntaxError, UndefinedFunctionError) as e: + raise BadInputError(f"Bad SQL statement: `{query}`") from e + # Sorting + order_direction = "ASC" if order_ascending else "DESC" + order_clauses = [] + if order_by: + for c in order_by: + cs = self.map_to_short_col_id.get(c, None) + if cs is None: continue - elif vec_decimals == 0: - if json_safe: - data = data.tolist() - elif vec_decimals > 0: - if json_safe: - data = [round(d, vec_decimals) for d in data.tolist()] - else: - data = data.round(vec_decimals) - elif float_decimals > 0 and col_id in float_columns: - row[col_id] = round(data, float_decimals) - # Convert null - if state == "" or state is None: - data = None if convert_null else data - row[col_id] = {"value": data} if include_original else data - continue - state = json_loads(state) - data = None if convert_null and state["is_null"] else data - if include_original: - ret = {"value": data} - if "original" in state: - ret["original"] = state["original"] - # if "error" in state: - # ret["error"] = state["error"] - row[col_id] = ret - else: - row[col_id] = data - return row - - df = df.apply(_process_row, axis=1) - # Remove hybrid search distance and match score columns - keep_cols = [c for c in df.columns.to_list() if not c.startswith("_")] - # Remove state columns - if remove_state_cols: - keep_cols = [c for c in keep_cols if not c.endswith("_")] - # Column selection - if columns is not None: - columns = {"id", "updated at"} | {c.lower() for c in columns} - keep_cols = [c for c in keep_cols if c.lower() in columns] - df = df[keep_cols] - return df - - def get_row( + if c in self.text_column_names: + order_clauses.append(f'LOWER("{cs}") {order_direction}') + 
else: + order_clauses.append(f'"{cs}" {order_direction}') + order_clauses.append(f'"ID" {order_direction}') + query += " ORDER BY " + ", ".join(order_clauses) + # Pagination + if limit: + params.append(limit) + query += f" LIMIT ${len(params)}" + if offset: + params.append(offset) + query += f" OFFSET ${len(params)}" + # Execute query + try: + rows = await conn.fetch(query, *params) + except UndefinedColumnError as e: + raise ResourceNotFoundError( + f'One or more columns is not found in table "{self.table_id}".' + ) from e + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except (PostgresSyntaxError, UndefinedFunctionError) as e: + raise BadInputError(f"Bad SQL statement: `{query}`") from e + # Map short column IDs back to long column IDs + rows = [{self.map_to_long_col_id[k]: v for k, v in dict(row).items()} for row in rows] + return Page[dict[str, Any]]( + items=rows, + offset=offset, + limit=total if limit is None else limit, + total=total, + ) + + async def get_row( self, - table_id: TableName, row_id: str, *, columns: list[str] | None = None, - convert_null: bool = True, remove_state_cols: bool = False, - json_safe: bool = False, - include_original: bool = False, - float_decimals: int = 0, - vec_decimals: int = 0, ) -> dict[str, Any]: - table = self.open_table(table_id) - rows = table.search().where(where=f"`ID` = '{row_id}'", prefilter=True).to_list() - if len(rows) == 0: - raise ResourceNotFoundError(f'Row "{row_id}" is not found.') - elif len(rows) > 1: - logger.warning(f"More than one row in table {table_id} with ID {row_id}") - rows = self._post_process_rows( - rows, - columns=columns, - convert_null=convert_null, - remove_state_cols=remove_state_cols, - json_safe=json_safe, - include_original=include_original, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ) - return rows[0] + """ + Get a single row by its row ID. - @staticmethod - def _count_rows_query(table_name: str) -> str: - return f"SELECT COUNT(*) FROM '{table_name}'" + Args: + row_id (str): ID of the row to be retrieved. + columns (list[str] | None, optional): A list of column names to include in the returned rows. + Defaults to None (return all columns). + remove_state_cols (bool, optional): If True, remove state columns. Defaults to False. - @staticmethod - def _list_rows_query( - table_name: str, - *, - sort_by: str, - sort_order: Literal["ASC", "DESC"] = "ASC", - starting_after: str | int | None = None, - id_column: str = "ID", - offset: int = 0, - limit: int = 100, - ) -> str: - if starting_after is None: - query = ( - f"""SELECT * FROM '{table_name}' ORDER BY "{sort_by}" {sort_order} LIMIT {limit}""" - ) - else: - query = f""" - WITH sorted_rows AS ( - SELECT - *, - ROW_NUMBER() OVER ( - ORDER BY "{sort_by}" {sort_order} - ) AS _row_num - FROM '{table_name}' - ), - cursor_position AS ( - SELECT _row_num - FROM sorted_rows - WHERE "{id_column}" = '{starting_after}' - ) - SELECT sr.* - FROM sorted_rows sr, cursor_position cp - WHERE sr._row_num > cp._row_num OR cp._row_num IS NULL - ORDER BY sr._row_num - OFFSET {offset} - LIMIT {limit} - """ - return query + Raises: + ResourceNotFoundError: If the table or row is not found. + + Returns: + row (dict[str, Any]): The row data dictionary. 
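+
+        Example:
+            A minimal usage sketch (assumes an already-opened table instance `table`
+            and a known `row_id`):
+
+                row = await table.get_row(row_id, remove_state_cols=True)
+                updated_at = row["Updated at"]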
+ """ + columns = self._filter_columns(columns, exclude_state=remove_state_cols) + query = f""" + SELECT {",".join([f'"{self.map_to_short_col_id[c]}"' for c in columns])} + FROM "{self.schema_id}"."{self.short_table_id}" + """ + # Get row + row = None + async with GENTABLE_ENGINE.transaction() as conn: + try: + row = await conn.fetchrow(f'{query} WHERE "ID" = $1', row_id) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + if not row: + raise ResourceNotFoundError( + f'Row "{row_id}" is not found in table "{self.table_id}".' + ) + # Map short column ID back to long column ID + row = {self.map_to_long_col_id[k]: v for k, v in dict(row).items()} + return row - def list_rows( + def postprocess_rows( self, - table_id: TableName, + rows: list[dict[str, Any]], *, - offset: int = 0, - limit: int = 1_000, - columns: list[ColName] | None = None, - convert_null: bool = True, - remove_state_cols: bool = False, - json_safe: bool = False, - include_original: bool = False, float_decimals: int = 0, vec_decimals: int = 0, - order_descending: bool = True, - ) -> tuple[list[dict[str, Any]], int]: - try: - table = self.open_table(table_id) - total = self.count_rows(table_id) - except ValueError as e: - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') from e - offset, limit = max(0, offset), max(1, limit) - if offset >= total: - rows = [] - else: - if offset + limit > total: - limit = total - offset - if order_descending: - offset = max(0, total - limit - offset) - if columns is not None: - if "ID" not in columns: - columns.insert(0, "ID") - if "Updated at" not in columns: - columns.insert(1, "Updated at") - rows = table._dataset.to_table(columns=columns, offset=offset, limit=limit).to_pylist() - rows = sorted(rows, reverse=order_descending, key=lambda r: r["ID"]) - rows = self._post_process_rows( - rows, - columns=columns, - convert_null=convert_null, - remove_state_cols=remove_state_cols, - json_safe=json_safe, - include_original=include_original, - float_decimals=float_decimals, - vec_decimals=vec_decimals, + include_state: bool = True, + ) -> list[dict[str, Any]]: + if not (isinstance(rows, list) and all(isinstance(r, dict) for r in rows)): + # We raise TypeError here since this is a programming error + raise TypeError("`rows` must be a list of dicts.") + for row in rows: + columns = list(row.keys()) + # Process data + for col_name in columns: + if col_name.endswith("_"): + continue + col_value = row[col_name] + # Process UUID and datetime + if isinstance(col_value, UUID): + col_value = str(col_value) + elif isinstance(col_value, datetime): + col_value = col_value.isoformat() + else: + # Rounding logic + if float_decimals > 0 and isinstance(col_value, float): + col_value = round(col_value, float_decimals) + if isinstance(col_value, np.ndarray): + if vec_decimals < 0: + del row[col_name] + continue + if vec_decimals > 0: + col_value = [round(v, vec_decimals) for v in col_value.tolist()] + else: + col_value = col_value.tolist() + # Process state + state = row.get(f"{col_name}_", None) + if state is None: + # Columns like "ID", "Updated at" do not have state + row[col_name] = col_value + continue + try: + state.pop("is_null", None) # Legacy attribute + except Exception as e: + self._log( + f'Failed to process state of column "{col_name}" due to {repr(e)} {type(state)=} {state=}', + "WARNING", + ) + row[col_name] = {"value": col_value, **state} if include_state else col_value + # Remove state + for col_name in columns: + if 
col_name.endswith("_"): + del row[col_name] + return rows + + def check_multiturn_column(self, column_id: str) -> LLMGenConfig: + cols = {c.column_id: c for c in self.column_metadata} + multiturn_cols = [c.column_id for c in self.column_metadata if c.is_chat_column] + column = cols.get(column_id, None) + if column is None: + raise ResourceNotFoundError( + ( + f'Table "{self.table_id}": Column "{column_id}" is not found. ' + f"Available multi-turn columns: {multiturn_cols}" + ) ) - return rows, total - - def delete_row(self, session: Session, table_id: TableName, row_id: str) -> Self: - with self.lock(table_id): - table = self.open_table(table_id) - table.delete(f"`ID` = '{row_id}'") - # Update metadata - meta = self.open_meta(session, table_id) - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - return self + gen_config = column.gen_config + if not (isinstance(gen_config, LLMGenConfig) and gen_config.multi_turn): + raise ResourceNotFoundError( + ( + f'Table "{self.table_id}": Column "{column_id}" is not a multi-turn LLM column. ' + f"Available multi-turn columns: {multiturn_cols}" + ) + ) + return gen_config - def delete_rows( + def interpolate_column( self, - session: Session, - table_id: TableName, - row_ids: list[str] | None = None, - where: str | None = "", - ) -> Self: - if row_ids is None: - row_ids = [] - with self.lock(table_id): - table = self.open_table(table_id) - for row_id in row_ids: - table.delete(f"`ID` = '{row_id}'") - if where: - table.delete(where) - # Update metadata - meta = self.open_meta(session, table_id) - meta.updated_at = datetime_now_iso() - session.add(meta) - session.commit() - return self - - @staticmethod - def _interpolate_column( prompt: str, - column_dtypes: dict[str, str], - column_contents: dict[str, Any], - ) -> str: + row: dict[str, Any], + ) -> str | list[TextContent | S3Content]: """ Replaces / interpolates column references in the prompt with their contents. Args: prompt (str): The original prompt with zero or more column references. + row (dict[str, Any]): The row data containing column values. + content_injection (bool, optional): If True, injects column content in the prompt. + If False, user prompt will be unchanged. Defaults to True. Returns: - new_prompt (str): The prompt with column references replaced. + content (str | list[TextContent | S3Content]): Message content with column references replaced. 
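+
+        Example:
+            A rough sketch; `Text` and `photo` are hypothetical column names and
+            `row` is a row data dictionary:
+
+                # Non-file references are replaced inline; file references are kept
+                # as `${...}` and returned alongside their S3 URIs for later loading.
+                content = table.interpolate_column("Describe ${photo} given ${Text}", row)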
""" + column_map = {c.column_id: c for c in self.column_metadata} + s3_contents: list[S3Content] = [] - def replace_match(match): + def _replace(match: re.Match) -> str: col_id = match.group(1) try: - if column_dtypes[col_id] == "image": - return "" - elif column_dtypes[col_id] == "audio": - return "" - return str(column_contents[col_id]) - except KeyError as e: - raise KeyError(f'Referenced column "{col_id}" is not found.') from e + # Referenced column is found + col = column_map[col_id] + col_data = row.get(col_id, None) + if col.is_file_column: + # File references will be loaded and interpolated in `GenExecutor` + if col_data is None: + # If file URI is None, we treat it as no content injection + return "" + else: + # Return URI and retain column reference for downstream interpolation + s3_contents.append(S3Content(uri=row[col_id], column_name=col_id)) + return f"${{{col_id}}}" + # Non-file references can interpolate directly + return str(col_data) + except KeyError: + # Referenced column is not found + # Maybe injected contents accidentally contain references + # We escape it here just in case + return f"\\${{{col_id}}}" - return re.sub(GEN_CONFIG_VAR_PATTERN, replace_match, prompt) + prompt = re.sub(GEN_CONFIG_VAR_PATTERN, _replace, prompt).strip() + if len(s3_contents) == 0: + return prompt + return s3_contents + [TextContent(text=prompt)] - def get_conversation_thread( + async def get_conversation_thread( self, - table_id: TableName, + *, column_id: str, row_id: str = "", - include: bool = True, - ) -> ChatThread: - with self.create_session() as session: - meta = self.open_meta(session, table_id) - cols = {c.id: c for c in meta.cols_schema} - chat_cols = {c.id: c for c in cols.values() if getattr(c.gen_config, "multi_turn", False)} - try: - gen_config = chat_cols[column_id].gen_config - except KeyError as e: - raise ResourceNotFoundError( - f'Column "{column_id}" is not found. Available chat columns: {list(chat_cols.keys())}' - ) from e + include_row: bool = True, + ) -> ChatThreadResponse: + """ + Get a conversation thread for a multi-turn LLM column. + + Args: + column_id (str): ID of the multi-turn LLM column. + row_id (str, optional): ID of the last row in the thread. + Defaults to "" (export all rows).. + include_row (bool, optional): Whether to include the row specified by `row_id`. + Defaults to True. 
+ + Returns: + response (ChatThreadResponse): _description_ + """ + gen_config = self.check_multiturn_column(column_id) ref_col_ids = re.findall(GEN_CONFIG_VAR_PATTERN, gen_config.prompt) - rows, _ = self.list_rows( - table_id=table_id, - offset=0, - limit=1_000_000, - columns=ref_col_ids + [column_id], - convert_null=True, - remove_state_cols=True, - json_safe=True, - float_decimals=0, - vec_decimals=0, - order_descending=False, - ) + columns = ref_col_ids + [column_id] if row_id: - row_ids = [r["ID"] for r in rows] - try: - rows = rows[: row_ids.index(row_id) + (1 if include else 0)] - except ValueError as e: - raise make_validation_error( - ValueError(f'Row ID "{row_id}" is not found in table "{table_id}".'), - loc=("body", "row_id"), - ) from e + where = '"ID" ' + (f"<= '{row_id}'" if include_row else f"< '{row_id}'") + else: + where = "" + rows = ( + await self.list_rows( + limit=None, + offset=0, + order_by=None, + order_ascending=True, + columns=columns, + where=where, + remove_state_cols=False, + ) + ).items + ref_cols = set(re.findall(GEN_CONFIG_VAR_PATTERN, gen_config.prompt)) + has_user_prompt = "User" in ref_cols thread = [] if gen_config.system_prompt: - thread.append(ChatEntry.system(gen_config.system_prompt)) + thread.append(ChatThreadEntry.system(gen_config.system_prompt)) for row in rows: + if has_user_prompt: + user_prompt = row.get("User", None) or None # Map "" to None + else: + user_prompt = None + row_id = str(row["ID"]) thread.append( - ChatEntry.user( - self._interpolate_column( - gen_config.prompt, - {c.id: c.dtype for c in cols.values()}, - row, - ) + ChatThreadEntry.user( + self.interpolate_column(gen_config.prompt, row), + user_prompt=user_prompt, + row_id=row_id, ) ) - thread.append(ChatEntry.assistant(row[column_id])) - return ChatThread(thread=thread) + thread.append( + ChatThreadEntry.assistant( + row[column_id], + references=row.get(f"{column_id}_", {}).get("references", None), + row_id=row_id, + ) + ) + return ChatThreadResponse(thread=thread, column_id=column_id) - def export_csv( - self, - table_id: TableName, - columns: list[ColName], - file_path: str = "", - delimiter: CSVDelimiter | str = ",", - ) -> pd.DataFrame: - if isinstance(delimiter, str): - try: - delimiter = CSVDelimiter[delimiter] - except KeyError as e: - raise make_validation_error( - ValueError(f'Delimiter can only be "," or "\\t", received: {delimiter}'), - loc=("body", "delimiter"), - ) from e - rows, total = self.list_rows( - table_id=table_id, - offset=0, - limit=self.count_rows(table_id), - columns=columns, - convert_null=True, - remove_state_cols=True, - json_safe=True, - include_original=False, - float_decimals=0, - vec_decimals=0, - order_descending=False, + @staticmethod + def _tokenize_regex_simple(text): + tokens = [] + for match in TOKEN_PATTERN.finditer(text): + # Figure out which group matched to determine category and get the string + if match.group(1): # Digits + token_str = match.group(1) + tokens.append(token_str) + elif match.group(2): # Letters + token_str = match.group(2).lower() # Lowercase letters + tokens.append(token_str) + elif match.group(3): # Hanzi + token_str = match.group(3) + tokens.append(token_str) # Append Hanzi directly + elif match.group(4): # Other + token_str = match.group(4) + tokens.append(token_str) # Append other char directly + return tokens + + @staticmethod + def _bm25_ranking( + fts_results: list[dict[str, Any]], + *, + query: str, + text_column_names: list[str], + weights: list[int] | None = None, + ascending: bool = False, + ) -> 
list[dict[str, Any]]: + corpus = [res[col] for res in fts_results for col in text_column_names] + tokenizer = bm25s.tokenization.Tokenizer( + splitter=GenerativeTableCore._tokenize_regex_simple, + stopwords=[ + "english", + ], + stemmer=stemmer.stem, ) - df = pd.DataFrame.from_dict(rows, orient="columns", dtype=None, columns=None) - if len(df) != total: - logger.error( - f"Table {table_id} has {total:,d} rows but exported DF has {len(df):,d} rows !!!" - ) - if file_path == "": - return df - if delimiter == CSVDelimiter.COMMA and not file_path.endswith(".csv"): - file_path = f"{file_path}.csv" - elif delimiter == CSVDelimiter.TAB and not file_path.endswith(".tsv"): - file_path = f"{file_path}.tsv" - df_to_csv(df, file_path, sep=delimiter.value) - return df - - def dump_parquet( + corpus = ["" if c is None else c for c in corpus] + corpus_tokens = tokenizer.tokenize(corpus, show_progress=False) + retriever = bm25s.BM25(backend="numpy") + retriever.index(corpus_tokens, show_progress=False) + query_tokens = tokenizer.tokenize([query], show_progress=False) + results, scores = retriever.retrieve( + query_tokens, k=len(corpus), show_progress=False, n_threads=1, sorted=True + ) + # Reshape scores into (n_docs, n_columns) and apply weights + scores_reshaped = scores[0, results.argsort()].reshape(-1, len(text_column_names)) + if weights: + scores_reshaped *= np.array(weights) + # Sum scores across columns + doc_scores = scores_reshaped.sum(axis=1) + + # Get sorted indices (ascending or descending) + sorted_indices = np.argsort(doc_scores) + if not ascending: + sorted_indices = sorted_indices[::-1] # Reverse for descending + + # Build sorted results with scores + ranked_results = [fts_results[i] for i in sorted_indices] + + for res, score in zip(ranked_results, doc_scores[sorted_indices], strict=True): + res["score"] = float(score) # Convert numpy.float32 to native Python float + return ranked_results + + async def fts_search( self, - session: Session, - table_id: TableName, - dest: str | BinaryIO, + query: str, *, - compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", - ) -> None: - from pyarrow.parquet import write_table - - with self.lock(table_id): - meta = self.open_meta(session, table_id) - table = self.open_table(table_id) - # Convert into Arrow Table - pa_table = table._dataset.to_table(offset=None, limit=None) - # Add file data into Arrow Table - file_col_ids = [col.id for col in meta.cols_schema if col.dtype in ["image", "audio"]] - for col_id in file_col_ids: - file_bytes = [] - for uri in pa_table.column(col_id).to_pylist(): - if not uri: - file_bytes.append(b"") - continue - with open_uri_sync(uri) as f: - file_bytes.append(f.read()) - # Append byte column - pa_table = pa_table.append_column( - pa.field(f"{col_id}__", pa.binary()), [file_bytes] - ) - # Add Generative Table metadata - pa_meta = pa_table.schema.metadata or {} - pa_table = pa_table.replace_schema_metadata( - {"gen_table_meta": meta.model_dump_json(), **pa_meta} - ) - if isinstance(dest, str): - if isdir(dest): - dest = join(dest, f"{table_id}.parquet") - elif not dest.endswith(".parquet"): - dest = f"{dest}.parquet" - write_table(pa_table, dest, compression=compression) - - async def import_parquet( - self, - session: Session, - source: str | BinaryIO, - table_id_dst: str | None, - ) -> tuple[LanceTable, TableMeta]: - from pyarrow.parquet import read_table + weights: dict[str, int] | None = None, + limit: int = 100, + offset: int = 0, + remove_state_cols: bool = False, + force_use_index: bool = False, + 
use_bm25_ranking: bool = False, + explain: bool = False, + ) -> list[dict[str, Any]]: + """ + Perform full-text search across all text columns using pgroonga. - # Check metadata - pa_table = read_table(source, columns=None, use_threads=False, memory_map=True) - try: - meta = TableMeta.model_validate_json(pa_table.schema.metadata[b"gen_table_meta"]) - except KeyError as e: - raise BadInputError("Missing table metadata in the Parquet file.") from e - except Exception as e: - raise BadInputError("Invalid table metadata in the Parquet file.") from e - # Check for required columns - required_columns = set(self.FIXED_COLUMN_IDS) - meta_cols = {c.id for c in meta.cols_schema} - if len(required_columns - meta_cols) > 0: - raise BadInputError( - f"Missing columns in table metadata: {list(required_columns - meta_cols)}." - ) - # Table ID must not exist - if table_id_dst is None: - table_id_dst = meta.id - with self.lock(table_id_dst): - if session.get(TableMeta, table_id_dst) is not None: - raise ResourceExistsError(f'Table "{table_id_dst}" already exists.') - # Upload files - file_col_ids = [col.id for col in meta.cols_schema if col.dtype in ["image", "audio"]] - for col_id in file_col_ids: - new_uris = [] - for old_uri, content in zip( - pa_table.column(col_id).to_pylist(), - pa_table.column(f"{col_id}__").to_pylist(), - strict=True, - ): - if len(content) == 0: - new_uris.append(None) - continue - mime_type = filetype.guess(content).mime - if mime_type is None: - mime_type = "application/octet-stream" - uri = await upload_file_to_s3( - self.organization_id, - self.project_id, - content, - mime_type, - old_uri.split("/")[-1], - ) - new_uris.append(uri) - # Drop old columns - pa_table = pa_table.drop_columns([col_id, f"{col_id}__"]) - # Append new column - pa_table = pa_table.append_column(pa.field(col_id, pa.utf8()), [new_uris]) - # Import Generative Table - meta.id = table_id_dst - session.add(meta) - session.commit() - session.refresh(meta) - table = self.lance_db.create_table(meta.id, data=pa_table, schema=pa_table.schema) - self.create_indexes( - session=session, - table_id=meta.id, - force=True, + Args: + query (str): Search query string. + limit (int, optional): Maximum number of rows to return. Defaults to 100. + offset (int, optional): Offset for pagination. Defaults to 0. + remove_state_cols (bool, optional): If True, remove state columns. Defaults to False. + force_use_index (bool, optional): If True, force using pgroonga index. Defaults to False. + use_bm25_ranking (bool, optional): If True, use BM25 ranking. Defaults to False. + explain (bool, optional): If True, return explain query. Defaults to False. + + Raises: + ResourceNotFoundError: If the table or column(s) is not found. + + Returns: + rows (list[dict[str, Any]]): List of row data dictionaries. 
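+
+        Example:
+            A minimal usage sketch (assumes an already-opened table instance `table`;
+            `Text` is a hypothetical text column used for weighting):
+
+                rows = await table.fts_search(
+                    "quarterly revenue",
+                    weights={"Text": 2},
+                    limit=20,
+                    use_bm25_ranking=True,
+                )
+                best = rows[0] if rows else None  # each row dict includes a "score" key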
+ """ + t0 = perf_counter() + if weights is None: + weights = [1 for _ in self.text_column_names] + else: + weights = [weights.get(n, 1) for n in self.text_column_names] + if len(weights) == 0: # if no text columns fts return empty list + return [] + # Build query + select_cols = self.data_table_model.get_column_ids(exclude_state=remove_state_cols) + # Do not enforce idx like: ($1, ARRAY{weights}, '{fts_index_id(self.table_id)}')::pgroonga_full_text_search_condition + # Pg planner will choose the best plan to run the query efficiently (for smaller number of rows might just use seq scan) + # for duplicated table with CTAS, if number of rows is small it might always use seq scan regardless, so forcing the index will fail + # tested a simple 3 col table, if number rows is 1000 then even with NULL index will be used. + index_name = f"'{fts_index_id(self.table_id)}'" if force_use_index else "NULL" + stmt = f""" + SELECT + {",".join(f'"{self.map_to_short_col_id[c]}"' for c in select_cols)}, + pgroonga_score(tableoid, ctid) AS score + FROM + "{self.schema_id}"."{self.short_table_id}" + WHERE + ARRAY[{", ".join(f'"{self.map_to_short_col_id[n]}"' for n in self.text_column_names)}] &@~ + ($1, ARRAY{weights}, {index_name})::pgroonga_full_text_search_condition + ORDER BY score DESC + LIMIT $2 OFFSET $3 + """ + if explain: + stmt = f"EXPLAIN ANALYZE {stmt}" + async with GENTABLE_ENGINE.transaction() as conn: + # Execute query + try: + rows = await conn.fetch(stmt, query, limit, offset) + except UndefinedColumnError as e: + raise ResourceNotFoundError( + f'One or more columns is not found in table "{self.table_id}".' + ) from e + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except DataError as e: + raise BadInputError(f"Bad input: {e}") from e + # Map short column IDs back to long column IDs + # Keys contain non-column IDs like "score" + results = [ + {self.map_to_long_col_id.get(k, k): v for k, v in dict(row).items()} for row in rows + ] + if len(results) > 0 and use_bm25_ranking: + results = self._bm25_ranking( + fts_results=results, + query=query, + text_column_names=self.text_column_names, + weights=weights, + ascending=False, ) - session.refresh(meta) - return table, meta + self._log(f"FTS search took t={(perf_counter() - t0) * 1e3:,.2f} ms.") + return results - @retry( - wait=wait_exponential(multiplier=1, min=2, max=10), - stop=stop_after_attempt(4), - reraise=True, - ) - def _run_query( + async def vector_search( self, - session: Session, - table_id: TableName, - table: LanceTable, - query: np.ndarray | list | str | None = None, - column_name: str | None = None, - where: str | None = None, - limit: PositiveInt = 10_000, - metric: str = "cosine", - nprobes: PositiveInt = 50, - refine_factor: PositiveInt = 20, + query: str, + *, + embedding_fn: Callable[[str, str], list[float] | Awaitable[list[float]]], + vector_column_names: list[str] | None = None, + limit: int = 100, + offset: int = 0, + remove_state_cols: bool = False, + explain: bool = False, ) -> list[dict[str, Any]]: - is_vector = isinstance(query, (list, np.ndarray)) - if query is None: - column_name = None - query_type = "auto" - elif is_vector: - query_type = "vector" - elif isinstance(query, str): - query = re.sub(r"[\W\s]", " ", query.lower()) - query_type = "fts" + """Perform vector similarity search using cosine distance. + + Args: + query (str): Search query string. 
+ embedding_fn (Callable[[str, str], list[float] | Awaitable[list[float]]]): Embedding function that + takes two string parameters (`str`, `str`) and returns a list of floats. + Can be either synchronous or asynchronous. + The first argument is the model ID and the second argument is the query, ie `embedding_fn(model, query)`. + vector_column_names (list[str] | None, optional): List of vector column name to search. + Defaults to None (all vector columns are used). + limit (int, optional): Maximum number of rows to return. Defaults to 100. + offset (int, optional): Offset for pagination. Defaults to 0. + remove_state_cols (bool, optional): If True, remove state columns. Defaults to False. + explain (bool, optional): If True, return explain query. Defaults to False. + + Raises: + TypeError: If `vector_column_names` is not a list of strings. + BadInputError: If not all columns are vector columns. + ResourceNotFoundError: If the table or column(s) is not found. + + Returns: + rows (list[dict[str, Any]]): List of row data dictionaries. + """ + t0 = perf_counter() + if vector_column_names is None: + vector_column_names = self.vector_column_names else: - raise TypeError("`query` must be one of [np.ndarray | list | str | None].") - query_builder = table.search( - query=query, - vector_column_name=column_name, - query_type=query_type, + if not ( + isinstance(vector_column_names, list) + and all(isinstance(n, str) for n in vector_column_names) + ): + # We raise TypeError here since this is a programming error + raise TypeError("`vector_column_names` must be a list of strings.") + # Ensure all columns are vector columns + if len(invalid_cols := set(vector_column_names) - set(self.vector_column_names)) > 0: + raise BadInputError( + ( + f'Table "{self.table_id}": All columns to be searched must be vector columns. ' + f"Invalid columns: {list(invalid_cols)}" + ) + ) + if len(vector_column_names) == 0: + return [] + # Get query vectors + models: list[str] = list( + { + getattr(c.gen_config, "embedding_model", "") + for c in self.column_metadata + if c.column_id in vector_column_names + } ) - if is_vector: - query_builder = ( - query_builder.metric(metric).nprobes(nprobes).refine_factor(refine_factor) - ) - if where: - query_builder = query_builder.where(where, prefilter=True) - try: - results = query_builder.limit(limit).to_list() - except ValueError: - logger.exception( - f'Failed to perform search on table "{table_id}" !!! Attempting index rebuild ...' - ) - index_ok = self.create_indexes(session, table_id, force=True) - if index_ok: - logger.warning(f'Reindex table "{table_id}" OK, retrying search ...') - else: - logger.error( - f'Failed to reindex table "{table_id}" !!! Retrying search anyway ...' + self._log(f"Embedding using models: {models}") + if iscoroutinefunction(embedding_fn): + query_vectors = await asyncio.gather(*[embedding_fn(m, query) for m in models]) + else: + with ThreadPoolExecutor() as executor: + query_vectors = list(executor.map(embedding_fn, models, [query] * len(models))) + query_vectors = {m: v for m, v in zip(models, query_vectors, strict=True)} + self._log(f"Embedding using {models} took t={(perf_counter() - t0) * 1e3:,.2f} ms.") + + t0 = perf_counter() + columns = [] + for c in self.column_metadata: + if c.column_id not in vector_column_names: + continue + vec = query_vectors[getattr(c.gen_config, "embedding_model", "")] + if len(vec) != c.vlen: + raise BadInputError( + f"Vector length mismatch for column {c.column_id}. Expected {c.vlen}, got {len(vec)}." 
) - results = query_builder.limit(limit).to_list() + columns.append((self.map_to_short_col_id[c.column_id], vec)) + if len(columns) == 0: + return [] + # CTE query + # https://learn.microsoft.com/en-us/answers/questions/2118689/how-to-search-across-multiple-vector-indexes-in-po + subqueries = [ + f""" + "{col_id}_results" AS ( + SELECT + "ID", ("{col_id}" <=> ${i + 1}) AS score + FROM + "{self.schema_id}"."{self.short_table_id}" + ORDER BY + score ASC + ) + """ + for i, (col_id, _) in enumerate(columns) + ] + select_cols = self.data_table_model.get_column_ids(exclude_state=remove_state_cols) + selects = [f't."{self.map_to_short_col_id[col]}"' for col in select_cols] + joins = [ + f'JOIN "{col_id}_results" ON "{columns[0][0]}_results"."ID" = "{col_id}_results"."ID"' + for col_id, _ in columns[1:] + ] + join_expr = "\n".join(joins) + stmt = f""" + WITH + {", ".join(subqueries)} + SELECT + {", ".join(selects)}, + {" + ".join(f'"{col_id}_results".score' for col_id, _ in columns)} AS score + FROM + "{columns[0][0]}_results" + {join_expr} + JOIN + "{self.schema_id}"."{self.short_table_id}" t + ON + t."ID" = "{columns[0][0]}_results"."ID" + ORDER BY + score ASC + LIMIT ${len(columns) + 1} OFFSET ${len(columns) + 2}; + """ + if explain: + stmt = f"EXPLAIN ANALYZE {stmt}" + async with GENTABLE_ENGINE.transaction() as conn: + # Execute query + try: + rows = await conn.fetch(stmt, *[vec for _, vec in columns], limit, offset) + except UndefinedColumnError as e: + raise ResourceNotFoundError( + f'One or more columns is not found in table "{self.table_id}".' + ) from e + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except DataError as e: + raise BadInputError(f"Bad input: {e}") from e + # Map short column IDs back to long column IDs + # Keys contain non-column IDs like "score" + results = [ + {self.map_to_long_col_id.get(k, k): v for k, v in dict(row).items()} for row in rows + ] + self._log(f"Vector search took t={(perf_counter() - t0) * 1e3:,.2f} ms.") return results @staticmethod def _reciprocal_rank_fusion( - search_results: list[list[dict]], result_key: str = "ID", K: int = 60 - ): + search_results: list[list[dict]], + result_key: str = "ID", + K: int = 60, + ) -> list[dict]: """ - Perform reciprocal rank fusion to merge the rank of the search results (arbitrary number of results and can be varying in length) + Perform reciprocal rank fusion to merge the rank of the search results + (arbitrary number of results and can vary in length). + Args: - search_results: list of search results from lance query, search result is a sorted list of dict (descending order of closeness) - result_key: dict key of the search result - K: const (def=60) for reciprocal rank fusion + search_results (list[list[dict]]): List of search results, + where each result is a sorted list of dict (descending order of closeness). + result_key (str, optional): Dictionary key of each item's ID. Defaults to "ID". + K (int, optional): Const for reciprocal rank fusion. Defaults to 60. Return: - A list of dict of original result with the rrf scores (higher scores, higher ranking) + rows (list[dict]): A list of dict of original result with the rrf scores (higher scores, higher ranking). 
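+
+        Note:
+            Scores are accumulated in the usual reciprocal-rank-fusion form,
+            roughly `score(item) = sum(1 / (K + rank))` over every result list the
+            item appears in, so items ranked highly by several searches rise to the top.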
""" rrf_scores = defaultdict(lambda: {"rrf_score": 0.0}) for search_result in search_results: @@ -1403,593 +3939,701 @@ def _reciprocal_rank_fusion( sorted_rrf = sorted(rrf_scores.values(), key=lambda x: x["rrf_score"], reverse=True) return sorted_rrf - def regex_search( + async def hybrid_search( self, - session: Session, - table_id: TableName, - query: str | None, + fts_query: str, + vs_query: str, *, - columns: list[ColName] | None = None, - convert_null: bool = True, + embedding_fn: Callable[[str, str], Awaitable[list[float] | np.ndarray]], + vector_column_names: list[str] | None = None, + limit: int = 100, + offset: int = 0, + use_bm25_ranking: bool = True, remove_state_cols: bool = False, - json_safe: bool = False, - include_original: bool = False, - float_decimals: int = 0, - vec_decimals: int = 0, - order_descending: bool = True, ) -> list[dict[str, Any]]: - table, meta = self.open_table_meta(session, table_id) - if self.count_rows(table_id) == 0: - return [] - if not isinstance(query, str): - raise TypeError(f"`query` must be string, received: {type(query)}") - rows = [] + """ + Perform vector similarity search using cosine distance. + + Args: + fts_query (str): FTS search query string. + vs_query (str): Vector search query string. + embedding_fn (Callable[[str, str], Awaitable[list[float] | np.ndarray]]): Async embedding function that + takes two string parameters (`str`, `str`) and returns a NumPy array or a list of floats. + The first argument is the model ID and the second argument is the query, ie `embedding_fn(model, query)`. + The returned NumPy array should be one-dimensional (ie a single vector). + vector_column_names (list[str] | None, optional): List of vector column name to search. + Defaults to None (all vector columns are used). + limit (int, optional): Maximum number of rows to return from FTS and vector searches. + Note that this means that hybrid search can return more than `limit` rows. Defaults to 100. + offset (int, optional): Offset for pagination. Defaults to 0. + use_bm25_ranking (bool, optional): If True, use BM25 ranking. Defaults to True. + remove_state_cols (bool, optional): If True, remove state columns. Defaults to False. + + Raises: + BadInputError: If not all columns are vector columns. + ResourceNotFoundError: If the table or column(s) is not found. + + Returns: + rows (list[dict[str, Any]]): List of row data dictionaries. 
+ """ t0 = perf_counter() - cols = self.fts_cols(meta) - for col in cols: - rows += ( - table.search() - .where(f"regexp_match(`{col.id}`, '{query}')") - .limit(table.count_rows()) - .to_list() - ) - logger.info(f"Regex search timings ({len(cols)} cols): {perf_counter() - t0:,.3f}") - # De-duplicate and sort - rows = {r["ID"]: r for r in rows}.values() - rows = sorted(rows, reverse=order_descending, key=lambda r: r["ID"]) - rows = self._post_process_rows( - rows, - columns=columns, - convert_null=convert_null, + fts_task = self.fts_search( + query=fts_query, + limit=limit, + offset=offset, + use_bm25_ranking=use_bm25_ranking, + remove_state_cols=remove_state_cols, + ) + vs_task = self.vector_search( + query=vs_query, + embedding_fn=embedding_fn, + vector_column_names=vector_column_names, + limit=limit, + offset=offset, remove_state_cols=remove_state_cols, - json_safe=json_safe, - include_original=include_original, - float_decimals=float_decimals, - vec_decimals=vec_decimals, ) + # Run both tasks concurrently and wait for them to complete + # asyncio.gather returns results in the order the tasks were passed + fts_result, vs_result = await asyncio.gather(fts_task, vs_task) + search_results = [fts_result, vs_result] + # RRF + rows = self._reciprocal_rank_fusion(search_results) + self._log(f"Hybrid search took t={(perf_counter() - t0) * 1e3:,.2f} ms.") return rows - async def hybrid_search( + def rows_to_documents(self, rows: list[dict[str, Any]]) -> list[str]: + cols = {c.column_id for c in self.column_metadata if not c.is_state_column} + documents = [ + ( + f"Title: {r.get('Title', '')}\nContent: {r.get('Text', '')}\n" + + "\n".join( + f"{k}: {v}" + for k, v in r.items() + if k not in self.FIXED_COLUMN_IDS and k in cols + ) + ) + for r in rows + ] + return documents + + # Row Update Ops + async def update_rows( self, - session: Session, - table_id: TableName, - query: str | None, + updates: dict[str, dict[str, Any]], *, - where: str | None = None, - limit: PositiveInt = 100, - columns: list[ColName] | None = None, - metric: str = "cosine", - nprobes: PositiveInt = 50, - refine_factor: PositiveInt = 20, - embedder: CloudEmbedder | None = None, - reranker: CloudReranker | None = None, - reranking_model: str | None = None, - convert_null: bool = True, - remove_state_cols: bool = False, - json_safe: bool = False, - include_original: bool = False, - float_decimals: int = 0, - vec_decimals: int = 0, - ) -> list[dict[str, Any]]: - if not (isinstance(limit, int) and limit > 0): - # TODO: Currently LanceDB is bugged, limit in theory can be None or 0 or negative - # https://github.com/lancedb/lancedb/issues/1151 - raise TypeError("`limit` must be a positive non-zero integer.") - t0 = perf_counter() - table, meta = self.open_table_meta(session, table_id) - if self.count_rows(table_id) == 0: - return [] - timings = {} - if query is None: - t1 = perf_counter() - rows = self._run_query( - session=session, - table_id=table_id, - table=table, - query=None, - column_name=None, - where=where, - limit=limit, + ignore_state_columns: bool = True, + ) -> None: + """ + Update multiple rows in the Generative Table. + + Args: + updates (dict[str, dict[str, Any]]): A dictionary mapping row ID to update data. + Each update data is a dictionary of column name to value. + ignore_state_columns (bool, optional): Whether to ignore state columns. Defaults to True. + + Raises: + TypeError: If the data is not a list of dictionaries. + BadInputError: If any row does not have an "ID" field. 
+ ResourceNotFoundError: If the table is not found. + + Returns: + self (GenerativeTableCore): The table instance. + """ + if not ( + isinstance(updates, dict) and all(isinstance(row, dict) for row in updates.values()) + ): + # We raise TypeError here since this is a programming error + raise TypeError("`updates` must be a dict of dicts.") + # Filter out non-existent fields + columns = set( + self.data_table_model.get_column_ids( + exclude_info=True, + exclude_state=ignore_state_columns, ) - timings["no_query"] = perf_counter() - t1 - else: - if not isinstance(query, str): - raise TypeError(f"`query` must be string, received: {type(query)}") - search_results = [] - # 2024-06 (BUG?): lance fts works on all indexed cols at once (can't specify the col to be searched) - # Thus no need to loop through indexed col one by one - if len(self.fts_cols(meta)) > 0: - t1 = perf_counter() - fts_result = self._run_query( - session=session, - table_id=table_id, - table=table, - query=query, - # column_name=c.id, - where=where, - limit=limit, - metric=metric, - nprobes=nprobes, - refine_factor=refine_factor, - ) - timings["FTS:"] = perf_counter() - t1 - search_results.append(fts_result) - for c in self.embedding_cols(meta): - t1 = perf_counter() - embedding = await embedder.embed_queries( - c.gen_config.embedding_model, texts=[query] - ) - # TODO: Benchmark this - # Searching using float16 seems to be faster on float32 and float16 indexes - # 2024-05-21, lance 0.6.13, pylance 0.10.12 - embedding = np.asarray(embedding.data[0].embedding, dtype=np.float16) - embedding = embedding / np.linalg.norm(embedding) - timings[f"Embed ({c.gen_config.embedding_model}): {c.id}"] = perf_counter() - t1 - t1 = perf_counter() - sub_rows = self._run_query( - session=session, - table_id=table_id, - table=table, - query=embedding, - column_name=c.id, - where=where, - limit=limit, - metric=metric, - nprobes=nprobes, - refine_factor=refine_factor, - ) - # vector_score from lance is 1.0 - cosine similarity (0. exact match) - search_results.append(sub_rows) - timings[f"VS: {c.id}"] = perf_counter() - t1 - # list of search results with rrf_score - rows = self._reciprocal_rank_fusion(search_results) - if reranker is None: - # No longer do a linear combination for hybrid scores, use RRF score instead. 
- _scores = [(f'(RRF_score={r["rrf_score"]:.1f}, ') for r in rows] - logger.info(f"Hybrid search scores: {_scores}") - else: - t1 = perf_counter() - chunks = await reranker.rerank_chunks( - reranking_model, - chunks=[ - Chunk( - text="" if row["Text"] is None else row["Text"], - title="" if row["Title"] is None else row["Title"], - ) - for row in rows - ], - query=query, - ) - rerank_order = [c[2] for c in chunks] - rows = [rows[idx] for idx in rerank_order] - timings[f"Rerank ({reranking_model})"] = perf_counter() - t1 - rows = rows[:limit] - rows = self._post_process_rows( - rows, - columns=columns, - convert_null=convert_null, - remove_state_cols=remove_state_cols, - json_safe=json_safe, - include_original=include_original, - float_decimals=float_decimals, - vec_decimals=vec_decimals, ) - timings["Total"] = perf_counter() - t0 - timings = {k: f"{v:,.3f}" for k, v in timings.items()} - logger.info(f"Hybrid search timings: {timings}") - return rows + # Validate and convert all rows + try: + updates = { + row_id: self._validate_row_data( + {k: v for k, v in row.items() if k in columns and k.lower() != "id"} + ).model_dump(exclude_unset=True) + for row_id, row in updates.items() + } + except ValidationError as e: + raise BadInputError(f"Input data contains errors: {e}") from e + async with GENTABLE_ENGINE.transaction() as conn: + try: + for row_id, update in updates.items(): + if len(update) == 0: + continue + _cols = [k for k in update.keys()] + # Build SQL statement + set_expr = ", ".join( + f'"{self.map_to_short_col_id[col]}" = ${i + 1}' + for i, col in enumerate(_cols) + ) + query = ( + f'UPDATE "{self.schema_id}"."{self.short_table_id}" ' + f'SET "Updated at" = statement_timestamp(), {set_expr} ' + f'WHERE "ID" = ${len(_cols) + 1}' + ) + # Update rows + await conn.execute(query, *(update[col] for col in _cols), row_id) + # Set updated at time + await self._set_updated_at(conn) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except DataError as e: + raise BadInputError(f"Bad input: {e}") from e - def scalar_cols(self, meta: TableMeta) -> list[ColumnSchema]: - return [c for c in meta.cols_schema if c.id.lower() in ("id", "updated at")] + # Row Delete Ops + async def delete_rows( + self, + *, + row_ids: list[str] | None = None, + where: str = "", + ) -> Self: + """ + Delete one or more rows from the Generative Table. - def embedding_cols(self, meta: TableMeta) -> list[ColumnSchema]: - return [c for c in meta.cols_schema if c.vlen > 0] + Args: + row_ids (list[str] | None, optional): List of row IDs to be deleted. + Defaults to None (match rows using `where`). + where (str, optional): SQL where clause. Defaults to "" (no filter). + It will be combined with `row_ids` using `AND`. - def fts_cols(self, meta: TableMeta) -> list[ColumnSchema]: - return [c for c in meta.cols_schema if c.dtype == ColumnDtype.STR and c.id.lower() != "id"] + Raises: + ResourceNotFoundError: If the table is not found. - def create_fts_index( - self, - session: Session, - table_id: TableName, + Returns: + self (GenerativeTableCore): The table instance. 
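+
+        Example:
+            A minimal usage sketch (assumes an already-opened table instance `table`
+            and a known `row_id`):
+
+                await table.delete_rows(row_ids=[row_id])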
+ """ + if row_ids is None: + row_ids = [] + if not (isinstance(row_ids, list) and all(isinstance(i, (str, UUID)) for i in row_ids)): + # We raise TypeError here since this is a programming error + raise TypeError("`row_ids` must be a list of strings.") + + # Build SQL query + filters = [] + if row_ids: + filters.append('("ID" = $1)') + row_ids = [(row_id,) for row_id in row_ids] + where = where.strip() + if where: + try: + where = f"({validate_where_expr(where, id_map=self.map_to_short_col_id)})" + except Exception as e: + raise BadInputError(str(e)) from e + filters.append(where) + if len(filters) == 0: + raise BadInputError("Either `row_ids` or `where` must be provided.") + async with GENTABLE_ENGINE.transaction() as conn: + try: + sql = f'DELETE FROM "{self.schema_id}"."{self.short_table_id}" WHERE {" AND ".join(filters)}' + if row_ids: + await conn.executemany(sql, row_ids) + else: + await conn.execute(sql) + # Set updated at time + await self._set_updated_at(conn) + except UndefinedTableError as e: + raise ResourceNotFoundError(f'Table "{self.table_id}" is not found.') from e + except PostgresSyntaxError as e: + raise BadInputError(f"Bad SQL statement: `{sql}`") from e + return self + + +class ActionTable(GenerativeTableCore): + TABLE_TYPE = TableType.ACTION + + @override + @classmethod + async def drop_schema( + cls, *, - force: bool = False, - ) -> bool: - table, meta = self.open_table_meta(session, table_id) - fts_cols = [c.id for c in self.fts_cols(meta)] - # Maybe can skip reindexing - if ( - (not force) - and meta.indexed_at_fts is not None - and meta.indexed_at_fts > meta.updated_at - ): - return False - num_rows = table.count_rows() - if num_rows == 0: - return False - if len(fts_cols) == 0: - return False - index_datetime = datetime_now_iso() - table.create_fts_index(fts_cols, replace=True) - # Update metadata - meta.indexed_at_fts = index_datetime - session.add(meta) - session.commit() - return True - - def create_scalar_index( - self, - session: Session, - table_id: TableName, + project_id: str, + ) -> None: + """ + Drops the project's schema along with all data tables. + """ + return await super().drop_schema( + project_id=project_id, + table_type=cls.TABLE_TYPE, + ) + + @override + @classmethod + async def create_table( + cls, *, - force: bool = False, - ) -> bool: - table, meta = self.open_table_meta(session, table_id) - # Maybe can skip reindexing - if ( - (not force) - and meta.indexed_at_sca is not None - and meta.indexed_at_sca > meta.updated_at - ): - return False - num_rows = table.count_rows() - if num_rows == 0: - return False - index_datetime = datetime_now_iso() - for c in self.scalar_cols(meta): - table.create_scalar_index(c.id, replace=True) - # Update metadata - meta.indexed_at_sca = index_datetime - session.add(meta) - session.commit() - return True - - def create_vector_index( - self, - session: Session, - table_id: TableName, - force: bool = False, + project_id: str, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + ) -> Self: + """ + Create a new Action Table with default prompts (if prompts are not provided). + + Args: + project_id (str): Project ID. + table_metadata (TableMetadata): Table metadata. + column_metadata_list (list[ColumnMetadata]): List of column metadata. + + Returns: + self (GenerativeTableCore): The table instance. 
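+
+        Example:
+            A rough sketch; `table_meta` and `column_metas` stand in for prepared
+            `TableMetadata` and `ColumnMetadata` objects:
+
+                table = await ActionTable.create_table(
+                    project_id=project_id,
+                    table_metadata=table_meta,
+                    column_metadata_list=column_metas,
+                )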
+ """ + return await cls._create_table( + project_id=project_id, + table_type=cls.TABLE_TYPE, + table_metadata=table_metadata, + column_metadata_list=column_metadata_list, + set_default_prompts=True, + ) + + @classmethod + async def duplicate_table( + cls, *, - metric: str = "cosine", - num_partitions: int | None = None, - num_sub_vectors: int | None = None, - accelerator: str | None = None, - index_cache_size: int | None = None, - ) -> bool: + project_id: str, + table_id_src: str, + table_id_dst: TableName | None = None, + include_data: bool = True, + create_as_child: bool = False, + created_by: str | None = None, + ) -> Self: """ - Creates a vector IVF-PQ index for each vector column. Existing indexes will be replaced. - This is a no-op if number of rows is less than 1,000. + Duplicate an existing table including schema, data and metadata. Args: - session (Session): SQLAlchemy session. - table_id (TableName): Table ID. - force (bool, optional): If True, force reindex. Defaults to False. - metric (str, optional): The distance metric type. - "L2" (alias to "euclidean"), "cosine" or "dot" (dot product). Defaults to "dot". - num_partitions (int, optional): The number of IVF partitions to create. - By default the number of partitions is the square root of the number of rows. - for example the square root of the number of rows. Defaults to None. - num_sub_vectors (int, optional): Number of sub-vectors of PQ. - This value controls how much the vector is compressed during the quantization step. - The more sub vectors there are the less the vector is compressed. - The default is the dimension of the vector divided by 16. - If the dimension is not evenly divisible by 16 we use the dimension divided by 8. - The above two cases are highly preferred. - Having 8 or 16 values per subvector allows us to use efficient SIMD instructions. - If the dimension is not visible by 8 then we use 1 subvector. - This is not ideal and will likely result in poor performance. + project_id (str): Project ID. + table_id_src (str): Name of the table to be duplicated. + table_id_dst (str | None, optional): Name for the new table. + Defaults to None (automatically find the next available table name). + include_data (bool, optional): If True, include data. Defaults to True. + create_as_child (bool, optional): If True, create the new table as a child of the source table. + Defaults to False. + created_by (str | None, optional): User ID of the user who created the table. Defaults to None. - accelerator (str | None, optional): str or `torch.Device`, optional. - If set, use an accelerator to speed up the index training process. - Accepted accelerator: "cuda" (Nvidia GPU) and "mps" (Apple Silicon GPU). - If not set, use the CPU. Defaults to None. - index_cache_size (int | None, optional): The size of the index cache in number of entries. Defaults to None. - index_cache_size (int | None, optional): The size of the index cache in number of entries. Defaults to None. + + Raises: + BadInputError: If `table_id_dst` is not None or a non-empty string. + ResourceNotFoundError: If table or column metadata cannot be found. Returns: - reindexed (bool): Whether the reindex operation is performed. 
- """ - table, meta = self.open_table_meta(session, table_id) - # Maybe can skip reindexing - if ( - (not force) - and meta.indexed_at_vec is not None - and meta.indexed_at_vec > meta.updated_at - ): - return False - num_rows = table.count_rows() - if num_rows < 10_000: - return False - index_datetime = datetime_now_iso() - num_partitions = num_partitions or max(1, int(np.sqrt(num_rows))) - for c in self.embedding_cols(meta): - if num_sub_vectors is None: - if c.vlen % 16 == 0: - num_sub_vectors = c.vlen // 16 - elif c.vlen % 8 == 0: - num_sub_vectors = c.vlen // 8 - else: - num_sub_vectors = 1 - table.create_index( - vector_column_name=c.id, - replace=True, - metric=metric, - num_partitions=num_partitions, - num_sub_vectors=num_sub_vectors, - accelerator=accelerator, - index_cache_size=index_cache_size, - ) - # Update metadata - meta.indexed_at_vec = index_datetime - session.add(meta) - session.commit() - return True + self (GenerativeTableCore): The duplicated table instance. + """ + return await super().duplicate_table( + project_id=project_id, + table_type=cls.TABLE_TYPE, + table_id_src=table_id_src, + table_id_dst=table_id_dst, + include_data=include_data, + create_as_child=create_as_child, + created_by=created_by, + ) - def create_indexes( - self, - session: Session, - table_id: TableName, + # Read + @classmethod + async def open_table( + cls, *, - force: bool = False, - ) -> bool: - """Creates scalar, vector, FTS indexes. + project_id: str, + table_id: str, + created_by: str | None = None, + request_id: str = "", + ) -> Self: + """ + Open an existing table. Args: - session (Session): SQLAlchemy session. - table_id (TableName): Table ID. - force (bool, optional): If True, force reindex. Defaults to False. + project_id (str): Project ID. + table_id (str): Name of the table. + created_by (str | None, optional): User who created the table. + If provided, will check if the table was created by the user. Defaults to None (any user). + request_id (str, optional): Request ID for logging. Defaults to "". Returns: - index_ok (bool): Whether at least one reindexing operation is performed. + self (GenerativeTableCore): The table instance. """ - t0 = perf_counter() - sca_reindexed = self.create_scalar_index(session, table_id, force=force) - t1 = perf_counter() - fts_reindexed = self.create_fts_index(session, table_id, force=force) - t2 = perf_counter() - vec_reindexed = self.create_vector_index(session, table_id, force=force) - t3 = perf_counter() - timings = [] - if sca_reindexed: - timings.append(f"scalar={t1-t0:,.2f} s") - if fts_reindexed: - timings.append(f"FTS={t2-t1:,.2f} s") - if vec_reindexed: - timings.append(f"vector={t3-t2:,.2f} s") - if len(timings) > 0: - timings = ", ".join(timings) - num_rows = self.open_table(table_id).count_rows() - logger.info( - ( - f'Index creation for table "{table_id}" with {num_rows:,d} rows took {t3-t0:,.2f} s ' - f"({timings})." 
-            )
-        )
-        return len(timings) > 0
-
-    def compact_files(self, table_id: TableName, *args, **kwargs) -> bool:
-        with self.lock(table_id):
-            table = self.open_table(table_id)
-            num_rows = table.count_rows()
-            if num_rows < 10:
-                return False
-            table.compact_files(*args, **kwargs)
-        return True
-
-    def cleanup_old_versions(
-        self,
-        table_id: TableName,
-        older_than: timedelta | None = None,
-        delete_unverified: bool = False,
-    ) -> bool:
-        with self.lock(table_id):
-            table = self.open_table(table_id)
-            num_rows = table.count_rows()
-            if num_rows < 3:
-                return False
-            table.cleanup_old_versions(older_than=older_than, delete_unverified=delete_unverified)
-        return True
+        return await super().open_table(
+            project_id=project_id,
+            table_type=cls.TABLE_TYPE,
+            table_id=table_id,
+            created_by=created_by,
+            request_id=request_id,
+        )
+
+    @classmethod
+    async def list_tables(
+        cls,
+        *,
+        project_id: str,
+        limit: int | None = 100,
+        offset: int = 0,
+        order_by: Literal["id", "updated_at"] = "updated_at",
+        order_ascending: bool = True,
+        created_by: str | None = None,
+        parent_id: str | None = None,
+        search_query: str = "",
+        search_columns: list[str] | None = None,
+        count_rows: bool = False,
+    ) -> Page[TableMetaResponse]:
+        """
+        List tables.
+
+        Args:
+            project_id (str): Project ID.
+            limit (int | None, optional): Maximum number of tables to return.
+                Defaults to 100. Pass None to return all tables.
+            offset (int, optional): Offset for pagination. Defaults to 0.
+            order_by (Literal["id", "updated_at"], optional): Sort tables by this attribute.
+                Defaults to "updated_at".
+            order_ascending (bool, optional): Whether to sort by ascending order.
+                Defaults to True.
+            created_by (str | None, optional): Return tables created by this user.
+                Defaults to None (return all tables).
+            parent_id (str | None, optional): Parent ID of tables to return.
+                Defaults to None (no parent ID filtering).
+                Additionally for Chat Table, you can list:
+                (1) all chat agents by passing in "_agent_"; or
+                (2) all chats by passing in "_chat_".
+            search_query (str, optional): A string to search for within table names.
+                The string is interpreted as both POSIX regular expression and literal string.
+                Defaults to "".
+            search_columns (list[str] | None, optional): List of columns to search within.
+                Defaults to None (search table ID).
+            count_rows (bool, optional): Whether to count the rows of the tables.
+                Defaults to False.
+
+        Returns:
+            tables (Page[TableMetaResponse]): List of tables.
+        """
+        return await super().list_tables(
+            project_id=project_id,
+            table_type=cls.TABLE_TYPE,
+            limit=limit,
+            offset=offset,
+            order_by=order_by,
+            order_ascending=order_ascending,
+            created_by=created_by,
+            parent_id=parent_id,
+            search_query=search_query,
+            search_columns=search_columns,
+            count_rows=count_rows,
+        )
+
+    @classmethod
+    async def import_table(
+        cls,
+        *,
+        project_id: str,
+        source: str | Path | BinaryIO,
+        table_id_dst: TableName | None,
+        reupload_files: bool = True,
+        progress_key: str = "",
+        verbose: bool = False,
+    ) -> Self:
+        """
+        Recreate a table (data and metadata) from a Parquet file.

-    def update_title(self, session: Session, table_id: TableName, title: str):
-        meta = self.open_meta(session, table_id)
-        meta.title = title
-        session.add(meta)
-        session.commit()
+        Args:
+            project_id (str): Project ID.
+            source (str | Path | BinaryIO): The path to the import file, or a file-like object opened in binary mode.
+            table_id_dst (TableName | None): Name or ID of the new table.
+                If None, the table ID in the Parquet metadata will be used.
+ reupload_files (bool, optional): If True, will reupload files to S3 with new URI. + Otherwise skip reupload and keep the original S3 paths for file columns. + Defaults to True. + progress_key (str, optional): Progress publish key. Defaults to "" (disabled). + verbose (bool, optional): If True, will produce verbose logging messages. + Defaults to False. + Raises: + ResourceExistsError: If the table already exists. -class ActionTable(GenerativeTable): - pass + Returns: + self (GenerativeTableCore): The table instance. + """ + return await super().import_table( + project_id=project_id, + table_type=cls.TABLE_TYPE, + source=source, + table_id_dst=table_id_dst, + reupload_files=reupload_files, + progress_key=progress_key, + verbose=verbose, + ) -class KnowledgeTable(GenerativeTable): - FIXED_COLUMN_IDS = ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] +class KnowledgeTable(ActionTable): + TABLE_TYPE = TableType.KNOWLEDGE + FIXED_COLUMN_IDS = [ + "ID", + "Updated at", + "Title", + "Title Embed", + "Text", + "Text Embed", + "File ID", + "Page", + ] @override - def create_table( - self, - session: Session, - schema: KnowledgeTableSchemaCreate, - model_list: ModelListConfig, - remove_state_cols: bool = False, - add_info_state_cols: bool = True, - ) -> tuple[LanceTable, TableMeta]: - if not isinstance(schema, KnowledgeTableSchemaCreate): - raise TypeError("`schema` must be an instance of `KnowledgeTableSchemaCreate`.") - schema = TableSchema( - id=schema.id, - cols=[ - ColumnSchema(id="Title", dtype=ColumnDtype.STR), - ColumnSchema( - id="Title Embed", - # TODO: Benchmark this - # float32 index creation is 2x faster than float16 - # float32 vector search is 10% to 50% faster than float16 - # 2024-05-21, lance 0.6.13, pylance 0.10.12 - # https://github.com/lancedb/lancedb/issues/1312 - dtype=ColumnDtype.FLOAT32, - vlen=model_list.get_embed_model_info(schema.embedding_model).embedding_size, - gen_config=EmbedGenConfig( - embedding_model=schema.embedding_model, - source_column="Title", - ), + @classmethod + async def create_table( + cls, + *, + project_id: str, + table_metadata: TableMetadata, + column_metadata_list: list[ColumnMetadata], + embedding_model: str, + ) -> Self: + """ + Create a new Knowledge Table with default prompts (if prompts are not provided). + + Args: + project_id (str): Project ID. + table_type (str): Table type. + table_metadata (TableMetadata): Table metadata. + column_metadata_list (list[ColumnMetadata]): List of column metadata. + embedding_model (str): ID of the embedding model. + + Returns: + self (GenerativeTableCore): The table instance. + """ + table_id = table_metadata.table_id + # Fetch model config + project = await cls._fetch_project(project_id) + try: + # If model is empty string, select a model based on capabilities + if embedding_model.strip() == "": + model = await cls._fetch_model_with_capabilities( + capabilities=[str(ModelCapability.EMBED)], + organization_id=project.organization_id, + ) + else: + model = await cls._fetch_model(embedding_model, project.organization_id) + except ResourceNotFoundError as e: + raise BadInputError( + f'Table "{table_id}": Model "{embedding_model}" is not found.' 
+ ) from e + # Use `dimensions` if specified; otherwise use `size` + embed_size = model.final_embedding_size + fixed_columns = [ + ColumnMetadata( + table_id=table_id, + column_id="Title", + dtype=ColumnDtype.STR, + ), + ColumnMetadata( + table_id=table_id, + column_id="Title Embed", + dtype=ColumnDtype.FLOAT, + vlen=embed_size, + gen_config=EmbedGenConfig( + embedding_model=model.id, + source_column="Title", ), - ColumnSchema(id="Text", dtype=ColumnDtype.STR), - ColumnSchema( - id="Text Embed", - dtype=ColumnDtype.FLOAT32, - vlen=model_list.get_embed_model_info(schema.embedding_model).embedding_size, - gen_config=EmbedGenConfig( - embedding_model=schema.embedding_model, - source_column="Text", - ), + ), + ColumnMetadata( + table_id=table_id, + column_id="Text", + dtype=ColumnDtype.STR, + ), + ColumnMetadata( + table_id=table_id, + column_id="Text Embed", + dtype=ColumnDtype.FLOAT, + vlen=embed_size, + gen_config=EmbedGenConfig( + embedding_model=model.id, + source_column="Text", ), - ColumnSchema(id="File ID", dtype=ColumnDtype.STR), - ColumnSchema(id="Page", dtype=ColumnDtype.INT), - ] - + schema.cols, + ), + ColumnMetadata( + table_id=table_id, + column_id="File ID", + dtype=ColumnDtype.STR, + ), + ColumnMetadata( + table_id=table_id, + column_id="Page", + dtype=ColumnDtype.INT, + ), + ] + return await cls._create_table( + project_id=project_id, + table_type=cls.TABLE_TYPE, + table_metadata=table_metadata, + column_metadata_list=fixed_columns + column_metadata_list, + set_default_prompts=True, ) - return super().create_table(session, schema, remove_state_cols, add_info_state_cols) - @override - def update_gen_config( + async def update_gen_config( self, - session: Session, - updates: GenConfigUpdateRequest, - ) -> TableMeta: - with self.create_session() as session: - table, meta = self.open_table_meta(session, updates.table_id) - num_rows = table.count_rows() - id2col = {c["id"]: c for c in meta.cols} - for col_id in updates.column_map: - if num_rows > 0 and id2col[col_id]["vlen"] > 0: - raise TableSchemaFixedError( - "Knowledge Table contains data, cannot update embedding config." - ) - return super().update_gen_config(session, updates) + update_mapping: dict[str, DiscriminatedGenConfig | None], + *, + allow_nonexistent_refs: bool = False, + ) -> Self: + """ + Update the generation configuration for a column. - @override - def add_columns( + Args: + update_mapping (dict[str, DiscriminatedGenConfig]): Mapping of column IDs to new generation configurations. + allow_nonexistent_refs (bool, optional): Ignore non-existent column and Knowledge Table references. + Otherwise will raise an error. Useful when importing old tables and performing maintenance. + Defaults to False. + + Raises: + ResourceNotFoundError: If the column is not found. + + Returns: + self (GenerativeTableCore): The table instance. + """ + # "Title Embed" and "Text Embed" columns must always have gen config + filtered = { + column_id: config + for column_id, config in update_mapping.items() + if not ( + column_id.lower() in {"title embed", "text embed"} + and not isinstance(config, EmbedGenConfig) + ) + } + + if not filtered: + return self + + return await super().update_gen_config( + update_mapping=filtered, allow_nonexistent_refs=allow_nonexistent_refs + ) + + async def update_rows( self, - session: Session, - schema: AddKnowledgeColumnSchema, - ) -> tuple[LanceTable, TableMeta]: + updates: dict[str, dict[str, Any]], + *, + ignore_state_columns: bool = True, + ) -> None: """ - Adds one or more input or output column. 
+        Update multiple rows in the Generative Table.

         Args:
-            session (Session): SQLAlchemy session.
-            schema (AddKnowledgeColumnSchema): Schema of the columns to be added.
+            updates (dict[str, dict[str, Any]]): A dictionary mapping row ID to update data.
+                Each update data is a dictionary of column name to value.
+            ignore_state_columns (bool, optional): Whether to ignore state columns. Defaults to True.

         Raises:
+            TypeError: If `updates` is not a dictionary of dictionaries.
+            BadInputError: If any row does not have an "ID" field.
             ResourceNotFoundError: If the table is not found.
-            ValueError: If any of the columns exists.

         Returns:
-            table (LanceTable): Lance table.
-            meta (TableMeta): Table metadata.
+            self (GenerativeTableCore): The table instance.
         """
-        if not isinstance(schema, AddKnowledgeColumnSchema):
-            raise TypeError("`schema` must be an instance of `AddKnowledgeColumnSchema`.")
-        # if self.open_table(schema.id).count_rows() > 0:
-        #     raise TableSchemaFixedError("Knowledge Table contains data, cannot add columns.")
-        return super().add_columns(session, schema)
+        return await super().update_rows(
+            updates=updates, ignore_state_columns=ignore_state_columns
+        )


-class ChatTable(GenerativeTable):
-    FIXED_COLUMN_IDS = ["User"]
+class ChatTable(ActionTable):
+    TABLE_TYPE = TableType.CHAT
+    FIXED_COLUMN_IDS = [
+        "ID",
+        "Updated at",
+        "User",
+    ]

     @override
-    def create_table(
-        self,
-        session: Session,
-        schema: ChatTableSchemaCreate,
-        remove_state_cols: bool = False,
-        add_info_state_cols: bool = True,
-    ) -> tuple[LanceTable, TableMeta]:
-        if not isinstance(schema, ChatTableSchemaCreate):
-            raise TypeError("`schema` must be an instance of `ChatTableSchemaCreate`.")
-        num_chat_cols = len([c for c in schema.cols if c.gen_config and c.gen_config.multi_turn])
+    @classmethod
+    async def create_table(
+        cls,
+        *,
+        project_id: str,
+        table_metadata: TableMetadata,
+        column_metadata_list: list[ColumnMetadata],
+    ) -> Self:
+        """
+        Create a new Chat Table with default prompts (if prompts are not provided).
+
+        Args:
+            project_id (str): Project ID.
+            table_metadata (TableMetadata): Table metadata.
+            column_metadata_list (list[ColumnMetadata]): List of column metadata.
+
+        Returns:
+            self (GenerativeTableCore): The table instance.
+        """
+        table_id = table_metadata.table_id
+        for col in column_metadata_list:
+            if col.column_id.lower() == "ai":
+                if isinstance(col.gen_config, LLMGenConfig):
+                    col.gen_config.multi_turn = True
+                else:
+                    col.gen_config = LLMGenConfig(multi_turn=True)
+        num_chat_cols = len([c for c in column_metadata_list if c.is_chat_column])
         if num_chat_cols == 0:
-            raise BadInputError("The table must have at least one multi-turn column.")
-        return super().create_table(session, schema, remove_state_cols, add_info_state_cols)
+            raise BadInputError(
+                f'Chat Table "{table_id}" must have at least one multi-turn column.'
+            )
+        return await cls._create_table(
+            project_id=project_id,
+            table_type=cls.TABLE_TYPE,
+            table_metadata=table_metadata,
+            column_metadata_list=column_metadata_list,
+            set_default_prompts=True,
+        )

-    @override
-    def add_columns(
+    async def update_gen_config(
         self,
-        session: Session,
-        schema: AddChatColumnSchema,
-    ) -> tuple[LanceTable, TableMeta]:
+        update_mapping: dict[str, DiscriminatedGenConfig | None],
+        *,
+        allow_nonexistent_refs: bool = False,
+    ) -> Self:
         """
-        Adds one or more input or output column.
+        Update the generation configuration for a column.

         Args:
-            session (Session): SQLAlchemy session.
- schema (AddChatColumnSchema): Schema of the columns to be added. + update_mapping (dict[str, DiscriminatedGenConfig]): Mapping of column IDs to new generation configurations. + allow_nonexistent_refs (bool, optional): Ignore non-existent column and Knowledge Table references. + Otherwise will raise an error. Useful when importing old tables and performing maintenance. + Defaults to False. Raises: - ResourceNotFoundError: If the table is not found. - ValueError: If any of the columns exists. + ResourceNotFoundError: If the column is not found. Returns: - table (LanceTable): Lance table. - meta (TableMeta): Table metadata. + self (GenerativeTableCore): The table instance. """ - if not isinstance(schema, AddChatColumnSchema): - raise TypeError("`schema` must be an instance of `AddChatColumnSchema`.") - with self.create_session() as session: - meta = self.open_meta(session, schema.id) - if meta.parent_id is not None: - raise TableSchemaFixedError("Unable to add columns to a conversation table.") - return super().add_columns(session, schema) + for column_id, config in update_mapping.items(): + if column_id.lower() == "ai" and isinstance(config, LLMGenConfig): + config.multi_turn = True # in-place mutation is fine + filtered = { + column_id: config + for column_id, config in update_mapping.items() + if not (column_id.lower() == "ai" and not isinstance(config, LLMGenConfig)) + } + return await super().update_gen_config( + update_mapping=filtered, allow_nonexistent_refs=allow_nonexistent_refs + ) - @override - def drop_columns( + async def drop_columns( self, - session: Session, - table_id: TableName, - column_names: list[ColName], - ) -> tuple[LanceTable, TableMeta]: + column_ids: list[str], + ) -> Self: """ - Drops one or more input or output column. + Drop columns from the Chat Table. Args: - session (Session): SQLAlchemy session. - table_id (str): The ID of the table. - column_names (list[str]): List of column ID to drop. + column_ids (list[str]): List of column IDs to drop. Raises: - TypeError: If `column_names` is not a list. - ResourceNotFoundError: If the table is not found. ResourceNotFoundError: If any of the columns is not found. - - Returns: - table (LanceTable): Lance table. - meta (TableMeta): Table metadata. """ - with self.create_session() as session: - meta = self.open_meta(session, table_id) - if meta.parent_id is not None: - raise TableSchemaFixedError("Unable to drop columns from a conversation table.") num_chat_cols = len( - [ - c - for c in meta.cols_schema - if c.id not in column_names and c.gen_config and c.gen_config.multi_turn - ] + [c for c in self.column_metadata if c.column_id not in column_ids and c.is_chat_column] ) if num_chat_cols == 0: - raise BadInputError("The table must have at least one multi-turn column.") - return super().drop_columns(session, table_id, column_names) - - @override - def rename_columns( - self, - session: Session, - table_id: TableName, - column_map: dict[ColName, ColName], - ) -> TableMeta: - with self.create_session() as session: - meta = self.open_meta(session, table_id) - if meta.parent_id is not None: - raise TableSchemaFixedError("Unable to rename columns of a conversation table.") - return super().rename_columns(session, table_id, column_map) + raise BadInputError( + f'Chat Table "{self.table_id}" must have at least one multi-turn column after column drop.' 
+ ) + return await super().drop_columns(column_ids) diff --git a/services/api/src/owl/db/gen_table_v2.py b/services/api/src/owl/db/gen_table_v2.py deleted file mode 100644 index 0bd2066..0000000 --- a/services/api/src/owl/db/gen_table_v2.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Self - -import numpy as np - - -class GenerativeTableCore: - ### --- Table CRUD --- ### - - # Create - @classmethod - async def create_table(cls, table_id: str) -> Self: - pass - - @classmethod - async def duplicate_table(cls, table_id: str) -> Self: - pass - - # Read - @classmethod - async def list_tables(cls, table_id: str) -> list[Self]: - pass - - @classmethod - async def get_table(cls, table_id: str) -> Self: - pass - - async def count_rows(self): - pass - - # Update - async def rename_table(self): - pass - - async def recreate_fts_index(self): - # Optional - pass - - async def recreate_vector_index(self): - # Optional - pass - - async def drop_fts_index(self): - # Optional - pass - - async def drop_vector_index(self): - # Optional - pass - - # Delete - async def drop_table(self): - pass - - # Import Export - async def export_table(self): - pass - - async def import_table(self): - pass - - async def export_data(self): - pass - - async def import_data(self): - pass - - ### --- Column CRUD --- ### - - # Create - async def add_column(self): - pass - - # Read ops are implemented as table ops - # Update - async def update_gen_config(self): - pass - - async def rename_column(self): - pass - - async def reorder_columns(self): - # Need to ensure that length of new order list matches the number of columns - pass - - # Delete - async def drop_column(self): - pass - - ### --- Row CRUD --- ### - - # Create - async def add_row(self): - pass - - async def add_rows(self): - # Optional, if batch operation is supported - pass - - # Read - async def list_rows(self): - pass - - async def get_row(self): - pass - - async def fts_search(self, query: str): - pass - - async def vector_search(self, query: list[float] | np.ndarray): - pass - - # Update - async def update_row(self): - pass - - async def update_rows(self): - # Optional, if batch operation is supported - pass - - # Delete - async def delete_row(self): - pass - - async def delete_rows(self): - # Optional, if batch operation is supported - pass diff --git a/services/api/src/owl/db/models/__init__.py b/services/api/src/owl/db/models/__init__.py new file mode 100644 index 0000000..eb135ea --- /dev/null +++ b/services/api/src/owl/db/models/__init__.py @@ -0,0 +1,22 @@ +from owl.db.models.oss import ( # noqa: F401 + BASE_PLAN_ID, + TEMPLATE_ORG_ID, + Deployment, + JamaiSQLModel, + ModelConfig, + ModelInfo, + Organization, + OrgMember, + PricePlan, + Project, + ProjectMember, + User, +) + +""" +Cloud-only models + +VerificationCode, +ProjectKey, +StripeEvent, +""" diff --git a/services/api/src/owl/db/models/oss.py b/services/api/src/owl/db/models/oss.py new file mode 100644 index 0000000..5f97165 --- /dev/null +++ b/services/api/src/owl/db/models/oss.py @@ -0,0 +1,1473 @@ +from base64 import urlsafe_b64decode, urlsafe_b64encode +from datetime import datetime +from decimal import Decimal +from functools import lru_cache +from typing import Any, Self, Type, TypeVar + +from pydantic import BaseModel, computed_field +from pydantic_extra_types.currency_code import ISO4217 +from pydantic_extra_types.timezone_name import TimeZoneName +from sqlalchemy.orm import declared_attr, selectinload +from sqlalchemy.sql.base import ExecutableOption +from sqlmodel import ( + VARCHAR, 
+ AutoString, + Boolean, + DateTime, + ForeignKey, + Integer, + MetaData, + Numeric, + Relationship, + SQLModel, + String, + Unicode, + and_, + asc, + desc, + exists, + func, + literal, + nulls_first, + nulls_last, + or_, + select, + text, + tuple_, +) +from sqlmodel import Field as SqlField +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlmodel.sql._expression_select_cls import SelectBase +from sqlmodel.sql.expression import SelectOfScalar + +from owl.configs import ENV_CONFIG +from owl.types import ( + DEFAULT_MUL_LANGUAGES, + CloudProvider, + DatetimeUTC, + LanguageCodeList, + ModelCapability, + ModelType, + Page, + PaymentState, + PositiveNonZeroInt, + Role, + SanitisedNonEmptyStr, + SanitisedStr, +) +from owl.utils import uuid7_str +from owl.utils.crypt import generate_key +from owl.utils.dates import now +from owl.utils.exceptions import ( + BadInputError, + InsufficientCreditsError, + NoTierError, + ResourceNotFoundError, +) +from owl.utils.io import json_dumps, json_loads +from owl.utils.types import JSON + +TEMPLATE_ORG_ID = "template" +BASE_PLAN_ID = "base" + + +def _encode_cursor(values: dict[str, Any]) -> str: + return urlsafe_b64encode(json_dumps(values).encode()).decode() + + +def _decode_cursor(token: str) -> dict[str, Any]: + raw = json_loads(urlsafe_b64decode(token.encode()).decode()) + if "created_at" in raw and isinstance(raw["created_at"], str): + raw["created_at"] = datetime.fromisoformat(raw["created_at"]) + elif "updated_at" in raw and isinstance(raw["updated_at"], str): + raw["updated_at"] = datetime.fromisoformat(raw["updated_at"]) + return raw + + +def _relationship( + back_populates: str | None = None, + link_model: Any | None = None, + *, + selectin: bool = True, + cascade: str | None = "all, delete-orphan", + sa_kwargs: dict[str, Any] | None = None, +): + sa_relationship_kwargs = dict(viewonly=True) + if isinstance(sa_kwargs, dict): + sa_relationship_kwargs.update(sa_kwargs) + if selectin: + sa_relationship_kwargs["lazy"] = "selectin" + if cascade: + sa_relationship_kwargs["cascade"] = cascade + return Relationship( + back_populates=back_populates, + link_model=link_model, + sa_relationship_kwargs=sa_relationship_kwargs, + ) + + +class JamaiSQLModel(SQLModel): + metadata = MetaData(schema="jamai") + + @declared_attr + def __tablename__(cls) -> str: + return cls.__name__ + + +ItemType = TypeVar("ItemType", bound=BaseModel) + + +class _TableBase(JamaiSQLModel, str_strip_whitespace=True): + meta: dict[str, Any] = SqlField( + {}, + sa_type=JSON, + description="Metadata.", + ) + created_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Creation datetime (UTC).", + ) + updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Update datetime (UTC).", + ) + + @classmethod + @lru_cache(maxsize=1) + def pk(cls) -> list[str]: + """Return every column name that is a primary key.""" + return [c.name for c in cls.__table__.primary_key] + + @classmethod + @lru_cache(maxsize=1) + def str_cols(cls) -> list[str]: + """Return every column name that is a string.""" + return [ + c.name + for c in cls.__table__.columns + if isinstance(c.type, (AutoString, VARCHAR, Unicode, String)) + ] + + @classmethod + @lru_cache(maxsize=1) + def nullable_cols(cls) -> list[str]: + """Return every column name that is nullable.""" + return [c.name for c in cls.__table__.columns if c.nullable] + + @classmethod + @lru_cache(maxsize=1) + def indexed_cols(cls) -> list[str]: + """ + 
Return every column name that participates in any declared index. + + Even though for Postgres, unique constraint creates an index automatically, + we still only list columns that explicitly declare an index. + https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#postgresql-index-reflection + """ + tbl = cls.__table__ + # flagged = {c.name for c in tbl.columns if c.index or c.unique or c.primary_key} + cols = {c.name for idx in tbl.indexes for c in idx.columns} + return [c.name for c in tbl.columns if c.name in cols] + + @classmethod + def _where_filter( + cls, + selection: SelectBase, + filters: dict[str, Any | list[Any]] | None, + ) -> SelectBase: + if filters: + selection = selection.where( + and_( + *[ + or_(*[getattr(cls, k) == vv for vv in v]) + if isinstance(v, list) + else getattr(cls, k) == v + for k, v in filters.items() + ] + ) + ) + return selection + + @classmethod + def _search_query_filter( + cls, + selection: SelectBase, + *, + search_query: str | None, + search_columns: list[str] | None, + ) -> SelectBase: + # Apply search filters + if search_query and search_columns: + search_conditions = [] + for column_name in search_columns: + if (column := getattr(cls, column_name, None)) is not None: + # Using case-insensitive regex match (~*) + search_conditions.append(column.op("~*")(search_query)) + if search_conditions: + selection = selection.where(or_(*search_conditions)) + return selection + + @classmethod + def _allow_block_list_filter( + cls, + selection: SelectBase, + filter_id: str, + *, + allow_list_attr: str = "allowed_orgs", + block_list_attr: str = "blocked_orgs", + ) -> SelectBase: + allow_list = getattr(cls, allow_list_attr) + block_list = getattr(cls, block_list_attr, None) + # Allow list + allow = or_(allow_list == [], allow_list.contains([filter_id])) + if block_list is None: + # No block list, just allow list + selection = selection.where(allow) + else: + # Block list + selection = selection.where(and_(allow, ~block_list.contains([filter_id]))) + return selection + + @classmethod + def _pagination( + cls, + selection: SelectBase, + *, + offset: int, + limit: int | None, + order_by: str, + order_ascending: bool, + after: str | None = None, + ) -> SelectBase: + # Apply ordering + order_col = getattr(cls, order_by, None) + if order_col is None: + raise BadInputError(f'Unable to order by column "{order_by}" as it does not exist.') + is_nullable = order_col.nullable + # Postgres index sorts nulls last (nulls are larger than non-null) + # But it is hard to get a string null coalesce value, so we sort null first + null_order_func = nulls_first if order_ascending else nulls_last + # Keyset pagination + # cursor = before or after + cursor = after + if cursor: + # if before: + # op = "__lt__" if order_ascending else "__gt__" + # else: + # op = "__gt__" if order_ascending else "__lt__" + op = "__gt__" if order_ascending else "__lt__" + try: + vals = _decode_cursor(cursor) + except Exception as e: + raise BadInputError(f'Pagination failed due to invalid cursor: "{cursor}"') from e + try: + pk_cols = tuple(getattr(cls, pk) for pk in cls.pk()) + pk_vals = tuple(vals[pk] for pk in cls.pk()) + cmp_val = vals[order_by] + except KeyError as e: + raise BadInputError( + f'Unable to order by column "{order_by}" as it is not found in the cursor.' 
+ ) from e + if is_nullable: + # This is mainly for JamaiBase rather than TokenVisor + if isinstance(order_col.type, Integer): + coalesce_val = literal(-(2**31 - 1)) # Standard 32-bit signed integer + elif isinstance(order_col.type, Numeric): + coalesce_val = literal(float("-inf")) + elif isinstance(order_col.type, Boolean): + coalesce_val = False + else: + coalesce_val = "" + # else: + # raise BadInputError( + # f'Unable to order by nullable column "{order_by}" of type {order_col.type}.' + # ) + if cmp_val is None: + cmp_val = coalesce_val + order_by_expr = func.coalesce(order_col, coalesce_val) + else: + order_by_expr = order_col + filter_cond = or_( + getattr(order_by_expr, op)(cmp_val), + and_(order_by_expr == cmp_val, getattr(tuple_(*pk_cols), op)(pk_vals)), + ) + selection = selection.where(filter_cond) + else: + selection = selection.offset(offset) + # Postgres ordering on Linux seems to be case-insensitive by default + # https://dba.stackexchange.com/a/131471 + # Apply LOWER() on text columns + if order_by in cls.str_cols(): + order_col = func.lower(order_col) + # Determine order function based on sort direction + order_func = asc if order_ascending else desc + # Pagination + if is_nullable: + order_by_expr = null_order_func(order_func(order_col)) + else: + order_by_expr = order_func(order_col) + selection = selection.order_by( + order_by_expr, *(order_func(getattr(cls, pk)) for pk in cls.pk()) + ) + if limit is not None: + selection = selection.limit(limit) + return selection + + def _generate_cursor(self, order_by: str) -> str: + cursor_keys = [order_by, *self.pk()] + cursor_values = {k: getattr(self, k) for k in cursor_keys} + return _encode_cursor(cursor_values) + + @classmethod + def _list( + cls, + *, + offset: int, + limit: int | None, + order_by: str, + order_ascending: bool, + search_query: str | None, + search_columns: list[str] | None, + filters: dict[str, Any | list[Any]] | None = None, + options: list[ExecutableOption] | None = None, + after: str | None = None, + ) -> tuple[SelectOfScalar[Self], SelectOfScalar[int]]: + ### --- Main query --- ### + items = cls._search_query_filter( + cls._where_filter(select(cls), filters), + search_query=search_query, + search_columns=search_columns, + ) + if options: + items = items.options(*options) + items = cls._pagination( + items, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + after=after, + ) + ### --- Count --- ### + # Same filters but without pagination + total = cls._search_query_filter( + cls._where_filter(select(func.count(getattr(cls, cls.pk()[0]))), filters), + search_query=search_query, + search_columns=search_columns, + ) + return items, total + + @classmethod + async def _fetch_list_and_cursor( + cls, + session: AsyncSession, + items: SelectOfScalar[Self], + total: SelectOfScalar[int], + order_by: str, + ) -> tuple[list[Self], int, str | None]: + items: list[Self] = (await session.exec(items)).all() + total: int = (await session.exec(total)).one() + if items: + end_cursor = items[-1]._generate_cursor(order_by) + else: + end_cursor = None + return items, total, end_cursor + + @classmethod + async def create( + cls, + session: AsyncSession, + body: dict[str, Any] | BaseModel, + ) -> Self: + item = cls.model_validate(body) + session.add(item) + await session.commit() + await session.refresh(item) + return item + + @classmethod + async def list_( + cls, + session: AsyncSession, + return_type: Type[ItemType], + *, + offset: int = 0, + limit: int | None = None, + order_by: str | 
None = None, + order_ascending: bool = True, + search_query: str | None = None, + search_columns: list[str] | None = None, + filters: dict[str, Any | list[Any]] | None = None, + options: list[ExecutableOption] | None = None, + after: str | None = None, + ) -> Page[ItemType]: + if order_by is None: + order_by = cls.pk()[0] + items, total = cls._list( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + filters=filters, + options=options, + after=after, + ) + items, total, end_cursor = await cls._fetch_list_and_cursor( + session=session, + items=items, + total=total, + order_by=order_by, + ) + return Page[return_type]( + items=items, + offset=offset, + limit=total if limit is None else limit, + total=total, + end_cursor=end_cursor, + ) + + @classmethod + async def get( + cls, + session: AsyncSession, + item_id: str, + *, + name: str = "", + **kwargs, + ) -> Self: + item = await session.get(cls, item_id, **kwargs) + if item is None: + raise ResourceNotFoundError( + f'{name if name else cls.__name__} "{item_id}" is not found.' + ) + return item + + @classmethod + async def _update( + cls, + session: AsyncSession, + item_id: str, + updates: dict[str, Any], + *, + name: str = "", + ) -> Self: + item = await cls.get(session, item_id, name=name) + for key, value in updates.items(): + setattr(item, key, value) + item.updated_at = now() + session.add(item) + return item + + @classmethod + async def update( + cls, + session: AsyncSession, + item_id: str, + body: BaseModel, + *, + name: str = "", + ) -> tuple[Self, dict[str, Any]]: + updates = body.model_dump(exclude_unset=True) + item = await cls._update(session, item_id, updates, name=name) + await session.commit() + await session.refresh(item) + return item, updates + + @classmethod + async def delete( + cls, + session: AsyncSession, + item_id: str, + *, + name: str = "", + ) -> None: + item = await cls.get(session, item_id, name=name) + await session.delete(item) + await session.commit() + + +class PricePlan(_TableBase, table=True): + id: SanitisedNonEmptyStr = SqlField( + default_factory=lambda: generate_key(8, "plan_"), + primary_key=True, + description="Price plan ID.", + ) + name: str = SqlField( + unique=True, + description="Price plan name. Must be unique.", + ) + stripe_price_id_live: str = SqlField( + index=True, + unique=True, + description="Stripe price ID (live mode). Must be unique.", + ) + stripe_price_id_test: str = SqlField( + index=True, + unique=True, + description="Stripe price ID (test mode). Must be unique.", + ) + flat_cost: float = SqlField( + description="Base price for the entire tier (in USD decimal terms).", + ) + credit_grant: float = SqlField( + description="Credit amount included (in USD decimal terms).", + ) + max_users: int | None = SqlField( + description="Maximum number of users per organization. `None` means no limit.", + ) + products: dict[str, Any] = SqlField( + sa_type=JSON, + description="Mapping of product ID to product.", + ) + allowed_orgs: list[str] = SqlField( + [], + index=True, + sa_type=JSON, + description=( + "List of IDs of organizations allowed to use this price plan. " + "If empty, all orgs are allowed." 
+ ), + ) + organizations: "Organization" = _relationship("price_plan", selectin=False) + + @computed_field(description="Stripe Price ID.") + @property + def stripe_price_id(self) -> str: + return ( + self.stripe_price_id_live + if ENV_CONFIG.stripe_api_key_plain.startswith("sk_live") + else self.stripe_price_id_test + ) + + @computed_field( + description="Whether this is a private price plan visible only to select organizations." + ) + @property + def is_private(self) -> bool: + return len(self.allowed_orgs) > 0 + + @classmethod + async def list_public( + cls, + session: AsyncSession, + return_type: Type[ItemType], + *, + offset: int, + limit: int | None, + order_by: str, + order_ascending: bool, + search_query: str | None, + search_columns: list[str] | None, + filters: dict[str, Any | list[Any]] | None = None, + after: str | None = None, + ) -> Page[ItemType]: + # List + items, total = cls._list( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + filters=filters, + after=after, + ) + # Filter + items = items.where(cls.allowed_orgs == []) + total = total.where(cls.allowed_orgs == []) + items, total, end_cursor = await cls._fetch_list_and_cursor( + session=session, + items=items, + total=total, + order_by=order_by, + ) + return Page[return_type]( + items=items, + offset=offset, + limit=total if limit is None else limit, + total=total, + end_cursor=end_cursor, + ) + + @classmethod + async def list_( + cls, + session: AsyncSession, + return_type: Type[ItemType], + *, + offset: int = 0, + limit: int | None = None, + order_by: str | None = None, + order_ascending: bool = True, + search_query: str | None = None, + search_columns: list[str] | None = None, + filters: dict[str, Any | list[Any]] | None = None, + after: str | None = None, + ) -> Page[ItemType]: + if order_by is None: + order_by = cls.pk()[0] + items, total = cls._list( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + filters=filters, + after=after, + ) + items, total, end_cursor = await cls._fetch_list_and_cursor( + session=session, + items=items, + total=total, + order_by=order_by, + ) + return Page[return_type]( + items=items, + offset=offset, + limit=total if limit is None else limit, + total=total, + end_cursor=end_cursor, + ) + + +class Deployment(_TableBase, table=True): + id: SanitisedNonEmptyStr = SqlField( + default_factory=uuid7_str, + primary_key=True, + description="Deployment ID.", + ) + model_id: str = SqlField( + sa_column_args=[ForeignKey("ModelConfig.id", ondelete="CASCADE", onupdate="CASCADE")], + index=True, + description="Model ID.", + ) + name: str = SqlField( + description="Name for the deployment.", + ) + routing_id: str = SqlField( + "", + description=( + "Model ID that the inference provider expects (whereas `model_id` is what the users will see). " + "OpenAI example: `model_id` CAN be `openai/gpt-5` but `routing_id` SHOULD be `gpt-5`." + ), + ) + api_base: str = SqlField( + "", + description=( + "(Optional) Hosting url. " + "Required for creating external cloud deployment using custom providers. " + "Example: `http://vllm-endpoint.xyz/v1`." + ), + ) + provider: str = SqlField( + "", + description=( + f"Inference provider of the model. " + f"Standard cloud providers are {CloudProvider.list_()}." + ), + ) + weight: int = SqlField( + 1, + ge=0, + description="Routing weight. Must be >= 0. 
A deployment is selected according to its relative weight.", + ) + cooldown_until: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Cooldown until datetime (UTC).", + ) + model: "ModelConfig" = _relationship("deployments") + + +class ModelInfo(_TableBase): + id: str = SqlField( + primary_key=True, + description=( + "Unique identifier. " + "Users will specify this to select a model. " + "Must follow the following format: `{provider}/{model_id}`. " + "Examples=['openai/gpt-4o-mini', 'Qwen/Qwen2.5-0.5B']" + ), + ) + type: ModelType = SqlField( + ModelType.LLM, + description="Model type. Can be completion, llm, embed, or rerank.", + ) + name: str = SqlField( + "", + description="Model name that is more user friendly.", + ) + owned_by: str = SqlField( + "", + description="Model provider (usually organization that trained the model).", + ) + capabilities: list[ModelCapability] = SqlField( + [ModelCapability.CHAT], + sa_type=JSON, + description="List of capabilities of model.", + ) + context_length: int = SqlField( + 4096, + description="Context length of model.", + ) + languages: LanguageCodeList = SqlField( + ["en"], + sa_type=JSON, + description=f'List of languages which the model is well-versed in. "*" and "mul" resolves to {DEFAULT_MUL_LANGUAGES}.', + ) + max_output_tokens: int | None = SqlField( + None, + description="Maximum number of output tokens, if not specified, will be based on context length.", + # examples=[8192], + ) + + +class ModelConfig(ModelInfo, table=True): + # --- All models --- # + type: ModelType = SqlField( + description="Model type. Can be completion, chat, embed, or rerank.", + ) + name: str = SqlField( + description="Model name that is more user friendly.", + ) + context_length: int = SqlField( + description="Context length of model. Examples=[4096]", + ) + capabilities: list[ModelCapability] = SqlField( + sa_type=JSON, + description="List of capabilities of model.", + ) + owned_by: str = SqlField( + "", + description="Model provider (usually organization that trained the model).", + ) + timeout: float = SqlField( + 15 * 60, + gt=0, + nullable=False, + description="Timeout in seconds. Must be greater than 0. Defaults to 15 minutes.", + ) + priority: int = SqlField( + 0, + description="Priority for fallback model selection. The larger the number, the higher the priority.", + ) + allowed_orgs: list[str] = SqlField( + [], + index=True, + sa_type=JSON, + description=( + "List of IDs of organizations allowed to use this model. " + "If empty, all orgs are allowed. Allow list is applied first, followed by block list." + ), + ) + blocked_orgs: list[str] = SqlField( + [], + index=True, + sa_type=JSON, + description=( + "List of IDs of organizations NOT allowed to use this model. " + "If empty, no org is blocked. Allow list is applied first, followed by block list." + ), + ) + # --- Chat models --- # + llm_input_cost_per_mtoken: float = SqlField( + -1.0, + description=( + "Cost in USD per million (mega) input / prompt token. " + "Can be zero. Negative values will be overridden with a default value." + ), + ) + llm_output_cost_per_mtoken: float = SqlField( + -1.0, + description=( + "Cost in USD per million (mega) output / completion token. " + "Can be zero. Negative values will be overridden with a default value." + ), + ) + # --- Embedding models --- # + embedding_size: PositiveNonZeroInt | None = SqlField( + None, + description=( + "The default embedding size of the model. 
" + "For example: `openai/text-embedding-3-large` has `embedding_size` of 3072 " + "but can be shortened to `embedding_dimensions` of 256; " + "`cohere/embed-v4.0` has `embedding_size` of 1536 " + "but can be shortened to `embedding_dimensions` of 256." + ), + ) + # Matryoshka embedding dimension + embedding_dimensions: PositiveNonZeroInt | None = SqlField( + None, + description=( + "The number of dimensions the resulting output embeddings should have. " + "Can be overridden by `dimensions` for each request. " + "Defaults to None (no reduction). " + "Note that this parameter will only be used when using models that support Matryoshka embeddings. " + "For example: `openai/text-embedding-3-large` has `embedding_size` of 3072 " + "but can be shortened to `embedding_dimensions` of 256; " + "`cohere/embed-v4.0` has `embedding_size` of 1536 " + "but can be shortened to `embedding_dimensions` of 256." + ), + ) + # Most likely only useful for HuggingFace models + embedding_transform_query: str | None = SqlField( + None, + description="Transform query that might be needed, esp. for hf models", + ) + embedding_cost_per_mtoken: float = SqlField( + -1.0, + description=( + "Cost in USD per million (mega) embedding tokens. " + "Can be zero. Negative values will be overridden with a default value." + ), + ) + # --- Reranking models --- # + reranking_cost_per_ksearch: float = SqlField( + -1.0, + description=( + "Cost in USD per thousand (kilo) searches. " + "Can be zero. Negative values will be overridden with a default value." + ), + ) + deployments: list[Deployment] = _relationship("model") + + @computed_field( + description="Whether this is a private model visible only to select organizations." + ) + @property + def is_private(self) -> bool: + return len(self.allowed_orgs) > 0 or len(self.blocked_orgs) > 0 + + @computed_field(description="Whether this model is active and ready for inference.") + @property + def is_active(self) -> bool: + return len(self.deployments) > 0 + + @classmethod + async def list_( + cls, + session: AsyncSession, + return_type: Type[ItemType], + *, + organization_id: str | None, + offset: int = 0, + limit: int | None = None, + order_by: str | None = None, + order_ascending: bool = True, + search_query: str | None = None, + search_columns: list[str] | None = None, + filters: dict[str, Any | list[Any]] | None = None, + after: str | None = None, + capabilities: list[ModelCapability] | None = None, + exclude_inactive: bool = False, + ) -> Page[ItemType]: + if order_by is None: + order_by = cls.pk()[0] + items, total = cls._list( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + filters=filters, + after=after, + ) + # Filter + if organization_id: + items = cls._allow_block_list_filter(items, organization_id) + total = cls._allow_block_list_filter(total, organization_id) + # Filter by capability + if capabilities is not None: + items = items.where(cls.capabilities.contains(capabilities)) + total = total.where(cls.capabilities.contains(capabilities)) + if exclude_inactive: + subquery = select(Deployment).where(Deployment.model_id == cls.id) + items = items.where(exists(subquery)) + total = total.where(exists(subquery)) + items, total, end_cursor = await cls._fetch_list_and_cursor( + session=session, + items=items, + total=total, + order_by=order_by, + ) + return Page[return_type]( + items=items, + offset=offset, + limit=total if limit is None else limit, + total=total, + 
end_cursor=end_cursor, + ) + + +class OrgMember(_TableBase, table=True): + user_id: str = SqlField( + foreign_key="User.id", + primary_key=True, + ondelete="CASCADE", + description="User ID.", + ) + organization_id: str = SqlField( + foreign_key="Organization.id", + primary_key=True, + ondelete="CASCADE", + description="Organization ID.", + ) + role: Role = SqlField( + Role.GUEST, + description="Organization role.", + ) + user: "User" = _relationship("org_memberships") + organization: "Organization" = _relationship("members") + + +class ProjectMember(_TableBase, table=True): + user_id: str = SqlField( + foreign_key="User.id", + primary_key=True, + ondelete="CASCADE", + description="User ID.", + ) + project_id: str = SqlField( + foreign_key="Project.id", + primary_key=True, + ondelete="CASCADE", + description="Project ID.", + ) + role: Role = SqlField( + Role.GUEST, + description="Project role.", + ) + user: "User" = _relationship("proj_memberships") + project: "Project" = _relationship("members") + + +class User(_TableBase, table=True): + id: str = SqlField( + default_factory=uuid7_str, + primary_key=True, + description="User ID.", + ) + name: str = SqlField( + index=True, + description="User's preferred name.", + ) + email: str = SqlField( + unique=True, + index=True, + description="User's email.", + ) + email_verified: bool = SqlField( + False, + description="Whether the email address is verified.", + ) + password_hash: str | None = SqlField( + None, + index=True, + description="Password hash.", + ) + picture_url: str | None = SqlField( + None, + description="User picture URL.", + ) + refresh_counter: int = SqlField( + 0, + description="Counter used as refresh token version for invalidation.", + ) + google_id: str | None = SqlField( + None, + index=True, + description="Google user ID.", + ) + google_name: str | None = SqlField( + None, + description="Google user's preferred name.", + ) + google_username: str | None = SqlField( + None, + description="Google username.", + ) + google_email: str | None = SqlField( + None, + description="Google email.", + ) + google_picture_url: str | None = SqlField( + None, + description="Google user picture URL.", + ) + google_updated_at: DatetimeUTC | None = SqlField( + None, + sa_type=DateTime(timezone=True), + description="Google user info update datetime (UTC).", + ) + github_id: str | None = SqlField( + None, + index=True, + description="GitHub user ID.", + ) + github_name: str | None = SqlField( + None, + description="GitHub user's preferred name.", + ) + github_username: str | None = SqlField( + None, + description="GitHub username.", + ) + github_email: str | None = SqlField( + None, + description="GitHub email.", + ) + github_picture_url: str | None = SqlField( + None, + description="GitHub user picture URL.", + ) + github_updated_at: DatetimeUTC | None = SqlField( + None, + sa_type=DateTime(timezone=True), + description="GitHub user info update datetime (UTC).", + ) + org_memberships: list[OrgMember] = _relationship("user") + proj_memberships: list[ProjectMember] = _relationship("user") + organizations: list["Organization"] = _relationship(None, link_model=OrgMember, selectin=False) + projects: list["Project"] = _relationship(None, link_model=ProjectMember, selectin=False) + # keys: list["ProjectKey"] = _relationship("user") + + @computed_field(description="Name for display.") + @property + def preferred_name(self) -> str: + return self.name or self.google_name or self.github_name + + @computed_field(description="Email for display.") + @property 
+ def preferred_email(self) -> str: + return self.email or self.google_email or self.github_email + + @computed_field(description="Picture URL for display.") + @property + def preferred_picture_url(self) -> str | None: + return self.picture_url or self.google_picture_url or self.github_picture_url + + @computed_field(description="Username for display.") + @property + def preferred_username(self) -> str | None: + return self.google_username or self.github_username + + @classmethod + async def list_( + cls, + session: AsyncSession, + return_type: Type[ItemType], + *, + offset: int = 0, + limit: int | None = None, + order_by: str | None = None, + order_ascending: bool = True, + search_query: str | None = None, + search_columns: list[str] | None = None, + filters: dict[str, Any | list[Any]] | None = None, + after: str | None = None, + ) -> Page[ItemType]: + return await super().list_( + session=session, + return_type=return_type, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + filters=filters, + options=[selectinload(cls.organizations), selectinload(cls.projects)], + after=after, + ) + + @classmethod + async def get( + cls, + session: AsyncSession, + item_id: str, + *, + name: str = "", + **kwargs, + ) -> Self: + where_expr = cls.id == item_id + if item_id.startswith("google-oauth2|"): + where_expr = or_(where_expr, cls.google_id == item_id.split("|")[1]) + elif item_id.startswith("github|"): + where_expr = or_(where_expr, cls.github_id == item_id.split("|")[1]) + item = ( + await session.exec( + select(User) + .where(where_expr) + .options(selectinload(cls.organizations), selectinload(cls.projects)), + execution_options=kwargs, + ) + ).one_or_none() + if item is None: + raise ResourceNotFoundError( + f'{name if name else cls.__name__} "{item_id}" is not found.' + ) + return item + + +class Organization(_TableBase, table=True): + id: SanitisedNonEmptyStr = SqlField( + default_factory=lambda: generate_key(24, "org_"), + primary_key=True, + description="Organization ID.", + ) + name: SanitisedStr = SqlField( + description="Organization name.", + ) + currency: ISO4217 = SqlField( + "USD", + description="Currency of the organization.", + ) + timezone: TimeZoneName | None = SqlField( + None, + description="Timezone specifier.", + ) + external_keys: dict[str, str] = SqlField( + {}, + sa_type=JSON, + description="Mapping of external service provider to its API key.", + ) + stripe_id: str | None = SqlField( + None, + index=True, + description="Stripe Customer ID.", + ) + # stripe_subscription_id: SanitisedIdStr | None = SqlField( + # None, + # description="Stripe Subscription ID.", + # ) + price_plan_id: str | None = SqlField( + None, + foreign_key="PricePlan.id", + index=True, + nullable=True, + description="Subscribed plan ID.", + ) + payment_state: PaymentState = SqlField( + PaymentState.NONE, + description=f"Payment state of the organization, one of {list(map(str, PaymentState))}.", + ) + last_subscription_payment_at: DatetimeUTC | None = SqlField( + None, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful subscription payment (UTC).", + ) + quota_reset_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Quota reset datetime (UTC).", + ) + credit: float = SqlField( + 0.0, + sa_type=Numeric(21, 12), + description=( + "Credit paid by the customer. " + "Unused credit will be carried forward to the next billing cycle. 
" + "Must be in the range [-999_999_999.0, 999_999_999.0] with up to 12 decimal places." + ), + ) + credit_grant: float = SqlField( + 0.0, + sa_type=Numeric(21, 12), + description=( + "Credit granted to the customer. " + "Unused credit will NOT be carried forward. " + "Must be in the range [-999_999_999.0, 999_999_999.0] with up to 12 decimal places." + ), + ) + llm_tokens_quota_mtok: float | None = SqlField( + 0.0, + description="LLM token quota in millions of tokens.", + ) + llm_tokens_usage_mtok: float = SqlField( + 0.0, + description="LLM token usage in millions of tokens.", + ) + llm_tokens_usage_updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful LLM token usage update (UTC).", + ) + embedding_tokens_quota_mtok: float | None = SqlField( + 0.0, + description="Embedding token quota in millions of tokens.", + ) + embedding_tokens_usage_mtok: float = SqlField( + 0.0, + description="Embedding token quota in millions of tokens.", + ) + embedding_tokens_usage_updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful Embedding token usage update (UTC).", + ) + reranker_quota_ksearch: float | None = SqlField( + 0.0, + description="Reranker quota for every thousand searches.", + ) + reranker_usage_ksearch: float = SqlField( + 0.0, + description="Reranker usage for every thousand searches.", + ) + reranker_usage_updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful Reranker usage update (UTC).", + ) + db_quota_gib: float | None = SqlField( + 0.0, + description="DB storage quota in GiB.", + ) + db_usage_gib: float = SqlField( + 0.0, + description="DB storage usage in GiB.", + ) + db_usage_updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful DB usage update (UTC).", + ) + file_quota_gib: float | None = SqlField( + 0.0, + description="File storage quota in GiB.", + ) + file_usage_gib: float = SqlField( + 0.0, + description="File storage usage in GiB.", + ) + file_usage_updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful File usage update (UTC).", + ) + egress_quota_gib: float | None = SqlField( + 0.0, + description="Egress quota in GiB.", + ) + egress_usage_gib: float = SqlField( + 0.0, + description="Egress usage in GiB.", + ) + egress_usage_updated_at: DatetimeUTC = SqlField( + default_factory=now, + sa_type=DateTime(timezone=True), + description="Datetime of the last successful egress usage update (UTC).", + ) + created_by: str = SqlField( + description="ID of the user that created this organization.", + ) + owner: str = SqlField( + description="ID of the user that owns this organization.", + ) + users: list[User] = _relationship("organizations", link_model=OrgMember, selectin=False) + members: list[OrgMember] = _relationship("organization", selectin=False) + projects: list["Project"] = _relationship("organization", selectin=False) + price_plan: PricePlan | None = _relationship("organizations") + + @staticmethod + def status_check(org: "Organization", *, raise_error: bool = False) -> bool: + """Whether the organization's quota is active (paid).""" + if ENV_CONFIG.is_oss or ENV_CONFIG.disable_billing: + return True + if org.id in ("0", TEMPLATE_ORG_ID): + return True 
+ if org.price_plan_id is None: + if raise_error: + raise NoTierError + else: + return False + if org.last_subscription_payment_at is None: + payment_on_time = False + else: + payment_on_time = ( + now() - org.last_subscription_payment_at + ).days <= ENV_CONFIG.payment_lapse_max_days + payment_ok = ( + org.payment_state in [PaymentState.SUCCESS, PaymentState.PROCESSING] or payment_on_time + ) + if payment_ok or (float(org.credit) + float(org.credit_grant)) > 0: + return True + elif raise_error: + raise InsufficientCreditsError + else: + return False + + @computed_field(description="Whether the organization's quota is active (paid).") + @property + def active(self) -> bool: + return self.status_check(self, raise_error=False) + + @computed_field(description="Quota snapshot.") + @property + def quotas(self) -> dict[str, dict[str, float | None]]: + return { + "llm_tokens": { + "quota": self.llm_tokens_quota_mtok, + "usage": self.llm_tokens_usage_mtok, + }, + "embedding_tokens": { + "quota": self.embedding_tokens_quota_mtok, + "usage": self.embedding_tokens_usage_mtok, + }, + "reranker_searches": { + "quota": self.reranker_quota_ksearch, + "usage": self.reranker_usage_ksearch, + }, + "db_storage": { + "quota": self.db_quota_gib, + "usage": self.db_usage_gib, + }, + "file_storage": { + "quota": self.file_quota_gib, + "usage": self.file_usage_gib, + }, + "egress": { + "quota": self.egress_quota_gib, + "usage": self.egress_usage_gib, + }, + } + + @classmethod + async def list_base_tier_orgs( + cls, + session: AsyncSession, + user_id: str, + ) -> list[Self]: + return ( + await session.exec( + select(cls).where( + cls.id != "0", # Internal org "0" is not counted against the limit + cls.price_plan_id == BASE_PLAN_ID, + exists( + select(OrgMember).where( + OrgMember.user_id == user_id, + OrgMember.organization_id == cls.id, + ) + ), + ) + ) + ).all() + + async def add_credit_grant( + self, + session: AsyncSession, + amount: float | Decimal, + ) -> None: + await session.exec( + text( + f""" + SELECT id FROM {JamaiSQLModel.metadata.schema}.add_credit_grant( + '{self.id}'::TEXT, + {amount:.12f}::NUMERIC(21, 12) + ); + """ + ) + ) + + +class Project(_TableBase, table=True): + id: str = SqlField( + default_factory=lambda: generate_key(24, "proj_"), + primary_key=True, + description="Project ID.", + ) + organization_id: str = SqlField( + foreign_key="Organization.id", + index=True, + description="Organization ID.", + ondelete="CASCADE", + ) + name: str = SqlField( + description="Project name.", + ) + description: str = SqlField( + description="Project description.", + ) + tags: list[str] = SqlField( + [], + sa_type=JSON, + description="Project tags.", + ) + profile_picture_url: str | None = SqlField( + None, + description="URL of the profile picture.", + ) + cover_picture_url: str | None = SqlField( + None, + description="URL of the cover picture.", + ) + created_by: str = SqlField( + description="ID of the user that created this project.", + ) + quotas: dict[str, Any] = SqlField( + {}, + sa_type=JSON, + description="Quotas allotted to this project.", + ) + owner: str = SqlField( + foreign_key="User.id", + description="ID of the user that owns this organization.", + ) + organization: Organization = _relationship("projects") + users: list[User] = _relationship("projects", link_model=ProjectMember, selectin=False) + members: list[ProjectMember] = _relationship("project", selectin=False) + # keys: list["ProjectKey"] = _relationship("project") + + @classmethod + async def list_( + cls, + session: AsyncSession, 
+ return_type: Type[ItemType], + *, + offset: int = 0, + limit: int | None = None, + order_by: str | None = None, + order_ascending: bool = True, + search_query: str | None = None, + search_columns: list[str] | None = None, + filters: dict[str, Any | list[Any]] | None = None, + after: str | None = None, + filter_by_user: str = "", + ) -> Page[ItemType]: + if order_by is None: + order_by = cls.pk()[0] + items, total = cls._list( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + search_query=search_query, + search_columns=search_columns, + filters=filters, + after=after, + ) + if filter_by_user: + subquery = select(ProjectMember).where( + ProjectMember.user_id == filter_by_user, + ProjectMember.project_id == cls.id, + ) + items = items.where(exists(subquery)) + total = total.where(exists(subquery)) + items, total, end_cursor = await cls._fetch_list_and_cursor( + session=session, + items=items, + total=total, + order_by=order_by, + ) + return Page[return_type]( + items=items, + offset=offset, + limit=total if limit is None else limit, + total=total, + end_cursor=end_cursor, + ) diff --git a/services/api/src/owl/db/oss_admin.py b/services/api/src/owl/db/oss_admin.py deleted file mode 100644 index 6e6ed84..0000000 --- a/services/api/src/owl/db/oss_admin.py +++ /dev/null @@ -1,171 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field, model_validator -from sqlmodel import JSON, Column, Relationship -from sqlmodel import Field as sql_Field -from typing_extensions import Self - -from owl.configs.manager import ENV_CONFIG -from owl.db import UserSQLModel -from owl.protocol import ExternalKeys, Name -from owl.utils import datetime_now_iso -from owl.utils.crypt import decrypt, generate_key - - -class _ProjectBase(UserSQLModel): - name: str = sql_Field( - description="Project name.", - ) - organization_id: str = sql_Field( - default="default", - foreign_key="organization.id", - index=True, - description="Organization ID.", - ) - - -class ProjectCreate(_ProjectBase): - name: Name = sql_Field( - description="Project name.", - ) - - -class ProjectUpdate(BaseModel): - id: str - """Project ID.""" - name: Name | None = sql_Field( - default=None, - description="Project name.", - ) - updated_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Project update datetime (ISO 8601 UTC).", - ) - - -class Project(_ProjectBase, table=True): - id: str = sql_Field( - primary_key=True, - default_factory=lambda: generate_key(24, "proj_"), - description="Project ID.", - ) - created_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Project creation datetime (ISO 8601 UTC).", - ) - updated_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Project update datetime (ISO 8601 UTC).", - ) - organization: "Organization" = Relationship(back_populates="projects") - """Organization that this project is associated with.""" - - -class ProjectRead(_ProjectBase): - id: str = sql_Field( - description="Project ID.", - ) - created_at: str = sql_Field( - description="Project creation datetime (ISO 8601 UTC).", - ) - updated_at: str = sql_Field( - description="Project update datetime (ISO 8601 UTC).", - ) - organization: "OrganizationRead" = sql_Field( - description="Organization that this project is associated with.", - ) - - -class _OrganizationBase(UserSQLModel): - id: str = sql_Field( - default=ENV_CONFIG.default_org_id, - primary_key=True, - description="Organization ID.", - ) - name: str = sql_Field( - 
default="Personal", - description="Organization name.", - ) - external_keys: dict[str, str] = sql_Field( - default={}, - sa_column=Column(JSON), - description="Mapping of service provider to its API key.", - ) - timezone: str | None = sql_Field( - default=None, - description="Timezone specifier.", - ) - models: dict[str, Any] = sql_Field( - default={}, - sa_column=Column(JSON), - description="The organization's custom model list, in addition to the provided default list.", - ) - - @property - def members(self) -> list: - # OSS does not support user accounts - return [] - - -class OrganizationCreate(_OrganizationBase): - name: str = sql_Field( - default="Personal", - description="Organization name.", - ) - - @model_validator(mode="after") - def check_external_keys(self) -> Self: - self.external_keys = ExternalKeys.model_validate(self.external_keys).model_dump() - return self - - -class OrganizationRead(_OrganizationBase): - created_at: str = sql_Field( - description="Organization creation datetime (ISO 8601 UTC).", - ) - updated_at: str = sql_Field( - description="Organization update datetime (ISO 8601 UTC).", - ) - projects: list[Project] | None = sql_Field( - default=None, - description="List of projects.", - ) - - def decrypt(self, key: str) -> Self: - if self.external_keys is not None: - self.external_keys = {k: decrypt(v, key) for k, v in self.external_keys.items()} - return self - - -class Organization(_OrganizationBase, table=True): - created_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Organization creation datetime (ISO 8601 UTC).", - ) - updated_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Organization update datetime (ISO 8601 UTC).", - ) - projects: list[Project] = Relationship(back_populates="organization") - """List of projects.""" - - -class OrganizationUpdate(BaseModel): - id: str - """Organization ID.""" - name: str | None = None - """Organization name.""" - external_keys: dict[str, str] | None = Field( - default=None, - description="Mapping of service provider to its API key.", - ) - timezone: str | None = Field(default=None) - """ - Timezone specifier. 
- """ - - @model_validator(mode="after") - def check_external_keys(self) -> Self: - if self.external_keys is not None: - self.external_keys = ExternalKeys.model_validate(self.external_keys).model_dump() - return self diff --git a/services/api/src/owl/db/template.py b/services/api/src/owl/db/template.py deleted file mode 100644 index 5a3f738..0000000 --- a/services/api/src/owl/db/template.py +++ /dev/null @@ -1,55 +0,0 @@ -from sqlmodel import Field as sql_Field -from sqlmodel import MetaData, Relationship, SQLModel - -from owl.protocol import Name -from owl.utils import datetime_now_iso - - -class TemplateSQLModel(SQLModel): - metadata = MetaData() - - -class TagTemplateLink(TemplateSQLModel, table=True): - tag_id: str = sql_Field( - primary_key=True, - foreign_key="tag.id", - description="Tag ID.", - ) - template_id: str = sql_Field( - primary_key=True, - foreign_key="template.id", - description="Template ID.", - ) - - -class Tag(TemplateSQLModel, table=True): - id: str = sql_Field( - primary_key=True, - description="Tag ID.", - ) - templates: list["Template"] = Relationship(back_populates="tags", link_model=TagTemplateLink) - - -class _TemplateBase(TemplateSQLModel): - id: str = sql_Field( - primary_key=True, - description="Template ID.", - ) - name: Name = sql_Field( - description="Template name.", - ) - created_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Template creation datetime (ISO 8601 UTC).", - ) - - -class Template(_TemplateBase, table=True): - tags: list[Tag] = Relationship( - back_populates="templates", - link_model=TagTemplateLink, - ) - - -class TemplateRead(_TemplateBase): - tags: list[Tag] diff --git a/services/api/src/owl/docio.py b/services/api/src/owl/docio.py deleted file mode 100644 index af95b61..0000000 --- a/services/api/src/owl/docio.py +++ /dev/null @@ -1,52 +0,0 @@ -from mimetypes import guess_type - -import httpx -from httpx import Timeout -from langchain.docstore.document import Document - -HTTP_CLIENT = httpx.Client(transport=httpx.HTTPTransport(retries=3), timeout=Timeout(5 * 60)) - - -class DocIOAPIFileLoader: - """Load files using docio API.""" - - def __init__( - self, - file_path: str, - url, - client: httpx.Client = HTTP_CLIENT, - ) -> None: - """Initialize with a file path.""" - self.url = url - self.file_path = file_path - self.client = client - - def load(self) -> list[Document]: - """Load file.""" - # Guess the MIME type of the file based on its extension - mime_type, _ = guess_type(self.file_path) - if mime_type is None: - mime_type = "application/octet-stream" # Default MIME type - - # Extract the filename from the file path - filename = self.file_path.split("/")[-1] - - # Return the response from the forwarded request - documents = [] - # Open the file in binary mode - with open(self.file_path, "rb") as f: - response = self.client.post( - f"{self.url}/v1/load_file", - files={ - "file": (filename, f, mime_type), - }, - timeout=None, - ) - if response.status_code != 200: - err_mssg = response.text - raise RuntimeError(err_mssg) - for doc in response.json(): - documents.append( - Document(page_content=doc["page_content"], metadata=doc["metadata"]) - ) - return documents diff --git a/services/api/src/owl/docparse.py b/services/api/src/owl/docparse.py new file mode 100644 index 0000000..f068e50 --- /dev/null +++ b/services/api/src/owl/docparse.py @@ -0,0 +1,602 @@ +import asyncio +import sys +from hashlib import blake2b +from io import BytesIO +from os.path import basename, splitext + +import httpx +import orjson +import 
pandas as pd +import xmltodict +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_core.documents.base import Document +from loguru import logger + +from owl.configs import CACHE, ENV_CONFIG +from owl.types import Chunk, SplitChunksParams, SplitChunksRequest +from owl.utils.exceptions import BadInputError, JamaiException, UnexpectedError +from owl.utils.io import get_bytes_size_mb, json_dumps, json_loads + +# Table mapping all non-printable characters to None +NOPRINT_TRANS_TABLE = { + i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable() and chr(i) != "\n" +} + + +def make_printable(s: str) -> str: + """ + Replace non-printable characters in a string using + `translate()` that removes characters that map to None. + + # https://stackoverflow.com/a/54451873 + """ + return s.translate(NOPRINT_TRANS_TABLE) + + +def format_chunks(documents: list[Document], file_name: str, page: int = None) -> list[Chunk]: + if page is not None: + for d in documents: + d.metadata["page"] = page + chunks = [ + # TODO: Probably can use regex for this + # Replace vertical tabs, form feed, Unicode replacement character + # page_content=d.page_content.replace("\x0c", " ") + # .replace("\x0b", " ") + # .replace("\uFFFD", ""), + # For now we use a more aggressive strategy + Chunk( + text=make_printable(d.page_content), + title=d.metadata.get("title", ""), + page=d.metadata.get("page", 0), + file_name=file_name, + file_path=file_name, + metadata=d.metadata, + ) + for d in documents + ] + return chunks + + +class BaseLoader: + """Base loader class for loading documents.""" + + def __init__(self, request_id: str = ""): + """ + Initialize the BaseLoader class. + + Args: + request_id (str, optional): Request ID for logging. Defaults to "". + """ + self.request_id = request_id + + def split_chunks( + self, request: SplitChunksRequest, page_break_placeholder: str | None = None + ) -> list[Chunk]: + """Split a list of chunks using RecursiveCharacterTextSplitter. + + Args: + request (SplitChunksRequest): Request containing chunks and splitting parameters. + page_break_placeholder (str | None): The string that signifies a page break. + + Returns: + list[Chunk]: A list of split chunks. + + Raises: + BadInputError: If the split method is not supported. + UnexpectedError: If chunk splitting fails. + """ + _id = request.id + logger.info(f"{_id} - Split documents request: {request.str_trunc()}") + if request.params.method == "RecursiveCharacterTextSplitter": + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=request.params.chunk_size, + chunk_overlap=request.params.chunk_overlap, + ) + else: + raise BadInputError(f"Split method not supported: {request.params.method}") + + # Pre-process chunks to handle page breaks before splitting by character count. + if page_break_placeholder is not None: + doc_chunks = [] + page_counter = 0 + for chunk in request.chunks: + texts_from_pages = chunk.text.split(page_break_placeholder) + for text in texts_from_pages: + page_counter += 1 + + new_metadata = chunk.metadata.copy() + new_metadata["page"] = page_counter + + doc_chunks.append( + Chunk( + text=text.strip(), + title=chunk.title, + page=page_counter, # Update page number + file_name=chunk.file_name, + file_path=chunk.file_name, + metadata=new_metadata, + ) + ) + else: + # If no page break handling is needed, use the chunks as they are. + doc_chunks = request.chunks + + try: + # Now, split the processed chunks (doc_chunks) by character count. 
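+            # Each chunk's text is wrapped in a Document and run through the
+            # RecursiveCharacterTextSplitter configured above; every resulting piece
+            # keeps the parent chunk's title, page number, file name and metadata.
+            # Rough illustration (hypothetical numbers): with chunk_size=1000 and
+            # chunk_overlap=200 a ~2,400-character chunk splits into about three
+            # pieces, each overlapping its neighbour by roughly 200 characters.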
+ chunks = [] + for chunk in doc_chunks: + chunks += [ + Chunk( + text=d.page_content, + title=chunk.title, + page=chunk.page, + file_name=chunk.file_name, + file_path=chunk.file_name, + metadata=chunk.metadata, + ) + for d in text_splitter.split_documents([Document(page_content=chunk.text)]) + ] + logger.info( + f"{_id} - {len(request.chunks):,d} chunks split into {len(chunks):,d} chunks.", + ) + return chunks + except Exception as e: + logger.exception(f"{_id} - Failed to split chunks.") + raise UnexpectedError("Failed to split chunks.") from e + + +class GeneralDocLoader(BaseLoader): + """ + General document loader class supporting various file extensions. + + This loader intelligently handles different file types, using DoclingLoader for + formats it supports and falling back to other methods for text-based and structured + data formats like JSON, XML, CSV, and TSV. + """ + + def __init__(self, request_id: str = ""): + """ + Initialize the GeneralDocLoader class. + + Args: + request_id (str, optional): Request ID for logging. Defaults to "". + """ + super().__init__(request_id=request_id) + + async def load_document( + self, + file_name: str, + content: bytes, + ) -> str: + """ + Loads and processes a file, converting it to Markdown format. + + Supports file types: PDF, DOCX, PPTX, XLSX, HTML, MD, TXT, JSON, JSONL, XML, CSV, TSV. + - PDF, DOCX, PPTX, XLSX, HTML: Parsed into Markdown using `DoclingLoader`. + - MD, TXT: Read directly. + - JSON: Formatted as a string with 2-space indenting. + - JSONL: Converted into Markdown table format using `pandas`. + - XML: Formatted as a JSON string with 2-space indenting. + - CSV, TSV: Converted into Markdown table format using `pandas`. + + Args: + file_name (str): The name of the file. + content (bytes): The binary content of the file. + + Returns: + str: The document content in Markdown format, or JSON string for JSON/XML. + + Raises: + BadInputError: If the parsing fails due to unsupported type or other errors. 
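+
+        Example (illustrative usage inside an async context; the file name and
+        request ID below are placeholders):
+            loader = GeneralDocLoader(request_id="req-123")
+            with open("report.pdf", "rb") as f:
+                md = await loader.load_document(file_name="report.pdf", content=f.read())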
+ """ + if len(content) == 0: + raise BadInputError(f'Input file "{file_name}" is empty.') + # Check cache + cache_ttl = ENV_CONFIG.document_loader_cache_ttl_sec + cache_key = "" + if cache_ttl > 0: + content_len = len(content) + content_hash = blake2b(content).hexdigest() + cache_key = f"document:{basename(file_name)}:{content_hash}:{content_len}" + # If multiple rows reference the same file, this lock prevents concurrent parsing + # Only the first row will trigger parsing, the rest will read from cache + # The lock expires after 2 minutes automatically if not released + async with CACHE.alock(f"{cache_key}:lock", blocking=cache_ttl > 0, expire=120): + md = None + if cache_key != "": + md = await CACHE.get(cache_key) + if md is not None: + # Extend cache TTL + await CACHE._redis_async.expire( + cache_key, ENV_CONFIG.document_loader_cache_ttl_sec + ) + logger.info(f'File "{file_name}" loaded from cache (cache key="{cache_key}").') + return md + try: + ext = splitext(file_name)[1].lower() + if ext in [".pdf", ".docx", ".pptx", ".xlsx", ".html"]: + doc_loader = DoclingLoader(self.request_id) + md = await doc_loader.load_document(file_name=file_name, content=content) + elif ext in [".md", ".txt"]: + md = content.decode("utf-8") + elif ext in [".json"]: + md = json_dumps( + json_loads(content.decode("utf-8")), option=orjson.OPT_INDENT_2 + ) + elif ext in [".jsonl"]: + md = pd.read_json( + BytesIO(content), + lines=True, + ).to_markdown() + elif ext in [".xml"]: + md = json_dumps(xmltodict.parse(content), option=orjson.OPT_INDENT_2) + elif ext in [".csv", ".tsv"]: + md = pd.read_csv( + BytesIO(content), + sep="\t" if ext == ".tsv" else ",", + ).to_markdown() + else: + raise BadInputError(f'File type "{ext}" is not supported at the moment.') + if len(md.strip()) == 0: + raise BadInputError(f'Input file "{file_name}" is empty.') + # Set cache + if cache_ttl > 0: + await CACHE.set(cache_key, md, ex=cache_ttl) + logger.info( + f'File "{file_name}" successfully parsed into markdown. (cache key="{cache_key}")' + ) + else: + logger.info(f'File "{file_name}" successfully parsed into markdown.') + return md + + except JamaiException: + raise + except pd.errors.EmptyDataError as e: + raise BadInputError(f'Input file "{file_name}" is empty.') from e + except Exception as e: + logger.error(f'Failed to parse file "{file_name}": {repr(e)}') + raise BadInputError(f'Failed to parse file "{file_name}".') from e + + async def load_document_chunks( + self, + file_name: str, + content: bytes, + chunk_size: int = 1000, + chunk_overlap: int = 200, + ) -> list[Chunk]: + """ + Loads and processes a file, splitting it into chunks. + + Supports file types: PDF, DOCX, PPTX, XLSX, HTML, MD, TXT, JSON, JSONL, XML, CSV, TSV. + - PDF, DOCX, PPTX, XLSX, HTML: Parsed and chunked using `DoclingLoader`. + - MD, TXT: Read directly and chunked using `RecursiveCharacterTextSplitter`. + - JSON, JSONL: Each JSON is formatted as a chunk with 2-space indenting. + - CSV, TSV: Each row is parsed into a JSON and formatted as a chunk with 2-space indenting. + - XML: Each XML is formatted as a JSON chunk with 2-space indenting. + + Args: + file_name (str): The name of the file. + content (bytes): The binary content of the file. + chunk_size (int): The desired size of each chunk in tokens. + chunk_overlap (int): The number of tokens to overlap between chunks. + + Returns: + list[Chunk]: A list of Chunk objects representing the processed file content. 
+ + Raises: + BadInputError: If the parsing and splitting fails due to unsupported type or other errors. + """ + if len(content) == 0: + raise BadInputError(f'Input file "{file_name}" is empty.') + # Check cache + cache_ttl = ENV_CONFIG.document_loader_cache_ttl_sec + cache_key = "" + if cache_ttl > 0: + content_len = len(content) + content_hash = blake2b(content).hexdigest() + cache_key = f"chunks:{basename(file_name)}:{content_hash}:{content_len}" + # If multiple rows reference the same file, this lock prevents concurrent parsing + # Only the first row will trigger parsing, the rest will read from cache + # The lock expires after 2 minutes automatically if not released + async with CACHE.alock(f"{cache_key}:lock", blocking=cache_ttl > 0, expire=120): + chunk_json_str = None + if cache_key != "": + chunk_json_str = await CACHE.get(cache_key) + if chunk_json_str is not None: + # Extend cache TTL + await CACHE._redis_async.expire(cache_key, cache_ttl) + logger.info( + f'File chunks "{file_name}" loaded from cache (cache key="{cache_key}").' + ) + return [Chunk.model_validate(chunk) for chunk in json_loads(chunk_json_str)] + try: + ext = splitext(file_name)[1].lower() + if ext in [".pdf", ".docx", ".pptx", ".xlsx", ".html"]: + if ext in [".pdf", ".pptx", ".xlsx"]: + doc_loader = DoclingLoader( + self.request_id, page_break_placeholder="=====Page===Break=====" + ) + else: + doc_loader = DoclingLoader(self.request_id, page_break_placeholder=None) + chunks = await doc_loader.load_document_chunks( + file_name=file_name, + content=content, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + elif ext in [".md", ".txt"]: + content = content.decode("utf-8") + if len(content.strip()) == 0: + raise BadInputError(f'Input file "{file_name}" is empty.') + chunks = format_chunks( + [Document(page_content=content, metadata={"page": 1})], + file_name, + ) + chunks = self.split_chunks( + SplitChunksRequest( + chunks=chunks, + params=SplitChunksParams( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ), + ) + ) + elif ext in [".json", ".jsonl", ".csv", ".tsv"]: + if ext in [".csv", ".tsv"]: + json_list = pd.read_csv( + BytesIO(content), + sep="\t" if ext == ".tsv" else ",", + ).to_dict(orient="records") + else: + content = content.decode("utf-8") + if ext == ".jsonl": + json_list = [ + json_loads(line) + for line in content.split("\n") + if line.strip() != "" + ] + else: + json_list = [json_loads(content)] + docs = [ + Document( + page_content=json_dumps(js, option=orjson.OPT_INDENT_2), + metadata={"page": 1, "row": i}, + ) + for i, js in enumerate(json_list) + ] + chunks = format_chunks(docs, file_name) + elif ext in [".xml"]: + chunks = format_chunks( + [ + Document( + page_content=json_dumps( + xmltodict.parse(content), option=orjson.OPT_INDENT_2 + ), + metadata={"page": 1}, + ) + ], + file_name, + ) + else: + raise BadInputError(f'File type "{ext}" is not supported at the moment.') + if len(chunks) == 0: + raise BadInputError(f'Input file "{file_name}" is empty.') + # Set cache + if cache_ttl > 0: + chunk_json_str = json_dumps([chunk.model_dump() for chunk in chunks]) + await CACHE.set(cache_key, chunk_json_str, ex=cache_ttl) + logger.info( + ( + f'File "{file_name}" successfully parsed and split into ' + f'{len(chunks):,d} chunks (cache key="{cache_key}").' + ) + ) + else: + logger.info( + ( + f'File "{file_name}" successfully parsed and split into ' + f"{len(chunks):,d} chunks." 
+ ) + ) + return chunks + + except JamaiException: + raise + except pd.errors.EmptyDataError as e: + raise BadInputError(f'Input file "{file_name}" is empty.') from e + except Exception as e: + logger.error(f'Failed to parse and split file "{file_name}": {repr(e)}') + raise BadInputError(f'Failed to parse and split file "{file_name}".') from e + + +class DoclingLoader(BaseLoader): + """ + A class for loading and processing documents using Docling-Serve API. + """ + + def __init__( + self, + request_id: str = "", + docling_serve_url: str | None = None, + page_break_placeholder: str | None = None, + ): + """ + Initialize the DoclingLoader class. + + Args: + request_id (str, optional): Request ID for logging. Defaults to "". + """ + super().__init__(request_id=request_id) + self.http_aclient = httpx.AsyncClient( + timeout=60.0 * 10, + transport=httpx.AsyncHTTPTransport(retries=3), + ) + self.docling_serve_url = ( + ENV_CONFIG.docling_url if docling_serve_url is None else docling_serve_url + ) + self.page_break_placeholder = page_break_placeholder + + async def retrieve_document_content( + self, + file_name: str, + content: bytes, + ) -> dict: # Expecting JSON response from docling-serve + """ + Retrieves the content of a document file using Docling-Serve API (async pattern). + + Args: + file_path (str): Path to the document file to be parsed (local temp path). + file_name (str): Original file name. + content (bytes): Binary content of the file. + force_full_page_ocr (bool): Whether to force full-page OCR. + + Returns: + dict: The JSON response from docling-serve. + + Raises: + HTTPException: If the document conversion fails via docling-serve. + """ + logger.info(f'{self.request_id} - Calling Docling-Serve for file "{file_name}".') + + files = {"files": (file_name, content, "application/octet-stream")} + data = { + "to_formats": ["md"], + "image_export_mode": "placeholder", + "pipeline": "standard", + "ocr": True, + "force_ocr": False, + "ocr_engine": "easyocr", + "pdf_backend": "dlparse_v4", + "table_mode": "accurate", + "abort_on_error": False, + "return_as_file": False, + } + + if self.page_break_placeholder is not None: + data["md_page_break_placeholder"] = self.page_break_placeholder + + try: + # Step 1: Start async conversion + response = await self.http_aclient.post( + f"{self.docling_serve_url}/v1alpha/convert/file/async", files=files, data=data + ) + response.raise_for_status() + task_id_data = response.json() + task_id = task_id_data.get("task_id") + if not task_id: + raise UnexpectedError("Docling-Serve did not return a task_id.") + + # Step 2: Poll for completion + poll_url = f"{self.docling_serve_url}/v1alpha/status/poll/{task_id}" + time_slept = 0 + sleep_for = 1 + task_status = None + while time_slept < ENV_CONFIG.docling_timeout_sec: + try: + poll_resp = await self.http_aclient.get(poll_url, timeout=20) + poll_resp.raise_for_status() + status_data = poll_resp.json() + task_status = status_data.get("task_status") + except Exception as e: + logger.error(f"Polling API error: {e}") + + if task_status == "success": + break # Exit polling loop + elif task_status in ("failure", "revoked"): + error_info = status_data.get("task_result", {}).get("error", "Unknown error") + raise UnexpectedError(f"Docling-Serve task failed: {error_info}") + # If not success, failure, or revoked, it's still processing or in another state + await asyncio.sleep(sleep_for) + time_slept += sleep_for + else: # Executed if the while loop completes without a 'break' + logger.error( + f"{self.request_id} - Polling 
timed out for Docling-Serve task {task_id} after {time_slept} seconds." + ) + raise UnexpectedError( + f"Polling timed out for Docling-Serve task {task_id} after {time_slept} seconds." + ) + + # Step 3: Fetch result + result_url = f"{self.docling_serve_url}/v1alpha/result/{task_id}" + result_resp = await self.http_aclient.get(result_url, timeout=60) + result_resp.raise_for_status() + return result_resp.json() + + except httpx.TimeoutException as e: + logger.error(f"Docling-Serve API timeout error: {e}") + raise UnexpectedError(f"Docling-Serve API timeout error: {e}") from e + except httpx.HTTPError as e: + logger.error(f"Docling-Serve API error: {e}") + raise UnexpectedError(f"Docling-Serve API error: {e}") from e + except Exception as e: + raise UnexpectedError(f"Docling-Serve API error: {e}") from e + + async def convert_document_to_markdown(self, file_name: str, content: bytes) -> str: + """ + Converts a document to Markdown format using Docling-Serve. + """ + docling_response = await self.retrieve_document_content(file_name, content) + logger.info( + f"Converted `{file_name}` to Markdown in {docling_response.get('processing_time', '0'):.3f} seconds, " + f"{get_bytes_size_mb(content):.3f} MB." + ) + return docling_response.get("document", {}).get("md_content", "") + + async def convert_document_to_chunks( + self, file_name: str, content: bytes, chunk_size: int, chunk_overlap: int + ) -> list[Chunk]: + """ + Converts a document to chunks, respecting page and table boundaries, using Docling-Serve. + """ + docling_response = await self.retrieve_document_content(file_name, content) + md_content = docling_response.get("document", {}).get("md_content", "") + + documents = [Document(page_content=md_content, metadata={"page": 1})] + chunks = format_chunks(documents, file_name) + + chunks = self.split_chunks( + SplitChunksRequest( + chunks=chunks, + params=SplitChunksParams( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ), + ), + page_break_placeholder=self.page_break_placeholder, + ) + + return chunks + + async def load_document(self, file_name: str, content: bytes) -> str: + """ + Loads and processes a document file, converting it to Markdown format using Docling-Serve. + """ + try: + md = await self.convert_document_to_markdown(file_name, content) + logger.info(f'File "{file_name}" loaded as markdown.') + return md + + except Exception as e: + logger.error(f"Failed to process file: {e}") + raise UnexpectedError(f"Failed to process file: {e}") from e + + async def load_document_chunks( + self, + file_name: str, + content: bytes, + chunk_size: int = 1000, + chunk_overlap: int = 200, + ) -> list[Chunk]: + """ + Loads and processes a document file, splitting it into chunks using Docling-Serve. + """ + try: + chunks = await self.convert_document_to_chunks( + file_name, content, chunk_size, chunk_overlap + ) + logger.info(f'File "{file_name}" loaded and split into {len(chunks):,d} chunks.') + return chunks + + except Exception as e: + logger.error(f"Failed to process file: {e}") + raise UnexpectedError(f"Failed to process file: {e}") from e diff --git a/services/api/src/owl/entrypoints/api.py b/services/api/src/owl/entrypoints/api.py index d9f2678..7daae9b 100644 --- a/services/api/src/owl/entrypoints/api.py +++ b/services/api/src/owl/entrypoints/api.py @@ -1,75 +1,104 @@ -""" -API server. 
-""" +import asyncio +from asyncio.coroutines import iscoroutine +from collections import defaultdict +from contextlib import asynccontextmanager +from time import perf_counter -import os -from typing import Any - -from fastapi import BackgroundTasks, FastAPI, Request, status -from fastapi.exceptions import RequestValidationError, ResponseValidationError +from fastapi import BackgroundTasks, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse -from filelock import Timeout +from gunicorn.app.base import BaseApplication from loguru import logger -from pydantic import BaseModel -from starlette.exceptions import HTTPException -from starlette.middleware.sessions import SessionMiddleware - -from jamaibase import JamAIAsync -from jamaibase.exceptions import ( - AuthorizationError, - BadInputError, - ContextOverflowError, - ExternalAuthError, - ForbiddenError, - InsufficientCreditsError, - ResourceExistsError, - ResourceNotFoundError, - ServerBusyError, - TableSchemaFixedError, - UnexpectedError, - UnsupportedMediaTypeError, - UpgradeTierError, +from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.instrumentation.redis import RedisInstrumentor + +from owl.configs import CACHE, ENV_CONFIG +from owl.db import create_db_engine_async, init_db, migrate_db, reset_db +from owl.routers import ( + auth, + conversation, + file, + gen_table, + gen_table_v1, + meters, + models, + organizations, + projects, + serving, + tasks, + templates, + users, ) -from owl.billing import BillingManager -from owl.configs.manager import CONFIG, ENV_CONFIG -from owl.protocol import COL_NAME_PATTERN, TABLE_NAME_PATTERN, UserAgent -from owl.routers import file, gen_table, llm, org_admin, template +from owl.routers.projects import v1 as projects_v1 +from owl.types import UserAgent from owl.utils import uuid7_str +from owl.utils.billing import CLICKHOUSE_CLIENT, BillingManager +from owl.utils.exceptions import JamaiException +from owl.utils.handlers import exception_handler, make_request_log_str, path_not_found_handler from owl.utils.logging import setup_logger_sinks, suppress_logging_handlers -from owl.utils.responses import ( - bad_input_response, - forbidden_response, - internal_server_error_response, - make_request_log_str, - make_response, - resource_exists_response, - resource_not_found_response, - server_busy_response, - unauthorized_response, -) - -if ENV_CONFIG.is_oss: - from owl.routers import oss_admin as admin - - cloud_auth = None -else: - from owl.routers import cloud_admin as admin - from owl.routers import cloud_auth +from owl.utils.mcp import get_mcp_router +from owl.utils.mcp.server import MCP_TOOL_TAG - -NO_AUTH_ROUTES = {"health", "public", "favicon.ico"} - -client = JamAIAsync(token=ENV_CONFIG.service_key_plain, timeout=60.0) -logger.enable("owl") -setup_logger_sinks() +OVERHEAD_LOG_ROUTES = {r.path for r in serving.router.routes} +# logger.enable("owl") +setup_logger_sinks(None) # We purposely don't intercept uvicorn logs since it is typically not useful # We also don't intercept transformers logs # replace_logging_handlers(["uvicorn.access"], False) -suppress_logging_handlers(["uvicorn", "litellm", "openmeter", "azure"], True) +suppress_logging_handlers(["uvicorn", "litellm", "azure", "openmeter", "pottery"], True) + +# --- Setup DB --- # +# Maybe reset DB +if ENV_CONFIG.db_reset: + 
asyncio.run(reset_db(reset_max_users=ENV_CONFIG.db_init_max_users)) +# Migration +asyncio.run(migrate_db()) +# Maybe populate DB with demo data +# If OSS and first launch, init user, organization and project +if ENV_CONFIG.db_init: + asyncio.run(init_db(init_max_users=ENV_CONFIG.db_init_max_users)) +# Maybe reset cache +if ENV_CONFIG.cache_reset: + CACHE.purge() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup logic + logger.info(f"Using configuration: {ENV_CONFIG}") + yield + logger.info("Shutting down...") + + # Close DB connection + logger.info("Closing DB connection.") + try: + await (await create_db_engine_async()).dispose() + except Exception as e: + logger.warning(f"Failed to close DB connection: {repr(e)}") + + # Close Redis connection + logger.info("Closing Redis connection.") + try: + await CACHE.aclose() + except Exception as e: + logger.warning(f"Failed to close Redis connection: {repr(e)}") + + # Flush buffer + logger.info("Flushing redis buffer to database.") + try: + await CLICKHOUSE_CLIENT.flush_buffer() + except Exception as e: + logger.warning(f"Failed to flush buffer: {repr(e)}") + finally: + ret = CLICKHOUSE_CLIENT.client.close() + if iscoroutine(ret): + await ret + logger.info("Shutdown complete.") app = FastAPI( + title="JamAI Base API", logger=logger, default_response_class=ORJSONResponse, # Should be faster openapi_url="/api/public/openapi.json", @@ -80,37 +109,107 @@ "url": "https://www.apache.org/licenses/LICENSE-2.0.html", }, servers=[dict(url="https://api.jamaibase.com")], + lifespan=lifespan, ) -services = [ - (admin.router, ["Backend Admin"], "/api"), - (admin.public_router, ["Backend Admin"], "/api"), - (org_admin.router, ["Organization Admin"], "/api/admin/org"), - (template.router, ["Templates"], "/api"), - (template.public_router, ["Templates (Public)"], "/api"), - (llm.router, ["Large Language Model"], "/api"), - (gen_table.router, ["Generative Table"], "/api"), - (file.router, ["File"], "/api"), -] + +# Programmatic Instrumentation +FastAPIInstrumentor.instrument_app(app) +RedisInstrumentor().instrument() +HTTPXClientInstrumentor().instrument() # Mount -for router, tags, prefix in services: +internal_api_tag = "" if ENV_CONFIG.is_oss else " (Internal API)" +app.include_router( + models.router, + prefix="/api", + tags=["Models" + internal_api_tag], +) +app.include_router( + auth.router, + prefix="/api", + tags=["Authentication" + internal_api_tag], +) +app.include_router( + users.router, + prefix="/api", + tags=["Users" + internal_api_tag], +) +app.include_router( + organizations.router, + prefix="/api", + tags=["Organizations" + internal_api_tag], +) +app.include_router( + projects.router, + prefix="/api", + tags=["Projects"], +) +app.include_router( + projects_v1.router, + deprecated=True, + prefix="/api/admin/org", + tags=["Organization Admin (Legacy)"], +) +app.include_router( + templates.router, + prefix="/api", + tags=["Templates"], +) +app.include_router( + conversation.router, + prefix="/api", + tags=["Conversations"], +) +app.include_router( + gen_table.router, + prefix="/api", + tags=["Generative Table (V2)"], +) +app.include_router( + gen_table_v1.router, + prefix="/api", + tags=["Generative Table (V1)"], + deprecated=True, +) +app.include_router( + serving.router, + prefix="/api", + tags=["Serving"], +) +app.include_router( + file.router, + prefix="/api", + tags=["File"], +) +app.include_router( + tasks.router, + prefix="/api", + tags=["Tasks"], +) +app.include_router( + meters.router, + prefix="/api", + tags=["Meters" 
+ internal_api_tag], +) +if ENV_CONFIG.is_cloud: + from owl.routers.cloud import logs, prices + app.include_router( - router, - prefix=prefix, - tags=tags, + prices.router, + prefix="/api", + tags=["Prices"], ) -if cloud_auth is not None: app.include_router( - cloud_auth.router, + logs.router, prefix="/api", - tags=["OAuth"], - ) - app.add_middleware( - SessionMiddleware, - secret_key=ENV_CONFIG.owl_session_secret_plain, - max_age=60 * 60 * 24 * 7, - https_only=ENV_CONFIG.owl_is_prod, + tags=["Logs (Internal Cloud-only API)"], ) +app.include_router( + get_mcp_router(app), + prefix="/api", + tags=["Model Context Protocol (MCP)"], +) + # Permissive CORS app.add_middleware( @@ -120,51 +219,9 @@ allow_methods=["*"], allow_headers=["*"], ) - - -@app.on_event("startup") -async def startup(): - # Router lifespan is broken as of fastapi==0.109.0 and starlette==0.35.1 - # https://github.com/tiangolo/fastapi/discussions/9664 - logger.info(f"Using configuration: {ENV_CONFIG}") - # Maybe purge Redis data - if ENV_CONFIG.owl_cache_purge: - CONFIG.purge() - if ENV_CONFIG.is_oss: - logger.opt(colors=True).info("Launching in OSS mode.") - from sqlalchemy import func - from sqlmodel import Session, select - - from owl.db import MAIN_ENGINE - from owl.db.oss_admin import Organization, Project - - with Session(MAIN_ENGINE) as session: - org = session.get(Organization, ENV_CONFIG.default_org_id) - if org is None: - org = Organization() - session.add(org) - session.commit() - session.refresh(org) - logger.info(f"Default organization created: {org}") - else: - logger.info(f"Default organization found: {org}") - # Default project could have been deleted - # As long as there is at least one project it's ok - project_count = session.exec(select(func.count(Project.id))).one() - if project_count == 0: - project = Project( - id=ENV_CONFIG.default_project_id, - name="Default", - organization_id=org.id, - ) - session.add(project) - session.commit() - session.refresh(project) - logger.info(f"Default project created: {project}") - else: - logger.info(f"{project_count:,d} projects found.") - else: - logger.opt(colors=True).info("Launching in Cloud mode.") +app.add_exception_handler(JamaiException, exception_handler) # Suppress starlette traceback +app.add_exception_handler(Exception, exception_handler) +app.add_exception_handler(404, path_not_found_handler) @app.middleware("http") @@ -178,35 +235,48 @@ async def log_request(request: Request, call_next): Returns: response (Response): Response of the path operation. 
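+
+    Note:
+        The request ID is taken from an incoming x-request-id header when present,
+        otherwise a UUIDv7 string is generated; it is echoed back to the client in
+        the x-request-id response header. Egress accounting and background billing
+        tasks are attached only when request.state.billing has been set for the
+        request.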
""" + request.state.request_start_time = perf_counter() # Set request state - request.state.id = uuid7_str() + request_id = request.headers.get("x-request-id", uuid7_str()) + request.state.id = request_id request.state.user_agent = UserAgent.from_user_agent_string( request.headers.get("user-agent", "") ) - request.state.billing = BillingManager(request=request) - - # OPTIONS are always allowed for CORS preflight: - if request.method == "OPTIONS": - return await call_next(request) - # The following paths are always allowed: - path_components = [p for p in request.url.path.split("/") if p][:2] - if request.method in ("GET", "HEAD") and ( - len(path_components) == 0 or path_components[-1] in NO_AUTH_ROUTES - ): - return await call_next(request) + request.state.timing = defaultdict(float) # Call request + path = request.url.path + if "api/health" not in path: + logger.info(make_request_log_str(request)) response = await call_next(request) - logger.info(make_request_log_str(request, response.status_code)) - - # Add egress events - request.state.billing.create_egress_events( - float(response.headers.get("content-length", 0)) / (1024**3) - ) - # Process billing (this will run AFTER streaming responses are sent) - tasks = BackgroundTasks() - tasks.add_task(request.state.billing.process_all) - response.background = tasks + response.headers["x-request-id"] = request_id + if "api/health" not in path: + logger.info(make_request_log_str(request, response.status_code)) + + # Process billing (this will run BEFORE any responses are sent) + if hasattr(request.state, "billing"): + billing: BillingManager = request.state.billing + # Add egress events + # This does not include SSE egress, and will need to be captured separately + egress_bytes = float(response.headers.get("content-length", 0)) + if egress_bytes > 0: + billing.create_egress_events(egress_bytes / (1024**3)) + # Background tasks will run AFTER streaming responses are sent + tasks = BackgroundTasks() + tasks.add_task(billing.process_all) + response.background = tasks + # Log timing + model_start_time = getattr(request.state, "model_start_time", None) + if ( + ENV_CONFIG.log_timings + and model_start_time + and any(p for p in OVERHEAD_LOG_ROUTES if p in path) + ): + overhead = model_start_time - request.state.request_start_time + breakdown = {k: f"{v * 1e3:,.1f} ms" for k, v in request.state.timing.items()} + logger.info( + f"{request.state.id} - Total overhead: {overhead * 1e3:,.1f} ms. 
Breakdown: {breakdown}" + ) return response @@ -219,256 +289,134 @@ async def health() -> ORJSONResponse: ) -# --- Order of handlers does not matter --- # - - -@app.exception_handler(AuthorizationError) -async def authorization_exc_handler(request: Request, exc: ForbiddenError): - return unauthorized_response(request, str(exc), exception=exc) - - -@app.exception_handler(ExternalAuthError) -async def external_auth_exc_handler(request: Request, exc: ExternalAuthError): - return unauthorized_response( - request, str(exc), error="external_authentication_failed", exception=exc - ) - - -@app.exception_handler(PermissionError) -async def permission_error_exc_handler(request: Request, exc: PermissionError): - return forbidden_response(request, str(exc), error="resource_protected", exception=exc) - - -@app.exception_handler(ForbiddenError) -async def forbidden_exc_handler(request: Request, exc: ForbiddenError): - return forbidden_response(request, str(exc), exception=exc) - - -@app.exception_handler(UpgradeTierError) -async def upgrade_tier_exc_handler(request: Request, exc: UpgradeTierError): - return forbidden_response(request, str(exc), error="upgrade_tier", exception=exc) - - -@app.exception_handler(InsufficientCreditsError) -async def insufficient_credits_exc_handler(request: Request, exc: InsufficientCreditsError): - return forbidden_response(request, str(exc), error="insufficient_credits", exception=exc) - - -@app.exception_handler(FileNotFoundError) -async def file_not_found_exc_handler(request: Request, exc: FileNotFoundError): - return resource_not_found_response(request, str(exc), exception=exc) - - -@app.exception_handler(ResourceNotFoundError) -async def resource_not_found_exc_handler(request: Request, exc: ResourceNotFoundError): - return resource_not_found_response(request, str(exc), exception=exc) - - -@app.exception_handler(FileExistsError) -async def file_exists_exc_handler(request: Request, exc: FileExistsError): - return resource_exists_response(request, str(exc), exception=exc) - - -@app.exception_handler(ResourceExistsError) -async def resource_exists_exc_handler(request: Request, exc: ResourceExistsError): - return resource_exists_response(request, str(exc), exception=exc) - - -@app.exception_handler(UnsupportedMediaTypeError) -async def unsupported_media_type_exc_handler(request: Request, exc: UnsupportedMediaTypeError): - logger.warning(f"{make_request_log_str(request, 415)} - {exc.__class__.__name__}: {exc}") - return ORJSONResponse( - status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, - content={ - "object": "error", - "error": "unsupported_media_type", - "message": str(exc), - "detail": str(exc), - "request_id": request.state.id, - "exception": "", - }, - ) +# Process OpenAPI docs +openapi_schema = app.openapi() +# Remove MCP and permission tags +for path_info in openapi_schema["paths"].values(): + for method_info in path_info.values(): + tags = method_info["tags"] + tags = [ + tag + for tag in tags + if not (tag == MCP_TOOL_TAG or tag.startswith(("system", "organization", "project"))) + ] + method_info["tags"] = tags +# Re-order paths to put internal APIs last +if ENV_CONFIG.is_cloud: + openapi_schema["paths"] = { + k: openapi_schema["paths"][k] + for k in sorted( + openapi_schema["paths"].keys(), + key=lambda p: internal_api_tag + in list(openapi_schema["paths"][p].values())[0]["tags"][0], + ) + } +if ENV_CONFIG.is_cloud: + # Add security schemes + openapi_schema["components"]["securitySchemes"] = { + "Authentication": {"type": "http", "scheme": "bearer"}, + } + 
openapi_schema["security"] = [{"Authentication": []}] + openapi_schema["info"]["x-logo"] = {"url": "https://www.jamaibase.com/favicon.svg"} +app.openapi_schema = openapi_schema -@app.exception_handler(BadInputError) -async def bad_input_exc_handler(request: Request, exc: BadInputError): - return bad_input_response(request, str(exc), exception=exc) +class StandaloneApplication(BaseApplication): + def __init__(self, app, options=None): + self.options = options or {} + self.application = app + super().__init__() + def load_config(self): + config = { + key: value + for key, value in self.options.items() + if key in self.cfg.settings and value is not None + } + for key, value in config.items(): + self.cfg.set(key.lower(), value) -@app.exception_handler(TableSchemaFixedError) -async def table_fixed_exc_handler(request: Request, exc: TableSchemaFixedError): - return bad_input_response(request, str(exc), error="table_schema_fixed", exception=exc) + def load(self): + return self.application -@app.exception_handler(ContextOverflowError) -async def context_overflow_exc_handler(request: Request, exc: ContextOverflowError): - return bad_input_response(request, str(exc), error="context_overflow", exception=exc) +# Gunicorn post_fork hook +def post_fork(server, worker): + from opentelemetry import metrics, trace + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from owl.utils.loguru_otlp_handler import OTLPHandler -class Wrapper(BaseModel): - body: Any + # from opentelemetry.instrumentation.auto_instrumentation import sitecustomize + # trace.set_tracer_provider(trace.get_tracer_provider()) + # metrics.set_meter_provider(metrics.get_meter_provider()) + # for manual instrumentation -@app.exception_handler(RequestValidationError) -async def request_validation_exc_handler(request: Request, exc: RequestValidationError): - content = None - try: - logger.info( - f"{make_request_log_str(request, 422)} - RequestValidationError: {exc.errors()}" - ) - errors, messages = [], [] - for i, e in enumerate(exc.errors()): - try: - msg = str(e["ctx"]["error"]).strip() - except Exception: - msg = e["msg"].strip() - if not msg.endswith("."): - msg = f"{msg}." - # Intercept Table and Column ID regex error message - if TABLE_NAME_PATTERN in msg: - msg = ( - "Table name or ID must be unique with at least 1 character and up to 100 characters. " - "Must start and end with an alphabet or number. " - "Characters in the middle can include `_` (underscore), `-` (dash), `.` (dot)." - ) - elif COL_NAME_PATTERN in msg: - msg = ( - "Column name or ID must be unique with at least 1 character and up to 100 characters. " - "Must start and end with an alphabet or number. " - "Characters in the middle can include `_` (underscore), `-` (dash), ` ` (space). " - 'Cannot be called "ID" or "Updated at" (case-insensitive).' - ) - - path = "" - for j, x in enumerate(e.get("loc", [])): - if isinstance(x, str): - if j > 0: - path += "." 
- path += x - elif isinstance(x, int): - path += f"[{x}]" - else: - raise TypeError("Unexpected type") - if path: - path += " : " - messages.append(f"{i + 1}. {path}{msg}") - error = {k: v for k, v in e.items() if k != "ctx"} - if "ctx" in e: - error["ctx"] = {k: repr(v) if k == "error" else v for k, v in e["ctx"].items()} - if "input" in e: - error["input"] = repr(e["input"]) - errors.append(error) - message = "\n".join(messages) - message = f"Your request contains errors:\n{message}" - content = { - "object": "error", - "error": "validation_error", - "message": message, - "detail": errors, - "request_id": request.state.id, - "exception": "", - **Wrapper(body=exc.body).model_dump(), + resource = Resource.create( + { + "service.name": "owl", + "service.instance.id": uuid7_str(), } - return ORJSONResponse( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content=content, - ) - except Exception: - if content is None: - content = repr(exc) - logger.exception(f"{request.state.id} - Failed to parse error data: {content}") - message = str(exc) - return ORJSONResponse( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content={ - "object": "error", - "error": "validation_error", - "message": message, - "detail": message, - "request_id": request.state.id, - "exception": exc.__class__.__name__, - }, + ) + # Meter provider configuration + reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" + ), + export_interval_millis=1000, + ) + provider = MeterProvider(resource=resource, metric_readers=[reader]) + metrics.set_meter_provider(provider) + # Trace provider configuration + trace_provider = TracerProvider(resource=resource) + trace_provider.add_span_processor( + BatchSpanProcessor( + OTLPSpanExporter( + endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" + ) ) - - -@app.exception_handler(Exception) -async def exception_handler(request: Request, exc: Exception): - return internal_server_error_response(request, exception=exc) - - -@app.exception_handler(UnexpectedError) -async def unexpected_error_handler(request: Request, exc: UnexpectedError): - return internal_server_error_response(request, exception=exc) - - -@app.exception_handler(ResponseValidationError) -async def response_validation_error_handler(request: Request, exc: ResponseValidationError): - return internal_server_error_response(request, exception=exc) - - -@app.exception_handler(Timeout) -async def write_lock_timeout_exc_handler(request: Request, exc: Timeout): - return server_busy_response( - request, - "This table is currently busy. Please try again later.", - exception=exc, - headers={"Retry-After": "10"}, ) - - -@app.exception_handler(ServerBusyError) -async def busy_exc_handler(request: Request, exc: ServerBusyError): - return server_busy_response( - request, - "The server is currently busy. 
Please try again later.", - exception=exc, - headers={"Retry-After": "30"}, + trace.set_tracer_provider(trace_provider) + + # # for auto-instrumentation + # trace.get_tracer_provider() + # metrics.get_meter_provider() + # Configure the OTLP Exporter + otlp_exporter = OTLPLogExporter( + endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" ) - -@app.exception_handler(HTTPException) -async def http_exc_handler(request: Request, exc: HTTPException): - return make_response( - request=request, - message=str(exc), - error="http_error", - status_code=exc.status_code, - detail=None, - exception=exc, - log=exc.status_code != 404, + # Create an instance of OTLPHandler + otlp_handler = OTLPHandler.create( + service_name="owl", + exporter=otlp_exporter, + development_mode=False, # Set to True for development ) - -if not ENV_CONFIG.is_oss: - openapi_schema = app.openapi() - # Add security schemes - openapi_schema["components"]["securitySchemes"] = { - "Authentication": { - "type": "http", - "scheme": "bearer", - }, - } - openapi_schema["security"] = [{"Authentication": []}] - openapi_schema["info"]["x-logo"] = {"url": "https://www.jamaibase.com/favicon.svg"} - app.openapi_schema = openapi_schema + logger.add(otlp_handler.sink, level="INFO") + server.log.info(f"Worker spawned (pid: {worker.pid})") if __name__ == "__main__": - import uvicorn - - if os.name == "nt": - import asyncio - from multiprocessing import freeze_support - - logger.warning("The system is Windows, performing asyncio and multiprocessing patches.") - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - freeze_support() - - uvicorn.run( - "owl.entrypoints.api:app", - reload=False, - host=ENV_CONFIG.owl_host, - port=ENV_CONFIG.owl_port, - workers=ENV_CONFIG.owl_workers, - limit_concurrency=ENV_CONFIG.owl_max_concurrency, - ) + options = { + "bind": f"{ENV_CONFIG.host}:{ENV_CONFIG.port}", + "workers": ENV_CONFIG.workers, + "worker_class": "uvicorn.workers.UvicornWorker", + "limit_concurrency": ENV_CONFIG.max_concurrency, + "timeout": 600, + "graceful_timeout": 60, + "max_requests": 2000, + "max_requests_jitter": 200, + "keepalive": 60, # AWS ALB and Nginx default to 60 seconds + "post_fork": post_fork, + "loglevel": "error", + } + StandaloneApplication(app, options).run() diff --git a/services/api/src/owl/entrypoints/chat_echo.py b/services/api/src/owl/entrypoints/chat_echo.py deleted file mode 100644 index 71cce87..0000000 --- a/services/api/src/owl/entrypoints/chat_echo.py +++ /dev/null @@ -1,121 +0,0 @@ -from time import time - -from fastapi import FastAPI, Response -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field - -from jamaibase.protocol import ( - ChatCompletionChoice, - ChatEntry, - ChatRequest, - CompletionUsage, -) -from owl.configs.manager import ENV_CONFIG - - -class ChatCompletionRequest(ChatRequest): - stream: bool = False - - -class ChatCompletionChoiceDelta(BaseModel): - delta: dict[str, str] = Field(description="A chat completion message generated by the model.") - index: int = Field(description="The index of the choice in the list of choices.") - finish_reason: str | None = Field( - default=None, - description=( - "The reason the model stopped generating tokens. " - "This will be stop if the model hit a natural stop point or a provided stop sequence, " - "length if the maximum number of tokens specified in the request was reached." 
- ), - ) - - -class ChatCompletionResponse(BaseModel): - id: str = Field( - description="A unique identifier for the chat completion. Each chunk has the same ID." - ) - object: str = Field( - default="chat.completion", - description="Type of API response object.", - examples=["chat.completion"], - ) - created: int = Field( - default_factory=lambda: int(time()), - description="The Unix timestamp (in seconds) of when the chat completion was created.", - ) - model: str = Field(description="The model used for the chat completion.") - choices: list[ChatCompletionChoice | ChatCompletionChoiceDelta] = Field( - description="A list of chat completion choices. Can be more than one if `n` is greater than 1." - ) - usage: CompletionUsage | None = Field( - description="Number of tokens consumed for the completion request.", - examples=[CompletionUsage(), None], - ) - - -app = FastAPI() - - -@app.post("/v1/chat/completions") -async def chat_completion(body: ChatCompletionRequest): - output = body.model_dump_json() - - if body.stream: - - async def stream_response(): - for i, char in enumerate(output): - chunk = ChatCompletionResponse( - id=body.id, - object="chat.completion.chunk", - model=body.model, - choices=[ - ChatCompletionChoiceDelta( - index=0, - delta=dict(content=char), - finish_reason=None if i < len(output) - 1 else "stop", - ) - ], - usage=CompletionUsage( - prompt_tokens=len(output), - completion_tokens=i + 1, - total_tokens=len(output) + i + 1, - ), - ) - yield f"data: {chunk.model_dump()}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(stream_response(), media_type="text/event-stream") - - return ChatCompletionResponse( - id=body.id, - model=body.model, - choices=[ - ChatCompletionChoice( - index=0, message=ChatEntry.assistant(output), finish_reason="stop" - ) - ], - usage=CompletionUsage( - prompt_tokens=len(output), - completion_tokens=len(output), - total_tokens=len(output) + len(output), - ), - ) - - -@app.get("/health", tags=["Health"]) -async def health() -> Response: - """Health check.""" - return Response(status_code=200) - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run( - "owl.entrypoints.chat_echo:app", - reload=False, - host=ENV_CONFIG.owl_host, - port=6868, - workers=1, - limit_concurrency=10, - ) diff --git a/services/api/src/owl/entrypoints/chat_python.py b/services/api/src/owl/entrypoints/chat_python.py deleted file mode 100644 index 66788c3..0000000 --- a/services/api/src/owl/entrypoints/chat_python.py +++ /dev/null @@ -1,137 +0,0 @@ -import io -from contextlib import redirect_stdout -from time import time - -from fastapi import FastAPI -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field - -from jamaibase.protocol import ( - ChatCompletionChoice, - ChatEntry, - ChatRequest, - CompletionUsage, -) -from owl.configs.manager import ENV_CONFIG - - -class ChatCompletionRequest(ChatRequest): - stream: bool = False - - -class ChatCompletionChoiceDelta(BaseModel): - delta: dict[str, str] = Field(description="A chat completion message generated by the model.") - index: int = Field(description="The index of the choice in the list of choices.") - finish_reason: str | None = Field( - default=None, - description=( - "The reason the model stopped generating tokens. " - "This will be stop if the model hit a natural stop point or a provided stop sequence, " - "length if the maximum number of tokens specified in the request was reached." 
- ), - ) - - -class ChatCompletionResponse(BaseModel): - id: str = Field( - description="A unique identifier for the chat completion. Each chunk has the same ID." - ) - object: str = Field( - default="chat.completion", - description="Type of API response object.", - examples=["chat.completion"], - ) - created: int = Field( - default_factory=lambda: int(time()), - description="The Unix timestamp (in seconds) of when the chat completion was created.", - ) - model: str = Field(description="The model used for the chat completion.") - choices: list[ChatCompletionChoice | ChatCompletionChoiceDelta] = Field( - description="A list of chat completion choices. Can be more than one if `n` is greater than 1." - ) - usage: CompletionUsage | None = Field( - description="Number of tokens consumed for the completion request.", - examples=[CompletionUsage(), None], - ) - - -app = FastAPI() - - -def assemble_script(body: ChatCompletionRequest): - messages = [ - message.content - if isinstance(message.content, str) - else "\n".join(d["text"] for d in message if d["type"] == "text") - for message in body.messages - if message.role == "user" - ] - script = "\n".join(messages) - return script - - -def execute_script(script): - with redirect_stdout(io.StringIO(newline="\n")) as f: - exec(script) - output = f.getvalue().strip() - return output.split("\n")[-1] - - -@app.post("/v1/chat/completions") -async def chat_completion(body: ChatCompletionRequest): - script = assemble_script(body) - output = execute_script(script) - - if body.stream: - - async def stream_response(): - for i, char in enumerate(output): - chunk = ChatCompletionResponse( - id=body.id, - object="chat.completion.chunk", - model=body.model, - choices=[ - ChatCompletionChoiceDelta( - index=0, - delta=dict(content=char), - finish_reason=None if i < len(output) - 1 else "stop", - ) - ], - usage=CompletionUsage( - prompt_tokens=len(script), - completion_tokens=i + 1, - total_tokens=len(script) + i + 1, - ), - ) - yield f"data: {chunk.model_dump()}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(stream_response(), media_type="text/event-stream") - - return ChatCompletionResponse( - id=body.id, - model=body.model, - choices=[ - ChatCompletionChoice( - index=0, message=ChatEntry.assistant(output), finish_reason="stop" - ) - ], - usage=CompletionUsage( - prompt_tokens=len(script), - completion_tokens=len(output), - total_tokens=len(script) + len(output), - ), - ) - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run( - "owl.entrypoints.chat_python:app", - reload=False, - host=ENV_CONFIG.owl_host, - port=6869, - workers=1, - limit_concurrency=10, - ) diff --git a/services/api/src/owl/entrypoints/llm.py b/services/api/src/owl/entrypoints/llm.py new file mode 100644 index 0000000..2cb3da9 --- /dev/null +++ b/services/api/src/owl/entrypoints/llm.py @@ -0,0 +1,626 @@ +import base64 +import hashlib +import io +import re +from asyncio import sleep +from contextlib import asynccontextmanager +from time import time +from typing import Any + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import ORJSONResponse, StreamingResponse +from loguru import logger +from PIL import Image +from pydantic import BaseModel, Field +from pydub import AudioSegment + +from owl.configs import CACHE, ENV_CONFIG +from owl.types import ( + AudioContent, + ChatCompletionChoice, + ChatCompletionChunkResponse, + ChatCompletionDelta, + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionUsage, + 
ChatRequest, + ChatRole, + CompletionUsageDetails, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + EmbeddingUsage, + ImageContent, + PromptUsageDetails, + SanitisedNonEmptyStr, + TextContent, + UserAgent, +) +from owl.utils import uuid7_str +from owl.utils.exceptions import BadInputError, JamaiException +from owl.utils.handlers import exception_handler, make_request_log_str, path_not_found_handler +from owl.utils.logging import setup_logger_sinks, suppress_logging_handlers + +# Setup logging +setup_logger_sinks(None) +suppress_logging_handlers(["uvicorn", "litellm", "pottery"], True) + + +class ChatCompletionRequest(ChatRequest): + stream: bool = False # Set default to False + + +class ModelSpec(BaseModel, validate_assignment=True): + id: SanitisedNonEmptyStr = Field(description="Model ID") + ttft_ms: int = Field(0, description="Time to first token (TTFT)") + tpot_ms: int = Field(0, description="Time per output token (TPOT)") + max_context_length: int = Field(int(1e12), description="Max context length") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup logic + logger.info(f"Using configuration: {ENV_CONFIG}") + yield + logger.info("Shutting down...") + + # Close Redis connection + logger.info("Closing Redis connection.") + try: + await CACHE.aclose() + except Exception as e: + logger.warning(f"Failed to close Redis connection: {repr(e)}") + + +app = FastAPI(title="Mock LLM", lifespan=lifespan) +app.add_exception_handler(JamaiException, exception_handler) # Suppress starlette traceback +app.add_exception_handler(Exception, exception_handler) +app.add_exception_handler(404, path_not_found_handler) + + +def _describe_image(image_content: ImageContent) -> str: + """ + Describe the image based on an `ImageContent` object. + + Args: + image_content (ImageContent): The `ImageContent` object containing the image URL or base64 data. + + Returns: + description (str): A brief description of the image. + """ + # Get image data from URL or base64 + url = image_content.image_url.url + + if url.startswith("data:"): + # Handle base64 encoded image + mime_match = re.match(r"data:([^;]+);base64,", url) + mime_type = mime_match.group(1) if mime_match else "image/unknown" + base64_data = url.split(",", 1)[1] + image_data = base64.b64decode(base64_data) + img = Image.open(io.BytesIO(image_data)) + else: + # Handle URL using httpx instead of requests + with httpx.Client() as client: + response = client.get(url) + response.raise_for_status() + mime_type = response.headers.get("Content-Type", "image/unknown") + img = Image.open(io.BytesIO(response.content)) + + # Convert to numpy array for calculations + img_array = np.asarray(img) + + # Get dimensions (height, width, channels) + if len(img_array.shape) == 2: # Grayscale image + height, width = img_array.shape + channels = 1 + img_array = img_array.reshape((height, width, 1)) + else: + height, width, channels = img_array.shape + + # Calculate mean and standard deviation + mean_value = float(np.mean(img_array)) + std_value = float(np.std(img_array)) + + return ( + f"There is an image with MIME type [{mime_type}], " + f"shape [{(height, width, channels)}], mean [{mean_value:,.1f}] and std [{std_value:,.1f}]." + ) + + +def _describe_audio(audio_content: AudioContent) -> str: + """ + Describe the audio based on an `AudioContent` object. + + Args: + audio_content (AudioContent): The `AudioContent` object containing the base64 encoded audio data. + + Returns: + description (str): A brief description of the audio. 
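+
+    Example:
+        Illustrative sketch only; `content` stands for any `AudioContent` whose
+        `input_audio` carries a base64-encoded one-second WAV clip (how the
+        object is constructed is defined in `owl.types` and not shown here):
+
+            _describe_audio(content)
+            # -> "There is an audio with MIME type [audio/wav], duration [1.0] seconds."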
+ """ + # Format to MIME type mapping + format_to_mime: dict[str, str] = {"mp3": "audio/mpeg", "wav": "audio/wav"} + # Get audio data and format + base64_data = audio_content.input_audio.data + audio_format = audio_content.input_audio.format + # Decode base64 data + audio_data = base64.b64decode(base64_data) + # Get MIME type + mime_type = format_to_mime.get(audio_format, f"audio/{audio_format}") + # Load audio using pydub + audio_file = io.BytesIO(audio_data) + + if audio_format == "mp3": + audio = AudioSegment.from_mp3(audio_file) + elif audio_format == "wav": + audio = AudioSegment.from_wav(audio_file) + else: + # This shouldn't happen due to the Literal type constraint, but just in case + raise BadInputError(f'Unsupported audio format: "{audio_format}".') + + # Calculate duration in seconds + duration_sec = len(audio) / 1000.0 # pydub uses milliseconds + return ( + f"There is an audio with MIME type [{mime_type}], duration [{duration_sec:,.1f}] seconds." + ) + + +def _describe_text(text_content: str | TextContent) -> str: + """ + Describe the text based on a `TextContent` object. + + Args: + text_content (str | TextContent): A string or `TextContent` object containing the text. + + Returns: + description (TextDescription): A `TextDescription` object with text metadata. + """ + if isinstance(text_content, str): + text = text_content + else: + text = text_content.text + text = text.strip() + num_tokens = 0 if text == "" else len(text.split(" ")) + return f"There is a text with [{num_tokens:,d}] tokens." + + +def _execute_python(code: str, context: dict[str, Any] | None = None) -> Any: + """ + Execute a string containing Python code and return its return value. + This version wraps the code in a function to properly capture return values. + + Args: + code (str): The Python code to execute + context (dict[str, Any] | None, optional): + Dictionary of variables to make available in the execution context. + Defaults to None. + + Returns: + value (Any): The return value of the executed code. 
+ """ + if context is None: + context = {} + # Wrap the code in a function to capture return values + wrapped_code = [ + "def __temp_function():", + "\n".join(" " + f"{k} = {repr(v)}" for k, v in context.items()), + "\n".join(" " + line for line in code.strip().split("\n")), + "__return_value__ = __temp_function()", + ] + # Execute the wrapped code + local_namespace = {} + exec("\n".join(wrapped_code), globals(), local_namespace) + return local_namespace.get("__return_value__") + + +def _parse_chat_model_id(model_id: str) -> ModelSpec: + spec = ModelSpec(id=model_id) + # Time to first token (TTFT) + if match := re.search(r"-ttft-(\d+)", model_id): + spec.ttft_ms = int(match.group(1)) + # Time per output token (TPOT) + if match := re.search(r"-tpot-(\d+)", model_id): + spec.tpot_ms = int(match.group(1)) + # Max context length + if match := re.search(r"-context-(\d+)", model_id): + spec.max_context_length = int(match.group(1)) + return spec + + +@app.post("/v1/chat/completions") +async def chat_completion(body: ChatCompletionRequest): + logger.info(f"Chat completion request: {body}") + model_spec = _parse_chat_model_id(body.model) + num_input_tokens = len(" ".join(m.text_content for m in body.messages).split(" ")) + user_messages = [m for m in body.messages if m.role == ChatRole.USER] + + # Test context length error handling + if num_input_tokens > model_spec.max_context_length: + return ORJSONResponse( + status_code=400, + content={ + "error": { + "message": ( + f"This model's maximum context length is {model_spec.max_context_length} tokens. " + f"However, your messages resulted in {num_input_tokens} tokens. " + "Please reduce the length of the messages." + ), + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded", + } + }, + ) + elif num_input_tokens + body.max_tokens > model_spec.max_context_length: + return ORJSONResponse( + status_code=400, + content={ + "error": { + "message": ( + f"This model's maximum context length is {model_spec.max_context_length} tokens. " + f"However, you requested {num_input_tokens + body.max_tokens} tokens " + f"({num_input_tokens} in the messages, {body.max_tokens} in the completion). " + "Please reduce the length of the messages or completion." 
+ ), + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded", + } + }, + ) + + if "lorem" in model_spec.id: + completion_tokens = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.".split(" ") + num_completion_tokens = body.max_tokens + elif "describe" in model_spec.id: + descriptions = [] + if body.messages[0].role == ChatRole.SYSTEM: + descriptions.append(f"System prompt: {_describe_text(body.messages[0].content)}") + if len(user_messages) == 0: + descriptions.append(_describe_text("")) + for message in user_messages: + if isinstance(message.content, str): + descriptions.append(_describe_text(message.content)) + else: + for c in message.content: + if isinstance(c, ImageContent): + descriptions.append(_describe_image(c)) + elif isinstance(c, AudioContent): + descriptions.append(_describe_audio(c)) + elif isinstance(c, TextContent): + descriptions.append(_describe_text(c)) + else: + raise BadInputError(f'Unknown content type: "{type(c)}".') + completion_tokens = "\n".join(descriptions).split(" ") + num_completion_tokens = len(completion_tokens) + elif "echo-request" in model_spec.id: + completion_tokens = body.model_dump_json().split(" ") + num_completion_tokens = len(completion_tokens) + elif "echo-prompt" in model_spec.id: + prompt_concat = " ".join(m.text_content for m in user_messages) + if body.messages[0].role == ChatRole.SYSTEM: + prompt_concat = f"{body.messages[0].text_content} {prompt_concat}" + completion_tokens = prompt_concat.strip().split(" ") + num_completion_tokens = len(completion_tokens) + elif "python" in model_spec.id: + if len(user_messages) == 0: + result = None + else: + result = _execute_python(user_messages[-1].text_content) + completion_tokens = [repr(result)] + num_completion_tokens = len(completion_tokens) + else: + raise BadInputError(f'Unknown model: "{model_spec.id}"') + + if body.stream: + + async def stream_response(): + if model_spec.ttft_ms > 0: + await sleep(model_spec.ttft_ms / 1000) + # Role chunk + for i in range(body.n): + chunk = ChatCompletionChunkResponse( + id=body.id, + model=model_spec.id, + choices=[ + ChatCompletionChoice( + index=i, + delta=ChatCompletionDelta(role="assistant", content="", refusal=None), + logprobs=None, + finish_reason=None, + ) + ], + usage=None, + object="chat.completion.chunk", + created=int(time()), + system_fingerprint=None, + service_tier=None, + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # Content chunks + for t in range(num_completion_tokens): + # If this is the last token + if t == num_completion_tokens - 1: + content = f"{completion_tokens[t % len(completion_tokens)]}" + else: + content = f"{completion_tokens[t % len(completion_tokens)]} " + for i in range(body.n): + if model_spec.tpot_ms > 0: + await sleep(model_spec.tpot_ms / 1000) + chunk = ChatCompletionChunkResponse( + id=body.id, + model=model_spec.id, + choices=[ + ChatCompletionChoice( + index=i, + delta=ChatCompletionDelta(content=content), + logprobs=None, + finish_reason=None, + ) + ], + usage=None, + object="chat.completion.chunk", + created=int(time()), + system_fingerprint=None, + service_tier=None, + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # Finish reason chunk + for i in range(body.n): + chunk = ChatCompletionChunkResponse( + id=body.id, + model=model_spec.id, + choices=[ + ChatCompletionChoice( + index=i, + logprobs=None, + finish_reason="length" + if num_completion_tokens == body.max_tokens + else "stop", + ) + ], + usage=None, + 
object="chat.completion.chunk", + created=int(time()), + system_fingerprint=None, + service_tier=None, + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # Usage chunk + chunk = ChatCompletionChunkResponse( + id=body.id, + model=model_spec.id, + choices=[], + usage=ChatCompletionUsage( + prompt_tokens=num_input_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_input_tokens + num_completion_tokens, + prompt_tokens_details=PromptUsageDetails(cached_tokens=0, audio_tokens=0), + completion_tokens_details=CompletionUsageDetails( + audio_tokens=0, + reasoning_tokens=0, + accepted_prediction_tokens=0, + rejected_prediction_tokens=0, + ), + ), + object="chat.completion.chunk", + created=int(time()), + system_fingerprint=None, + service_tier=None, + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_response(), media_type="text/event-stream") + + # Non-stream + if (model_spec.ttft_ms + model_spec.tpot_ms) > 0: + await sleep((model_spec.ttft_ms + model_spec.tpot_ms * len(completion_tokens)) / 1000) + response = ChatCompletionResponse( + id=body.id, + model=model_spec.id, + choices=[ + ChatCompletionChoice( + index=i, + message=ChatCompletionMessage( + content=" ".join( + completion_tokens[t % len(completion_tokens)] + for t in range(num_completion_tokens) + ) + ), + logprobs=None, + finish_reason="length", + ) + for i in range(body.n) + ], + usage=ChatCompletionUsage( + prompt_tokens=num_input_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_input_tokens + num_completion_tokens, + ), + ) + return response + + +def _parse_embedding_model_id(model_id: str, fallback_dim: int = 768) -> int: + """ + Extract the embedding dimension from the model_id if present (e.g. '...-dim-768'). + Otherwise return fallback_dim. + """ + if match := re.search(r"-dim-(\d+)", model_id): + return int(match.group(1)) + return fallback_dim + + +@app.post("/v1/embeddings") +async def embeddings(body: EmbeddingRequest) -> EmbeddingResponse: + """ + Mock embedding endpoint that deterministically generates embeddings by + seeding NumPy with a hash derived from each input string. 
+ """ + # Validate inputs + inputs: list[str] + if isinstance(body.input, str): + text = body.input.strip() + if text == "": + raise BadInputError("Input cannot be an empty string.") + inputs = [text] + else: + inputs = [] + for i, s in enumerate(body.input): + t = s.strip() + if t == "": + raise BadInputError(f"Input at index {i} cannot be an empty string.") + inputs.append(t) + + # Determine embedding dimension + dim: int + if body.dimensions is not None: + if body.dimensions <= 0: + raise BadInputError("`dimensions` must be a positive integer.") + dim = body.dimensions + else: + dim = _parse_embedding_model_id(body.model, fallback_dim=768) + + # Generate deterministic embeddings per input + data: list[EmbeddingResponseData] = [] + prompt_token_count = 0 + + for idx, text in enumerate(inputs): + # Naive token counting by whitespace + prompt_token_count += 0 if text == "" else len(text.split()) + # Deterministic seed from SHA-256 + sha = hashlib.blake2b(text.encode("utf-8")).hexdigest() + seed = int(sha[:16], 16) % (2**32) + rng = np.random.default_rng(seed) + vec = rng.standard_normal(size=dim, dtype=np.float32) + if body.encoding_format == "float": + emb_value: list[float] | str = vec.tolist() + else: + # base64 encoding of float32 bytes + emb_value = base64.b64encode(vec.tobytes()).decode("ascii") + data.append(EmbeddingResponseData(embedding=emb_value, index=idx)) + + return EmbeddingResponse( + data=data, + model=body.model, + usage=EmbeddingUsage( + prompt_tokens=prompt_token_count, + total_tokens=prompt_token_count, + ), + ) + + +@app.get("/health", tags=["Health"]) +async def health() -> ORJSONResponse: + """Health check.""" + return ORJSONResponse(status_code=200, content={}) + + +@app.middleware("http") +async def log_request(request: Request, call_next): + """ + Args: + request (Request): Starlette request object. + call_next (Callable): A function that will receive the request, + pass it to the path operation, and returns the response generated. + + Returns: + response (Response): Response of the path operation. + """ + # Set request state + request_id = request.headers.get("x-request-id", uuid7_str()) + request.state.id = request_id + request.state.user_agent = UserAgent.from_user_agent_string( + request.headers.get("user-agent", "") + ) + # Call request + logger.info(make_request_log_str(request)) + response = await call_next(request) + response.headers["x-request-id"] = request_id + logger.info(make_request_log_str(request, response.status_code)) + return response + + +if __name__ == "__main__": + import uvicorn + + logger.info(f"Starting LLM test server on {ENV_CONFIG.host}:{ENV_CONFIG.port + 1}") + uvicorn.run( + "owl.entrypoints.llm:app", + reload=False, + host=ENV_CONFIG.host, + port=ENV_CONFIG.port + 1, + workers=2, + limit_concurrency=100, + ) + +""" +OpenAI Chat Completion SSE + + +{ +"error": { + "message": "This model's maximum context length is 16385 tokens. However, your messages resulted in 19901 tokens. Please reduce the length of the messages.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded" + } +} + +{ + "error": { + "message": "This model's maximum context length is 16385 tokens. However, you requested 18242 tokens (16242 in the messages, 2000 in the completion). 
Please reduce the length of the messages or completion.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded" + } +} + +{ + "id": "chatcmpl-AtBWW4Kf8NoM4WDBaNSBLR8fD0fc6", + "object": "chat.completion", + "created": 1737715700, + "model": "gpt-3.5-turbo-1106", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nS", + "refusal": null + }, + "logprobs": null, + "finish_reason": "length" + } + ], + "usage": { + "prompt_tokens": 17, + "completion_tokens": 2, + "total_tokens": 19, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default", + "system_fingerprint": "fp_2f141ce944" +} + +data: {"id":"chatcmpl-AtBSi41j2M6DGdAzfHgpTKjKKqtMy","object":"chat.completion.chunk","created":1737715464,"model":"gpt-3.5-turbo-1106","service_tier":"default","system_fingerprint":"fp_7fe28551a8","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-AtBSi41j2M6DGdAzfHgpTKjKKqtMy","object":"chat.completion.chunk","created":1737715464,"model":"gpt-3.5-turbo-1106","service_tier":"default","system_fingerprint":"fp_7fe28551a8","choices":[{"index":0,"delta":{"content":"S"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-AtBSi41j2M6DGdAzfHgpTKjKKqtMy","object":"chat.completion.chunk","created":1737715464,"model":"gpt-3.5-turbo-1106","service_tier":"default","system_fingerprint":"fp_7fe28551a8","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"length"}]} + +data: {"id":"chatcmpl-AtBbXar0rpsdn69L9cIeeu88frXVd","object":"chat.completion.chunk","created":1737716011,"model":"gpt-3.5-turbo-1106","service_tier":"default","system_fingerprint":"fp_2f141ce944","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":2,"total_tokens":19,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + +data: [DONE] +""" diff --git a/services/api/src/owl/entrypoints/starling.py b/services/api/src/owl/entrypoints/starling.py index f36efef..d39d577 100644 --- a/services/api/src/owl/entrypoints/starling.py +++ b/services/api/src/owl/entrypoints/starling.py @@ -8,85 +8,106 @@ ``` """ -import os +from datetime import timedelta -from celery import Celery from celery.schedules import crontab +from celery.signals import worker_process_init from loguru import logger - -from owl.configs.manager import CONFIG, ENV_CONFIG +from opentelemetry import metrics, trace +from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.celery import CeleryInstrumentor +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +from owl.configs import ENV_CONFIG, celery_app +from owl.utils import uuid7_str from owl.utils.logging import ( replace_logging_handlers, 
setup_logger_sinks, suppress_logging_handlers, ) +from owl.utils.loguru_otlp_handler import OTLPHandler -# Maybe purge Redis data -if ENV_CONFIG.owl_cache_purge: - CONFIG.purge() - -SCHEDULER_DB = f"{ENV_CONFIG.owl_db_dir}/_scheduler" logger.enable("") -setup_logger_sinks(f"{ENV_CONFIG.owl_log_dir}/starling.log") +setup_logger_sinks(None) replace_logging_handlers(["uvicorn.access"], False) -suppress_logging_handlers(["litellm", "openmeter", "azure"], True) - - -try: - if not os.path.exists(SCHEDULER_DB): - os.makedirs(SCHEDULER_DB, exist_ok=True) - logger.info(f"Created scheduler directory at {SCHEDULER_DB}") - else: - logger.info(f"Scheduler directory already exists at {SCHEDULER_DB}") -except Exception as e: - logger.error(f"Error creating scheduler directory: {e}") - - -# Set up Celery -app = Celery("tasks", broker=f"redis://{ENV_CONFIG.owl_redis_host}:{ENV_CONFIG.owl_redis_port}/0") - -# Configure Celery -app.conf.update( - result_backend=f"redis://{ENV_CONFIG.owl_redis_host}:{ENV_CONFIG.owl_redis_port}/0", - task_serializer="json", - accept_content=["json"], - result_serializer="json", - result_expires=36000, - timezone="UTC", - enable_utc=True, - beat_schedule_filename=os.path.join(SCHEDULER_DB, "celerybeat-schedule"), -) +suppress_logging_handlers(["uvicorn", "litellm", "azure", "openmeter", "pottery"], True) + + +@worker_process_init.connect(weak=False) +def init_celery_tracing(*args, **kwargs): + CeleryInstrumentor().instrument() + + resource = Resource.create( + { + "service.name": "starling", + "service.instance.id": uuid7_str(), + } + ) + # Meter provider configuration + reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" + ), + export_interval_millis=1, + ) + provider = MeterProvider(resource=resource, metric_readers=[reader]) + metrics.set_meter_provider(provider) + # Trace provider configuration + trace_provider = TracerProvider(resource=resource) + trace_provider.add_span_processor( + BatchSpanProcessor( + OTLPSpanExporter( + endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" + ), + schedule_delay_millis=1, + ) + ) + trace.set_tracer_provider(trace_provider) + + otlp_exporter = OTLPLogExporter( + endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" + ) + + # Create an instance of OTLPHandler + otlp_handler = OTLPHandler.create( + service_name="starling", + exporter=otlp_exporter, + development_mode=False, # Set to True for development + export_interval_ms=1, + ) + + logger.add(otlp_handler.sink, level="INFO") + # Load task modules -app.conf.imports = [ +celery_app.conf.imports = [ + # "owl.tasks.checks", + "owl.tasks.database", + "owl.tasks.gen_table", "owl.tasks.genitor", - "owl.tasks.storage", ] # Configure the scheduler -app.conf.beat_schedule = {} +# celery_app.conf.beat_schedule = { +# "periodic-model-check": { +# "task": "owl.tasks.checks.test_models", +# "schedule": crontab(minute="*/10"), +# } +# } # Add periodic storage update task if service_key_plain is not empty if ENV_CONFIG.service_key_plain != "": - app.conf.beat_schedule["periodic-storage-update"] = { - "task": "owl.tasks.storage.periodic_storage_update", - "schedule": crontab(minute=f"*/{ENV_CONFIG.owl_compute_storage_period_min}"), + celery_app.conf.beat_schedule["periodic-flush-clickhouse-buffer"] = { + "task": "owl.tasks.database.run_periodic_flush_buffer", + "schedule": timedelta(seconds=ENV_CONFIG.flush_clickhouse_buffer_sec), } -# Add 
Lance-related tasks -app.conf.beat_schedule.update( - { - "lance-periodic-reindex": { - "task": "owl.tasks.storage.lance_periodic_reindex", - "schedule": crontab(minute=f"*/{max(1,ENV_CONFIG.owl_reindex_period_sec//60)}"), - }, - "lance-periodic-optimize": { - "task": "owl.tasks.storage.lance_periodic_optimize", - "schedule": crontab(minute=f"*/{max(1,ENV_CONFIG.owl_optimize_period_sec//60)}"), - }, - } -) - # Check if S3-related environment variables are present and non-empty if all( getattr(ENV_CONFIG, attr, "") # Use getattr to safely access attributes @@ -98,7 +119,7 @@ ] ): logger.info("S3 Backup tasks has been configured.") - app.conf.beat_schedule.update( + celery_app.conf.beat_schedule.update( { "backup-to-s3": { "task": "owl.tasks.genitor.backup_to_s3", diff --git a/services/api/src/owl/llm.py b/services/api/src/owl/llm.py deleted file mode 100644 index f4bf99c..0000000 --- a/services/api/src/owl/llm.py +++ /dev/null @@ -1,698 +0,0 @@ -from copy import deepcopy -from datetime import datetime, timezone -from functools import lru_cache -from os.path import join -from time import time -from typing import AsyncGenerator - -import litellm -import openai -from fastapi import Request -from litellm import Router -from litellm.router import RetryPolicy -from loguru import logger - -from jamaibase.exceptions import ( - BadInputError, - ContextOverflowError, - ExternalAuthError, - JamaiException, - ResourceNotFoundError, - ServerBusyError, - UnexpectedError, -) -from owl.billing import BillingManager -from owl.configs.manager import ENV_CONFIG -from owl.db.gen_table import KnowledgeTable -from owl.models import CloudEmbedder, CloudReranker -from owl.protocol import ( - ChatCompletionChoiceDelta, - ChatCompletionChoiceOutput, - ChatCompletionChunk, - ChatEntry, - ChatRole, - Chunk, - CompletionUsage, - ExternalKeys, - LLMModelConfig, - ModelInfo, - ModelInfoResponse, - ModelListConfig, - RAGParams, - References, -) -from owl.utils import mask_content, mask_string, select_external_api_key - -litellm.drop_params = True -litellm.set_verbose = False -litellm.suppress_debug_info = True - - -@lru_cache(maxsize=64) -def _get_llm_router(model_json: str, external_api_keys: str): - models = ModelListConfig.model_validate_json(model_json).llm_models - ExternalApiKeys = ExternalKeys.model_validate_json(external_api_keys) - # refer to https://docs.litellm.ai/docs/routing for more details - return Router( - model_list=[ - { - "model_name": m.id, - "litellm_params": { - "model": deployment.litellm_id if deployment.litellm_id.strip() else m.id, - "api_key": select_external_api_key(ExternalApiKeys, deployment.provider), - "api_base": deployment.api_base if deployment.api_base.strip() else None, - }, - } - for m in models - for deployment in m.deployments - ], - routing_strategy="latency-based-routing", - num_retries=3, - retry_policy=RetryPolicy( - TimeoutErrorRetries=3, - RateLimitErrorRetries=3, - ContentPolicyViolationErrorRetries=3, - AuthenticationErrorRetries=0, - BadRequestErrorRetries=0, - ContextWindowExceededErrorRetries=0, - ), - retry_after=5.0, - timeout=ENV_CONFIG.owl_llm_timeout_sec, - allowed_fails=3, - cooldown_time=5.5, - debug_level="DEBUG", - redis_host=ENV_CONFIG.owl_redis_host, - redis_port=ENV_CONFIG.owl_redis_port, - ) - - -class LLMEngine: - def __init__( - self, - *, - request: Request, - ) -> None: - self.request = request - self.id: str = request.state.id - self.organization_id: str = request.state.org_id - self.project_id: str = request.state.project_id - self.org_models: 
ModelListConfig = request.state.org_models - self.external_keys: ExternalKeys = request.state.external_keys - self.is_browser: bool = request.state.user_agent.is_browser - self._billing: BillingManager = request.state.billing - - @property - def router(self): - return _get_llm_router( - model_json=self.request.state.all_models.model_dump_json(), - external_api_keys=self.external_keys.model_dump_json(), - ) - - @staticmethod - def _prepare_hyperparams(model: str, hyperparams: dict, **kwargs) -> dict: - if isinstance(hyperparams.get("stop", None), list) and len(hyperparams["stop"]) == 0: - hyperparams["stop"] = None - hyperparams.update(kwargs) - if model.startswith("anthropic"): - hyperparams["extra_headers"] = {"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} - return hyperparams - - @staticmethod - def _prepare_messages(messages: list[ChatEntry | dict]) -> list[ChatEntry]: - messages: list[ChatEntry] = [ChatEntry.model_validate(m) for m in messages] - if len(messages) == 0: - raise ValueError("`messages` is an empty list.") - elif len(messages) == 1: - # [user] - if messages[0].role in (ChatRole.USER.value, ChatRole.USER): - pass - # [system] - elif messages[0].role in (ChatRole.SYSTEM.value, ChatRole.SYSTEM): - messages.append(ChatEntry.user(content=".")) - # [assistant] - else: - messages = [ChatEntry.system(content="."), ChatEntry.user(content=".")] + messages - else: - # [user, ...] - if messages[0].role in (ChatRole.USER.value, ChatRole.USER): - pass - # [system, ...] - elif messages[0].role in (ChatRole.SYSTEM.value, ChatRole.SYSTEM): - # [system, assistant, ...] - if messages[1].role in (ChatRole.ASSISTANT.value, ChatRole.ASSISTANT): - messages.insert(1, ChatEntry.user(content=".")) - # [assistant, ...] - else: - messages = [ChatEntry.system(content="."), ChatEntry.user(content=".")] + messages - return messages - - def _log_completion_masked( - self, - model: str, - messages: list[ChatEntry], - **hyperparams, - ): - body = dict( - model=model, - messages=[ - {"role": m["role"], "content": mask_content(m["content"])} for m in messages - ], - **hyperparams, - ) - logger.info(f"{self.id} - Generating chat completions: {body}") - - def _log_exception( - self, - model: str, - messages: list[ChatEntry], - api_key: str = "", - **hyperparams, - ): - body = dict( - model=model, - messages=[{"role": m["role"], "content": m["content"]} for m in messages], - api_key=mask_string(api_key), - **hyperparams, - ) - logger.exception(f"{self.id} - Chat completion got unexpected error !!! 
{body}") - - def _map_and_log_exception( - self, - e: Exception, - model: str, - messages: list[ChatEntry], - api_key: str = "", - **hyperparams, - ) -> Exception: - request_id = hyperparams.get("id", None) - err_mssg = getattr(e, "message", str(e)) - log_mssg = f"{request_id} - LiteLLM {e.__class__.__name__}: {err_mssg}" - if isinstance(e, JamaiException): - logger.info(log_mssg) - return e - elif isinstance(e, openai.BadRequestError): - logger.info(log_mssg) - return BadInputError(err_mssg) - elif isinstance(e, openai.AuthenticationError): - logger.info(log_mssg) - return ExternalAuthError(err_mssg) - elif isinstance(e, (openai.RateLimitError, openai.APITimeoutError)): - logger.info(log_mssg) - return ServerBusyError(err_mssg) - elif isinstance(e, openai.OpenAIError): - logger.warning(log_mssg) - return UnexpectedError(err_mssg) - else: - self._log_exception(model, messages, api_key, **hyperparams) - return UnexpectedError(err_mssg) - - def model_info( - self, - model: str = "", - capabilities: list[str] | None = None, - ) -> ModelInfoResponse: - model_list: ModelListConfig = self.request.state.all_models - models = model_list.models - # Filter by name - if model != "": - models = [m for m in models if m.id == model] - # Filter by capability - if capabilities is not None: - for capability in capabilities: - models = [m for m in models if capability in m.capabilities] - if len(models) == 0: - raise ResourceNotFoundError(f"No model found with capabilities: {capabilities}") - response = ModelInfoResponse( - data=[ModelInfo.model_validate(m.model_dump()) for m in models] - ) - return response - - def model_names( - self, - prefer: str = "", - capabilities: list[str] | None = None, - ) -> list[str]: - models = self.model_info( - model="", - capabilities=capabilities, - ) - names = [m.id for m in models.data] - if prefer in names: - names.remove(prefer) - names.insert(0, prefer) - return names - - def get_model_name(self, model: str, capabilities: list[str] | None = None) -> str: - capabilities = ["chat"] if capabilities is None else capabilities - models = self.model_info( - model="", - capabilities=capabilities, - ) - return [m.name for m in models.data if m.id == model][0] - - def validate_model_id( - self, - model: str = "", - capabilities: list[str] | None = None, - ) -> str: - capabilities = ["chat"] if capabilities is None else capabilities - if model == "": - models: ModelListConfig = self.request.state.all_models - model = models.get_default_model(capabilities) - logger.info(f'{self.id} - Empty model changed to "{model}"') - else: - models = self.model_info( - model="", - capabilities=capabilities, - ) - model_ids = [m.id for m in models.data] - if model not in model_ids: - err_mssg = ( - f'Model "{model}" is not available among models with capabilities {capabilities}. ' - f"Choose from: {model_ids}" - ) - logger.info(f"{self.id} - {err_mssg}") - # Return different error message depending if request came from browser - if self.is_browser: - model_names = ", ".join(m.name for m in models.data) - err_mssg = ( - f'Model "{model}" is not available among models with capabilities: {', '.join(capabilities)}. 
' - f'Choose from: {model_names}' - ) - raise ResourceNotFoundError(err_mssg) - return model - - async def generate_stream( - self, - model: str, - messages: list[ChatEntry | dict], - capabilities: list[str] | None = None, - **hyperparams, - ) -> AsyncGenerator[ChatCompletionChunk, None]: - api_key = "" - usage = None - try: - model = model.strip() - # check audio model type - is_audio_gen_model = False - if model != "": - model_config: LLMModelConfig = self.request.state.all_models.get_llm_model_info( - model - ) - if ( - "audio" in model_config.capabilities - and model_config.deployments[0].provider == "openai" - ): - is_audio_gen_model = True - hyperparams = self._prepare_hyperparams(model, hyperparams, stream=True) - messages = self._prepare_messages(messages) - # omit system prompt for audio input with audio gen - if is_audio_gen_model and messages[0].role in (ChatRole.SYSTEM.value, ChatRole.SYSTEM): - messages = messages[1:] - messages = [m.model_dump(mode="json", exclude_none=True) for m in messages] - model = self.validate_model_id( - model=model, - capabilities=capabilities, - ) - self._log_completion_masked(model, messages, **hyperparams) - if is_audio_gen_model: - response = await self.router.acompletion( - model=model, - modalities=["text", "audio"], - audio={"voice": "alloy", "format": "pcm16"}, - messages=messages, - # Fixes discrepancy between stream and non-stream token usage - stream_options={"include_usage": True}, - **hyperparams, - ) - else: - response = await self.router.acompletion( - model=model, - messages=messages, - # Fixes discrepancy between stream and non-stream token usage - stream_options={"include_usage": True}, - **hyperparams, - ) - output_text = "" - usage = CompletionUsage() - async for chunk in response: - if hasattr(chunk, "usage"): - usage = CompletionUsage( - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens, - total_tokens=chunk.usage.total_tokens, - ) - yield ChatCompletionChunk( - id=self.id, - object="chat.completion.chunk", - created=int(time()), - model=model, - usage=usage, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(choice.delta.audio.get("transcript", "")) - if is_audio_gen_model and choice.delta.audio is not None - else ChatCompletionChoiceOutput.assistant( - choice.delta.content, - tool_calls=[ - tool_call.model_dump() for tool_call in choice.delta.tool_calls - ] - if isinstance(chunk.choices[0].delta.tool_calls, list) - else None, - ), - index=choice.index, - finish_reason=choice.get( - "finish_reason", chunk.get("finish_reason", None) - ), - ) - for choice in chunk.choices - ], - ) - if is_audio_gen_model and chunk.choices[0].delta.audio is not None: - output_text += chunk.choices[0].delta.audio.get("transcript", "") - else: - content = chunk.choices[0].delta.content - output_text += content if content else "" - logger.info(f"{self.id} - Streamed completion: <{mask_string(output_text)}>") - - self._billing.create_llm_events( - model=model, - input_tokens=usage.prompt_tokens, - output_tokens=usage.completion_tokens, - ) - except Exception as e: - self._map_and_log_exception(e, model, messages, api_key, **hyperparams) - yield ChatCompletionChunk( - id=self.id, - object="chat.completion.chunk", - created=int(time()), - model=model, - usage=usage, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(f"[ERROR] {e!r}"), - index=0, - finish_reason="error", - ) - ], - ) - - async def generate( - self, - model: str, - messages: list[ChatEntry | dict], - 
capabilities: list[str] | None = None, - **hyperparams, - ) -> ChatCompletionChunk: - api_key = "" - try: - model = model.strip() - # check audio model type - is_audio_gen_model = False - if model != "": - model_config: LLMModelConfig = self.request.state.all_models.get_llm_model_info( - model - ) - if ( - "audio" in model_config.capabilities - and model_config.deployments[0].provider == "openai" - ): - is_audio_gen_model = True - hyperparams = self._prepare_hyperparams(model, hyperparams, stream=False) - messages = self._prepare_messages(messages) - # omit system prompt for audio input with audio gen - if is_audio_gen_model and messages[0].role in (ChatRole.SYSTEM.value, ChatRole.SYSTEM): - messages = messages[1:] - messages = [m.model_dump(mode="json", exclude_none=True) for m in messages] - model = self.validate_model_id( - model=model, - capabilities=capabilities, - ) - self._log_completion_masked(model, messages, **hyperparams) - if is_audio_gen_model: - completion = await self.router.acompletion( - model=model, - modalities=["text", "audio"], - audio={"voice": "alloy", "format": "pcm16"}, - messages=messages, - **hyperparams, - ) - else: - completion = await self.router.acompletion( - model=model, - messages=messages, - **hyperparams, - ) - self._billing.create_llm_events( - model=model, - input_tokens=completion.usage.prompt_tokens, - output_tokens=completion.usage.completion_tokens, - ) - choices = [] - for choice in completion.choices: - if is_audio_gen_model and choice.message.audio.transcript is not None: - choice.message.content = choice.message.audio.transcript - choices.append(choice.model_dump()) - completion = ChatCompletionChunk( - id=self.id, - object="chat.completion", - created=completion.created, - model=model, - usage=completion.usage.model_dump(), - choices=choices, - ) - logger.info(f"{self.id} - Generated completion: <{mask_string(completion.text)}>") - return completion - except Exception as e: - raise self._map_and_log_exception(e, model, messages, api_key, **hyperparams) from e - - async def retrieve_references( - self, - model: str, - messages: list[ChatEntry | dict], - rag_params: RAGParams | dict | None, - **hyperparams, - ) -> tuple[list[ChatEntry], References | None]: - if rag_params is None: - return messages, None - - hyperparams = self._prepare_hyperparams(model, hyperparams) - messages = self._prepare_messages(messages) - has_file_input = True if isinstance(messages[-1].content, list) else False - rag_params = RAGParams.model_validate(rag_params) - search_query = rag_params.search_query - # Reformulate query if not provided - if search_query == "": - hyperparams.update(temperature=0.01, top_p=0.01, max_tokens=512) - rewriter_messages = deepcopy(messages) - if rewriter_messages[0].role not in (ChatRole.SYSTEM.value, ChatRole.SYSTEM): - logger.warning(f"{self.id} - `messages[0].role` is not `system` !!!") - rewriter_messages.insert(0, ChatEntry.system("You are a concise assistant.")) - if has_file_input: - query_ori = rewriter_messages[-1].content[0]["text"] - else: - query_ori = rewriter_messages[-1].content - - # Search query rewriter - now = datetime.now(timezone.utc) - rewriter_messages[-1] = ChatEntry.user( - ( - f"QUESTION: `{query_ori}`\n\n" - f"Current datetime: {now.isoformat()}\n" - "You need to retrieve documents that are relevant to the user by using a search engine. " - "Use the information provided to generate one good Google search query sentence in English. " - "Do not include any search modifiers or symbols. 
" - "Make sure all relevant keywords are in the sentence. " - "Convert any ranges into comma-separated list of items. " - "Any date or time in the query should be in numeric format, " - f'for example last year is "{now.year - 1}", last 2 years is "{now.year - 1}, {now.year}". ' - "Reply with only the query. Do not include reasoning, explanations, or notes." - ) - ) - completion = await self.generate( - model=model, - messages=rewriter_messages, - **hyperparams, - ) - search_query = completion.text.strip() - if search_query.startswith('"') and search_query.endswith('"'): - search_query = search_query[1:-1] - logger.info( - ( - f'{self.id} - Rewritten query using "{model}": ' - f"<{mask_string(query_ori)}> -> <{mask_string(search_query)}>" - ) - ) - - # Query - rag_params.search_query = search_query - if rag_params.reranking_model is not None: - reranker = CloudReranker(request=self.request) - else: - reranker = None - embedder = CloudEmbedder(request=self.request) - logger.info(f"{self.id} - Querying table: {rag_params}") - lance_path = join( - ENV_CONFIG.owl_db_dir, self.organization_id, self.project_id, "knowledge" - ) - sqlite_path = f"sqlite:///{lance_path}.db" - table = KnowledgeTable(sqlite_path, lance_path) - with table.create_session() as session: - rows = await table.hybrid_search( - session=session, - table_id=rag_params.table_id, - embedder=embedder, - reranker=reranker, - reranking_model=rag_params.reranking_model, - query=search_query, - limit=rag_params.k, - remove_state_cols=True, - float_decimals=0, - vec_decimals=0, - ) - if len(rows) > 1: - logger.info( - ( - f"{self.id} - Retrieved {len(rows):,d} rows from hybrid search: " - f"[{self._mask_retrieved_row(rows[0])}, ..., {self._mask_retrieved_row(rows[-1])}]" - ) - ) - elif len(rows) == 1: - logger.info( - ( - f"{self.id} - Retrieved 1 row from hybrid search: " - f"[{self._mask_retrieved_row(rows[0])}]" - ) - ) - else: - logger.warning(f"{self.id} - Failed to retrieve any rows from hybrid search !") - chunks = [ - Chunk( - text="" if row["Text"] is None else row["Text"], - title="" if row["Title"] is None else row["Title"], - page=row["Page"], - document_id="" if row["File ID"] is None else row["File ID"], - chunk_id=row["ID"], - ) - for row in rows - ] - references = References(chunks=chunks, search_query=search_query) - - # Generate - # https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/retrieval_qa/prompt.py - new_prompt = """UP-TO-DATE CONTEXT:\n\n""" - for chunk in chunks: - new_prompt += f""" -# Document: {chunk.title} -## Document ID: {chunk.chunk_id} - -{chunk.text} - -""" - # new_prompt += f""" - # QUESTION:\n{body.messages[-1].content} - - # Answer the question with citation of relevant documents in the form of `\cite{{Document ID}}`. - # """ # noqa: W605 - new_prompt += f""" -QUESTION:\n{messages[-1].content[0]["text"].strip() if has_file_input else messages[-1].content.strip()} - -Answer the question. 
-""" # noqa: W605 - logger.debug( - "{id} - Constructed new user prompt: {prompt}", - id=self.id, - prompt=new_prompt, - ) - if has_file_input: - new_content = [{"type": "text", "text": new_prompt}, messages[-1].content[1]] - else: - new_content = new_prompt - messages[-1] = ChatEntry.user(content=new_content) - return messages, references - - @staticmethod - def _mask_retrieved_row(row: dict[str, str | None]): - return { - "ID": row["ID"], - "File ID": row["File ID"], - "Title": mask_string(row["Title"]), - "Text": mask_string(row["Text"]), - "Page": str(row["Page"]), - } - - async def rag_stream( - self, - model: str, - messages: list[ChatEntry | dict], - rag_params: RAGParams | None = None, - **hyperparams, - ) -> AsyncGenerator[References | ChatCompletionChunk, None]: - try: - hyperparams = self._prepare_hyperparams(model, hyperparams) - messages, references = await self.retrieve_references( - model=model, - messages=messages, - rag_params=rag_params, - **hyperparams, - ) - if references is not None: - yield references - async for chunk in self.generate_stream( - model=model, - messages=messages, - **hyperparams, - ): - yield chunk - except Exception as e: - self._log_exception(model, messages, **hyperparams) - yield ChatCompletionChunk( - id=self.id, - object="chat.completion.chunk", - created=int(time()), - model=model, - usage=None, - choices=[ - ChatCompletionChoiceDelta( - message=ChatEntry.assistant(f"[ERROR] {e!r}"), - index=0, - finish_reason="error", - ) - ], - ) - - async def rag( - self, - model: str, - messages: list[ChatEntry | dict], - capabilities: list[str] | None = None, - rag_params: RAGParams | dict | None = None, - **hyperparams, - ) -> ChatCompletionChunk: - hyperparams = self._prepare_hyperparams(model, hyperparams) - messages, references = await self.retrieve_references( - model=model, - messages=messages, - rag_params=rag_params, - **hyperparams, - ) - try: - response = await self.generate( - model=model, - messages=messages, - capabilities=capabilities, - **hyperparams, - ) - response.references = references - except ContextOverflowError: - logger.warning(f"{self.id} - Chat is too long, returning references only.") - response = ChatCompletionChunk( - id=self.id, - object="chat.completion", - created=int(time()), - model=model, - usage=None, - choices=[], - references=references, - ) - return response diff --git a/services/api/src/owl/loaders.py b/services/api/src/owl/loaders.py deleted file mode 100644 index f53a4c0..0000000 --- a/services/api/src/owl/loaders.py +++ /dev/null @@ -1,283 +0,0 @@ -import re -import sys -from os.path import join, splitext -from tempfile import TemporaryDirectory - -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_core.documents.base import Document -from loguru import logger - -from jamaibase.exceptions import BadInputError -from owl.configs.manager import ENV_CONFIG -from owl.docio import DocIOAPIFileLoader -from owl.protocol import Chunk, SplitChunksParams, SplitChunksRequest -from owl.unstructuredio import UnstructuredAPIFileLoader - -# build a table mapping all non-printable characters to None -NOPRINT_TRANS_TABLE = { - i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable() and chr(i) != "\n" -} - - -def make_printable(s: str) -> str: - """ - Replace non-printable characters in a string using - `translate()` that removes characters that map to None. 
- - # https://stackoverflow.com/a/54451873 - """ - return s.translate(NOPRINT_TRANS_TABLE) - - -def format_chunks(documents: list[Document], file_name: str, page: int = None) -> list[Chunk]: - if page is not None: - for d in documents: - d.metadata["page"] = page - chunks = [ - # TODO: Probably can use regex for this - # Replace vertical tabs, form feed, Unicode replacement character - # page_content=d.page_content.replace("\x0c", " ") - # .replace("\x0b", " ") - # .replace("\uFFFD", ""), - # For now we use a more aggressive strategy - Chunk( - text=make_printable(d.page_content), - title=d.metadata.get("title", ""), - page=d.metadata.get("page", 0), - file_name=file_name, - file_path=file_name, - metadata=d.metadata, - ) - for d in documents - ] - return chunks - - -async def load_file( - file_name: str, - content: bytes, - chunk_size: int, - chunk_overlap: int, -) -> list[Chunk]: - """ - Asynchronously loads and processes a file, converting its content into a list of Chunk objects. - - Args: - file_name (str): The name of the file to be loaded. - content (bytes): The binary content of the file. - chunk_size (int): The desired size of each chunk. - chunk_overlap (int): The amount of overlap between chunks. - - Returns: - list[Chunk]: A list of Chunk objects representing the processed file content. - - Raises: - ValueError: If the file type is not supported. - """ - - ext = splitext(file_name)[1].lower() - with TemporaryDirectory() as tmp_dir_path: - tmp_path = join(tmp_dir_path, f"tmpfile{ext}") - with open(tmp_path, "wb") as tmp: - tmp.write(content) - tmp.flush() - logger.debug(f"Loading from temporary file: {tmp_path}") - - if ext in (".csv", ".tsv", ".json", ".jsonl"): - loader = DocIOAPIFileLoader(tmp_path, ENV_CONFIG.docio_url) - documents = loader.load() - logger.debug('File "{file_name}" loaded: {docs}', file_name=file_name, docs=documents) - chunks = format_chunks(documents, file_name, page=1) - if ext == ".json": - chunks = split_chunks( - SplitChunksRequest( - chunks=chunks, - params=SplitChunksParams( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ), - ) - ) - - elif ext in (".html", ".xml", ".pptx", ".ppt", ".xlsx", ".xls", ".docx", ".doc"): - loader = UnstructuredAPIFileLoader( - tmp_path, - url=ENV_CONFIG.unstructuredio_url, - api_key=ENV_CONFIG.unstructuredio_api_key_plain, - mode="paged", - xml_keep_tags=True, - ) - documents = await loader.aload() - logger.debug('File "{file_name}" loaded: {docs}', file_name=file_name, docs=documents) - chunks = format_chunks(documents, file_name) - chunks = split_chunks( - SplitChunksRequest( - chunks=chunks, - params=SplitChunksParams( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ), - ) - ) - - elif ext in (".md", ".txt"): - loader = UnstructuredAPIFileLoader( - tmp_path, - url=ENV_CONFIG.unstructuredio_url, - api_key=ENV_CONFIG.unstructuredio_api_key_plain, - mode="elements", - chunking_strategy="by_title", - max_characters=chunk_size, - overlap=chunk_overlap, - ) - documents = await loader.aload() - logger.debug('File "{file_name}" loaded: {docs}', file_name=file_name, docs=documents) - chunks = format_chunks(documents, file_name) - - elif ext == ".pdf": - - def unstructured_api_file_loader( - strategy: str, split_pdf_page: bool - ) -> UnstructuredAPIFileLoader: - return UnstructuredAPIFileLoader( - tmp_path, - url=ENV_CONFIG.unstructuredio_url, - api_key=ENV_CONFIG.unstructuredio_api_key_plain, - mode="elements", - strategy=strategy, - chunking_strategy="by_title", - max_characters=chunk_size, - 
overlap=chunk_overlap, - multipage_sections=False, # respect page boundaries - include_page_breaks=True, - split_pdf_page=split_pdf_page, - ) - - if ENV_CONFIG.owl_fast_pdf_parsing: - strategy, split_pdf_page = "fast", False - documents = await unstructured_api_file_loader( - strategy=strategy, split_pdf_page=split_pdf_page - ).aload() - if len(documents) == 0: - strategy = "ocr_only" - logger.info( - "[Scan PDF Detected]: No text or content is found, running `ocr` mode." - ) - else: - strategy, split_pdf_page = "hi_res", True - - documents = await unstructured_api_file_loader( - strategy=strategy, split_pdf_page=split_pdf_page - ).aload() - logger.info( - f"File '{file_name}' parsed in `{strategy}` mode {'with' if split_pdf_page else 'without'} partitioning." - ) - logger.debug(f"File '{file_name}' content: {documents}") - chunks = format_chunks(documents, file_name) - if strategy == "hi_res": - chunks = combine_table_chunks(chunks=chunks) - - else: - raise BadInputError(f'File type "{ext}" is not supported at the moment.') - - logger.info(f'File "{file_name}" loaded and split into {len(chunks):,d} chunks.') - return chunks - - -def combine_table_chunks(chunks: list[Chunk]) -> list[Chunk]: - """Combines chunks identified as parts of a table into a single chunk. - - This function iterates through the chunks and identifies consecutive chunks that - belong to the same table based on the presence of "text_as_html" and "is_continuation" - metadata flags. It then merges these chunks into a single chunk, preserving the - table's HTML structure. - - Args: - chunks (List[Chunk]): A list of Chunk objects. - - Returns: - List[Chunk]: A list of Chunk objects with table chunks combined. - """ - table_chunk_idx_groups = {} - current_table_start_idx = 0 - for i, chunk in enumerate(chunks): - if "text_as_html" in chunk.metadata and chunk.metadata.get("is_continuation", False): - table_chunk_idx_groups[current_table_start_idx].append(i) - elif "text_as_html" in chunk.metadata: - current_table_start_idx = i - table_chunk_idx_groups[current_table_start_idx] = [current_table_start_idx] - chunk.metadata.pop("orig_elements", None) - - table_indexes = table_chunk_idx_groups.keys() - processed_chunks = [] - current_table_start_idx = 0 - current_table_end_idx = 0 - table_chunk = Chunk(text="") - for i, chunk in enumerate(chunks): - if i in table_indexes: - current_table_start_idx = i - current_table_end_idx = table_chunk_idx_groups[i][-1] - table_chunk = Chunk( - text=chunk.metadata.get("text_as_html", chunk.text), - title=chunk.title, - page=chunk.page, - file_name=chunk.file_name, - file_path=chunk.file_path, - metadata=chunk.metadata.copy(), - ) - table_chunk.metadata.pop("text_as_html", None) - if current_table_end_idx == current_table_start_idx: - processed_chunks.append(table_chunk) - elif i > current_table_start_idx and i <= current_table_end_idx: - table_chunk.text += chunk.metadata.get("text_as_html", chunk.text) - if i == current_table_end_idx: - processed_chunks.append(table_chunk) - else: - processed_chunks.append(chunk) - - return processed_chunks - - -def split_chunks(request: SplitChunksRequest) -> list[Chunk]: - _id = request.id - logger.info(f"{_id} - Split documents request: {request.str_trunc()}") - if request.params.method == "RecursiveCharacterTextSplitter": - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=request.params.chunk_size, - chunk_overlap=request.params.chunk_overlap, - ) - else: - raise ValueError(f"Split method not supported: {request.params.method}") - - try: - 
chunks = []
-            for chunk in request.chunks:
-                # Module-level functions store compiled objects in a cache
-                text_tables_parts = re.split(r"(<table>.*?</table>)", chunk.text, flags=re.DOTALL)
-                table_split_texts = [part for part in text_tables_parts if part]
-                for table_split_text in table_split_texts:
-                    if table_split_text.startswith("<table>") and table_split_text.endswith(
-                        "</table>
" - ): - chunks.append(chunk) - else: - chunks += [ - Chunk( - text=d.page_content, - title=chunk.title, - page=chunk.page, - file_name=chunk.file_name, - file_path=chunk.file_name, - metadata=chunk.metadata, - ) - for d in text_splitter.split_documents( - [Document(page_content=chunk.text, metadata={})] - ) - ] - logger.info( - f"{_id} - {len(request.chunks):,d} chunks split into {len(chunks):,d} chunks.", - ) - return chunks - except Exception: - logger.exception("Failed to split chunks.") - raise diff --git a/services/api/src/owl/models.py b/services/api/src/owl/models.py deleted file mode 100644 index f609dc4..0000000 --- a/services/api/src/owl/models.py +++ /dev/null @@ -1,416 +0,0 @@ -import asyncio -import base64 -import imghdr -import io -import itertools -from functools import lru_cache - -import httpx -import litellm -import orjson -from fastapi import Request -from langchain.schema.embeddings import Embeddings -from litellm import Router -from litellm.router import RetryPolicy -from loguru import logger - -from jamaibase.utils.io import json_loads -from owl.configs.manager import ENV_CONFIG -from owl.protocol import ( - Chunk, - ClipInputData, - CompletionUsage, - EmbeddingModelConfig, - EmbeddingResponse, - EmbeddingResponseData, - ExternalKeys, - ModelListConfig, - RerankingModelConfig, -) -from owl.utils import select_external_api_key - -litellm.drop_params = True -litellm.set_verbose = False -litellm.suppress_debug_info = True - -HTTP_CLIENT = httpx.AsyncClient(timeout=60.0, transport=httpx.AsyncHTTPTransport(retries=3)) - - -@lru_cache(maxsize=32) -def _get_embedding_router(model_json: str, external_api_keys: str): - models = ModelListConfig.model_validate_json(model_json).embed_models - ExternalApiKeys = ExternalKeys.model_validate_json(external_api_keys) - # refer to https://docs.litellm.ai/docs/routing for more details - return Router( - model_list=[ - { - "model_name": m.id, - "litellm_params": { - "model": deployment.litellm_id if deployment.litellm_id.strip() else m.id, - "api_key": select_external_api_key(ExternalApiKeys, deployment.provider), - "api_base": deployment.api_base if deployment.api_base.strip() else None, - }, - } - for m in models - for deployment in m.deployments - ], - routing_strategy="latency-based-routing", - num_retries=3, - retry_policy=RetryPolicy( - TimeoutErrorRetries=3, - RateLimitErrorRetries=3, - ContentPolicyViolationErrorRetries=3, - AuthenticationErrorRetries=0, - BadRequestErrorRetries=0, - ContextWindowExceededErrorRetries=0, - ), - retry_after=5.0, - timeout=ENV_CONFIG.owl_embed_timeout_sec, - allowed_fails=3, - cooldown_time=5.5, - ) - - -# Cached function -def get_embedding_router(all_models: ModelListConfig, external_keys: ExternalKeys) -> Router: - return _get_embedding_router( - model_json=all_models.model_dump_json(), - external_api_keys=external_keys.model_dump_json(), - ) - - -class CloudBase: - @staticmethod - def batch(seq, n): - if n < 1: - raise ValueError("`n` must be > 0") - for i in range(0, len(seq), n): - yield seq[i : i + n] - - @staticmethod - def _resolve_provider_model_name(id: str) -> str: - split_names = id.split("/") - if len(split_names) < 2: - raise ValueError("`id` needs to be in the form of provider/model_name") - # this assume using huggingface model (usually org/model_name) - return split_names[0], "/".join(split_names[1:]) - - -class CloudReranker(CloudBase): - API_MAP = { - "cohere": ENV_CONFIG.cohere_api_base, - "voyage": ENV_CONFIG.voyage_api_base, - "jina": ENV_CONFIG.jina_api_base, - } - - def 
__init__(self, request: Request): - """Reranker router. - - Args: - request (Request): Starlette request object. - - Raises: - ValueError: If provider is not supported. - """ - from owl.billing import BillingManager - - self.request = request - self.external_keys: ExternalKeys = request.state.external_keys - self._billing: BillingManager = request.state.billing - - def set_rerank_model(self, reranker_name): - # Get embedder_config - reranker_config: RerankingModelConfig = ( - self.request.state.all_models.get_rerank_model_info(reranker_name) - ) - reranker_config = reranker_config.model_dump(exclude_none=True) - _, model_name = self._resolve_provider_model_name(reranker_config["id"]) - self.reranker_config = reranker_config - # 2024-10-03: reranker only support single deployment now. - deployment = reranker_config["deployments"][0] - self.provider_name = deployment["provider"] - if deployment["provider"] not in ["ellm", "cohere", "voyage", "jina"]: - raise ValueError( - f"reranker `provider`: {deployment['provider']} not supported please use only following provider: ellm/cohere/voyage/jina" - ) - api_url = ( - deployment["api_base"] + "/rerank" - if self.provider_name == "ellm" - else self.API_MAP[self.provider_name] + "/rerank" - ) - api_key = select_external_api_key(self.external_keys, self.provider_name) - self.reranking_args = { - "model": model_name, - "api_key": api_key, - "api_url": api_url, - } - - async def rerank_chunks( - self, - reranker_name: str, - chunks: list[Chunk], - query: str, - batch_size: int = 256, - title_weight: float = 0.6, - content_weight: float = 0.4, - use_concat: bool = False, - ) -> list[tuple[Chunk, float, int]]: - self.set_rerank_model(reranker_name) # configure the reranker to be used - if self.provider_name == "voyage": - batch_size = 32 # voyage has a limit on token lengths 100,000 - all_contents = [d.text for d in chunks] - all_titles = [d.title for d in chunks] - self._billing.check_reranker_quota(model_id=self.reranker_config["id"]) - if use_concat: - all_concats = [ - "Title: " + _title + "\nContent: " + _content - for _title, _content in zip(all_titles, all_contents, strict=True) - ] - concat_scores = await self._rerank_by_batch(query, all_concats, batch_size) - scores = [x["relevance_score"] for x in concat_scores] - else: - content_scores = await self._rerank_by_batch(query, all_contents, batch_size) - title_scores = await self._rerank_by_batch(query, all_titles, batch_size) - scores = [ - ( - c["relevance_score"] * content_weight + t["relevance_score"] * title_weight - if chunks[idx].title != "" - else 0.0 - ) - for idx, (c, t) in enumerate(zip(content_scores, title_scores, strict=True)) - ] - self._billing.create_reranker_events( - self.reranker_config["id"], - len(all_titles) // 100, - ) - reranked_chunks = sorted( - ((d, s, i) for i, (d, s) in enumerate(zip(chunks, scores, strict=True))), - key=lambda x: x[1], - reverse=True, - ) - logger.info(f"Reranked order: {[r[2] for r in reranked_chunks]}") - return reranked_chunks - - async def _rerank(self, query, documents: list[str]) -> list[dict]: - headers = { - "Content-Type": "application/json", - "Authorization": ( - f"Bearer {self.reranking_args['api_key']}" - if self.provider_name in self.API_MAP.keys() - else "" - ), - } - data = { - "model": self.reranking_args["model"], - "query": query, - "documents": documents, - "return_documents": False, - } - - response = await HTTP_CLIENT.post( - self.reranking_args["api_url"], headers=headers, json=data - ) - if response.status_code != 200: - raise 
RuntimeError(response.text) - response = json_loads(response.text) - if self.provider_name == "voyage": - return response["data"] - else: - return response["results"] - - async def _rerank_by_batch(self, query, documents: list[str], batch_size: int) -> list[dict]: - all_data = [] - for document in self.batch(documents, batch_size): - _tmp = await self._rerank( - query, document - ) # this scores might not be sorted by input index. some provider will sort result by relevance score - _tmp = sorted(_tmp, key=lambda x: x["index"], reverse=False) # sort by index - all_data.extend(_tmp) - return all_data - - -class CloudEmbedder(CloudBase): - def __init__(self, request: Request): - """Embedder router. - - Args: - request (Request): Starlette request object. - """ - from owl.billing import BillingManager - - self.request = request - self.external_keys: ExternalKeys = request.state.external_keys - self._billing: BillingManager = request.state.billing - - def set_embed_model(self, embedder_name): - # Get embedder_config - embedder_config: EmbeddingModelConfig = self.request.state.all_models.get_embed_model_info( - embedder_name - ) - embedder_config = embedder_config.model_dump(exclude_none=True) - self.embedder_config = embedder_config - self.embedder_router = get_embedding_router( - self.request.state.all_models, self.external_keys - ) - for deployment in embedder_config["deployments"]: - if deployment["provider"] not in ["ellm", "openai", "cohere", "voyage", "jina"]: - raise ValueError( - ( - f"Embedder provider {deployment['provider']} not supported, " - "please use only following provider: ellm/openai/cohere/voyage/jina" - ) - ) - self.embedding_args = { - "model": embedder_config["id"], - "dimensions": self.embedder_config.get("dimensions"), - } - - async def embed_texts(self, texts: list[str]) -> EmbeddingResponse: - if self.embedder_config["owned_by"] == "jina": - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.external_keys.jina}", - } - data = {"input": texts, "model": self.embedding_args["model"]} - response = await HTTP_CLIENT.post( - ENV_CONFIG.jina_api_base + "/embeddings", - headers=headers, - json=data, - ) - if response.status_code != 200: - raise RuntimeError(response.text) - response = EmbeddingResponse.model_validate_json(response.text) - else: - response = await self.embedder_router.aembedding(**self.embedding_args, input=texts) - response = EmbeddingResponse.model_validate(response.model_dump()) - - return response - - async def embed_documents( - self, - embedder_name: str, - texts: list[str], - batch_size: int = 2048, - ) -> EmbeddingResponse: - self.set_embed_model(embedder_name) - """Embed search docs.""" - if not isinstance(texts, list): - raise TypeError("`texts` must be a list.") - if self.embedder_config["owned_by"] == "cohere": - self.embedding_args["input_type"] = "search_document" - batch_size = 96 # limit on cohere server - if self.embedder_config["owned_by"] == "jina": - batch_size = 128 # don't know limit, but too large will timeout - if self.embedder_config["owned_by"] == "voyage": - batch_size = 128 # limit on voyage server - if self.embedder_config["owned_by"] == "openai": - batch_size = 256 # limited by token per min (10,000,000) - self._billing.check_embedding_quota(model_id=self.embedder_config["id"]) - responses = await asyncio.gather( - *[self.embed_texts(txt) for txt in self.batch(texts, batch_size)] - ) - embeddings = [e.embedding for e in itertools.chain(*[r.data for r in responses])] - usages = CompletionUsage( - 
prompt_tokens=sum(r.usage.prompt_tokens for r in responses), - total_tokens=sum(r.usage.total_tokens for r in responses), - ) - embeddings = EmbeddingResponse( - data=[EmbeddingResponseData(embedding=e, index=i) for i, e in enumerate(embeddings)], - model=responses[0].model, - usage=usages, - ) - self._billing.create_embedding_events( - model=self.embedder_config["id"], - token_usage=usages.total_tokens, - ) - return embeddings - - async def embed_queries(self, embedder_name: str, texts: list[str]) -> EmbeddingResponse: - self.set_embed_model(embedder_name) - """Embed query text.""" - if not isinstance(texts, list): - raise TypeError("`texts` must be a list.") - if self.embedding_args.get("transform_query"): - texts = [self.embedding_args.get("transform_query") + text for text in texts] - if self.embedder_config["owned_by"] == "cohere": - self.embedding_args["input_type"] = "search_query" - self._billing.check_embedding_quota(model_id=self.embedder_config["id"]) - response = await self.embed_texts(texts) - self._billing.create_embedding_events( - model=self.embedder_config["id"], - token_usage=response.usage.total_tokens, - ) - return response - - -class CloudImageEmbedder(CloudBase, Embeddings): - def __init__(self): - """ - Args: - client: an httpx client - Info: - Read the clip_api_base from the .env directly - Only use for image embedding - Query can be text/image - can be used for text-to-image search or image-to-image search - DO NOT DO image-to-text-and-image search - same modality would most certainly always result in a higher scores than different modality obj - """ - api_url = ENV_CONFIG.clip_api_base + "/post" - self.embedding_args = { - "api_url": api_url, - } - - async def _embed(self, objects: list[ClipInputData]) -> list[list[float]]: - parsed_data = self._parse_data(objects) - headers = {"Content-Type": "application/json"} - data = {"data": parsed_data, "execEndpoint": "/"} - response = await HTTP_CLIENT.post( - self.embedding_args["api_url"], - headers=headers, - data=orjson.dumps(data), - ) - if response.status_code != 200: - raise RuntimeError(response.text) - return [x["embedding"] for x in json_loads(response)["data"]] - - def _parse_data(self, objects: list[ClipInputData]): - """ - The objects are list of [ClipInputData] - """ - return [ - {"uri": self._get_blob_from_data(obj)} if obj.image_filename else {"text": obj.content} - for obj in objects - ] - - def _get_blob_from_data(self, data: ClipInputData): - """get blob from ClipInputData""" - with io.BytesIO(data.content) as f: - # Get the image format - try: - img_format = imghdr.what(f).lower() - except Exception as e: - raise ValueError( - f"object {data.image_filename} is not a valid image format." 
- ) from e - # Read the image file - img_data = f.read() - img_base64 = base64.b64encode(img_data) - data_uri = f"data:image/{img_format};base64," + img_base64.decode("utf-8") - return data_uri - - async def embed_documents( - self, objects: list[ClipInputData], batch_size: int = 64 - ) -> list[list[float]]: - """Embed search objects (image).""" - if not isinstance(objects, list): - raise TypeError("`objects` must be a list.") - embeddings = await asyncio.gather( - *[self._embed(obj) for obj in self.batch(objects, batch_size)] - ) - return list(itertools.chain(*embeddings)) - - async def embed_query(self, data: ClipInputData) -> list[float]: - """Embed query text/image.""" - embeddings = await self._embed([data]) - return embeddings[0] # should just have 1 elements diff --git a/services/api/src/owl/protocol.py b/services/api/src/owl/protocol.py deleted file mode 100644 index 622d23d..0000000 --- a/services/api/src/owl/protocol.py +++ /dev/null @@ -1,2598 +0,0 @@ -""" -NOTES: - -- Pydantic supports setting mutable values as default. - This is in contrast to native `dataclasses` where it is not supported. - -- Pydantic supports setting default fields in any order. - This is in contrast to native `dataclasses` where fields with default values must be defined after non-default fields. -""" - -from __future__ import annotations - -import re -from copy import deepcopy -from datetime import datetime, timezone -from enum import Enum, EnumMeta -from functools import cached_property, reduce -from os.path import splitext -from typing import Annotated, Any, Generic, Literal, Sequence, Type, TypeVar, Union - -import numpy as np -import pyarrow as pa -from loguru import logger -from natsort import natsorted -from pydantic import ( - AfterValidator, - BaseModel, - BeforeValidator, - ConfigDict, - Discriminator, - Field, - Tag, - ValidationError, - computed_field, - create_model, - field_validator, - model_validator, -) -from sqlmodel import JSON, Column, MetaData, SQLModel -from sqlmodel import Field as sql_Field -from typing_extensions import Self - -from jamaibase import protocol as p -from jamaibase.exceptions import ResourceNotFoundError -from jamaibase.utils.io import json_dumps -from owl.utils import datetime_now_iso, uuid7_draft2_str -from owl.version import __version__ as owl_version - -PositiveInt = Annotated[int, Field(ge=0, description="Positive integer.")] -PositiveNonZeroInt = Annotated[int, Field(gt=0, description="Positive non-zero integer.")] - - -def sanitise_document_id(v: str) -> str: - if v.startswith('"') and v.endswith('"'): - v = v[1:-1] - return v - - -def sanitise_document_id_list(v: list[str]) -> list[str]: - return [sanitise_document_id(vv) for vv in v] - - -DocumentID = Annotated[str, AfterValidator(sanitise_document_id)] -DocumentIDList = Annotated[list[str], AfterValidator(sanitise_document_id_list)] - -EXAMPLE_CHAT_MODEL_IDS = ["openai/gpt-4o-mini"] -# for openai embedding models doc: https://platform.openai.com/docs/guides/embeddings -# for cohere embedding models doc: https://docs.cohere.com/reference/embed -# for jina embedding models doc: https://jina.ai/embeddings/ -# for voyage embedding models doc: https://docs.voyageai.com/docs/embeddings -# for hf embedding models doc: check the respective hf model page, name should be ellm/{org}/{model} -EXAMPLE_EMBEDDING_MODEL_IDS = [ - "openai/text-embedding-3-small-512", - "ellm/sentence-transformers/all-MiniLM-L6-v2", -] -# for cohere reranking models doc: https://docs.cohere.com/reference/rerank-1 -# for jina reranking models 
doc: https://jina.ai/reranker -# for colbert reranking models doc: https://docs.voyageai.com/docs/reranker -# for hf embedding models doc: check the respective hf model page, name should be ellm/{org}/{model} -EXAMPLE_RERANKING_MODEL_IDS = [ - "cohere/rerank-multilingual-v3.0", - "ellm/cross-encoder/ms-marco-TinyBERT-L-2", -] - -IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".gif", ".webp"] -AUDIO_FILE_EXTENSIONS = [".mp3", ".wav"] -DOCUMENT_FILE_EXTENSIONS = [ - ".pdf", - ".txt", - ".md", - ".docx", - ".xml", - ".html", - ".json", - ".csv", - ".tsv", - ".jsonl", - ".xlsx", - ".xls", -] - -Name = Annotated[ - str, - BeforeValidator(lambda v, _: v.strip() if isinstance(v, str) else v), - Field( - pattern=r"\w+", - max_length=100, - description=( - "Name or ID. Must be unique with at least 1 non-symbol character and up to 100 characters." - ), - ), -] - - -class UserAgent(BaseModel): - is_browser: bool = Field( - default=True, - description="Whether the request originates from a browser or an app.", - examples=[True, False], - ) - agent: str = Field( - description="The agent, such as 'SDK', 'Chrome', 'Firefox', 'Edge', or an empty string if it cannot be determined.", - examples=["", "SDK", "Chrome", "Firefox", "Edge"], - ) - agent_version: str = Field( - default="", - description="The agent version, or an empty string if it cannot be determined.", - examples=["", "5.0", "0.3.0"], - ) - os: str = Field( - default="", - description="The system/OS name and release, such as 'Windows NT 10.0', 'Linux 5.15.0-113-generic', or an empty string if it cannot be determined.", - examples=["", "Windows NT 10.0", "Linux 5.15.0-113-generic"], - ) - architecture: str = Field( - default="", - description="The machine type, such as 'AMD64', 'x86_64', or an empty string if it cannot be determined.", - examples=["", "AMD64", "x86_64"], - ) - language: str = Field( - default="", - description="The SDK language, such as 'TypeScript', 'Python', or an empty string if it is not applicable.", - examples=["", "TypeScript", "Python"], - ) - language_version: str = Field( - default="", - description="The SDK language version, such as '4.9', '3.10.14', or an empty string if it is not applicable.", - examples=["", "4.9", "3.10.14"], - ) - - @computed_field( - description="The system/OS name, such as 'Linux', 'Darwin', 'Java', 'Windows', or an empty string if it cannot be determined.", - examples=["", "Windows NT", "Linux"], - ) - @property - def system(self) -> str: - return self._split_os_string()[0] - - @computed_field( - description="The system's release, such as '2.2.0', 'NT', or an empty string if it cannot be determined.", - examples=["", "10", "5.15.0-113-generic"], - ) - @property - def system_version(self) -> str: - return self._split_os_string()[1] - - def _split_os_string(self) -> tuple[str, str]: - match = re.match(r"([^\d]+) ([\d.]+).*$", self.os) - if match: - os_name = match.group(1).strip() - os_version = match.group(2).strip() - return os_name, os_version - else: - return "", "" - - @classmethod - def from_user_agent_string(cls, ua_string: str) -> Self: - if not ua_string: - return cls(is_browser=False, agent="") - - # SDK pattern - sdk_match = re.match(r"SDK/(\S+) \((\w+)/(\S+); ([^;]+); (\w+)\)", ua_string) - if sdk_match: - return cls( - is_browser=False, - agent="SDK", - agent_version=sdk_match.group(1), - os=sdk_match.group(4), - architecture=sdk_match.group(5), - language=sdk_match.group(2), - language_version=sdk_match.group(3), - ) - - # Browser pattern - browser_match = re.match(r"Mozilla/5.0 
\(([^)]+)\).*", ua_string) - if browser_match: - os_info = browser_match.group(1).split(";") - # Microsoft Edge - match = re.match(r".+(Edg/.+)$", ua_string) - if match: - return cls( - agent="Edge", - agent_version=match.group(1).split("/")[-1].strip(), - os=os_info[0].strip(), - architecture=os_info[-1].strip() if len(os_info) == 3 else "", - language="", - language_version="", - ) - # Firefox - match = re.match(r".+(Firefox/.+)$", ua_string) - if match: - return cls( - agent="Firefox", - agent_version=match.group(1).split("/")[-1].strip(), - os=os_info[0].strip(), - architecture=os_info[-1].strip() if len(os_info) == 3 else "", - language="", - language_version="", - ) - # Chrome - match = re.match(r".+(Chrome/.+)$", ua_string) - if match: - return cls( - agent="Chrome", - agent_version=match.group(1).split("/")[-1].strip(), - os=os_info[0].strip(), - architecture=os_info[-1].strip() if len(os_info) == 3 else "", - language="", - language_version="", - ) - return cls(is_browser="mozilla" in ua_string.lower(), agent="") - - -class ExternalKeys(BaseModel): - model_config = ConfigDict(extra="forbid") - custom: str = "" - openai: str = "" - anthropic: str = "" - gemini: str = "" - cohere: str = "" - groq: str = "" - together_ai: str = "" - jina: str = "" - voyage: str = "" - hyperbolic: str = "" - cerebras: str = "" - sambanova: str = "" - deepseek: str = "" - - -class OkResponse(BaseModel): - ok: bool = True - - -class StringResponse(BaseModel): - object: Literal["string"] = Field( - default="string", - description='The object type, which is always "string".', - examples=["string"], - ) - data: str = Field( - description="The string data.", - examples=["text"], - ) - - -class AdminOrderBy(str, Enum): - ID = "id" - """Sort by `id` column.""" - NAME = "name" - """Sort by `name` column.""" - CREATED_AT = "created_at" - """Sort by `created_at` column.""" - UPDATED_AT = "updated_at" - """Sort by `updated_at` column.""" - - def __str__(self) -> str: - return self.value - - -class GenTableOrderBy(str, Enum): - ID = "id" - """Sort by `id` column.""" - UPDATED_AT = "updated_at" - """Sort by `updated_at` column.""" - - def __str__(self) -> str: - return self.value - - -class TemplateMeta(BaseModel): - """Template metadata.""" - - name: Name - description: str - tags: list[str] - created_at: str = Field( - default_factory=datetime_now_iso, - description="Creation datetime (ISO 8601 UTC).", - ) - - -class ModelCapability(str, Enum): - COMPLETION = "completion" - CHAT = "chat" - IMAGE = "image" - AUDIO = "audio" - TOOL = "tool" - EMBED = "embed" - RERANK = "rerank" - - def __str__(self) -> str: - return self.value - - -class ModelInfo(BaseModel): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - "Users will specify this to select a model." - ), - examples=EXAMPLE_CHAT_MODEL_IDS, - ) - object: str = Field( - default="model", - description="Type of API response object.", - examples=["model"], - ) - name: str = Field( - description="Name of the model.", - examples=["OpenAI GPT-4o Mini"], - ) - context_length: int = Field( - description="Context length of model.", - examples=[16384], - ) - languages: list[str] = Field( - description="List of languages which the model is well-versed in.", - examples=[["en"]], - ) - owned_by: str = Field( - default="", - description="The organization that owns the model. 
Defaults to the provider in model ID.", - examples=["openai"], - ) - capabilities: list[ModelCapability] = Field( - description="List of capabilities of model.", - examples=[[ModelCapability.CHAT]], - ) - - @model_validator(mode="after") - def check_owned_by(self) -> Self: - if self.owned_by.strip() == "": - self.owned_by = self.id.split("/")[0] - return self - - -class ModelInfoResponse(BaseModel): - object: str = Field( - default="chat.model_info", - description="Type of API response object.", - examples=["chat.model_info"], - ) - data: list[ModelInfo] = Field( - description="List of model information.", - ) - - -class ModelDeploymentConfig(BaseModel): - litellm_id: str = Field( - default="", - description=( - "LiteLLM routing / mapping ID. " - 'For example, you can map "openai/gpt-4o" calls to "openai/gpt-4o-2024-08-06". ' - 'For vLLM with OpenAI compatible server, use "openai/".' - ), - examples=EXAMPLE_CHAT_MODEL_IDS, - ) - api_base: str = Field( - default="", - description="Hosting url for the model.", - ) - provider: str = Field( - default="", - description="Provider of the model.", - ) - - -class ModelConfig(ModelInfo): - priority: int = Field( - default=0, - ge=0, - description="Priority when assigning default model. Larger number means higher priority.", - ) - deployments: list[ModelDeploymentConfig] = Field( - [], - description="List of model deployment configs.", - ) - litellm_id: str = Field( - default="", - deprecated=True, - description=( - "Deprecated. Retained for compatibility. " - "LiteLLM routing / mapping ID. " - 'For example, you can map "openai/gpt-4o" calls to "openai/gpt-4o-2024-08-06". ' - 'For vLLM with OpenAI compatible server, use "openai/".' - ), - examples=EXAMPLE_CHAT_MODEL_IDS, - ) - api_base: str = Field( - default="", - deprecated=True, - description="Deprecated. Retained for compatibility. Hosting url for the model.", - ) - - @model_validator(mode="after") - def compat_deployments(self) -> Self: - if len(self.deployments) > 0: - return self - self.deployments = [ - ModelDeploymentConfig( - litellm_id=self.litellm_id, - api_base=self.api_base, - provider=self.id.split("/")[0], - ) - ] - return self - - -class LLMModelConfig(ModelConfig): - input_cost_per_mtoken: float = Field( - default=-1.0, - description="Cost in USD per million (mega) input / prompt token.", - ) - output_cost_per_mtoken: float = Field( - default=-1.0, - description="Cost in USD per million (mega) output / completion token.", - ) - capabilities: list[ModelCapability] = Field( - default=[ModelCapability.CHAT], - description="List of capabilities of model.", - examples=[[ModelCapability.CHAT]], - ) - - @model_validator(mode="after") - def check_cost_per_mtoken(self) -> Self: - # GPT-4o-mini pricing (2024-08-10) - if self.input_cost_per_mtoken <= 0: - self.input_cost_per_mtoken = 0.150 - if self.output_cost_per_mtoken <= 0: - self.output_cost_per_mtoken = 0.600 - return self - - -class EmbeddingModelConfig(ModelConfig): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - 'For self-hosted models with Infinity, use "ellm/{org}/{model}". ' - "Users will specify this to select a model." 
- ), - examples=EXAMPLE_EMBEDDING_MODEL_IDS, - ) - embedding_size: int = Field( - description="Embedding size of the model", - ) - # Currently only useful for openai - dimensions: int | None = Field( - default=None, - description="Dimensions, a reduced embedding size (openai specs).", - ) - # Most likely only useful for hf models - transform_query: str | None = Field( - default=None, - description="Transform query that might be needed, esp. for hf models", - ) - capabilities: list[ModelCapability] = Field( - default=[ModelCapability.EMBED], - description="List of capabilities of model.", - examples=[[ModelCapability.EMBED]], - ) - cost_per_mtoken: float = Field( - default=-1, - description="Cost in USD per million embedding tokens.", - ) - - @model_validator(mode="after") - def check_cost_per_mtoken(self) -> Self: - # OpenAI text-embedding-3-small pricing (2024-09-09) - if self.cost_per_mtoken < 0: - self.cost_per_mtoken = 0.022 - return self - - -class RerankingModelConfig(ModelConfig): - id: str = Field( - description=( - 'Unique identifier in the form of "{provider}/{model_id}". ' - 'For self-hosted models with Infinity, use "ellm/{org}/{model}". ' - "Users will specify this to select a model." - ), - examples=EXAMPLE_RERANKING_MODEL_IDS, - ) - capabilities: list[ModelCapability] = Field( - default=[ModelCapability.RERANK], - description="List of capabilities of model.", - examples=[[ModelCapability.RERANK]], - ) - cost_per_ksearch: float = Field( - default=-1, - description="Cost in USD for a thousand searches.", - ) - - @model_validator(mode="after") - def check_cost_per_ksearch(self) -> Self: - # Cohere rerank-multilingual-v3.0 pricing (2024-09-09) - if self.cost_per_ksearch < 0: - self.cost_per_ksearch = 2.0 - return self - - -class ModelListConfig(BaseModel): - object: str = Field( - default="configs.models", - description="Type of API response object.", - examples=["configs.models"], - ) - llm_models: list[LLMModelConfig] = [] - embed_models: list[EmbeddingModelConfig] = [] - rerank_models: list[RerankingModelConfig] = [] - - @cached_property - def models(self) -> list[LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig]: - """A list of all the models.""" - return self.llm_models + self.embed_models + self.rerank_models - - @cached_property - def model_map(self) -> dict[str, LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig]: - """A map of all the models.""" - return {m.id: m for m in self.models} - - def get_model_info( - self, model_id: str - ) -> LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig: - try: - return self.model_map[model_id] - except KeyError: - raise ValueError( - f"Invalid model ID: {model_id}. 
Available models: {[m.id for m in self.models]}" - ) from None - - def get_llm_model_info(self, model_id: str) -> LLMModelConfig: - return self.get_model_info(model_id) - - def get_embed_model_info(self, model_id: str) -> EmbeddingModelConfig: - return self.get_model_info(model_id) - - def get_rerank_model_info(self, model_id: str) -> RerankingModelConfig: - return self.get_model_info(model_id) - - def get_default_model(self, capabilities: list[str] | None = None) -> str: - models = self.models - if capabilities is not None: - for capability in capabilities: - models = [m for m in models if capability in m.capabilities] - # if `capabilities`` is chat only, filter out audio model - if capabilities == ["chat"]: - models = [m for m in models if "audio" not in m.capabilities] - if len(models) == 0: - raise ResourceNotFoundError(f"No model found with capabilities: {capabilities}") - model = natsorted(models, key=self._sort_key_with_priority)[0] - return model.id - - @staticmethod - def _sort_key_with_priority( - x: LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig, - ) -> str: - return (int(not x.id.startswith("ellm")), -x.priority, x.name) - - @model_validator(mode="after") - def sort_models(self) -> Self: - self.llm_models = list(natsorted(self.llm_models, key=self._sort_key)) - self.embed_models = list(natsorted(self.embed_models, key=self._sort_key)) - self.rerank_models = list(natsorted(self.rerank_models, key=self._sort_key)) - return self - - @model_validator(mode="after") - def unique_model_ids(self) -> Self: - if len(set(m.id for m in self.models)) != len(self.models): - raise ValueError("There are repeated model IDs in the config.") - return self - - @staticmethod - def _sort_key( - x: LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig, - ) -> str: - return (int(not x.id.startswith("ellm")), x.name) - - def __add__(self, other: ModelListConfig) -> ModelListConfig: - if isinstance(other, ModelListConfig): - self_ids = set(m.id for m in self.models) - other_ids = set(m.id for m in other.models) - repeated_ids = self_ids.intersection(other_ids) - if len(repeated_ids) != 0: - raise ValueError( - f"There are repeated model IDs among the two configs: {list(repeated_ids)}" - ) - return ModelListConfig( - llm_models=self.llm_models + other.llm_models, - embed_models=self.embed_models + other.embed_models, - rerank_models=self.rerank_models + other.rerank_models, - ) - else: - raise TypeError( - f"Unsupported operand type(s) for +: 'ModelListConfig' and '{type(other)}'" - ) - - -class Chunk(p.Chunk): - pass - - -class SplitChunksParams(p.SplitChunksParams): - pass - - -class SplitChunksRequest(BaseModel): - id: str = Field( - default="", - description="Request ID for logging purposes.", - examples=["018ed5f1-6399-71f7-86af-fc18d4a3e3f5"], - ) - chunks: list[Chunk] = Field( - description="List of `Chunk` where each will be further split into chunks.", - examples=[ - [ - Chunk( - text="The Name of the Title is Hope\n\n...", - title="The Name of the Title is Hope", - page=0, - file_name="sample_tables.pdf", - file_path="amagpt/sample_tables.pdf", - metadata={ - "total_pages": 3, - "Author": "Ben Trovato", - "CreationDate": "D:20231031072817Z", - "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", - "Keywords": "Image Captioning, Deep Learning", - "ModDate": "D:20231031073146Z", - "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 
2023) kpathsea version 6.3.5", - "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", - "Trapped": "False", - }, - ) - ] - ], - ) - params: SplitChunksParams = Field( - default=SplitChunksParams(), - description="How to split each document. Defaults to `RecursiveCharacterTextSplitter` with chunk_size = 1000 and chunk_overlap = 200.", - examples=[SplitChunksParams()], - ) - - def str_trunc(self) -> str: - return f"id={self.id} len(chunks)={len(self.chunks)} params={self.params}" - - -class RAGParams(BaseModel): - table_id: str = Field(description="Knowledge Table ID", examples=["my-dataset"], min_length=2) - reranking_model: str | None = Field( - default=None, - description="Reranking model to use for hybrid search.", - examples=[EXAMPLE_RERANKING_MODEL_IDS[0], None], - ) - search_query: str = Field( - default="", - description="Query used to retrieve items from the KB database. If not provided (default), it will be generated using LLM.", - ) - k: Annotated[int, Field(gt=0, le=1024)] = Field( - default=3, - gt=0, - le=1024, - description="Top-k closest text in terms of embedding distance. Must be in [1, 1024]. Defaults to 3.", - examples=[3], - ) - rerank: bool = Field( - default=True, - description="Flag to perform rerank on the retrieved results. Defaults to True.", - examples=[True, False], - ) - concat_reranker_input: bool = Field( - default=False, - description="Flag to concat title and content as reranker input. Defaults to False.", - examples=[True, False], - ) - - -class VectorSearchRequest(RAGParams): - id: str = Field( - default="", - description="Request ID for logging purposes.", - examples=["018ed5f1-6399-71f7-86af-fc18d4a3e3f5"], - ) - search_query: str = Field(description="Query used to retrieve items from the KB database.") - - -class VectorSearchResponse(BaseModel): - object: str = Field( - default="kb.search_response", - description="Type of API response object.", - examples=["kb.search_response"], - ) - chunks: list[Chunk] = Field( - default=[], - description="A list of `Chunk`.", - examples=[ - [ - Chunk( - text="The Name of the Title is Hope\n\n...", - title="The Name of the Title is Hope", - page=0, - file_name="sample_tables.pdf", - file_path="amagpt/sample_tables.pdf", - metadata={ - "total_pages": 3, - "Author": "Ben Trovato", - "CreationDate": "D:20231031072817Z", - "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", - "Keywords": "Image Captioning, Deep Learning", - "ModDate": "D:20231031073146Z", - "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5", - "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", - "Trapped": "False", - }, - ) - ] - ], - ) - - -class ChatRole(str, Enum): - """Represents who said a chat message.""" - - SYSTEM = "system" - """The message is from the system (usually a steering prompt).""" - USER = "user" - """The message is from the user.""" - ASSISTANT = "assistant" - """The message is from the language model.""" - # FUNCTION = "function" - # """The message is the result of a function call.""" - - def __str__(self) -> str: - return self.value - - -def sanitise_name(v: str) -> str: - """Replace any non-alphanumeric and dash characters with space. - - Args: - v (str): Raw name string. 
- - Returns: - out (str): Sanitised name string that is safe for OpenAI. - """ - return re.sub(r"[^a-zA-Z0-9_-]", "_", v).strip() - - -MessageName = Annotated[str, AfterValidator(sanitise_name)] - - -class MessageToolCallFunction(BaseModel): - arguments: str - name: str | None - - -class MessageToolCall(BaseModel): - id: str | None - function: MessageToolCallFunction - type: str - - -class ChatEntry(BaseModel): - """Represents a message in the chat context.""" - - model_config = ConfigDict(use_enum_values=True) - - role: ChatRole - """Who said the message?""" - content: str | list[dict[str, str | dict[str, str]]] - """The content of the message.""" - name: MessageName | None = None - """The name of the user who sent the message, if set (user messages only).""" - - @classmethod - def system(cls, content: str, **kwargs): - """Create a new system message.""" - return cls(role=ChatRole.SYSTEM, content=content, **kwargs) - - @classmethod - def user(cls, content: str, **kwargs): - """Create a new user message.""" - return cls(role=ChatRole.USER, content=content, **kwargs) - - @classmethod - def assistant(cls, content: str | list[dict[str, str]] | None, **kwargs): - """Create a new assistant message.""" - return cls(role=ChatRole.ASSISTANT, content=content, **kwargs) - - @field_validator("content", mode="before") - @classmethod - def coerce_input(cls, value: Any) -> str | list[dict[str, str | dict[str, str]]]: - if isinstance(value, list): - return [cls.coerce_input(v) for v in value] - if isinstance(value, dict): - return {k: cls.coerce_input(v) for k, v in value.items()} - if isinstance(value, str): - return value - if value is None: - return "" - return str(value) - - -class ChatCompletionChoiceOutput(ChatEntry): - tool_calls: list[MessageToolCall] | None = None - """List of tool calls if the message includes tool call responses.""" - - -class ChatThread(BaseModel): - object: str = Field( - default="chat.thread", - description="Type of API response object.", - examples=["chat.thread"], - ) - thread: list[ChatEntry] = Field( - default=[], - description="List of chat messages.", - examples=[ - [ - ChatEntry.system(content="You are an assistant."), - ChatEntry.user(content="Hello."), - ] - ], - ) - - -class CompletionUsage(BaseModel): - prompt_tokens: int = Field(default=0, description="Number of tokens in the prompt.") - completion_tokens: int = Field( - default=0, description="Number of tokens in the generated completion." - ) - total_tokens: int = Field( - default=0, description="Total number of tokens used in the request (prompt + completion)." - ) - - -class ChatCompletionChoice(BaseModel): - message: ChatEntry | ChatCompletionChoiceOutput = Field( - description="A chat completion message generated by the model." - ) - index: int = Field(description="The index of the choice in the list of choices.") - finish_reason: str | None = Field( - default=None, - description=( - "The reason the model stopped generating tokens. " - "This will be stop if the model hit a natural stop point or a provided stop sequence, " - "length if the maximum number of tokens specified in the request was reached." 
- ), - ) - - @property - def text(self) -> str: - """The text of the most recent chat completion.""" - return self.message.content - - -class ChatCompletionChoiceDelta(ChatCompletionChoice): - @computed_field - @property - def delta(self) -> ChatEntry | ChatCompletionChoiceOutput: - return self.message - - -class References(BaseModel): - object: str = Field( - default="chat.references", - description="Type of API response object.", - examples=["chat.references"], - ) - chunks: list[Chunk] = Field( - default=[], - description="A list of `Chunk`.", - examples=[ - [ - Chunk( - text="The Name of the Title is Hope\n\n...", - title="The Name of the Title is Hope", - page=0, - file_name="sample_tables.pdf", - file_path="amagpt/sample_tables.pdf", - metadata={ - "total_pages": 3, - "Author": "Ben Trovato", - "CreationDate": "D:20231031072817Z", - "Creator": "LaTeX with acmart 2023/10/14 v1.92 Typesetting articles for the Association for Computing Machinery and hyperref 2023-07-08 v7.01b Hypertext links for LaTeX", - "Keywords": "Image Captioning, Deep Learning", - "ModDate": "D:20231031073146Z", - "PTEX.Fullbanner": "This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023) kpathsea version 6.3.5", - "Producer": "3-Heights(TM) PDF Security Shell 4.8.25.2 (http://www.pdf-tools.com) / pdcat (www.pdf-tools.com)", - "Trapped": "False", - }, - ) - ] - ], - ) - search_query: str = Field(description="Query used to retrieve items from the KB database.") - finish_reason: Literal["stop", "context_overflow"] | None = Field( - default=None, - description=""" -In streaming mode, reference chunk will be streamed first. -However, if the model's context length is exceeded, then there will be no further completion chunks. -In this case, "finish_reason" will be set to "context_overflow". -Otherwise, it will be None or null. -""", - ) - - def remove_contents(self): - copy = self.model_copy(deep=True) - for d in copy.documents: - d.page_content = "" - return copy - - -class ChatCompletionChunk(BaseModel): - id: str = Field( - description="A unique identifier for the chat completion. Each chunk has the same ID." - ) - object: str = Field( - default="chat.completion.chunk", - description="Type of API response object.", - examples=["chat.completion.chunk"], - ) - created: int = Field( - description="The Unix timestamp (in seconds) of when the chat completion was created." - ) - model: str = Field(description="The model used for the chat completion.") - usage: CompletionUsage | None = Field( - description="Number of tokens consumed for the completion request.", - examples=[CompletionUsage(), None], - ) - choices: list[ChatCompletionChoice | ChatCompletionChoiceDelta] = Field( - description="A list of chat completion choices. Can be more than one if `n` is greater than 1." 
- ) - references: References | None = Field( - default=None, - description="Contains the references retrieved from database when performing chat completion with RAG.", - ) - - @property - def message(self) -> ChatEntry | ChatCompletionChoiceOutput | None: - return self.choices[0].message if len(self.choices) > 0 else None - - @property - def prompt_tokens(self) -> int: - return self.usage.prompt_tokens - - @property - def completion_tokens(self) -> int: - return self.usage.completion_tokens - - @property - def text(self) -> str: - """The text of the most recent chat completion.""" - return self.message.content if len(self.choices) > 0 else "" - - @property - def finish_reason(self) -> str | None: - return self.choices[0].finish_reason if len(self.choices) > 0 else None - - -class GenTableStreamReferences(References): - object: str = Field( - default="gen_table.references", - description="Type of API response object.", - examples=["gen_table.references"], - ) - output_column_name: str - - -class GenTableChatCompletionChunks(BaseModel): - object: str = Field( - default="gen_table.completion.chunks", - description="Type of API response object.", - examples=["gen_table.completion.chunks"], - ) - columns: dict[str, ChatCompletionChunk] - row_id: str - - -class GenTableRowsChatCompletionChunks(BaseModel): - object: str = Field( - default="gen_table.completion.rows", - description="Type of API response object.", - examples=["gen_table.completion.rows"], - ) - rows: list[GenTableChatCompletionChunks] - - -class GenTableStreamChatCompletionChunk(ChatCompletionChunk): - object: str = Field( - default="gen_table.completion.chunk", - description="Type of API response object.", - examples=["gen_table.completion.chunk"], - ) - output_column_name: str - row_id: str - - -class FunctionParameter(BaseModel): - type: str = Field( - default="", description="The type of the parameter, e.g., 'string', 'number'." - ) - description: str = Field(default="", description="A description of the parameter.") - enum: list[str] = Field( - default=[], description="An optional list of allowed values for the parameter." - ) - - -class FunctionParameters(BaseModel): - type: str = Field( - default="object", description="The type of the parameters object, usually 'object'." - ) - properties: dict[str, FunctionParameter] = Field( - description="The properties of the parameters object." - ) - required: list[str] = Field(description="A list of required parameter names.") - additionalProperties: bool = Field( - default=False, description="Whether additional properties are allowed." - ) - - -class Function(BaseModel): - name: str = Field(default="", description="The name of the function.") - description: str = Field(default="", description="A description of what the function does.") - parameters: FunctionParameters = Field(description="The parameters for the function.") - - -class Tool(BaseModel): - type: str = Field(default="function", description="The type of the tool, e.g., 'function'.") - function: Function = Field(description="The function details of the tool.") - - -class ToolChoiceFunction(BaseModel): - name: str = Field(default="", description="The name of the function.") - - -class ToolChoice(BaseModel): - type: str = Field(default="function", description="The type of the tool, e.g., 'function'.") - function: ToolChoiceFunction = Field(description="Select a tool for the chat model to use.") - - -class ChatRequest(BaseModel): - id: str = Field( - default="", - description="Chat ID. 
Must be unique against document ID for it to be embeddable. Defaults to ''.", - ) - model: str = Field( - default="", - description="ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - ) - messages: list[ChatEntry] = Field( - description="A list of messages comprising the conversation so far.", - min_length=1, - ) - rag_params: RAGParams | None = Field( - default=None, - description="Retrieval Augmented Generation search params. Defaults to None (disabled).", - examples=[None], - ) - temperature: Annotated[float, Field(ge=0.001, le=2.0)] = Field( - default=0.2, - description=""" -What sampling temperature to use, in [0.001, 2.0]. -Higher values like 0.8 will make the output more random, -while lower values like 0.2 will make it more focused and deterministic. -""", - examples=[0.2], - ) - top_p: Annotated[float, Field(ge=0.001, le=1.0)] = Field( - default=0.6, - description=""" -An alternative to sampling with temperature, called nucleus sampling, -where the model considers the results of the tokens with top_p probability mass. -So 0.1 means only the tokens comprising the top 10% probability mass are considered. -Must be in [0.001, 1.0]. -""", - examples=[0.6], - ) - n: int = Field( - default=1, - description="How many chat completion choices to generate for each input message.", - examples=[1], - ) - stream: bool = Field( - default=True, - description=""" -If set, partial message deltas will be sent, like in ChatGPT. -Tokens will be sent as server-sent events as they become available, -with the stream terminated by a 'data: [DONE]' message. -""", - examples=[True], - ) - stop: list[str] | None = Field( - default=None, - description="Up to 4 sequences where the API will stop generating further tokens.", - examples=[None], - ) - max_tokens: PositiveNonZeroInt = Field( - default=2048, - description=""" -The maximum number of tokens to generate in the chat completion. -Must be in [1, context_length - 1). Default is 2048. -The total length of input tokens and generated tokens is limited by the model's context length. -""", - examples=[2048], - ) - presence_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, -increasing the model's likelihood to talk about new topics. -""", - examples=[0.0], - ) - frequency_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, -decreasing the model's likelihood to repeat the same line verbatim. -""", - examples=[0.0], - ) - logit_bias: dict = Field( - default={}, - description=""" -Modify the likelihood of specified tokens appearing in the completion. -Accepts a json object that maps tokens (specified by their token ID in the tokenizer) -to an associated bias value from -100 to 100. -Mathematically, the bias is added to the logits generated by the model prior to sampling. -The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; -values like -100 or 100 should result in a ban or exclusive selection of the relevant token. -""", - examples=[{}], - ) - user: str = Field( - default="", - description="A unique identifier representing your end-user. 
For monitoring and debugging purposes.", - examples=[""], - ) - - @field_validator("stop", mode="after") - @classmethod - def convert_stop(cls, v: list[str] | None) -> list[str] | None: - if isinstance(v, list) and len(v) == 0: - v = None - return v - - -class ChatRequestWithTools(ChatRequest): - tools: list[Tool] = Field( - description="A list of tools available for the chat model to use.", - min_length=1, - examples=[ - # --- [Tool Function] --- - # def get_delivery_date(order_id: str) -> datetime: - # # Connect to the database - # conn = sqlite3.connect('ecommerce.db') - # cursor = conn.cursor() - # # ... - [ - Tool( - type="function", - function=Function( - name="get_delivery_date", - description="Get the delivery date for a customer's order.", - parameters=FunctionParameters( - type="object", - properties={ - "order_id": FunctionParameter( - type="string", description="The customer's order ID." - ) - }, - required=["order_id"], - additionalProperties=False, - ), - ), - ) - ], - ], - ) - tool_choice: str | ToolChoice = Field( - default="auto", - description="Set `auto` to let chat model pick a tool or select a tool for the chat model to use.", - examples=[ - "auto", - ToolChoice(type="function", function=ToolChoiceFunction(name="get_delivery_date")), - ], - ) - - -class EmbeddingRequest(BaseModel): - input: str | list[str] = Field( - description=( - "Input text to embed, encoded as a string or array of strings " - "(to embed multiple inputs in a single request). " - "The input must not exceed the max input tokens for the model, and cannot contain empty string." - ), - examples=["What is a llama?", ["What is a llama?", "What is an alpaca?"]], - ) - model: str = Field( - description=( - "The ID of the model to use. " - "You can use the List models API to see all of your available models." - ), - examples=EXAMPLE_EMBEDDING_MODEL_IDS, - ) - type: Literal["query", "document"] = Field( - default="document", - description=( - 'Whether the input text is a "query" (used to retrieve) or a "document" (to be retrieved).' - ), - examples=["query", "document"], - ) - encoding_format: Literal["float", "base64"] = Field( - default="float", - description=( - '_Optional_. The format to return the embeddings in. Can be either "float" or "base64". ' - "`base64` string should be decoded as a `float32` array. " - "Example: `np.frombuffer(base64.b64decode(response), dtype=np.float32)`" - ), - examples=["float", "base64"], - ) - - -class EmbeddingResponseData(BaseModel): - object: str = Field( - default="embedding", - description="Type of API response object.", - examples=["embedding"], - ) - embedding: list[float] | str = Field( - description=( - "The embedding vector, which is a list of floats or a base64-encoded string. " - "The length of vector depends on the model." 
- ), - examples=[[0.0, 1.0, 2.0], []], - ) - index: int = Field( - default=0, - description="The index of the embedding in the list of embeddings.", - examples=[0, 1], - ) - - -class EmbeddingResponse(BaseModel): - object: str = Field( - default="list", - description="Type of API response object.", - examples=["list"], - ) - data: list[EmbeddingResponseData] = Field( - description="List of `EmbeddingResponseData`.", - examples=[[EmbeddingResponseData(embedding=[0.0, 1.0, 2.0])]], - ) - model: str = Field( - description="The ID of the model used.", - examples=["openai/text-embedding-3-small-512"], - ) - usage: CompletionUsage = Field( - default=CompletionUsage(), - description="The number of tokens consumed.", - examples=[CompletionUsage()], - ) - - -class ClipInputData(BaseModel): - """Data model for Clip input data, assume if image_filename is None then it have to be text, otherwise, the input is an image with bytes content""" - - content: str | bytes - """content of this input data, either be str of text or an """ - image_filename: str | None - """image filename of the content, None if the content is text""" - - -T = TypeVar("T") - - -class Page(BaseModel, Generic[T]): - items: Annotated[ - Sequence[T], Field(description="List of items paginated items.", examples=[[]]) - ] = [] - offset: Annotated[int, Field(description="Number of skipped items.", examples=[0])] = 0 - limit: Annotated[int, Field(description="Number of items per page.", examples=[0])] = 0 - total: Annotated[int, Field(description="Total number of items.", examples=[0])] = 0 - starting_after: Annotated[ - str | int | None, Field(description="Pagination cursor.", examples=["31a0552", 0, None]) - ] = None - - -def nd_array_before_validator(x): - return np.array(x) if isinstance(x, list) else x - - -def datetime_str_before_validator(x): - return x.isoformat() if isinstance(x, datetime) else str(x) - - -COL_NAME_PATTERN = r"^[A-Za-z0-9]([A-Za-z0-9 _-]{0,98}[A-Za-z0-9])?$" -TABLE_NAME_PATTERN = r"^[A-Za-z0-9]([A-Za-z0-9._-]{0,98}[A-Za-z0-9])?$" -ODD_SINGLE_QUOTE = r"(? 0: - return list[float] if json_safe else NdArray - return _str_to_py_type[py_type] - - -class MetaEnum(EnumMeta): - def __contains__(cls, x): - try: - cls[x] - except KeyError: - return False - return True - - -class CSVDelimiter(Enum, metaclass=MetaEnum): - COMMA = "," - """Comma-separated""" - TAB = "\t" - """Tab-separated""" - - def __str__(self) -> str: - return self.value - - -class ColumnDtype(str, Enum, metaclass=MetaEnum): - INT = "int" - INT8 = "int8" - FLOAT = "float" - FLOAT32 = "float32" - FLOAT16 = "float16" - BOOL = "bool" - STR = "str" - DATE_TIME = "date-time" - IMAGE = "image" - AUDIO = "audio" - - def __str__(self) -> str: - return self.value - - -class ColumnDtypeCreate(str, Enum, metaclass=MetaEnum): - INT = "int" - FLOAT = "float" - BOOL = "bool" - STR = "str" - IMAGE = "image" - AUDIO = "audio" - - def __str__(self) -> str: - return self.value - - -class TableType(str, Enum, metaclass=MetaEnum): - ACTION = "action" - """Action table.""" - KNOWLEDGE = "knowledge" - """Knowledge table.""" - CHAT = "chat" - """Chat table.""" - - def __str__(self) -> str: - return self.value - - -class LLMGenConfig(BaseModel): - object: Literal["gen_config.llm"] = Field( - default="gen_config.llm", - description='The object type, which is always "gen_config.llm".', - examples=["gen_config.llm"], - ) - model: str = Field( - default="", - description="ID of the model to use. 
See the model endpoint compatibility table for details on which models work with the Chat API.", - ) - system_prompt: str = Field( - default="", - description="System prompt for the LLM.", - ) - prompt: str = Field( - default="", - description="Prompt for the LLM.", - ) - multi_turn: bool = Field( - default=False, - description="Whether this column is a multi-turn chat with history along the entire column.", - ) - rag_params: RAGParams | None = Field( - default=None, - description="Retrieval Augmented Generation search params. Defaults to None (disabled).", - examples=[None], - ) - temperature: Annotated[float, Field(ge=0.001, le=2.0)] = Field( - default=0.2, - description=""" -What sampling temperature to use, in [0.001, 2.0]. -Higher values like 0.8 will make the output more random, -while lower values like 0.2 will make it more focused and deterministic. -""", - examples=[0.2], - ) - top_p: Annotated[float, Field(ge=0.001, le=1.0)] = Field( - default=0.6, - description=""" -An alternative to sampling with temperature, called nucleus sampling, -where the model considers the results of the tokens with top_p probability mass. -So 0.1 means only the tokens comprising the top 10% probability mass are considered. -Must be in [0.001, 1.0]. -""", - examples=[0.6], - ) - stop: list[str] | None = Field( - default=None, - description="Up to 4 sequences where the API will stop generating further tokens.", - examples=[None], - ) - max_tokens: PositiveNonZeroInt = Field( - default=2048, - description=""" -The maximum number of tokens to generate in the chat completion. -Must be in [1, context_length - 1). Default is 2048. -The total length of input tokens and generated tokens is limited by the model's context length. -""", - examples=[2048], - ) - presence_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, -increasing the model's likelihood to talk about new topics. -""", - examples=[0.0], - ) - frequency_penalty: float = Field( - default=0.0, - description=""" -Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, -decreasing the model's likelihood to repeat the same line verbatim. -""", - examples=[0.0], - ) - logit_bias: dict = Field( - default={}, - description=""" -Modify the likelihood of specified tokens appearing in the completion. -Accepts a json object that maps tokens (specified by their token ID in the tokenizer) -to an associated bias value from -100 to 100. -Mathematically, the bias is added to the logits generated by the model prior to sampling. -The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; -values like -100 or 100 should result in a ban or exclusive selection of the relevant token. 
-""", - examples=[{}], - ) - - @model_validator(mode="before") - @classmethod - def compat(cls, data: Any) -> Any: - if isinstance(data, BaseModel): - data = data.model_dump() - if not isinstance(data, dict): - raise TypeError( - f"Input to `LLMGenConfig` must be a dict or BaseModel, received: {type(data)}" - ) - if data.get("system_prompt", None) or data.get("prompt", None): - return data - messages: list[dict[str, Any]] = data.get("messages", []) - num_prompts = len(messages) - if num_prompts >= 2: - data["system_prompt"] = messages[0]["content"] - data["prompt"] = messages[1]["content"] - elif num_prompts == 1: - if messages[0]["role"] == "system": - data["system_prompt"] = messages[0]["content"] - data["prompt"] = "" - elif messages[0]["role"] == "user": - data["system_prompt"] = "" - data["prompt"] = messages[0]["content"] - else: - raise ValueError( - f'Attribute "messages" cannot contain only assistant messages: {messages}' - ) - data["object"] = "gen_config.llm" - return data - - @field_validator("stop", mode="after") - @classmethod - def convert_stop(cls, v: list[str] | None) -> list[str] | None: - if isinstance(v, list) and len(v) == 0: - v = None - return v - - -class EmbedGenConfig(BaseModel): - object: Literal["gen_config.embed"] = Field( - default="gen_config.embed", - description='The object type, which is always "gen_config.embed".', - examples=["gen_config.embed"], - ) - embedding_model: str = Field( - description="The embedding model to use.", - examples=EXAMPLE_EMBEDDING_MODEL_IDS, - ) - source_column: str = Field( - description="The source column for embedding.", - examples=["text_column"], - ) - - -class CodeGenConfig(p.CodeGenConfig): - pass - - -def _gen_config_discriminator(x: Any) -> str | None: - object_attr = getattr(x, "object", None) - if object_attr: - return object_attr - if isinstance(x, BaseModel): - x = x.model_dump() - if isinstance(x, dict): - if "object" in x: - return x["object"] - if "embedding_model" in x: - return "gen_config.embed" - else: - return "gen_config.llm" - return None - - -GenConfig = LLMGenConfig | EmbedGenConfig | CodeGenConfig -DiscriminatedGenConfig = Annotated[ - Union[ - Annotated[CodeGenConfig, Tag("gen_config.code")], - Annotated[LLMGenConfig, Tag("gen_config.llm")], - Annotated[LLMGenConfig, Tag("gen_config.chat")], - Annotated[EmbedGenConfig, Tag("gen_config.embed")], - ], - Discriminator(_gen_config_discriminator), -] - - -class ColumnSchema(BaseModel): - id: str = Field(description="Column name.") - dtype: ColumnDtype = Field( - default=ColumnDtype.STR, - description='Column data type, one of ["int", "int8", "float", "float32", "float16", "bool", "str", "date-time", "image"]', - ) - vlen: PositiveInt = Field( # type: ignore - default=0, - description=( - "_Optional_. Vector length. " - "If this is larger than zero, then `dtype` must be one of the floating data types. Defaults to zero." - ), - ) - index: bool = Field( - default=True, - description=( - "_Optional_. Whether to build full-text-search (FTS) or vector index for this column. " - "Only applies to string and vector columns. Defaults to True." - ), - ) - gen_config: DiscriminatedGenConfig | None = Field( - default=None, - description=( - '_Optional_. Generation config. If provided, then this column will be an "Output Column". ' - "Table columns on its left can be referenced by `${column-name}`." 
- ), - ) - - @model_validator(mode="after") - def check_vector_column_dtype(self) -> Self: - if self.vlen > 0 and self.dtype not in (ColumnDtype.FLOAT32, ColumnDtype.FLOAT16): - raise ValueError("Vector columns must contain float32 or float16 only.") - return self - - -class ColumnSchemaCreate(ColumnSchema): - id: ColName = Field(description="Column name.") - dtype: ColumnDtypeCreate = Field( - default=ColumnDtypeCreate.STR, - description='Column data type, one of ["int", "float", "bool", "str", "image", "audio"]', - ) - - @model_validator(mode="before") - def match_column_dtype_file_to_image(self) -> Self: - if self.get("dtype", "") == "file": - self["dtype"] = ColumnDtype.IMAGE - return self - - @model_validator(mode="after") - def check_output_column_dtype(self) -> Self: - if self.gen_config is not None and self.vlen == 0: - if isinstance(self.gen_config, CodeGenConfig): - if self.dtype not in (ColumnDtype.STR, ColumnDtype.IMAGE): - raise ValueError( - "Output column must be either string or image column when gen_config is CodeGenConfig." - ) - elif self.dtype != ColumnDtype.STR: - raise ValueError("Output column must be string column.") - return self - - -class TableSQLModel(SQLModel): - metadata = MetaData() - - -class TableBase(TableSQLModel): - id: str = sql_Field(primary_key=True, description="Table name.") - version: str = sql_Field( - default=owl_version, description="Table version, following owl version." - ) - meta: dict[str, Any] = sql_Field( - sa_column=Column(JSON), - default={}, - description="Additional metadata about the table.", - ) - - -class TableSchema(TableBase): - cols: list[ColumnSchema] = sql_Field(description="List of column schema.") - - def get_col(self, id: str): - return [c for c in self.cols if c.id.lower() == id.lower()][0] - - @staticmethod - def _get_col_dtype(py_type: str, vlen: int = 0): - if vlen > 0: - return pa.list_(_str_to_arrow[py_type], vlen) - return _str_to_arrow[py_type] - - @property - def pyarrow(self) -> pa.Schema: - return pa.schema( - [pa.field(c.id, self._get_col_dtype(c.dtype.value, c.vlen)) for c in self.cols] - ) - - @property - def pyarrow_vec(self) -> pa.Schema: - return pa.schema( - [ - pa.field(c.id, self._get_col_dtype(c.dtype.value, c.vlen)) - for c in self.cols - if c.vlen > 0 - ] - ) - - def add_state_cols(self) -> Self: - """ - Adds state columns. - - Returns: - self (TableSchemaCreate): TableSchemaCreate - """ - cols = [] - for c in self.cols: - cols.append(c) - if c.id.lower() not in ("id", "updated at"): - cols.append(ColumnSchema(id=f"{c.id}_", dtype=ColumnDtype.STR)) - self.cols = cols - return self - - def add_info_cols(self) -> Self: - """ - Adds "ID", "Updated at" columns. - - Returns: - self (TableSchemaCreate): TableSchemaCreate - """ - self.cols = [ - ColumnSchema(id="ID", dtype=ColumnDtype.STR), - ColumnSchema(id="Updated at", dtype=ColumnDtype.DATE_TIME), - ] + self.cols - return self - - @staticmethod - def get_default_prompts( - table_id: str, - curr_column: ColumnSchema, - column_ids: list[str], - ) -> tuple[str, str]: - input_cols = "\n\n".join(c + ": ${" + c + "}" for c in column_ids) - if getattr(curr_column.gen_config, "multi_turn", False): - system_prompt = ( - f'You are an agent named "{table_id}". Be helpful. Provide answers based on the information given. ' - "Ensuring that your reply is easy to understand and is accessible to all users. " - "Be factual and do not hallucinate." - ) - user_prompt = "${User}" - else: - system_prompt = ( - "You are a versatile data generator. 
" - "Your task is to process information from input data and generate appropriate responses " - "based on the specified column name and input data. " - "Adapt your output format and content according to the column name provided." - ) - user_prompt = ( - f'Table name: "{table_id}"\n\n' - f"{input_cols}\n\n" - f'Based on the available information, provide an appropriate response for the column "{curr_column.id}".\n' - "Remember to act as a cell in a spreadsheet and provide concise, " - "relevant information without explanations unless specifically requested." - ) - return system_prompt, user_prompt - - @model_validator(mode="after") - def check_gen_configs(self) -> Self: - for i, col in enumerate(self.cols): - gen_config = col.gen_config - if gen_config is None: - continue - available_cols = [ - col - for col in self.cols[:i] - if (not col.id.endswith("_")) - and col.id.lower() not in ("id", "updated at") - and col.vlen == 0 - ] - col_ids = [col.id for col in available_cols] - col_ids_set = set(col_ids) - if isinstance(gen_config, EmbedGenConfig): - if gen_config.source_column not in col_ids_set: - raise ValueError( - ( - f"Table '{self.id}': " - f"Embedding config of column '{col.id}' referenced " - f"an invalid source column '{gen_config.source_column}'. " - "Make sure you only reference columns on its left. " - f"Available columns: {col_ids}." - ) - ) - elif isinstance(gen_config, CodeGenConfig): - source_col = next( - (c for c in available_cols if c.id == gen_config.source_column), None - ) - if source_col is None: - raise ValueError( - ( - f"Table '{self.id}': " - f"Code Execution config of column '{col.id}' referenced " - f"an invalid source column '{gen_config.source_column}'. " - "Make sure you only reference columns on its left. " - f"Available columns: {col_ids}." - ) - ) - if source_col.dtype != ColumnDtype.STR: - raise ValueError( - ( - f"Table '{self.id}': " - f"Code Execution config of column '{col.id}' referenced " - f"a source column '{gen_config.source_column}' with an invalid datatype of '{source_col.dtype}'. " - "Make sure the source column is Str typed." - ) - ) - elif isinstance(gen_config, LLMGenConfig): - # Insert default prompts if needed - system_prompt, user_prompt = self.get_default_prompts( - table_id=self.id, - curr_column=col, - column_ids=[col.id for col in available_cols if col.gen_config is None], - ) - if not gen_config.system_prompt.strip(): - gen_config.system_prompt = system_prompt - if not gen_config.prompt.strip(): - gen_config.prompt = user_prompt - # Check references - for message in (gen_config.system_prompt, gen_config.prompt): - for key in re.findall(GEN_CONFIG_VAR_PATTERN, message): - if key not in col_ids_set: - raise ValueError( - ( - f"Table '{self.id}': " - f"Generation prompt of column '{col.id}' referenced " - f"an invalid source column '{key}'. " - "Make sure you only reference columns on its left. " - f"Available columns: {col_ids}." 
- ) - ) - return self - - -class TableSchemaCreate(TableSchema): - id: TableName = Field(description="Table name.") - cols: list[ColumnSchemaCreate] = Field(description="List of column schema.") - - @model_validator(mode="after") - def check_cols(self) -> Self: - if len(set(c.id.lower() for c in self.cols)) != len(self.cols): - raise ValueError("There are repeated column names (case-insensitive) in the schema.") - if sum(c.id.lower() in ("id", "updated at") for c in self.cols) > 0: - raise ValueError("Schema cannot contain column names: 'ID' or 'Updated at'.") - if sum(c.vlen > 0 for c in self.cols) > 0: - raise ValueError("Schema cannot contain columns with `vlen` > 0.") - return self - - -class ActionTableSchemaCreate(TableSchemaCreate): - pass - - -class AddActionColumnSchema(ActionTableSchemaCreate): - @model_validator(mode="after") - def check_gen_configs(self) -> Self: - # Check gen config using TableSchema - return self - - -class KnowledgeTableSchemaCreate(TableSchemaCreate): - embedding_model: str - - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - num_text_cols = sum( - c.id.lower() in ("text", "title", "file id", "page") for c in self.cols - ) - if num_text_cols != 0: - raise ValueError( - "Schema cannot contain column names: 'Text', 'Title', 'File ID', 'Page'." - ) - return self - - @staticmethod - def get_default_prompts(*args, **kwargs) -> tuple[str, str]: - # This should act as if its AddKnowledgeColumnSchema - return "", "" - - -class AddKnowledgeColumnSchema(TableSchemaCreate): - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - num_text_cols = sum( - c.id.lower() in ("text", "title", "file id", "page") for c in self.cols - ) - if num_text_cols != 0: - raise ValueError( - "Schema cannot contain column names: 'Text', 'Title', 'File ID', 'Page'." - ) - return self - - @model_validator(mode="after") - def check_gen_configs(self) -> Self: - # Check gen config using TableSchema - return self - - -class ChatTableSchemaCreate(TableSchemaCreate): - pass - - -class AddChatColumnSchema(TableSchemaCreate): - @model_validator(mode="after") - def check_cols(self) -> Self: - super().check_cols() - return self - - @model_validator(mode="after") - def check_gen_configs(self) -> Self: - # Check gen config using TableSchema - return self - - -class TableMeta(TableBase, table=True): - cols: list[dict[str, Any]] = sql_Field( - sa_column=Column(JSON), description="List of column schema." - ) - parent_id: str | None = sql_Field( - default=None, - description="The parent table ID. If None (default), it means this is a template table.", - ) - title: str = sql_Field( - default="", - description="Chat title. Defaults to ''.", - ) - updated_at: str = sql_Field( - default_factory=datetime_now_iso, - description="Table last update timestamp (ISO 8601 UTC).", - ) # SQLite does not support TZ - indexed_at_fts: str | None = sql_Field( - default=None, description="Table last FTS index timestamp (ISO 8601 UTC)." - ) - indexed_at_vec: str | None = sql_Field( - default=None, description="Table last vector index timestamp (ISO 8601 UTC)." - ) - indexed_at_sca: str | None = sql_Field( - default=None, description="Table last scalar index timestamp (ISO 8601 UTC)." 
- ) - - @property - def cols_schema(self) -> list[ColumnSchema]: - return [ColumnSchema.model_validate(c) for c in deepcopy(self.cols)] - - @property - def regular_cols(self) -> list[ColumnSchema]: - return [c for c in self.cols_schema if not c.id.endswith("_")] - - -class TableMetaResponse(TableSchema): - parent_id: TableName | None = Field( - description="The parent table ID. If None (default), it means this is a template table.", - ) - title: str = Field(description="Chat title. Defaults to ''.") - updated_at: str = Field( - description="Table last update timestamp (ISO 8601 UTC).", - ) # SQLite does not support TZ - indexed_at_fts: str | None = Field( - description="Table last FTS index timestamp (ISO 8601 UTC)." - ) - indexed_at_vec: str | None = Field( - description="Table last vector index timestamp (ISO 8601 UTC)." - ) - indexed_at_sca: str | None = Field( - description="Table last scalar index timestamp (ISO 8601 UTC)." - ) - num_rows: int = Field( - default=-1, - description="Number of rows in the table. Defaults to -1 (not counted).", - ) - - @model_validator(mode="after") - def check_gen_configs(self) -> Self: - return self - - @model_validator(mode="after") - def remove_state_cols(self) -> Self: - self.cols = [c for c in self.cols if not c.id.endswith("_")] - return self - - -class RowAddData(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - table_meta: TableMeta = Field(description="Table metadata.") - data: list[dict[ColName, Any]] = Field( - description="List of row data to add or update. Each list item is a mapping of column ID to its value." - ) - errors: list[list[str]] | None = Field( - default=None, - description=( - "List of row columns that encountered errors (perhaps LLM generation failed mid-stream). " - "Each list item is a list of column IDs." 
- ), - ) - - @model_validator(mode="after") - def check_errors(self) -> Self: - if self.errors is None: - return self - if len(self.errors) != len(self.data): - raise ValueError( - ( - "`errors` must contain same number of items as `data`, " - f"received: len(errors)={len(self.errors)} len(data)={len(self.data)}" - ) - ) - return self - - @model_validator(mode="after") - def check_data(self) -> Self: - if "updated at" in self.data: - raise ValueError("`data` cannot contain keys: 'Updated at'.") - return self - - @model_validator(mode="after") - def handle_nulls_and_validate(self) -> Self: - return self._handle_nulls_and_validate() - - def _handle_nulls_and_validate(self, check_missing_cols: bool = True) -> Self: - cols = { - c.id: c - for c in self.table_meta.cols_schema - if not (c.id.lower() in ("id", "updated at") or c.id.endswith("_")) - } - # Create the row schema for validation - PydanticSchema: Type[BaseModel] = create_model( - f"{self.__class__.__name__}Schema", - __config__=ConfigDict(arbitrary_types_allowed=True), - **{c.id: (str_to_py_type(c.dtype.value, c.vlen) | None, None) for c in cols.values()}, - ) - self.errors = [[] for _ in self.data] - - # Validate - for d, err in zip(self.data, self.errors, strict=True): - # Fill in missing cols - if check_missing_cols: - for k in cols: - if k not in d: - d[k] = None - try: - PydanticSchema.model_validate(d) - except ValidationError as e: - failed_cols = set(reduce(lambda a, b: a + b, (err["loc"] for err in e.errors()))) - logger.info( - f"Table {self.table_meta.id}: These columns failed validation: {failed_cols}" - ) - else: - failed_cols = {} - for k in list(d.keys()): - if k not in cols: - continue - col = cols[k] - state = {} - if k in failed_cols: - d[k], state["original"] = None, d[k] - if k in err: - d[k] = None - # state["error"] = True - if d[k] is None: - if col.dtype == ColumnDtype.INT: - d[k] = 0 - elif col.dtype == ColumnDtype.FLOAT: - d[k] = 0.0 - elif col.dtype == ColumnDtype.BOOL: - d[k] = False - elif col.dtype in (ColumnDtype.STR, ColumnDtype.IMAGE): - # Store null string as "" - # https://github.com/lancedb/lancedb/issues/1160 - d[k] = "" - elif col.vlen > 0: - # TODO: Investigate setting null vectors to np.nan - # Pros: nan vectors won't show up in vector search - # Cons: May cause error during vector indexing - d[k] = np.zeros([col.vlen], dtype=_str_to_py_type[col.dtype.value]) - state["is_null"] = True - else: - if col.vlen > 0: - d[k] = np.asarray(d[k], dtype=_str_to_py_type[col.dtype.value]) - state["is_null"] = False - d[f"{k}_"] = json_dumps(state) - d["Updated at"] = datetime.now(timezone.utc) - return self - - def set_id(self) -> Self: - """ - Sets ID, - - Returns: - self (RowAddData): RowAddData - """ - for d in self.data: - if "ID" not in d: - d["ID"] = uuid7_draft2_str() - return self - - def sql_escape(self) -> Self: - cols = {c.id: c for c in self.table_meta.cols_schema} - for d in self.data: - for k in list(d.keys()): - if cols[k].dtype == ColumnDtype.STR: - d[k] = re.sub(ODD_SINGLE_QUOTE, "''", d[k]) - return self - - -class RowUpdateData(RowAddData): - @model_validator(mode="after") - def check_data(self) -> Self: - if sum(n.lower() in ("id", "updated at") for d in self.data for n in d) > 0: - raise ValueError("`data` cannot contain keys: 'ID' or 'Updated at'.") - return self - - @model_validator(mode="after") - def handle_nulls_and_validate(self) -> Self: - return self._handle_nulls_and_validate(check_missing_cols=False) - - -class GenConfigUpdateRequest(BaseModel): - table_id: TableName = 
Field(description="Table name or ID.") - column_map: dict[ColName, DiscriminatedGenConfig | None] = Field( - description=( - "Mapping of column ID to generation config JSON in the form of `GenConfig`. " - "Table columns on its left can be referenced by `${column-name}`." - ) - ) - - @model_validator(mode="after") - def check_column_map(self) -> Self: - if sum(n.lower() in ("id", "updated at") for n in self.column_map) > 0: - raise ValueError("column_map cannot contain keys: 'ID' or 'Updated at'.") - return self - - -class ColumnRenameRequest(BaseModel): - table_id: TableName = Field(description="Table name or ID.") - column_map: dict[ColName, ColName] = Field( - description="Mapping of old column names to new column names." - ) - - @model_validator(mode="after") - def check_column_map(self) -> Self: - if sum(n.lower() in ("id", "updated at") for n in self.column_map) > 0: - raise ValueError("`column_map` cannot contain keys: 'ID' or 'Updated at'.") - return self - - -class ColumnReorderRequest(BaseModel): - table_id: TableName = Field(description="Table name or ID.") - column_names: list[ColName] = Field(description="List of column ID in the desired order.") - - @field_validator("column_names", mode="after") - @classmethod - def check_unique_column_names(cls, value: list[ColName]) -> list[ColName]: - if len(set(n.lower() for n in value)) != len(value): - raise ValueError("Column names must be unique (case-insensitive).") - return value - - -class ColumnDropRequest(BaseModel): - table_id: TableName = Field(description="Table name or ID.") - column_names: list[ColName] = Field(description="List of column ID to drop.") - - @model_validator(mode="after") - def check_column_names(self) -> Self: - if sum(n.lower() in ("id", "updated at") for n in self.column_names) > 0: - raise ValueError("`column_names` cannot contain keys: 'ID' or 'Updated at'.") - return self - - -class Task(BaseModel): - output_column_name: str - body: LLMGenConfig - - -class RowAdd(BaseModel): - table_id: TableName = Field( - description="Table name or ID.", - ) - data: dict[ColName, Any] = Field( - description="Mapping of column names to its value.", - ) - stream: bool = Field( - default=True, - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. Whether or not to concurrently generate the output columns.", - ) - - -class RowAddRequest(BaseModel): - table_id: TableName = Field( - description="Table name or ID.", - ) - data: list[dict[ColName, Any]] = Field( - min_length=1, - description=( - "List of mapping of column names to its value. " - "In other words, each item in the list is a row, and each item is a mapping. " - "Minimum 1 row, maximum 100 rows." - ), - ) - stream: bool = Field( - default=True, - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. 
Whether or not to concurrently generate the output rows and columns.", - ) - - def __repr__(self): - _data = [ - { - k: ( - {"type": type(v), "shape": v.shape, "dtype": v.dtype} - if isinstance(v, np.ndarray) - else v - ) - } - for row in self.data - for k, v in row.items() - ] - return ( - f"{self.__class__.__name__}(" - f"table_id={self.table_id} stream={self.stream} reindex={self.reindex}" - f"concurrent={self.concurrent} data={_data}" - ")" - ) - - @model_validator(mode="after") - def check_data(self) -> Self: - for row in self.data: - for value in row.values(): - if isinstance(value, str) and ( - value.startswith("s3://") or value.startswith("file://") - ): - extension = splitext(value)[1].lower() - if extension not in IMAGE_FILE_EXTENSIONS + AUDIO_FILE_EXTENSIONS: - raise ValueError( - "Unsupported file type. Make sure the file belongs to " - "one of the following formats: \n" - f"[Image File Types]: \n{IMAGE_FILE_EXTENSIONS} \n" - f"[Audio File Types]: \n{AUDIO_FILE_EXTENSIONS}" - ) - return self - - -class RowAddRequestWithLimit(RowAddRequest): - data: list[dict[ColName, Any]] = Field( - min_length=1, - max_length=100, - description=( - "List of mapping of column names to its value. " - "In other words, each item in the list is a row, and each item is a mapping. " - "Minimum 1 row, maximum 100 rows." - ), - ) - - -class RowUpdateRequest(BaseModel): - table_id: TableName = Field( - description="Table name or ID.", - ) - row_id: str = Field( - description="ID of the row to update.", - ) - data: dict[ColName, Any] = Field( - description="Mapping of column names to its value.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - - @model_validator(mode="after") - def check_data(self) -> Self: - for value in self.data.values(): - if isinstance(value, str) and ( - value.startswith("s3://") or value.startswith("file://") - ): - extension = splitext(value)[1].lower() - if extension not in IMAGE_FILE_EXTENSIONS + AUDIO_FILE_EXTENSIONS: - raise ValueError( - "Unsupported file type. Make sure the file belongs to " - "one of the following formats: \n" - f"[Image File Types]: \n{IMAGE_FILE_EXTENSIONS} \n" - f"[Audio File Types]: \n{AUDIO_FILE_EXTENSIONS}" - ) - return self - - -class RegenStrategy(str, Enum): - """Strategies for selecting columns during row regeneration.""" - - RUN_ALL = "run_all" - RUN_BEFORE = "run_before" - RUN_SELECTED = "run_selected" - RUN_AFTER = "run_after" - - def __str__(self) -> str: - return self.value - - -class RowRegen(BaseModel): - table_id: TableName = Field( - description="Table name or ID.", - ) - row_id: str = Field( - description="ID of the row to regenerate.", - ) - regen_strategy: RegenStrategy = Field( - default=RegenStrategy.RUN_ALL, - description=( - "_Optional_. Strategy for selecting columns to regenerate." - "Choose `run_all` to regenerate all columns in the specified row; " - "Choose `run_before` to regenerate columns up to the specified column_id; " - "Choose `run_selected` to regenerate only the specified column_id; " - "Choose `run_after` to regenerate columns starting from the specified column_id; " - ), - ) - output_column_id: str | None = Field( - default=None, - description=( - "_Optional_. Output column name to indicate the starting or ending point of regen for `run_before`, " - "`run_selected` and `run_after` strategies. 
Required if `regen_strategy` is not 'run_all'. " - "Given columns are 'C1', 'C2', 'C3' and 'C4', if column_id is 'C3': " - "`run_before` regenerate columns 'C1', 'C2' and 'C3'; " - "`run_selected` regenerate only column 'C3'; " - "`run_after` regenerate columns 'C3' and 'C4'; " - ), - ) - stream: bool = Field( - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. Whether or not to concurrently generate the output columns.", - ) - - -class RowRegenRequest(BaseModel): - table_id: TableName = Field( - description="Table name or ID.", - ) - row_ids: list[str] = Field( - min_length=1, - max_length=100, - description="List of ID of the row to regenerate. Minimum 1 row, maximum 100 rows.", - ) - regen_strategy: RegenStrategy = Field( - default=RegenStrategy.RUN_ALL, - description=( - "_Optional_. Strategy for selecting columns to regenerate." - "Choose `run_all` to regenerate all columns in the specified row; " - "Choose `run_before` to regenerate columns up to the specified column_id; " - "Choose `run_selected` to regenerate only the specified column_id; " - "Choose `run_after` to regenerate columns starting from the specified column_id; " - ), - ) - output_column_id: str | None = Field( - default=None, - description=( - "_Optional_. Output column name to indicate the starting or ending point of regen for `run_before`, " - "`run_selected` and `run_after` strategies. Required if `regen_strategy` is not 'run_all'. " - "Given columns are 'C1', 'C2', 'C3' and 'C4', if column_id is 'C3': " - "`run_before` regenerate columns 'C1', 'C2' and 'C3'; " - "`run_selected` regenerate only column 'C3'; " - "`run_after` regenerate columns 'C3' and 'C4'; " - ), - ) - stream: bool = Field( - description="Whether or not to stream the LLM generation.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." - ), - ) - concurrent: bool = Field( - default=True, - description="_Optional_. Whether or not to concurrently generate the output rows and columns.", - ) - - @model_validator(mode="after") - def check_output_column_id_provided(self) -> Self: - if self.regen_strategy != RegenStrategy.RUN_ALL and self.output_column_id is None: - raise ValueError( - "`output_column_id` is required for regen_strategy other than 'run_all'." - ) - return self - - @model_validator(mode="after") - def sort_row_ids(self) -> Self: - self.row_ids = sorted(self.row_ids) - return self - - -class RowDeleteRequest(BaseModel): - table_id: TableName = Field(description="Table name or ID.") - row_ids: list[str] | None = Field( - min_length=1, - max_length=100, - default=None, - description="List of ID of the row to delete. Minimum 1 row, maximum 100 rows.", - ) - where: str | None = Field( - default=None, - description="_Optional_. SQL where clause. If not provided, will match all rows and thus deleting all table content.", - ) - reindex: bool | None = Field( - default=None, - description=( - "_Optional_. If True, reindex immediately. If False, wait until next periodic reindex. " - "If None (default), reindex immediately for smaller tables." 
-        ),
-    )
-
-
-class EmbedFileRequest(BaseModel):
-    table_id: TableName = Field(description="Table name or ID.")
-    file_id: str = Field(description="ID of the file.")
-    chunk_size: Annotated[
-        int, Field(description="Maximum chunk size (number of characters). Must be > 0.", gt=0)
-    ] = 1000
-    chunk_overlap: Annotated[
-        int, Field(description="Overlap in characters between chunks. Must be >= 0.", ge=0)
-    ] = 200
-    # stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = (
-    #     True
-    # )
-
-
-class SearchRequest(BaseModel):
-    table_id: TableName = Field(description="Table name or ID.")
-    query: str = Field(
-        min_length=1,
-        description="Query for full-text-search (FTS) and vector search. Must not be empty.",
-    )
-    where: str | None = Field(
-        default=None,
-        description="_Optional_. SQL where clause. If not provided, will match all rows.",
-    )
-    limit: Annotated[int, Field(gt=0, le=1_000)] = Field(
-        default=100, description="_Optional_. Min 1, max 1000. Number of rows to return."
-    )
-    metric: str = Field(
-        default="cosine",
-        description='_Optional_. Vector search similarity metric. Defaults to "cosine".',
-    )
-    nprobes: Annotated[int, Field(gt=0, le=1000)] = Field(
-        default=50,
-        description=(
-            "_Optional_. Set the number of partitions to search (probe). "
-            "This argument is only used when the vector column has an IVF PQ index. If there is no index then this value is ignored. "
-            "The IVF stage of IVF PQ divides the input into partitions (clusters) of related values. "
-            "The partitions whose centroids are closest to the query vector will be exhaustively searched to find matches. "
-            "This parameter controls how many partitions should be searched. "
-            "Increasing this value will increase the recall of your query but will also increase the latency of your query. Defaults to 50."
-        ),
-    )
-    refine_factor: Annotated[int, Field(gt=0, le=1000)] = Field(
-        default=20,
-        description=(
-            "_Optional_. A multiplier to control how many additional rows are taken during the refine step. "
-            "This argument is only used when the vector column has an IVF PQ index. "
-            "If there is no index then this value is ignored. "
-            "An IVF PQ index stores compressed (quantized) values. "
-            "The query vector is compared against these values and, since they are compressed, the comparison is inaccurate. "
-            "This parameter can be used to refine the results. "
-            "It can both improve recall and correct the ordering of the nearest results. "
-            "To refine results, LanceDB will first perform an ANN search to find the nearest limit * refine_factor results. "
-            "In other words, if refine_factor is 3 and limit is the default (10) then the first 30 results will be selected. "
-            "LanceDB then fetches the full, uncompressed values for these 30 results. "
-            "The results are then reordered by the true distance and only the nearest 10 are kept. Defaults to 20."
-        ),
-    )
-    float_decimals: int = Field(
-        default=0,
-        description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).",
-    )
-    vec_decimals: int = Field(
-        default=0,
-        description="_Optional_. Number of decimals for vectors. If it is negative, exclude vector columns. 
Defaults to 0 (no rounding).", - ) - reranking_model: Annotated[ - str | None, Field(description="Reranking model to use for hybrid search.") - ] = None - - -class FileUploadRequest(BaseModel): - file_path: Annotated[str, Field(description="File path of the document to be uploaded.")] - table_id: Annotated[str, Field(description="Knowledge Table name / ID.")] - chunk_size: Annotated[ - int, Field(description="Maximum chunk size (number of characters). Must be > 0.", gt=0) - ] = 1000 - chunk_overlap: Annotated[ - int, Field(description="Overlap in characters between chunks. Must be >= 0.", ge=0) - ] = 200 - # overwrite: Annotated[ - # bool, - # Field( - # description="Whether to overwrite the file.", - # examples=[True, False], - # ), - # ] = False - - -class TableDataImportRequest(BaseModel): - file_path: Annotated[str, Field(description="CSV or TSV file path.")] - table_id: Annotated[str, Field(description="Table name / ID.")] - stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( - True - ) - # column_names: Annotated[ - # list[str] | None, - # Field( - # description="A list of columns names if the CSV does not have header row. Defaults to None (read from CSV)." - # ), - # ] = None - # columns: Annotated[ - # list[str] | None, - # Field( - # description="A list of columns to be imported. Defaults to None (import all columns except 'ID' and 'Updated at')." - # ), - # ] = None - delimiter: Annotated[ - str, - Field(description='The delimiter of the file: can be "," or "\\t". Defaults to ",".'), - ] = "," - - -class FileUploadResponse(p.FileUploadResponse): - pass - - -class GetURLRequest(p.GetURLRequest): - pass - - -class GetURLResponse(p.GetURLResponse): - pass diff --git a/services/api/src/owl/routers/auth.py b/services/api/src/owl/routers/auth.py new file mode 100644 index 0000000..a3e99bf --- /dev/null +++ b/services/api/src/owl/routers/auth.py @@ -0,0 +1,90 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Request +from pwdlib import PasswordHash +from sqlmodel import select + +from owl.db import AsyncSession, yield_async_session +from owl.db.models import User +from owl.types import ( + PasswordChangeRequest, + PasswordLoginRequest, + UserAuth, + UserCreate, + UserReadObscured, +) +from owl.utils.auth import auth_user_service_key +from owl.utils.exceptions import ( + AuthorizationError, + ForbiddenError, + ResourceNotFoundError, + handle_exception, +) + +router = APIRouter() + + +@router.post("/v2/auth/register/password", summary="Register with email and password.") +@handle_exception +async def register_password( + request: Request, + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: UserCreate, +) -> UserReadObscured: + from owl.routers.users.oss import create_user + + return await create_user(request=request, token="", session=session, body=body) + + +@router.post("/v2/auth/login/password", summary="Login with email and password.") +@handle_exception +async def login_password( + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: PasswordLoginRequest, +) -> UserReadObscured: + user = (await session.exec(select(User).where(User.email == body.email))).one_or_none() + if user: + password_hash, updated_hash = user.password_hash, None + if password_hash is None: + raise AuthorizationError("Invalid password.") + hasher = PasswordHash.recommended() + password_match, updated_hash = hasher.verify_and_update(body.password, user.password_hash) + if password_match: + if updated_hash 
is not None: + user.password_hash = updated_hash + session.add(user) + await session.commit() + await session.refresh(user) + else: + raise AuthorizationError("Invalid password.") + else: + raise AuthorizationError("User not found.") + user = await User.get(session, user.id, populate_existing=True) + return user + + +@router.patch("/v2/auth/login/password", summary="Change password.") +@handle_exception +async def change_password( + session: Annotated[AsyncSession, Depends(yield_async_session)], + _user: Annotated[UserAuth, Depends(auth_user_service_key)], + body: PasswordChangeRequest, +) -> UserReadObscured: + if _user.email != body.email: + raise ForbiddenError("You can only update your own account.") + # Re-fetch user to set `password_hash` + user = await User.get(session, _user.id) + if user is None: + raise ResourceNotFoundError(f'User "{_user.id}" is not found.') + password_hash, updated_hash = user.password_hash, None + hasher = PasswordHash.recommended() + password_match = hasher.verify(body.password, password_hash) + if password_match: + updated_hash = hasher.hash(body.new_password) + user.password_hash = updated_hash + session.add(user) + await session.commit() + await session.refresh(user) + else: + raise AuthorizationError("Invalid existing password.") + return user diff --git a/services/api/src/owl/routers/conversation.py b/services/api/src/owl/routers/conversation.py new file mode 100644 index 0000000..6ad7aaa --- /dev/null +++ b/services/api/src/owl/routers/conversation.py @@ -0,0 +1,574 @@ +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import StreamingResponse +from loguru import logger + +from owl.db.gen_executor import MultiRowGenExecutor +from owl.db.gen_table import ChatTable +from owl.types import ( + AgentMetaResponse, + ConversationCreateRequest, + ConversationMetaResponse, + ConversationThreadsResponse, + GetConversationThreadsQuery, + ListMessageQuery, + ListQuery, + LLMGenConfig, + MessageAddRequest, + MessagesRegenRequest, + MessageUpdateRequest, + MultiRowAddRequest, + MultiRowRegenRequest, + OkResponse, + OrganizationRead, + Page, + ProjectRead, + SanitisedStr, + TableMetaResponse, + UserAuth, +) +from owl.utils.auth import auth_user_project, has_permissions +from owl.utils.billing import BillingManager +from owl.utils.exceptions import ResourceNotFoundError, handle_exception +from owl.utils.lm import LMEngine +from owl.utils.mcp import MCP_TOOL_TAG + +router = APIRouter() + + +def _table_meta_to_conv(metas: Page[TableMetaResponse]) -> Page[ConversationMetaResponse]: + """Converts Page[TableMetaResponse] to Page[ConversationMetaResponse].""" + return Page[ConversationMetaResponse]( + items=[ConversationMetaResponse.from_table_meta(m) for m in metas.items], + limit=metas.limit, + offset=metas.offset, + total=metas.total, + ) + + +async def _generate_and_save_title( + request: Request, + project: ProjectRead, + organization: OrganizationRead, + conversation_id: str, + table: ChatTable, +): + first_multiturn_column_meta = next( + ( + c + for c in table.column_metadata + if isinstance(c.gen_config, LLMGenConfig) and c.gen_config.multi_turn + ), + None, + ) + if first_multiturn_column_meta is None: + raise ResourceNotFoundError( + f'Conversation "{conversation_id}" has no multi-turn LLM column configured.' 
+ ) + + first_multiturn_column_id = first_multiturn_column_meta.column_id + title_model_id = first_multiturn_column_meta.gen_config.model + + # Generate title after the first user message is saved and streamed + rows_page = await table.list_rows(limit=1, order_ascending=True) + first_user_content = rows_page.items[0].get("User", "") + first_assistant_content = rows_page.items[0].get(first_multiturn_column_id, "") + + llm = LMEngine(organization=organization, project=project, request=request) + generated_title = await llm.generate_chat_title( + user_content=first_user_content, + assistant_content=first_assistant_content, + model=title_model_id, + ) + await table.update_table_title(generated_title) + + +### --- Conversations CRUD --- ### + + +@router.post( + "/v2/conversations", + summary="Creates a new conversation and sends the first message. " + "Title will be generated automatically if not provided.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def create_conversation( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: ConversationCreateRequest, +) -> StreamingResponse: + user, project, org = auth_info + has_permissions(user, ["project"], project_id=project.id) + table_id = body.agent_id + # Validate data early + row_data = MultiRowAddRequest(table_id=table_id, data=[body.data], stream=True) + table = await ChatTable.open_table(project_id=project.id, table_id=table_id) + if table.table_metadata.parent_id is not None: + raise ResourceNotFoundError(f'Agent "{table_id}" is not found.') + + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + + table = await table.duplicate_table( + project_id=project.id, + table_id_src=table_id, + table_id_dst=None, + include_data=False, + create_as_child=True, + created_by=user.id, + ) + if body.title is not None: + table = await table.update_table_title(body.title) + conversation_id = table.table_metadata.table_id + row_data.table_id = conversation_id + executor = MultiRowGenExecutor( + request=request, + table=table, + organization=org, + project=project, + body=row_data, + ) + + async def stream_generator(): + meta = ConversationMetaResponse.from_table_meta(table.v1_meta_response) + yield f"event: metadata\ndata: {meta.model_dump_json()}\n\n" + + generator = await executor.generate() + async for chunk in generator: + if body.title is None and chunk == "data: [DONE]\n\n": + try: + await _generate_and_save_title( + request=request, + project=project, + organization=org, + conversation_id=conversation_id, + table=table, + ) + except Exception as e: + logger.error(f"Error generating title: {repr(e)}") + finally: + meta = ConversationMetaResponse.from_table_meta(table.v1_meta_response) + yield f"event: metadata\ndata: {meta.model_dump_json()}\n\n" + yield chunk + + return StreamingResponse( + content=stream_generator(), + status_code=200, + media_type="text/event-stream", + headers={"X-Accel-Buffering": "no"}, + ) + + +@router.get( + "/v2/conversations/list", + summary="Lists all conversations.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def list_conversations( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + params: Annotated[ListQuery, Query()], +) -> Page[ConversationMetaResponse]: + user, project, _ = 
auth_info + has_permissions(user, ["project"], project_id=project.id) + + metas = await ChatTable.list_tables( + project_id=project.id, + limit=params.limit, + offset=params.offset, + order_by=params.order_by, + order_ascending=params.order_ascending, + created_by=user.id, + parent_id="_chat_", + search_query=params.search_query, + search_columns=["title"], + ) + return _table_meta_to_conv(metas) + + +@router.get( + "/v2/conversations/agents/list", + summary="Lists all available agents.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def list_agents( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + params: Annotated[ListQuery, Query()], +) -> Page[ConversationMetaResponse]: + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + + metas = await ChatTable.list_tables( + project_id=project.id, + limit=params.limit, + offset=params.offset, + order_by=params.order_by, + order_ascending=params.order_ascending, + parent_id="_agent_", + search_query=params.search_query, + search_columns=["table_id"], + ) + return _table_meta_to_conv(metas) + + +@router.get( + "/v2/conversations", + summary="Fetches a single conversation (table) metadata.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def get_conversation( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + conversation_id: Annotated[str, Query(description="The ID of the conversation to fetch.")], +) -> ConversationMetaResponse: + """Fetches a single conversation metadata.""" + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{conversation_id}" not found.') from e + return ConversationMetaResponse.from_table_meta(table.v1_meta_response) + + +@router.get( + "/v2/conversations/agents", + summary="Fetches a single agent (table) metadata.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def get_agent( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + agent_id: Annotated[str, Query(description="The ID of the agent to fetch.")], +) -> AgentMetaResponse: + """Fetches a single agent metadata.""" + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + try: + table = await ChatTable.open_table(project_id=project.id, table_id=agent_id) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Agent "{agent_id}" not found.') from e + return AgentMetaResponse.from_table_meta(table.v1_meta_response) + + +@router.post( + "/v2/conversations/title", + summary="Generates a title for a conversation based on the first user message and assistant response. 
" + "If the conversation already has a title, it will be overwritten.", + description="Permissions: `project`.", +) +@handle_exception +async def generate_conversation_title( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + conversation_id: Annotated[ + str, Query(description="The ID of the conversation to generate a title for.") + ], +) -> ConversationMetaResponse: + user, project, org = auth_info + has_permissions(user, ["project"], project_id=project.id) + + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{conversation_id}" not found.') from e + + await _generate_and_save_title( + request=request, + project=project, + organization=org, + conversation_id=conversation_id, + table=table, + ) + return ConversationMetaResponse.from_table_meta(table.v1_meta_response) + + +@router.patch( + "/v2/conversations/title", + summary="Renames conversation title.", + description="Permissions: `project`.", +) +@handle_exception +async def rename_conversation_title( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + conversation_id: Annotated[str, Query(description="The ID of the conversation to rename.")], + title: Annotated[SanitisedStr, Query(description="The new title for the conversation.")], +) -> ConversationMetaResponse: + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{conversation_id}" not found.') from e + + table = await table.update_table_title(title) + return ConversationMetaResponse.from_table_meta(table.v1_meta_response) + + +@router.delete( + "/v2/conversations", + summary="Deletes a conversation permanently.", + description="Permissions: `project`.", +) +@handle_exception +async def delete_conversation( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + conversation_id: Annotated[str, Query(description="The ID of the conversation to delete.")], +) -> OkResponse: + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{conversation_id}" not found.') from e + + await table.drop_table() + return OkResponse() + + +### --- Messages CRUD --- ### + + +@router.post( + "/v2/conversations/messages", + summary="Sends a message to a conversation and streams the response.", + description="Permissions: `project`.", +) +@handle_exception +async def send_message( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: MessageAddRequest, +) -> StreamingResponse: + user, project, org = auth_info + has_permissions(user, ["project"], project_id=project.id) + conversation_id = body.conversation_id + # Validate data early + row_data = MultiRowAddRequest(table_id=conversation_id, data=[body.data], stream=True) + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=conversation_id, 
created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{conversation_id}" not found.') from e + + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + + executor = MultiRowGenExecutor( + request=request, + table=table, + organization=org, + project=project, + body=row_data, + ) + + return StreamingResponse( + content=await executor.generate(), + status_code=200, + media_type="text/event-stream", + headers={"X-Accel-Buffering": "no"}, + ) + + +@router.get( + "/v2/conversations/messages/list", + summary="Lists messages in a conversation.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def list_messages( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + params: Annotated[ListMessageQuery, Query()], +) -> Page[dict[str, Any]]: + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=params.conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{params.conversation_id}" not found.') from e + + return await table.list_rows( + limit=params.limit, + offset=params.offset, + order_by=[params.order_by], + order_ascending=params.order_ascending, + columns=params.columns, + where=params.where, + search_query=params.search_query, + search_columns=params.search_columns, + remove_state_cols=False, + ) + + +@router.post( + "/v2/conversations/messages/regen", + summary="Regenerates a specific message in a conversation and streams back the response.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def regen_conversation_message( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: MessagesRegenRequest, +) -> StreamingResponse: + user, project, org = auth_info + has_permissions(user, ["project"], project_id=project.id) + + conversation_id = body.conversation_id + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{conversation_id}" not found.') from e + + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_egress_quota() + + # Construct the full request for the executor + regen_rows = await table.list_rows( + where=f"\"ID\" >= '{body.row_id}'", columns=["ID"], order_by=["ID"], order_ascending=True + ) + regen_row_ids = [str(r["ID"]) for r in regen_rows.items] + + executor = MultiRowGenExecutor( + request=request, + table=table, + organization=org, + project=project, + body=MultiRowRegenRequest( + table_id=table.table_metadata.table_id, row_ids=regen_row_ids, stream=True + ), + ) + + return StreamingResponse( + content=await executor.generate(), + status_code=200, + media_type="text/event-stream", + headers={"X-Accel-Buffering": "no"}, + ) + + +@router.patch( + "/v2/conversations/messages", + summary="Updates a specific message in a conversation.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def update_conversation_message( + request: Request, + 
auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: MessageUpdateRequest, +) -> OkResponse: + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + + try: + table = await ChatTable.open_table( + project_id=project.id, table_id=body.conversation_id, created_by=user.id + ) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{body.conversation_id}" not found.') from e + + # Check quota for DB write + billing: BillingManager = request.state.billing + billing.has_db_storage_quota() + + await table.update_rows({body.row_id: body.data}) + + return OkResponse() + + +### --- Threads CRUD --- ### + + +@router.get( + "/v2/conversations/threads", + summary="Get all threads from a conversation or an agent.", + description="Permissions: `project`.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def get_threads( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + params: Annotated[GetConversationThreadsQuery, Query()], +) -> ConversationThreadsResponse: + user, project, _ = auth_info + has_permissions(user, ["project"], project_id=project.id) + table_id = params.conversation_id + try: + table = await ChatTable.open_table(project_id=project.id, table_id=table_id) + except ResourceNotFoundError as e: + raise ResourceNotFoundError(f'Conversation "{table_id}" not found.') from e + if table.table_metadata.parent_id is None: + pass + elif table.table_metadata.created_by != user.id: + raise ResourceNotFoundError(f'Conversation "{table_id}" not found.') + if params.column_ids: + for column_id in params.column_ids: + table.check_multiturn_column(column_id) + cols = params.column_ids + else: + cols = [c.column_id for c in table.column_metadata if c.is_chat_column] + return ConversationThreadsResponse( + threads={c: await table.get_conversation_thread(column_id=c) for c in cols}, + conversation_id=table_id, + ) diff --git a/services/api/src/owl/routers/file.py b/services/api/src/owl/routers/file.py index fed6675..1b88072 100644 --- a/services/api/src/owl/routers/file.py +++ b/services/api/src/owl/routers/file.py @@ -1,4 +1,3 @@ -import mimetypes import os from os.path import splitext from typing import Annotated @@ -6,24 +5,35 @@ import httpx from fastapi import APIRouter, Depends, Request, Response, UploadFile -from fastapi.responses import FileResponse, JSONResponse +from fastapi.responses import ORJSONResponse from loguru import logger -from jamaibase.exceptions import ResourceNotFoundError -from owl.configs.manager import ENV_CONFIG -from owl.protocol import FileUploadResponse, GetURLRequest, GetURLResponse -from owl.utils.auth import ProjectRead, auth_user_project +from owl.configs import ENV_CONFIG +from owl.types import ( + FileUploadResponse, + GetURLRequest, + GetURLResponse, + OrganizationRead, + ProjectRead, + UserAuth, +) +from owl.utils.auth import auth_user_project, has_permissions +from owl.utils.billing import BillingManager from owl.utils.exceptions import handle_exception from owl.utils.io import ( AUDIO_WHITE_LIST_EXT, - LOCAL_FILE_DIR, - S3_CLIENT, + NON_PDF_DOC_WHITE_LIST_EXT, UPLOAD_WHITE_LIST_MIME, + get_global_thumbnail_path, get_s3_aclient, - upload_file_to_s3, + guess_mime, + s3_upload, ) -HTTP_ACLIENT = httpx.AsyncClient() if S3_CLIENT else None +HTTP_ACLIENT = httpx.AsyncClient( + timeout=10.0, + transport=httpx.AsyncHTTPTransport(retries=3), +) router = APIRouter() @@ -40,8 +50,8 @@ async 
def _generate_presigned_url(s3_client, bucket_name: str, key: str) -> str: return urlunparse( ( parsed_url.scheme, - ENV_CONFIG.owl_file_proxy_url, - "/api/v1/files" + parsed_url.path, + ENV_CONFIG.file_proxy_url, + "/api/v2/files" + parsed_url.path, parsed_url.params, parsed_url.query, parsed_url.fragment, @@ -49,52 +59,30 @@ async def _generate_presigned_url(s3_client, bucket_name: str, key: str) -> str: ) -@router.get("/v1/files/{path:path}") +@router.get("/v2/files/{path:path}") +@router.get("/v1/files/{path:path}", deprecated=True) @handle_exception async def proxy_file(request: Request, path: str) -> Response: - if HTTP_ACLIENT: - # S3 file handling - encoded_path = quote(path) - original_url = f"{ENV_CONFIG.s3_endpoint}/{encoded_path}?{request.query_params}" - response = await HTTP_ACLIENT.get(original_url) - # Determine the MIME type - mime_type, _ = mimetypes.guess_type(original_url) - if mime_type is None: - mime_type = "application/octet-stream" - # Set the Content-Disposition header - headers = dict(response.headers) - headers["Content-Disposition"] = "inline" - headers["Content-Type"] = mime_type - return Response( - content=response.content, - status_code=response.status_code, - headers=headers, - ) - - elif os.path.exists(LOCAL_FILE_DIR): - # Local file handling - file_path = os.path.join(LOCAL_FILE_DIR, path) - if not os.path.exists(file_path) or not os.path.isfile(file_path): - raise ResourceNotFoundError( - "Requested resource in not found in configured local file store." - ) - # Determine the MIME type - mime_type, _ = mimetypes.guess_type(file_path) - if mime_type is None: - mime_type = "application/octet-stream" - return FileResponse( - path=file_path, - media_type=mime_type, - filename=os.path.basename(file_path), - content_disposition_type="inline", - ) - - else: - raise ResourceNotFoundError("Neither S3 nor local file store is configured") + encoded_path = quote(path) + original_url = f"{ENV_CONFIG.s3_endpoint}/{encoded_path}?{request.query_params}" + response = await HTTP_ACLIENT.get(original_url) + # Set the Content-Disposition header + response.headers["Content-Disposition"] = "inline" + # Usually we can get the MIME type from S3 metadata + if "Content-Type" not in response.headers: + response.headers["Content-Type"] = guess_mime(path) + return Response( + content=response.content, + status_code=response.status_code, + headers=response.headers, + ) -@router.options("/v1/files/upload") -@router.options("/v1/files/upload/", deprecated=True) +@router.options( + "/v2/files/upload", + summary="Get CORS preflight options for file upload endpoint.", +) +@router.options("/v1/files/upload", deprecated=True) @handle_exception async def upload_file_options(): headers = { @@ -103,53 +91,63 @@ async def upload_file_options(): "Access-Control-Allow-Headers": "Content-Type", "Access-Control-Allow-Methods": "POST, OPTIONS", } - return JSONResponse(content={"accepted_types": list(UPLOAD_WHITE_LIST_MIME)}, headers=headers) + return ORJSONResponse( + content={"accepted_types": list(UPLOAD_WHITE_LIST_MIME)}, + headers=headers, + ) -@router.post("/v1/files/upload") -@router.post("/v1/files/upload/", deprecated=True) +@router.post( + "/v2/files/upload", + summary="Upload a file to the server.", + description="Permissions: `organization` OR `project`.", +) +@router.post("/v1/files/upload", deprecated=True) @handle_exception async def upload_file( - project: Annotated[ProjectRead, Depends(auth_user_project)], + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, 
OrganizationRead], Depends(auth_user_project) + ], file: UploadFile, ) -> FileUploadResponse: + user, project, org = auth_info + has_permissions( + user, + ["organization", "project"], + organization_id=org.id, + project_id=project.id, + ) + # Check quota + billing: BillingManager = request.state.billing + billing.has_file_storage_quota() content = await file.read() - uri = await upload_file_to_s3( - project.organization.id, project.id, content, file.content_type, file.filename + uri = await s3_upload( + project.organization.id, + project.id, + content, + content_type=file.content_type, + filename=file.filename, ) return FileUploadResponse(uri=uri) -@router.post("/v1/files/url/raw", response_model=GetURLResponse) +@router.post("/v2/files/url/raw") +@router.post("/v1/files/url/raw", deprecated=True) @handle_exception -async def get_raw_file_urls(body: GetURLRequest, request: Request) -> GetURLResponse: +async def get_raw_file_urls(body: GetURLRequest) -> GetURLResponse: results = [] - if S3_CLIENT: - # S3 file store - async with get_s3_aclient() as aclient: - for uri in body.uris: - file_url = "" - if uri.startswith("s3://"): - try: - bucket_name, key = uri[5:].split("/", 1) - file_url = await _generate_presigned_url(aclient, bucket_name, key) - except Exception as e: - logger.exception( - f'Error generating URL for "{uri}" due to {e.__class__.__name__}: {e}' - ) - results.append(file_url) - else: - # Local file store + async with get_s3_aclient() as aclient: for uri in body.uris: file_url = "" - if uri.startswith("file://"): - try: - local_path = os.path.abspath(uri[7:]) - if os.path.exists(local_path): - # Generate a URL for the local file - relative_path = os.path.relpath(local_path, LOCAL_FILE_DIR) - file_url = str(request.url_for("proxy_file", path=relative_path)) - except Exception as e: + try: + bucket_name, key = uri[5:].split("/", 1) + file_url = await _generate_presigned_url(aclient, bucket_name, key) + except Exception as e: + err_mssg = str(e) + if "NoSuchBucket" in err_mssg: + pass + else: logger.exception( f'Error generating URL for "{uri}" due to {e.__class__.__name__}: {e}' ) @@ -157,43 +155,32 @@ async def get_raw_file_urls(body: GetURLRequest, request: Request) -> GetURLResp return GetURLResponse(urls=results) -@router.post("/v1/files/url/thumb", response_model=GetURLResponse) +@router.post("/v2/files/url/thumb") +@router.post("/v1/files/url/thumb", deprecated=True) @handle_exception -async def get_thumbnail_urls(body: GetURLRequest, request: Request) -> GetURLResponse: +async def get_thumbnail_urls(body: GetURLRequest) -> GetURLResponse: results = [] - if S3_CLIENT: - # S3 file store - async with get_s3_aclient() as aclient: - for uri in body.uris: - file_url = "" - if uri.startswith("s3://"): - try: - ext = splitext(uri)[1].lower() - bucket_name, key = uri[5:].split("/", 1) - thumb_ext = "mp3" if ext in AUDIO_WHITE_LIST_EXT else "webp" - thumb_key = key.replace("raw", "thumb") - thumb_key = f"{os.path.splitext(thumb_key)[0]}.{thumb_ext}" - file_url = await _generate_presigned_url(aclient, bucket_name, thumb_key) - except Exception as e: - logger.exception( - f'Error generating URL for "{uri}" due to {e.__class__.__name__}: {e}' - ) - results.append(file_url) - else: - # Local file store + async with get_s3_aclient() as aclient: for uri in body.uris: file_url = "" - if uri.startswith("file://"): - try: - ext = splitext(uri)[1].lower() - local_path = os.path.abspath(uri[7:]) - thumb_ext = "mp3" if ext in AUDIO_WHITE_LIST_EXT else "webp" - thumb_path = 
local_path.replace("raw", "thumb") - thumb_path = f"{os.path.splitext(thumb_path)[0]}.{thumb_ext}" - if os.path.exists(thumb_path): - relative_path = os.path.relpath(thumb_path, LOCAL_FILE_DIR) - file_url = str(request.url_for("proxy_file", path=relative_path)) - except Exception as e: + try: + ext = splitext(uri)[1].lower() + bucket_name, key = uri[5:].split("/", 1) + thumb_ext = "mp3" if ext in AUDIO_WHITE_LIST_EXT else "webp" + if ext in NON_PDF_DOC_WHITE_LIST_EXT: + thumb_key = os.path.join( + key[: key.index("raw/")], + get_global_thumbnail_path(ext), + ) + else: + thumb_key = key.replace("raw", "thumb") + thumb_key = f"{os.path.splitext(thumb_key)[0]}.{thumb_ext}" + file_url = await _generate_presigned_url(aclient, bucket_name, thumb_key) + except Exception as e: + err_mssg = str(e) + if "NoSuchBucket" in err_mssg: + pass + else: logger.exception( f'Error generating URL for "{uri}" due to {e.__class__.__name__}: {e}' ) diff --git a/services/api/src/owl/routers/gen_table.py b/services/api/src/owl/routers/gen_table.py index 27b717b..1b3cb5c 100644 --- a/services/api/src/owl/routers/gen_table.py +++ b/services/api/src/owl/routers/gen_table.py @@ -1,1399 +1,1091 @@ import re +from asyncio import sleep from io import BytesIO -from os import listdir, makedirs -from os.path import isdir, join, splitext -from shutil import copy2, copytree +from os.path import join, splitext from tempfile import TemporaryDirectory +from time import perf_counter from typing import Annotated, Any -import numpy as np -import pandas as pd -import tiktoken +from celery.result import AsyncResult from fastapi import ( APIRouter, BackgroundTasks, Depends, - File, Form, Path, Query, Request, Response, - UploadFile, ) from fastapi.responses import FileResponse, StreamingResponse from loguru import logger - -from jamaibase.exceptions import ( - ResourceNotFoundError, - TableSchemaFixedError, - UnsupportedMediaTypeError, - make_validation_error, +from pydantic import Field + +from owl.configs import CACHE +from owl.db.gen_executor import MultiRowGenExecutor +from owl.db.gen_table import ( + ActionTable, + ChatTable, + ColumnMetadata, + KnowledgeTable, + TableMetadata, ) -from jamaibase.utils.io import csv_to_df, json_loads -from owl.configs.manager import ENV_CONFIG -from owl.db.gen_executor import MultiRowsGenExecutor -from owl.db.gen_table import GenerativeTable -from owl.llm import LLMEngine -from owl.loaders import load_file -from owl.models import CloudEmbedder, CloudReranker -from owl.protocol import ( - GEN_CONFIG_VAR_PATTERN, - TABLE_NAME_PATTERN, +from owl.docparse import GeneralDocLoader +from owl.tasks.gen_table import import_gen_table +from owl.types import ( ActionTableSchemaCreate, - AddActionColumnSchema, - AddChatColumnSchema, - AddKnowledgeColumnSchema, - ChatEntry, ChatTableSchemaCreate, - ChatThread, - CodeGenConfig, - ColName, + ChatThreadsResponse, ColumnDropRequest, - ColumnDtype, ColumnRenameRequest, ColumnReorderRequest, CSVDelimiter, - EmbedGenConfig, - GenConfig, + DuplicateTableQuery, + ExportTableDataQuery, + FileEmbedFormData, GenConfigUpdateRequest, - GenTableOrderBy, + GetTableRowQuery, + GetTableThreadsQuery, KnowledgeTableSchemaCreate, - LLMGenConfig, + ListTableQuery, + ListTableRowQuery, + MultiRowAddRequest, + MultiRowAddRequestWithLimit, + MultiRowDeleteRequest, + MultiRowRegenRequest, + MultiRowUpdateRequestWithLimit, OkResponse, + OrganizationRead, Page, - RowAddRequest, - RowAddRequestWithLimit, - RowDeleteRequest, - RowRegenRequest, - RowUpdateRequest, + ProjectRead, + 
RenameTableQuery, SearchRequest, + TableDataImportFormData, + TableImportFormData, + TableImportProgress, TableMetaResponse, - TableSchema, TableSchemaCreate, TableType, + UserAuth, +) +from owl.utils.auth import auth_user_project, has_permissions +from owl.utils.billing import BillingManager +from owl.utils.exceptions import ( + ServerBusyError, + UnexpectedError, + UnsupportedMediaTypeError, + handle_exception, ) -from owl.utils import uuid7_str -from owl.utils.auth import ProjectRead, auth_user_project -from owl.utils.exceptions import handle_exception -from owl.utils.io import EMBED_WHITE_LIST_MIME, upload_file_to_s3 +from owl.utils.io import EMBED_WHITE_LIST_MIME, guess_mime, s3_temporary_file, s3_upload +from owl.utils.lm import LMEngine +from owl.utils.mcp import MCP_TOOL_TAG router = APIRouter() -def _validate_gen_config( - llm: LLMEngine, - gen_config: GenConfig | None, - table_type: TableType, - column_id: str, - image_column_ids: list[str], - audio_column_ids: list[str], -) -> GenConfig | None: - if gen_config is None: - return gen_config - if isinstance(gen_config, LLMGenConfig): - # Set multi-turn for Chat Table - if table_type == TableType.CHAT and column_id.lower() == "ai": - gen_config.multi_turn = True - # Assign a LLM model if not specified - try: - capabilities = ["chat"] - for message in (gen_config.system_prompt, gen_config.prompt): - for col_id in re.findall(GEN_CONFIG_VAR_PATTERN, message): - if col_id in image_column_ids: - capabilities = ["image"] - if col_id in audio_column_ids: - capabilities = ["audio"] - break - gen_config.model = llm.validate_model_id( - model=gen_config.model, - capabilities=capabilities, - ) - except ValueError as e: - raise ResourceNotFoundError("There is no chat model available.") from e - except ResourceNotFoundError as e: - raise ResourceNotFoundError( - f'Column {column_id} used a chat model "{gen_config.model}" that is not available.' - ) from e - # Check Knowledge Table existence - if gen_config.rag_params is None: - return gen_config - ref_table_id = gen_config.rag_params.table_id - kt_table_dir = join( - ENV_CONFIG.owl_db_dir, - llm.organization_id, - llm.project_id, - TableType.KNOWLEDGE, - f"{ref_table_id}.lance", - ) - if not (isdir(kt_table_dir) and len(listdir(kt_table_dir)) > 0): - raise ResourceNotFoundError( - f"Column {column_id} referred to a Knowledge Table '{ref_table_id}' that does not exist." - ) - # Validate Reranking Model - reranking_model = gen_config.rag_params.reranking_model - if reranking_model is None: - return gen_config - try: - gen_config.rag_params.reranking_model = llm.validate_model_id( - model=reranking_model, - capabilities=["rerank"], - ) - except ValueError as e: - raise ResourceNotFoundError("There is no reranking model available.") from e - except ResourceNotFoundError as e: - raise ResourceNotFoundError( - f'Column {column_id} used a reranking model "{reranking_model}" that is not available.' 
- ) from e - elif isinstance(gen_config, CodeGenConfig): - pass - elif isinstance(gen_config, EmbedGenConfig): - pass - return gen_config - - -def _create_table( +TABLE_CLS: dict[TableType, ActionTable | KnowledgeTable | ChatTable] = { + TableType.ACTION: ActionTable, + TableType.KNOWLEDGE: KnowledgeTable, + TableType.CHAT: ChatTable, +} + + +async def _create_table( + *, request: Request, - organization_id: str, - project_id: str, + user: UserAuth, + project: ProjectRead, + org: OrganizationRead, table_type: TableType, schema: TableSchemaCreate, ) -> TableMetaResponse: - # Validate - llm = LLMEngine(request=request) - image_column_ids = [ - col.id - for col in schema.cols - if col.dtype == ColumnDtype.IMAGE and not col.id.endswith("_") - ] - audio_column_ids = [ - col.id - for col in schema.cols - if col.dtype == ColumnDtype.AUDIO and not col.id.endswith("_") - ] - for col in schema.cols: - col.gen_config = _validate_gen_config( - llm=llm, - gen_config=col.gen_config, - table_type=table_type, - column_id=col.id, - image_column_ids=image_column_ids, - audio_column_ids=audio_column_ids, - ) - if table_type == TableType.KNOWLEDGE: - try: - embedding_model = schema.embedding_model - schema.embedding_model = llm.validate_model_id( - model=embedding_model, - capabilities=["embed"], + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + # Check quota + billing: BillingManager = request.state.billing + billing.has_db_storage_quota() + billing.has_egress_quota() + kwargs = dict( + project_id=project.id, + table_metadata=TableMetadata( + table_id=schema.id, + created_by=user.id, + ), + column_metadata_list=[ + ColumnMetadata( + table_id=schema.id, + column_id=col.id, + dtype=col.dtype.to_column_type(), + vlen=col.vlen, + gen_config=col.gen_config, ) - except ValueError as e: - raise ResourceNotFoundError("There is no embedding model available.") from e - except ResourceNotFoundError as e: - raise ResourceNotFoundError( - f'Column used a embedding model "{embedding_model}" that is not available.' 
- ) from e - table = GenerativeTable.from_ids(organization_id, project_id, table_type) - # Create - with table.create_session() as session: - _, meta = ( - table.create_table(session, schema, request.state.all_models) - if table_type == TableType.KNOWLEDGE - else table.create_table(session, schema) - ) - meta = TableMetaResponse(**meta.model_dump(), num_rows=0) - return meta + for col in schema.cols + ], + ) + if table_type == TableType.KNOWLEDGE: + table = await KnowledgeTable.create_table(embedding_model=schema.embedding_model, **kwargs) + else: + table = await TABLE_CLS[table_type].create_table(**kwargs) + return table.v1_meta_response -@router.post("/v1/gen_tables/action") +@router.post( + "/v2/gen_tables/action", + summary="Create an action table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def create_action_table( +async def create_action_table( request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], body: ActionTableSchemaCreate, ) -> TableMetaResponse: - return _create_table(request, project.organization.id, project.id, TableType.ACTION, body) + user, project, org = auth_info + return await _create_table( + request=request, + user=user, + project=project, + org=org, + table_type=TableType.ACTION, + schema=body, + ) -@router.post("/v1/gen_tables/knowledge") +@router.post( + "/v2/gen_tables/knowledge", + summary="Create a knowledge table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def create_knowledge_table( +async def create_knowledge_table( request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], body: KnowledgeTableSchemaCreate, ) -> TableMetaResponse: - return _create_table(request, project.organization.id, project.id, TableType.KNOWLEDGE, body) + user, project, org = auth_info + return await _create_table( + request=request, + user=user, + project=project, + org=org, + table_type=TableType.KNOWLEDGE, + schema=body, + ) -@router.post("/v1/gen_tables/chat") +@router.post( + "/v2/gen_tables/chat", + summary="Create a chat table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def create_chat_table( +async def create_chat_table( request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], body: ChatTableSchemaCreate, ) -> TableMetaResponse: - return _create_table(request, project.organization.id, project.id, TableType.CHAT, body) - - -def _duplicate_table( - organization_id: str, - project_id: str, - table_type: TableType, - table_id_src: str, - table_id_dst: str, - include_data: bool, - create_as_child: bool, -) -> TableMetaResponse: - # Duplicate - table = GenerativeTable.from_ids(organization_id, project_id, table_type) - with table.create_session() as session: - meta = table.duplicate_table( - session, - table_id_src, - table_id_dst, - include_data, - create_as_child=create_as_child, - ) - meta = TableMetaResponse(**meta.model_dump(), 
num_rows=table.count_rows(meta.id)) - return meta + user, project, org = auth_info + return await _create_table( + request=request, + user=user, + project=project, + org=org, + table_type=TableType.CHAT, + schema=body, + ) -@router.post("/v1/gen_tables/{table_type}/duplicate/{table_id_src}") +@router.post( + "/v2/gen_tables/{table_type}/duplicate", + summary="Duplicate a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def duplicate_table( +async def duplicate_table( *, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: TableType, - table_id_src: str = Path(pattern=TABLE_NAME_PATTERN, description="Source table name or ID."), - table_id_dst: str | None = Query( - default=None, pattern=TABLE_NAME_PATTERN, description="Destination table name or ID." - ), - include_data: bool = Query( - default=True, - description="_Optional_. Whether to include the data from the source table in the duplicated table. Defaults to `True`.", - ), - create_as_child: bool = Query( - default=False, - description=( - "_Optional_. Whether the new table is a child table. Defaults to `False`. " - "If this is True, then `include_data` will be set to True." - ), - ), + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[DuplicateTableQuery, Query()], ) -> TableMetaResponse: - if create_as_child: - include_data = True - if not table_id_dst: - table_id_dst = f"{table_id_src}_{uuid7_str()}" - return _duplicate_table( - organization_id=project.organization.id, + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, project_id=project.id, - table_type=table_type, - table_id_src=table_id_src, - table_id_dst=table_id_dst, - include_data=include_data, - create_as_child=create_as_child, ) - - -@router.post("/v1/gen_tables/{table_type}/duplicate/{table_id_src}/{table_id_dst}") -@handle_exception -def duplicate_table_deprecated( - *, - response: Response, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: TableType, - table_id_src: str = Path(pattern=TABLE_NAME_PATTERN, description="Source table name or ID."), - table_id_dst: str = Path( - pattern=TABLE_NAME_PATTERN, description="Destination table name or ID." - ), - include_data: bool = Query( - default=True, - description="_Optional_. Whether to include the data from the source table in the duplicated table. Defaults to `True`.", - ), - deploy: bool = Query( - default=False, - description="_Optional_. Whether to deploy the duplicated table. Defaults to `False`.", - ), -) -> TableMetaResponse: - response.headers["Warning"] = ( - '299 - "This endpoint is deprecated and will be removed in v0.4. ' - "Use '/v1/gen_tables/{table_type}/duplicate/{table_id_src}' instead." 
- '"' + # Check quota + billing: BillingManager = request.state.billing + billing.has_db_storage_quota() + billing.has_egress_quota() + table = await TABLE_CLS[table_type].open_table( + project_id=project.id, table_id=params.table_id_src ) - return _duplicate_table( - organization_id=project.organization.id, + table = await table.duplicate_table( project_id=project.id, - table_type=table_type, - table_id_src=table_id_src, - table_id_dst=table_id_dst, - include_data=include_data, - create_as_child=deploy, + table_id_src=params.table_id_src, + table_id_dst=params.table_id_dst, + include_data=params.include_data, + create_as_child=params.create_as_child, + created_by=user.id, ) + return table.v1_meta_response -@router.post("/v1/gen_tables/{table_type}/rename/{table_id_src}/{table_id_dst}") +@router.get( + "/v2/gen_tables/{table_type}", + summary="Get a specific table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def rename_table( - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id_src: Annotated[str, Path(description="Source table name or ID.")], # Don't validate - table_id_dst: Annotated[ - str, - Path( - pattern=TABLE_NAME_PATTERN, - description="Destination table name or ID.", - ), +async def get_table( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: Annotated[str, Query(description="Name of the table to fetch.")], ) -> TableMetaResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - meta = table.rename_table(session, table_id_src, table_id_dst) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(table_id_dst)) - return meta + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=table_id) + return table.v1_meta_response -@router.delete("/v1/gen_tables/{table_type}/{table_id}") -@handle_exception -def delete_table( - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id: Annotated[str, Path(description="The ID of the table to delete.")], # Don't validate -) -> OkResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - table.delete_table(session, table_id) - return OkResponse() +class _ListTableQuery(ListTableQuery): + created_by: Annotated[ + str | None, + Field( + min_length=1, + description="Return tables created by this user. 
Defaults to None (return all tables).", + ), + ] = None -@router.get("/v1/gen_tables/{table_type}") +@router.get( + "/v2/gen_tables/{table_type}/list", + summary="List tables of a specific type.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def list_tables( - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def list_tables( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - offset: Annotated[ - int, - Query( - ge=0, - description="_Optional_. Item offset for pagination. Defaults to 0.", - ), - ] = 0, - limit: Annotated[ - int, - Query( - gt=0, - le=100, - description="_Optional_. Number of tables to return (min 1, max 100). Defaults to 100.", - ), - ] = 100, - parent_id: Annotated[ - str | None, - Query( - description=( - "_Optional_. Parent ID of tables to return. Defaults to None (return all tables). " - "Additionally for Chat Table, you can list: " - '(1) all chat agents by passing in "_agent_"; or ' - '(2) all chats by passing in "_chat_".' - ), - ), - ] = None, - search_query: Annotated[ - str, - Query( - max_length=100, - description='_Optional_. A string to search for within table IDs as a filter. Defaults to "" (no filter).', - ), - ] = "", - order_by: Annotated[ - GenTableOrderBy, - Query( - min_length=1, - description='_Optional_. Sort tables by this attribute. Defaults to "updated_at".', - ), - ] = GenTableOrderBy.UPDATED_AT, - order_descending: Annotated[ - bool, - Query(description="_Optional_. Whether to sort by descending order. Defaults to True."), - ] = True, - count_rows: Annotated[ - bool, - Query( - description="_Optional_. Whether to count the rows of the tables. Defaults to False." 
- ), - ] = False, + params: Annotated[_ListTableQuery, Query()], ) -> Page[TableMetaResponse]: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - metas, total = table.list_meta( - session, - offset=offset, - limit=limit, - remove_state_cols=True, - parent_id=parent_id, - search_query=search_query, - order_by=order_by, - order_descending=order_descending, - count_rows=count_rows, - ) - return Page[TableMetaResponse]( - items=metas, - offset=offset, - limit=limit, - total=total, - ) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + metas = await TABLE_CLS[table_type].list_tables( + project_id=project.id, + limit=params.limit, + offset=params.offset, + order_by=params.order_by, + order_ascending=params.order_ascending, + created_by=getattr(params, "created_by", None), + parent_id=params.parent_id, + search_query=params.search_query, + count_rows=params.count_rows, + ) + return metas -@router.get("/v1/gen_tables/{table_type}/{table_id}") +@router.post( + "/v2/gen_tables/{table_type}/rename", + summary="Rename a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def get_table( - request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def rename_table( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - table_id: str = Path(pattern=TABLE_NAME_PATTERN, description="The ID of the table to fetch."), + params: Annotated[RenameTableQuery, Query()], ) -> TableMetaResponse: - organization_id = project.organization.id - project_id = project.id - try: - table = GenerativeTable.from_ids(organization_id, project_id, table_type) - with table.create_session() as session: - meta = table.open_meta(session, table_id, remove_state_cols=True) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta - except ResourceNotFoundError: - lance_path = join( - ENV_CONFIG.owl_db_dir, - organization_id, - project_id, - table_type, - f"{table_id}.lance", - ) - if isdir(lance_path): - logger.exception( - f"{request.state.id} - Table cannot be opened but the directory exists !!!" 
- ) - dst_dir = join( - ENV_CONFIG.owl_db_dir, - "problematic", - organization_id, - project_id, - table_type, - ) - makedirs(dst_dir, exist_ok=True) - _uuid = uuid7_str() - copytree(lance_path, join(dst_dir, f"{table_id}_{_uuid}.lance")) - copy2( - join( - ENV_CONFIG.owl_db_dir, - organization_id, - project_id, - f"{table_type}.db", - ), - join( - ENV_CONFIG.owl_db_dir, - "problematic", - organization_id, - project_id, - f"{table_type}_{_uuid}.db", - ), - ) - raise + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table( + project_id=project.id, table_id=params.table_id_src + ) + table = await table.rename_table(params.table_id_dst) + return table.v1_meta_response -@router.post("/v1/gen_tables/{table_type}/gen_config/update") +@router.delete( + "/v2/gen_tables/{table_type}", + summary="Delete a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) @handle_exception -def update_gen_config( - request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def delete_table( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - updates: GenConfigUpdateRequest, -) -> TableMetaResponse: - # Validate - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - meta = table.open_meta(session, updates.table_id) - llm = LLMEngine(request=request) - image_column_ids = [ - col["id"] - for col in meta.cols - if col["dtype"] == ColumnDtype.IMAGE and not col["id"].endswith("_") - ] - audio_column_ids = [ - col["id"] - for col in meta.cols - if col["dtype"] == ColumnDtype.AUDIO and not col["id"].endswith("_") - ] - - if table_type == TableType.KNOWLEDGE: - # Knowledge Table "Title Embed" and "Text Embed" columns must always have gen config - for c in ["Title Embed", "Text Embed"]: - if c in updates.column_map and updates.column_map[c] is None: - updates.column_map.pop(c) - elif table_type == TableType.CHAT: - # Chat Table AI column must always have gen config - if "AI" in updates.column_map and updates.column_map["AI"] is None: - updates.column_map.pop("AI") - - updates.column_map = { - col_id: _validate_gen_config( - llm=llm, - gen_config=gen_config, - table_type=table_type, - column_id=col_id, - image_column_ids=image_column_ids, - audio_column_ids=audio_column_ids, - ) - for col_id, gen_config in updates.column_map.items() - } - # Update - meta = table.update_gen_config(session, updates) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta + table_id: Annotated[str, Query(description="Name of the table to be deleted.")], +) -> OkResponse: + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=table_id) + await table.drop_table() + return OkResponse() -def _add_columns( +@router.post( + "/v2/gen_tables/{table_type}/columns/add", + summary="Add columns to a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) +@handle_exception +async def add_columns( request: Request, - organization_id: str, - 
project_id: str, - table_type: TableType, - schema: TableSchemaCreate, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: TableSchemaCreate, ) -> TableMetaResponse: - # Validate - table = GenerativeTable.from_ids(organization_id, project_id, table_type) - with table.create_session() as session: - meta = table.open_meta(session, schema.id) - llm = LLMEngine(request=request) - cols = TableSchema( - id=meta.id, cols=[c.model_dump() for c in meta.cols_schema + schema.cols] - ).cols - image_column_ids = [ - col.id for col in cols if col.dtype == ColumnDtype.IMAGE and not col.id.endswith("_") - ] - audio_column_ids = [ - col.id for col in cols if col.dtype == ColumnDtype.AUDIO and not col.id.endswith("_") - ] - schema.cols = [col for col in cols if col.id in set(c.id for c in schema.cols)] - for col in schema.cols: - col.gen_config = _validate_gen_config( - llm=llm, - gen_config=col.gen_config, - table_type=table_type, + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + for col in body.cols: + table = await table.add_column( + ColumnMetadata( + table_id=body.id, column_id=col.id, - image_column_ids=image_column_ids, - audio_column_ids=audio_column_ids, + dtype=col.dtype.to_column_type(), + vlen=col.vlen, + gen_config=col.gen_config, ) - # Create - _, meta = table.add_columns(session, schema) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta + ) + return table.v1_meta_response -@router.post("/v1/gen_tables/action/columns/add") +@router.post( + "/v2/gen_tables/{table_type}/columns/rename", + summary="Rename columns in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def add_action_columns( - request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], - body: AddActionColumnSchema, +async def rename_columns( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: ColumnRenameRequest, ) -> TableMetaResponse: - return _add_columns(request, project.organization.id, project.id, TableType.ACTION, body) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + table = await table.rename_columns(body.column_map) + return table.v1_meta_response -@router.post("/v1/gen_tables/knowledge/columns/add") +@router.patch( + "/v2/gen_tables/{table_type}/gen_config", + summary="Update generation configuration for table columns.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def add_knowledge_columns( - request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], - body: 
AddKnowledgeColumnSchema, +async def update_gen_config( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + updates: GenConfigUpdateRequest, ) -> TableMetaResponse: - return _add_columns(request, project.organization.id, project.id, TableType.KNOWLEDGE, body) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table( + project_id=project.id, table_id=updates.table_id + ) + table = await table.update_gen_config(update_mapping=updates.column_map) + return table.v1_meta_response -@router.post("/v1/gen_tables/chat/columns/add") +@router.post( + "/v2/gen_tables/{table_type}/columns/reorder", + summary="Reorder columns in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def add_chat_columns( - request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], - body: AddChatColumnSchema, -) -> TableMetaResponse: - return _add_columns(request, project.organization.id, project.id, TableType.CHAT, body) - - -def _create_indexes( - project: ProjectRead, - table_type: TableType, - table_id: str, +async def reorder_columns( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: ColumnReorderRequest, ) -> TableMetaResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - table.create_indexes(session, table_id) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + table = await table.reorder_columns(body.column_names) + return table.v1_meta_response -@router.post("/v1/gen_tables/{table_type}/columns/drop") +@router.post( + "/v2/gen_tables/{table_type}/columns/drop", + summary="Drop columns from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def drop_columns( - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def drop_columns( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], body: ColumnDropRequest, ) -> TableMetaResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - _, meta = table.drop_columns(session, body.table_id, body.column_names) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - bg_tasks.add_task(_create_indexes, project, table_type, body.table_id) - return meta - - -@router.post("/v1/gen_tables/{table_type}/columns/rename") + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, 
table_id=body.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_db_storage_quota() + billing.has_egress_quota() + table = await table.drop_columns(body.column_names) + return table.v1_meta_response + + +@router.post( + "/v2/gen_tables/{table_type}/rows/add", + summary="Add rows to a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def rename_columns( - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def add_rows( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - body: ColumnRenameRequest, -) -> TableMetaResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - meta = table.rename_columns(session, body.table_id, body.column_map) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta + body: MultiRowAddRequestWithLimit, +): + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + executor = MultiRowGenExecutor( + request=request, + table=table, + organization=org, + project=project, + body=body, + ) + if body.stream: + return StreamingResponse( + content=await executor.generate(), + status_code=200, + media_type="text/event-stream", + headers={"X-Accel-Buffering": "no"}, + ) + else: + return await executor.generate() -@router.post("/v1/gen_tables/{table_type}/columns/reorder") +@router.get( + "/v2/gen_tables/{table_type}/rows/list", + summary="List rows in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def reorder_columns( - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def list_rows( + *, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - body: ColumnReorderRequest, -) -> TableMetaResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - meta = table.reorder_columns(session, body.table_id, body.column_names) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta + params: Annotated[ListTableRowQuery, Query()], +) -> Page[dict[str, Any]]: + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=params.table_id) + rows = await table.list_rows( + limit=params.limit, + offset=params.offset, + order_by=[params.order_by], + order_ascending=params.order_ascending, + columns=params.columns, + where=params.where, + search_query=params.search_query, + search_columns=params.search_columns, + remove_state_cols=False, + ) + return 
Page[dict[str, Any]]( + items=table.postprocess_rows( + rows.items, + float_decimals=params.float_decimals, + vec_decimals=params.vec_decimals, + ), + offset=params.offset, + limit=params.limit, + total=rows.total, + ) -@router.get("/v1/gen_tables/{table_type}/{table_id}/rows") +@router.get( + "/v2/gen_tables/{table_type}/rows", + summary="Get a specific row from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def list_rows( +async def get_row( *, - project: Annotated[ProjectRead, Depends(auth_user_project)], + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - table_id: str = Path(pattern=TABLE_NAME_PATTERN, description="Table ID or name."), - offset: int = Query( - default=0, - ge=0, - description="_Optional_. Item offset for pagination. Defaults to 0.", - ), - limit: int = Query( - default=100, - gt=0, - le=100, - description="_Optional_. Number of rows to return (min 1, max 100). Defaults to 100.", - ), - search_query: str = Query( - default="", - max_length=10_000, - description='_Optional_. A string to search for within the rows as a filter. Defaults to "" (no filter).', - ), - columns: list[ColName] | None = Query( - default=None, - description="_Optional_. A list of column names to include in the response. Default is to return all columns.", - ), - float_decimals: int = Query( - default=0, - ge=0, - description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).", - ), - vec_decimals: int = Query( - default=0, - description="_Optional_. Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding).", - ), - order_descending: Annotated[ - bool, - Query(description="_Optional_. Whether to sort by descending order. 
Defaults to True."), - ] = True, -) -> Page[dict[ColName, Any]]: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - if search_query == "": - rows, total = table.list_rows( - table_id=table_id, - offset=offset, - limit=limit, - columns=columns, - convert_null=True, - remove_state_cols=True, - json_safe=True, - include_original=True, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - order_descending=order_descending, - ) - else: - with table.create_session() as session: - rows = table.regex_search( - session=session, - table_id=table_id, - query=search_query, - columns=columns, - convert_null=True, - remove_state_cols=True, - json_safe=True, - include_original=True, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - order_descending=order_descending, - ) - total = len(rows) - rows = rows[offset : offset + limit] - return Page[dict[ColName, Any]](items=rows, offset=offset, limit=limit, total=total) + params: Annotated[GetTableRowQuery, Query()], +) -> dict[str, Any]: + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=params.table_id) + row = await table.get_row( + row_id=params.row_id, + columns=params.columns, + remove_state_cols=False, + ) + row = table.postprocess_rows( + [row], + float_decimals=params.float_decimals, + vec_decimals=params.vec_decimals, + )[0] + return row -@router.get("/v1/gen_tables/{table_type}/{table_id}/rows/{row_id}") +@router.get( + "/v2/gen_tables/{table_type}/threads", + summary="Get all multi-turn / conversation threads from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def get_row( - *, - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def get_conversation_threads( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - table_id: str = Path(pattern=TABLE_NAME_PATTERN, description="Table ID or name."), - row_id: Annotated[str, Path(description="The ID of the specific row to fetch.")], - columns: list[ColName] | None = Query( - default=None, - description="_Optional_. A list of column names to include in the response. Default is to return all columns.", - ), - float_decimals: int = Query( - default=0, - ge=0, - description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).", - ), - vec_decimals: int = Query( - default=0, - description="_Optional_. Number of decimals for vectors. If its negative, exclude vector columns. 
Defaults to 0 (no rounding).", - ), -) -> dict[ColName, Any]: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - row = table.get_row( - table_id, - row_id, - columns=columns, - convert_null=True, - remove_state_cols=True, - json_safe=True, - include_original=True, - float_decimals=float_decimals, - vec_decimals=vec_decimals, + params: Annotated[GetTableThreadsQuery, Query()], +) -> ChatThreadsResponse: + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table_id = params.table_id + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=table_id) + if params.column_ids: + for column_id in params.column_ids: + table.check_multiturn_column(column_id) + cols = params.column_ids + else: + cols = [c.column_id for c in table.column_metadata if c.is_chat_column] + return ChatThreadsResponse( + threads={ + c: await table.get_conversation_thread( + column_id=c, + row_id=params.row_id, + include_row=params.include_row, + ) + for c in cols + }, + table_id=table_id, ) - return row -@router.post("/v1/gen_tables/{table_type}/rows/add") +@router.post( + "/v2/gen_tables/{table_type}/hybrid_search", + summary="Perform hybrid search on a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -async def add_rows( +async def hybrid_search( request: Request, - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - body: RowAddRequestWithLimit, -): - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - # Check quota - request.state.billing.check_gen_table_llm_quota(table, body.table_id) - # Checks - with table.create_session() as session: - meta = table.open_meta(session, body.table_id) - has_chat_cols = ( - sum( - col["gen_config"] is not None and col["gen_config"].get("multi_turn", False) - for col in meta.cols - ) - > 0 + body: SearchRequest, +) -> list[dict[str, Any]]: + # TODO: Maybe this should return `Page` instead of `list` + def split_query_to_or_terms(query): + # Regular expression to match either quoted phrases or words + pattern = r'("[^"]*"|\S+)' + parts = re.findall(pattern, query) + return " OR ".join(parts) + + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, ) - # Maybe re-index - if body.reindex or ( - body.reindex is None - and table.count_rows(body.table_id) <= ENV_CONFIG.owl_immediate_reindex_max_rows - ): - bg_tasks.add_task(_create_indexes, project, table_type, body.table_id) - executor = MultiRowsGenExecutor( - table=table, - meta=meta, + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + lm = LMEngine( + organization=org, + project=project, request=request, - body=body, - rows_batch_size=(1 if has_chat_cols else ENV_CONFIG.owl_concurrent_rows_batch_size), - cols_batch_size=ENV_CONFIG.owl_concurrent_cols_batch_size, - max_write_batch_size=(1 if has_chat_cols else 
ENV_CONFIG.owl_max_write_batch_size), ) - if body.stream: - return StreamingResponse( - content=await executor.gen_rows(), - status_code=200, - media_type="text/event-stream", - headers={"X-Accel-Buffering": "no"}, - ) - else: - return await executor.gen_rows() + # Do a split and OR join for fts query + fts_query = split_query_to_or_terms(body.query) + + # As of 2025-04-17, this endpoint does not perform query rewrite + rows = await table.hybrid_search( + fts_query=fts_query, + vs_query=body.query, + embedding_fn=lm.embed_query_as_vector, + vector_column_names=None, + limit=body.limit, + offset=0, + remove_state_cols=False, + ) + # Rerank + if len(rows) > 0 and body.reranking_model is not None: + order = ( + await lm.rerank_documents( + model=body.reranking_model, + query=body.query, + documents=table.rows_to_documents(rows), + ) + ).results + rows = [rows[i.index] for i in order] + rows = rows[: body.limit] + rows = table.postprocess_rows( + rows, + float_decimals=body.float_decimals, + vec_decimals=body.vec_decimals, + ) + return rows -@router.post("/v1/gen_tables/{table_type}/rows/regen") +@router.post( + "/v2/gen_tables/{table_type}/rows/regen", + summary="Regenerate rows in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception async def regen_rows( request: Request, - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - body: RowRegenRequest, + body: MultiRowRegenRequest, ): - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - # Check quota - request.state.billing.check_gen_table_llm_quota(table, body.table_id) - # Checks - with table.create_session() as session: - meta = table.open_meta(session, body.table_id) - if body.output_column_id is not None: - output_column_ids = [col["id"] for col in meta.cols if col["gen_config"] is not None] - if len(output_column_ids) > 0 and body.output_column_id not in output_column_ids: - raise ResourceNotFoundError( - ( - f'`output_column_id` "{body.output_column_id}" is not found. 
' - f"Available output columns: {output_column_ids}" - ) - ) - has_chat_cols = ( - sum( - col["gen_config"] is not None and col["gen_config"].get("multi_turn", False) - for col in meta.cols - ) - > 0 + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, ) - # Maybe re-index - if body.reindex or ( - body.reindex is None - and table.count_rows(body.table_id) <= ENV_CONFIG.owl_immediate_reindex_max_rows - ): - bg_tasks.add_task(_create_indexes, project, table_type, body.table_id) - executor = MultiRowsGenExecutor( - table=table, - meta=meta, + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + executor = MultiRowGenExecutor( request=request, + table=table, + organization=org, + project=project, body=body, - rows_batch_size=(1 if has_chat_cols else ENV_CONFIG.owl_concurrent_rows_batch_size), - cols_batch_size=ENV_CONFIG.owl_concurrent_cols_batch_size, - max_write_batch_size=(1 if has_chat_cols else ENV_CONFIG.owl_max_write_batch_size), ) if body.stream: return StreamingResponse( - content=await executor.gen_rows(), + content=await executor.generate(), status_code=200, media_type="text/event-stream", headers={"X-Accel-Buffering": "no"}, ) else: - return await executor.gen_rows() - - -@router.post("/v1/gen_tables/{table_type}/rows/update") -@handle_exception -def update_row( - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - body: RowUpdateRequest, -) -> OkResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - # Check column type - if table_type == TableType.KNOWLEDGE: - col_names = set(n.lower() for n in body.data.keys()) - if "text embed" in col_names or "title embed" in col_names: - raise TableSchemaFixedError("Cannot update 'Text Embed' or 'Title Embed'.") - # Update - with table.create_session() as session: - table.update_rows( - session, - body.table_id, - where=f"`ID` = '{body.row_id}'", - values=body.data, - ) - if body.reindex or ( - body.reindex is None - and table.count_rows(body.table_id) <= ENV_CONFIG.owl_immediate_reindex_max_rows - ): - bg_tasks.add_task(_create_indexes, project, table_type, body.table_id) - return OkResponse() + return await executor.generate() -@router.post("/v1/gen_tables/{table_type}/rows/delete") +@router.patch( + "/v2/gen_tables/{table_type}/rows", + summary="Update rows in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def delete_rows( - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def update_rows( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - body: RowDeleteRequest, + body: MultiRowUpdateRequestWithLimit, ) -> OkResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - table.delete_rows(session, body.table_id, body.row_ids, body.where) - if body.reindex or ( - body.reindex is None - and 
table.count_rows(body.table_id) <= ENV_CONFIG.owl_immediate_reindex_max_rows - ): - bg_tasks.add_task(_create_indexes, project, table_type, body.table_id) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + await table.update_rows(body.data) return OkResponse() -@router.delete("/v1/gen_tables/{table_type}/{table_id}/rows/{row_id}") +@router.post( + "/v2/gen_tables/{table_type}/rows/delete", + summary="Delete rows from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", + tags=[MCP_TOOL_TAG, "organization.MEMBER", "project.MEMBER"], +) @handle_exception -def delete_row( - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], +async def delete_rows( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - table_id: str = Path(pattern=TABLE_NAME_PATTERN, description="Table ID or name."), - row_id: str = Path(description="The ID of the specific row to delete."), - reindex: Annotated[bool, Query(description="Whether to reindex immediately.")] = True, + body: MultiRowDeleteRequest, ) -> OkResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - table.delete_row(session, table_id, row_id) - if reindex: - bg_tasks.add_task(_create_indexes, project, table_type, table_id) - return OkResponse() - - -@router.get("/v1/gen_tables/{table_type}/{table_id}/thread") -@handle_exception -def get_conversation_thread( - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id: Annotated[str, Path(pattern=TABLE_NAME_PATTERN, description="Table ID or name.")], - column_id: Annotated[str, Query(description="ID / name of the column to fetch.")], - row_id: Annotated[ - str, - Query( - description='_Optional_. ID / name of the last row in the thread. Defaults to "" (export all rows).' - ), - ] = "", - include: Annotated[ - bool, - Query( - description="_Optional_. Whether to include the row specified by `row_id`. Defaults to True." 
- ), - ] = True, -) -> ChatThread: - # Fetch - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - return table.get_conversation_thread( - table_id=table_id, - column_id=column_id, - row_id=row_id, - include=include, - ) - - -@router.post("/v1/gen_tables/{table_type}/hybrid_search") -@handle_exception -async def hybrid_search( - request: Request, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - body: SearchRequest, -) -> list[dict[ColName, Any]]: - # Search - embedder = CloudEmbedder(request=request) - if body.reranking_model is not None: - reranker = CloudReranker(request=request) - else: - reranker = None - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - rows = await table.hybrid_search( - session, - body.table_id, - query=body.query, - where=body.where, - limit=body.limit, - metric=body.metric, - nprobes=body.nprobes, - refine_factor=body.refine_factor, - embedder=embedder, - reranker=reranker, - reranking_model=body.reranking_model, - vec_decimals=body.vec_decimals, - convert_null=True, - remove_state_cols=True, - json_safe=True, - include_original=True, - ) - return rows - - -def list_files(): - pass - - -def _truncate_text(text: str, max_context_length: int, encoding_name: str = "cl100k_base") -> str: - """Truncates the text to fit within the max_context_length.""" - - encoding = tiktoken.get_encoding(encoding_name) - encoded_text = encoding.encode(text) - - if len(encoded_text) <= max_context_length: - return text - - truncated_encoded = encoded_text[:max_context_length] - truncated_text = encoding.decode(truncated_encoded) - return truncated_text - - -async def _embed( - embedder_name: str, embedder: CloudEmbedder, texts: list[str], embed_dtype: str -) -> np.ndarray: - if len(texts) == 0: - raise make_validation_error( - ValueError("There is no text or content to embed."), loc=("body", "file") - ) - embeddings = await embedder.embed_documents(embedder_name, texts=texts) - embeddings = np.asarray([d.embedding for d in embeddings.data], dtype=embed_dtype) - embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) - return embeddings - - -async def _embed_file( - request: Request, - bg_tasks: BackgroundTasks, - project: ProjectRead, - table_id: str, - file_name: str, - file_content: bytes, - file_uri: str, - chunk_size: int, - chunk_overlap: int, -) -> OkResponse: - request_id = request.state.id - logger.info(f'{request_id} - Parsing file "{file_name}".') - chunks = await load_file(file_name, file_content, chunk_size, chunk_overlap) - logger.info(f'{request_id} - Embedding file "{file_name}" with {len(chunks):,d} chunks.') - - # --- Extract title --- # - excerpt = "".join(d.text for d in chunks[:8])[:50000] - llm = LLMEngine(request=request) - model = llm.validate_model_id( - model="", - capabilities=["chat"], + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, ) - logger.debug(f"{request_id} - Performing title extraction using: {model}") - try: - response = await llm.generate( - id=request_id, - model=model, - messages=[ - ChatEntry.system("You are an concise assistant."), - ChatEntry.user( - ( - f"CONTEXT:\n{excerpt}\n\n" - "From the excerpt, extract the document title or guess a possible title. " - "Provide the title without explanation." 
- ) - ), - ], - max_tokens=200, - temperature=0.01, - top_p=0.01, - stream=False, - ) - title = response.text.strip() - if title.startswith('"') and title.endswith('"'): - title = title[1:-1] - except Exception: - logger.exception(f"{request_id} - Title extraction errored for excerpt: \n{excerpt}\n") - title = "" - - # --- Add into Knowledge Table --- # - organization_id = project.organization.id - project_id = project.id - table = GenerativeTable.from_ids(organization_id, project_id, TableType.KNOWLEDGE) - # Check quota - request.state.billing.check_gen_table_llm_quota(table, table_id) - with table.create_session() as session: - meta = table.open_meta(session, table_id) - title_embed = None - text_embeds = [] - for col in meta.cols: - if col["vlen"] == 0: - continue - gen_config = EmbedGenConfig.model_validate(col["gen_config"]) - request.state.billing.check_embedding_quota(model_id=gen_config.embedding_model) - embedder = CloudEmbedder(request=request) - if col["id"] == "Title Embed": - title_embed = await _embed( - gen_config.embedding_model, embedder, [title], col["dtype"] - ) - title_embed = title_embed[0] - elif col["id"] == "Text Embed": - # Truncate based on embedder context length - embedder_context_length = ( - (llm.model_info(gen_config.embedding_model)).data[0].context_length - ) - texts = [_truncate_text(chunk.text, embedder_context_length) for chunk in chunks] - - text_embeds = await _embed( - gen_config.embedding_model, - embedder, - texts, - col["dtype"], - ) - else: - continue - if title_embed is None or len(text_embeds) == 0: - raise RuntimeError( - "Sorry we encountered an issue during embedding. Please try again later." - ) - row_add_data = [ - { - "Text": chunk.text, - "Text Embed": text_embed, - "Title": title, - "Title Embed": title_embed, - "File ID": file_uri, - "Page": chunk.page, - } - for chunk, text_embed in zip(chunks, text_embeds, strict=True) - ] - logger.info( - f'{request_id} - Writing file "{file_name}" with {len(chunks):,d} chunks to DB.' - ) - await add_rows( - request=request, - bg_tasks=bg_tasks, - project=project, - table_type=TableType.KNOWLEDGE, - body=RowAddRequest.model_construct(table_id=table_id, data=row_add_data, stream=False), - ) - bg_tasks.add_task(_create_indexes, project, "knowledge", table_id) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=body.table_id) + await table.delete_rows(row_ids=body.row_ids, where=body.where) return OkResponse() -@router.options("/v1/gen_tables/knowledge/embed_file") -@router.options("/v1/gen_tables/knowledge/upload_file", deprecated=True) +@router.options( + "/v2/gen_tables/knowledge/embed_file", + summary="Get CORS preflight options for file embedding endpoint", + description="Permissions: None, publicly accessible.", +) @handle_exception -async def embed_file_options(request: Request, response: Response): +async def embed_file_options(): headers = { "Allow": "POST, OPTIONS", "Accept": ", ".join(EMBED_WHITE_LIST_MIME), "Access-Control-Allow-Methods": "POST, OPTIONS", "Access-Control-Allow-Headers": "Content-Type", } - if "upload_file" in request.url.path: - response.headers["Warning"] = ( - '299 - "This endpoint is deprecated and will be removed in v0.4. ' - "Use '/v1/gen_tables/{table_type}/embed_file' instead." 
- '"' - ) return Response(content=None, headers=headers) -@router.post("/v1/gen_tables/knowledge/embed_file") -@router.post("/v1/gen_tables/knowledge/upload_file", deprecated=True) +@router.post( + "/v2/gen_tables/knowledge/embed_file", + summary="Embed a file into a knowledge table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) @handle_exception async def embed_file( *, request: Request, - response: Response, - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], - file: Annotated[UploadFile, File(description="The file.")], - file_name: Annotated[str, Form(description="File name.", deprecated=True)] = "", - table_id: Annotated[str, Form(pattern=TABLE_NAME_PATTERN, description="Knowledge Table ID.")], - # overwrite: Annotated[ - # bool, Form(description="Whether to overwrite old file with the same name.") - # ] = False, - chunk_size: Annotated[ - int, Form(description="Maximum chunk size (number of characters). Must be > 0.", gt=0) - ] = 2000, - chunk_overlap: Annotated[ - int, Form(description="Overlap in characters between chunks. Must be >= 0.", ge=0) - ] = 200, - # stream: Annotated[ - # bool, Form(description="Whether or not to stream the LLM generation.") - # ] = True, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + data: Annotated[FileEmbedFormData, Form()], ) -> OkResponse: - if "upload_file" in request.url.path: - response.headers["Warning"] = ( - '299 - "This endpoint is deprecated and will be removed in v0.4. ' - "Use '/v1/gen_tables/{table_type}/embed_file' instead." - '"' - ) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) # Validate the Content-Type of the uploaded file - file_name = file.filename or file_name - if splitext(file_name)[1].lower() == ".jsonl": - file_content_type = "application/jsonl" - elif splitext(file_name)[1].lower() == ".md": - file_content_type = "text/markdown" - elif splitext(file_name)[1].lower() == ".tsv": - file_content_type = "text/tab-separated-values" - else: - file_content_type = file.content_type - if file_content_type not in EMBED_WHITE_LIST_MIME: + file_name = data.file.filename or data.file_name + mime = guess_mime(file_name) + if mime == "application/octet-stream": + mime = data.file.content_type + if mime not in EMBED_WHITE_LIST_MIME: raise UnsupportedMediaTypeError( - f"File type '{file_content_type}' is unsupported. Accepted types are: {', '.join(EMBED_WHITE_LIST_MIME)}" + f'File type "{mime}" is unsupported. 
Accepted types are: {", ".join(EMBED_WHITE_LIST_MIME)}' ) - # --- Add into File Table --- # - content = await file.read() - uri = await upload_file_to_s3( + table = await KnowledgeTable.open_table( + project_id=project.id, + table_id=data.table_id, + ) + # Check quota + request_id: str = request.state.id + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + # --- Store original file into S3 --- # + file_content = await data.file.read() + file_uri = await s3_upload( project.organization.id, project.id, - content, - file_content_type, - file_name, + file_content, + content_type=mime, + filename=file_name, ) # if overwrite: # file_table.delete_file(file_name=file_name) # --- Add into Knowledge Table --- # - return await _embed_file( - request=request, - bg_tasks=bg_tasks, + logger.info(f'{request_id} - Parsing file "{file_name}".') + doc_parser = GeneralDocLoader(request_id=request_id) + chunks = await doc_parser.load_document_chunks( + file_name, file_content, data.chunk_size, data.chunk_overlap + ) + logger.info(f'{request_id} - Embedding file "{file_name}" with {len(chunks):,d} chunks.') + + # --- Extract title --- # + lm = LMEngine( + organization=org, project=project, - table_id=table_id, - file_name=file_name, - file_content=content, - file_uri=uri, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, + request=request, ) + ext = splitext(file_name)[1].lower() + if ext in [".pdf", ".pptx", ".xlsx"]: + first_page_chunks = [d.text for d in chunks if d.page == 1] + # If the first page content is too short, use the first 8 chunks instead + if len(first_page_chunks) < 3: + first_page_chunks = [d.text for d in chunks[:8]] + excerpt = "".join(first_page_chunks)[:50000] + else: + excerpt = "".join(d.text for d in chunks[:8])[:50000] + logger.debug(f"{request_id} - Performing title extraction.") + title = await lm.generate_title(excerpt=excerpt, model="") + + # --- Embed --- # + title_embed = text_embeds = None + for col in table.column_metadata: + if col.column_id.lower() == "title embed": + title_embed = await lm.embed_documents( + model=col.gen_config.embedding_model, + texts=[title], + encoding_format="float", + ) + title_embed = title_embed.data[0].embedding + elif col.column_id.lower() == "text embed": + text_embeds = await lm.embed_documents( + model=col.gen_config.embedding_model, + texts=[chunk.text for chunk in chunks], + encoding_format="float", + ) + text_embeds = [data.embedding for data in text_embeds.data] + + if title_embed is None or text_embeds is None or len(text_embeds) == 0: + raise UnexpectedError( + "Sorry we encountered an issue during embedding. If this issue persists, please contact support." 
+ ) + # --- Store into Knowledge Table --- # + row_add_data = [ + { + "Title": title, + "Title Embed": title_embed, + "Text": chunk.text, + "Text Embed": text_embed, + "File ID": file_uri, + "Page": chunk.page, + } + for chunk, text_embed in zip(chunks, text_embeds, strict=True) + ] + await table.add_rows(row_add_data) + return OkResponse() -@router.post("/v1/gen_tables/{table_type}/import_data") +@router.post( + "/v2/gen_tables/{table_type}/import_data", + summary="Import data into a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) @handle_exception async def import_table_data( request: Request, - bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - file: Annotated[UploadFile, File(description="The CSV or TSV file.")], - table_id: Annotated[ - str, - Form( - pattern=TABLE_NAME_PATTERN, - description="ID or name of the table that the data should be imported into.", - ), + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) ], - stream: Annotated[ - bool, Form(description="Whether or not to stream the LLM generation.") - ] = True, - # List of inputs is bugged as of 2024-07-14: https://github.com/tiangolo/fastapi/pull/9928/files - # column_names: Annotated[ - # list[ColName] | None, - # Form( - # description="_Optional_. A list of columns names if the CSV does not have header row. Defaults to None (read from CSV).", - # ), - # ] = None, - # columns: Annotated[ - # list[ColName] | None, - # Form( - # description="_Optional_. A list of columns to be imported. Defaults to None (import all columns except 'ID' and 'Updated at').", - # ), - # ] = None, - delimiter: Annotated[ - CSVDelimiter, - Form(description='The delimiter, can be "," or "\\t". 
Defaults to ",".'), - ] = CSVDelimiter.COMMA, + table_type: Annotated[TableType, Path(description="Table type.")], + data: Annotated[TableDataImportFormData, Form()], ): - # Get column info - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with table.create_session() as session: - meta = table.open_meta(session, table_id, remove_state_cols=True) - cols = { - c.id.lower(): c for c in meta.cols_schema if c.id.lower() not in ("id", "updated at") - } - cols_dtype = { - c.id: c.dtype - for c in meta.cols_schema - if c.id.lower() not in ("id", "updated at") and c.vlen == 0 - } - # --- Read file as DataFrame --- # - content = await file.read() - try: - df = csv_to_df(content.decode("utf-8"), sep=delimiter.value) - # Do not import "ID" and "Updated at" - keep_cols = [c for c in df.columns.tolist() if c.lower() in cols] - df = df.filter(items=keep_cols, axis="columns") - except ValueError as e: - raise make_validation_error(e, loc=("body", "file")) from e - # if isinstance(columns, list) and len(columns) > 0: - # df = df[columns] - if len(df) == 0: - raise make_validation_error( - ValueError("The file provided is empty."), loc=("body", "file") - ) - # Convert vector data - for col_id in df.columns.tolist(): - if cols[col_id.lower()].vlen > 0: - df[col_id] = df[col_id].apply(json_loads) - # Cast data to follow column dtype - for col_id, dtype in cols_dtype.items(): - if col_id not in df.columns: - continue - try: - if dtype == "str": - df[col_id] = df[col_id].apply(lambda x: str(x) if not pd.isna(x) else x) - else: - if dtype in [ColumnDtype.IMAGE, ColumnDtype.AUDIO]: - dtype = "str" - df[col_id] = df[col_id].astype(dtype, errors="raise") - except ValueError as e: - raise make_validation_error(e, loc=("body", "file")) from e - # Convert DF to list of dicts - row_add_data = df.to_dict(orient="records") + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=data.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + # Import data + rows = await table.read_csv( + input_path=BytesIO(await data.file.read()), + column_id_mapping=None, + delimiter=data.delimiter, + ignore_info_columns=True, # Ignore "ID" and "Updated at" columns + ) return await add_rows( request=request, - bg_tasks=bg_tasks, - project=project, + auth_info=auth_info, table_type=table_type, - body=RowAddRequest(table_id=table_id, data=row_add_data, stream=stream), + body=MultiRowAddRequest(table_id=data.table_id, data=rows, stream=data.stream), ) -@router.get("/v1/gen_tables/{table_type}/{table_id}/export_data") +@router.get( + "/v2/gen_tables/{table_type}/export_data", + summary="Export data from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) @handle_exception -def export_table_data( - *, +async def export_table_data( + request: Request, bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id: Annotated[ - str, - Path(pattern=TABLE_NAME_PATTERN, description="ID or name of the table to be exported."), + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) ], - delimiter: Annotated[ - CSVDelimiter, - 
Query(description='The delimiter, can be "," or "\\t". Defaults to ",".'), - ] = CSVDelimiter.COMMA, - columns: Annotated[ - list[ColName] | None, - Query( - min_length=1, - description="_Optional_. A list of columns to be exported. Defaults to None (export all columns).", - ), - ] = None, + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[ExportTableDataQuery, Query()], ) -> FileResponse: - # Export data - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - ext = ".csv" if delimiter == CSVDelimiter.COMMA else ".tsv" + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=params.table_id) + # Check quota + billing: BillingManager = request.state.billing + billing.has_gen_table_quota(table) + billing.has_db_storage_quota() + billing.has_egress_quota() + # Temporary file + ext = ".csv" if params.delimiter == CSVDelimiter.COMMA else ".tsv" tmp_dir = TemporaryDirectory() - filename = f"{table_id}{ext}" + filename = f"{params.table_id}{ext}" filepath = join(tmp_dir.name, filename) # Keep a reference to the directory and only delete upon completion bg_tasks.add_task(tmp_dir.cleanup) - # Get column ordering - with table.create_session() as session: - meta = table.open_meta(session, table_id, remove_state_cols=True) - columns_order = [c.id for c in meta.cols_schema] - if columns is None: - columns_to_export = columns_order - else: - columns_to_export = [ - col for col in columns_order if col in columns or col.lower() in ("id", "updated at") - ] - table.export_csv( - table_id=table_id, - columns=columns_to_export, - file_path=filepath, - delimiter=delimiter, + # Export + await table.export_data( + output_path=filepath, + columns=params.columns, + where="", + limit=None, + offset=0, + delimiter=params.delimiter, ) return FileResponse( path=filepath, @@ -1402,53 +1094,97 @@ def export_table_data( ) -@router.post("/v1/gen_tables/{table_type}/import") +@router.post( + "/v2/gen_tables/{table_type}/import", + summary="Import a table including its metadata.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) @handle_exception async def import_table( - project: Annotated[ProjectRead, Depends(auth_user_project)], + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], table_type: Annotated[TableType, Path(description="Table type.")], - file: Annotated[UploadFile, File(description="The parquet file.")], - table_id_dst: Annotated[ - str | None, - Form(pattern=TABLE_NAME_PATTERN, description="The ID or name of the new table."), - ] = None, -) -> TableMetaResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) - with BytesIO(await file.read()) as source: - with table.create_session() as session: - _, meta = await table.import_parquet( - session=session, - source=source, - table_id_dst=table_id_dst, + data: Annotated[TableImportFormData, Form()], +) -> TableMetaResponse | OkResponse: + user, project, org = auth_info + if not data.migrate: + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + # Check quota + billing: BillingManager = request.state.billing + billing.has_db_storage_quota() + billing.has_egress_quota() + # Import + async with 
s3_temporary_file(await data.file.read(), "application/vnd.apache.parquet") as uri: + result: AsyncResult = import_gen_table.delay( + source=uri, + project_id=project.id, + table_type=table_type, + table_id_dst=data.table_id_dst, + reupload_files=not data.migrate, + progress_key=data.progress_key, + verbose=data.migrate, + ) + # Poll progress + initial_wait: float = 0.5 + max_wait: float = 30 * 60 # 30 minutes + t0 = perf_counter() + i = 1 + while (not result.ready()) and ((perf_counter() - t0) < max_wait): + await sleep(min(initial_wait * i, 5.0)) + if not data.blocking: + prog = await CACHE.get_progress(data.progress_key, TableImportProgress) + if prog.load_data.progress == 100: + return OkResponse(progress_key=data.progress_key) + i += 1 + if (perf_counter() - t0) >= max_wait: + raise ServerBusyError( + "Table import took too long to complete. Please try again later." ) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta + return TableMetaResponse.model_validate_json(result.get(propagate=True)) -@router.get("/v1/gen_tables/{table_type}/{table_id}/export") +@router.get( + "/v2/gen_tables/{table_type}/export", + summary="Export a table including its metadata.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) @handle_exception -def export_table( - *, +async def export_table( + request: Request, bg_tasks: BackgroundTasks, - project: Annotated[ProjectRead, Depends(auth_user_project)], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id: Annotated[ - str, - Path(pattern=TABLE_NAME_PATTERN, description="ID or name of the table to be exported."), + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: Annotated[str, Query(description="Table name.")], ) -> FileResponse: - table = GenerativeTable.from_ids(project.organization.id, project.id, table_type) + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + # Check quota + billing: BillingManager = request.state.billing + billing.has_db_storage_quota() + billing.has_egress_quota() + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=table_id) + # Temporary file tmp_dir = TemporaryDirectory() filename = f"{table_id}.parquet" filepath = join(tmp_dir.name, filename) # Keep a reference to the directory and only delete upon completion bg_tasks.add_task(tmp_dir.cleanup) - with table.create_session() as session: - table.dump_parquet( - session=session, - table_id=table_id, - dest=filepath, - ) + # Export + await table.export_table(filepath) return FileResponse( path=filepath, filename=filename, diff --git a/services/api/src/owl/routers/gen_table_v1.py b/services/api/src/owl/routers/gen_table_v1.py new file mode 100644 index 0000000..3d2d953 --- /dev/null +++ b/services/api/src/owl/routers/gen_table_v1.py @@ -0,0 +1,916 @@ +from typing import Annotated, Any, Literal + +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + File, + Form, + Path, + Query, + Request, + Response, + UploadFile, +) +from fastapi.responses import FileResponse +from pydantic import BaseModel, BeforeValidator, Field + +from owl.routers import gen_table as v2 +from owl.types import ( + TABLE_NAME_PATTERN, + ActionTableSchemaCreate, + ChatTableSchemaCreate, + ChatThreadResponse, + ColumnDropRequest, + 
ColumnRenameRequest, + ColumnReorderRequest, + CSVDelimiter, + GenConfigUpdateRequest, + KnowledgeTableSchemaCreate, + MultiRowAddRequestWithLimit, + MultiRowDeleteRequest, + MultiRowRegenRequest, + MultiRowUpdateRequest, + OkResponse, + OrganizationRead, + Page, + ProjectRead, + RowUpdateRequest, + SanitisedNonEmptyStr, + SearchRequest, + TableMetaResponse, + TableSchemaCreate, + TableType, + UserAuth, + empty_string_to_none, +) +from owl.utils import uuid7_str +from owl.utils.auth import auth_user_project, has_permissions +from owl.utils.exceptions import handle_exception + +router = APIRouter() + + +@router.post( + "/v1/gen_tables/action", + summary="Create an action table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def create_action_table( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: ActionTableSchemaCreate, +) -> TableMetaResponse: + return await v2.create_action_table(request=request, auth_info=auth_info, body=body) + + +@router.post( + "/v1/gen_tables/knowledge", + summary="Create a knowledge table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def create_knowledge_table( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: KnowledgeTableSchemaCreate, +) -> TableMetaResponse: + user, project, org = auth_info + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org.id, + project_id=project.id, + ) + return await v2.create_knowledge_table(request=request, auth_info=auth_info, body=body) + + +@router.post( + "/v1/gen_tables/chat", + summary="Create a chat table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def create_chat_table( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: ChatTableSchemaCreate, +) -> TableMetaResponse: + return await v2.create_chat_table(request=request, auth_info=auth_info, body=body) + + +@router.post( + "/v1/gen_tables/{table_type}/duplicate/{table_id_src}", + summary="Duplicate a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def duplicate_table( + *, + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: TableType, + table_id_src: str = Path(description="Name of the table to be duplicated."), + table_id_dst: str | None = Query( + None, + pattern=TABLE_NAME_PATTERN, + max_length=100, + description=( + "_Optional_. Name for the new table." + "Defaults to None (automatically find the next available table name)." + ), + ), + include_data: bool = Query( + True, + description="_Optional_. Whether to include data from the source table. Defaults to `True`.", + ), + create_as_child: bool = Query( + False, + description=( + "_Optional_. Whether the new table is a child table. Defaults to `False`. " + "If this is `True`, then `include_data` will be set to `True`." 
+ ), + ), +) -> TableMetaResponse: + return await v2.duplicate_table( + request=request, + auth_info=auth_info, + table_type=table_type, + params=v2.DuplicateTableQuery( + table_id_src=table_id_src, + table_id_dst=table_id_dst, + include_data=include_data, + create_as_child=create_as_child, + ), + ) + + +@router.post( + "/v1/gen_tables/{table_type}/duplicate/{table_id_src}/{table_id_dst}", + deprecated=True, + summary="Duplicate a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def duplicate_table_deprecated( + *, + request: Request, + response: Response, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: TableType, + table_id_src: str = Path(description="Source table name or ID."), + table_id_dst: str = Path( + pattern=TABLE_NAME_PATTERN, + max_length=100, + description="Destination table name or ID.", + ), + include_data: bool = Query( + True, + description="_Optional_. Whether to include the data from the source table in the duplicated table. Defaults to `True`.", + ), + deploy: bool = Query( + False, + description="_Optional_. Whether to deploy the duplicated table. Defaults to `False`.", + ), +) -> TableMetaResponse: + response.headers["Warning"] = ( + '299 - "This endpoint is deprecated and will be removed in v0.5. ' + "Use '/v1/gen_tables/{table_type}/duplicate/{table_id_src}' instead." + '"' + ) + return await v2.duplicate_table( + request=request, + auth_info=auth_info, + table_type=table_type, + params=v2.DuplicateTableQuery( + table_id_src=table_id_src, + table_id_dst=table_id_dst, + include_data=include_data, + create_as_child=deploy, + ), + ) + + +@router.get( + "/v1/gen_tables/{table_type}/{table_id}", + summary="Get a specific table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def get_table( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: str = Path(description="The ID of the table to fetch."), +) -> TableMetaResponse: + return await v2.get_table(auth_info=auth_info, table_type=table_type, table_id=table_id) + + +class _ListTableQueryLegacy(BaseModel): + offset: Annotated[ + int, + Field(ge=0, description="Item offset for pagination. Defaults to 0."), + ] = 0 + limit: Annotated[ + int, + Field( + gt=0, + le=100, + description="Number of tables to return (min 1, max 100). Defaults to 100.", + ), + ] = 100 + parent_id: Annotated[ + str | None, + Field( + min_length=1, + description=( + "Parent ID of tables to return. Defaults to None (return all tables). " + "Additionally for Chat Table, you can list: " + '(1) all chat agents by passing in "_agent_"; or ' + '(2) all chats by passing in "_chat_".' + ), + ), + ] = None + search_query: Annotated[ + str, + Field( + max_length=255, + description='A string to search for within table IDs as a filter. Defaults to "" (no filter).', + ), + ] = "" + order_by: Annotated[ + Literal["id", "updated_at"], + Field(description='Sort tables by this attribute. Defaults to "updated_at".'), + ] = "updated_at" + order_descending: Annotated[ + bool, + Field( + description="Whether to sort by descending order. Defaults to True.", + ), + ] = True + count_rows: Annotated[ + bool, + Field(description="Whether to count the rows of the tables. 
Defaults to False."), + ] = False + + +@router.get( + "/v1/gen_tables/{table_type}", + summary="List tables of a specific type.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def list_tables( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[_ListTableQueryLegacy, Query()], +) -> Page[TableMetaResponse]: + kwargs = params.model_dump() + order_ascending = not kwargs.pop("order_descending", True) + return await v2.list_tables( + auth_info=auth_info, + table_type=table_type, + params=v2.ListTableQuery(order_ascending=order_ascending, **kwargs), + ) + + +@router.post( + "/v1/gen_tables/{table_type}/rename/{table_id_src}/{table_id_dst}", + summary="Rename a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def rename_table( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id_src: Annotated[str, Path(description="Source table name.")], # Don't validate + table_id_dst: Annotated[ + str, + Path(pattern=TABLE_NAME_PATTERN, max_length=100, description="New name for the table."), + ], +) -> TableMetaResponse: + return await v2.rename_table( + auth_info=auth_info, + table_type=table_type, + params=v2.RenameTableQuery(table_id_src=table_id_src, table_id_dst=table_id_dst), + ) + + +@router.delete( + "/v1/gen_tables/{table_type}/{table_id}", + summary="Delete a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def delete_table( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: Annotated[str, Path(description="Name of the table to be deleted.")], +) -> OkResponse: + return await v2.delete_table(auth_info=auth_info, table_type=table_type, table_id=table_id) + + +@router.post( + "/v1/gen_tables/{table_type}/columns/add", + summary="Add columns to a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def add_columns( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: TableSchemaCreate, +) -> TableMetaResponse: + return await v2.add_columns( + request=request, auth_info=auth_info, table_type=table_type, body=body + ) + + +@router.post( + "/v1/gen_tables/{table_type}/columns/rename", + summary="Rename columns in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def rename_columns( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: ColumnRenameRequest, +) -> TableMetaResponse: + return await v2.rename_columns(auth_info=auth_info, table_type=table_type, body=body) + + +@router.post( + "/v1/gen_tables/{table_type}/gen_config/update", + summary="Update generation configuration for table columns.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def update_gen_config( + auth_info: 
Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + updates: GenConfigUpdateRequest, +) -> TableMetaResponse: + return await v2.update_gen_config(auth_info=auth_info, table_type=table_type, updates=updates) + + +@router.post( + "/v1/gen_tables/{table_type}/columns/reorder", + summary="Reorder columns in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def reorder_columns( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: ColumnReorderRequest, +) -> TableMetaResponse: + return await v2.reorder_columns(auth_info=auth_info, table_type=table_type, body=body) + + +@router.post( + "/v1/gen_tables/{table_type}/columns/drop", + summary="Drop columns from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def drop_columns( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: ColumnDropRequest, +) -> TableMetaResponse: + return await v2.drop_columns( + request=request, auth_info=auth_info, table_type=table_type, body=body + ) + + +@router.post( + "/v1/gen_tables/{table_type}/rows/add", + summary="Add rows to a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def add_rows( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: MultiRowAddRequestWithLimit, +): + return await v2.add_rows( + request=request, auth_info=auth_info, table_type=table_type, body=body + ) + + +class _ListTableRowQueryLegacy(BaseModel): + offset: Annotated[ + int, + Field(ge=0, description="Item offset for pagination. Defaults to 0."), + ] = 0 + limit: Annotated[ + int, + Field( + gt=0, + le=100, + description="Number of rows to return (min 1, max 100). Defaults to 100.", + ), + ] = 100 + order_descending: Annotated[ + bool, + Field( + description="Whether to sort by descending order. Defaults to True.", + ), + ] = True + columns: Annotated[ + list[str] | None, + Field( + description="A list of column names to include in the response. Default is to return all columns.", + ), + ] = None + search_query: Annotated[ + str, + Field( + max_length=10_000, + description=( + "A string to search for within row data as a filter. " + 'The string is interpreted as both POSIX regular expression and literal string. Defaults to "" (no filter). ' + "It will be combined other filters using `AND`." + ), + ), + ] = "" + float_decimals: Annotated[ + int, + Field( + ge=0, + description="Number of decimals for float values. Defaults to 0 (no rounding).", + ), + ] = 0 + vec_decimals: Annotated[ + int, + Field( + description="Number of decimals for vectors. If its negative, exclude vector columns. 
Defaults to 0 (no rounding).", + ), + ] = 0 + + +@router.get( + "/v1/gen_tables/{table_type}/{table_id}/rows", + summary="List rows in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def list_rows( + *, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: str = Path(description="Table ID or name."), + params: Annotated[_ListTableRowQueryLegacy, Query()], +) -> Page[dict[str, Any]]: + kwargs = params.model_dump() + order_ascending = not kwargs.pop("order_descending", True) + response = await v2.list_rows( + auth_info=auth_info, + table_type=table_type, + params=v2.ListTableRowQuery(table_id=table_id, order_ascending=order_ascending, **kwargs), + ) + # Reproduce V1 "value" bug for backwards compatibility + if params.columns: + for col in params.columns: + for row in response.items: + if col in row and isinstance(row[col], dict): + row[col] = row[col].get("value", row[col]) + return response + + +@router.get( + "/v1/gen_tables/{table_type}/{table_id}/rows/{row_id}", + summary="Get a specific row from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def get_row( + *, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: str = Path(description="Table ID or name."), + row_id: Annotated[str, Path(description="The ID of the specific row to fetch.")], + columns: list[str] | None = Query( + default=None, + description="_Optional_. A list of column names to include in the response. Default is to return all columns.", + ), + float_decimals: int = Query( + default=0, + ge=0, + description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).", + ), + vec_decimals: int = Query( + default=0, + description="_Optional_. Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding).", + ), +) -> dict[str, Any]: + return await v2.get_row( + auth_info=auth_info, + table_type=table_type, + params=v2.GetTableRowQuery( + table_id=table_id, + row_id=row_id, + columns=columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + ), + ) + + +@router.get( + "/v1/gen_tables/{table_type}/{table_id}/thread", + summary="Get a conversation thread from a multi-turn LLM column.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def get_conversation_thread( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: Annotated[str, Path(description="Table ID or name.")], + column_id: Annotated[str, Query(description="ID / name of the column to fetch.")], + row_id: Annotated[ + str, + Query( + description='_Optional_. ID / name of the last row in the thread. Defaults to "" (export all rows).' + ), + ] = "", + include: Annotated[ + bool, + Query( + description="_Optional_. Whether to include the row specified by `row_id`. Defaults to True." 
+ ), + ] = True, +) -> ChatThreadResponse: + response = await v2.get_conversation_threads( + auth_info=auth_info, + table_type=table_type, + params=v2.GetTableThreadsQuery( + table_id=table_id, column_ids=[column_id], row_id=row_id, include_row=include + ), + ) + return response.threads[column_id] + + +@router.post( + "/v1/gen_tables/{table_type}/hybrid_search", + summary="Perform hybrid search on a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def hybrid_search( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: SearchRequest, +) -> list[dict[str, Any]]: + return await v2.hybrid_search( + request=request, auth_info=auth_info, table_type=table_type, body=body + ) + + +@router.post( + "/v1/gen_tables/{table_type}/rows/regen", + summary="Regenerate rows in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def regen_rows( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: MultiRowRegenRequest, +): + return await v2.regen_rows( + request=request, auth_info=auth_info, table_type=table_type, body=body + ) + + +@router.post( + "/v1/gen_tables/{table_type}/rows/update", + summary="Update a row in a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def update_row( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: RowUpdateRequest, +) -> OkResponse: + return await v2.update_rows( + request=request, + auth_info=auth_info, + table_type=table_type, + body=MultiRowUpdateRequest(table_id=body.table_id, data={body.row_id: body.data}), + ) + + +@router.post( + "/v1/gen_tables/{table_type}/rows/delete", + summary="Delete rows from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def delete_rows( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + body: MultiRowDeleteRequest, +) -> OkResponse: + return await v2.delete_rows(auth_info=auth_info, table_type=table_type, body=body) + + +@router.delete( + "/v1/gen_tables/{table_type}/{table_id}/rows/{row_id}", + summary="Delete a row from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def delete_row( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: str = Path(description="Table ID or name."), + row_id: str = Path(description="The ID of the specific row to delete."), +) -> OkResponse: + return await v2.delete_rows( + auth_info=auth_info, + table_type=table_type, + body=MultiRowDeleteRequest(table_id=table_id, row_ids=[row_id]), + ) + + +@router.options( + "/v1/gen_tables/knowledge/embed_file", + summary="Get CORS preflight options for file embedding endpoint", + description="Permissions: None, publicly accessible.", +) +@router.options( + 
"/v1/gen_tables/knowledge/upload_file", + deprecated=True, + summary="Get CORS preflight options for file embedding endpoint", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def embed_file_options(request: Request, response: Response): + if "upload_file" in request.url.path: + response.headers["Warning"] = ( + '299 - "This endpoint is deprecated and will be removed in v0.5. ' + "Use '/v1/gen_tables/{table_type}/embed_file' instead." + '"' + ) + return await v2.embed_file_options() + + +class FileEmbedFormData(BaseModel): + file: Annotated[UploadFile, File(description="The file.")] + file_name: Annotated[str, Field(description="File name.", deprecated=True)] = "" + table_id: Annotated[SanitisedNonEmptyStr, Field(description="Knowledge Table ID.")] + # overwrite: Annotated[ + # bool, Field(description="Whether to overwrite old file with the same name.") + # ] = False, + chunk_size: Annotated[ + int, Field(gt=0, description="Maximum chunk size (number of characters). Must be > 0.") + ] = 2000 + chunk_overlap: Annotated[ + int, Field(ge=0, description="Overlap in characters between chunks. Must be >= 0.") + ] = 200 + + +@router.post( + "/v1/gen_tables/knowledge/embed_file", + summary="Embed a file into a knowledge table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@router.post( + "/v1/gen_tables/knowledge/upload_file", + deprecated=True, + summary="Embed a file into a knowledge table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def embed_file( + *, + request: Request, + response: Response, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + data: Annotated[FileEmbedFormData, Form()], +) -> OkResponse: + if "upload_file" in request.url.path: + response.headers["Warning"] = ( + '299 - "This endpoint is deprecated and will be removed in v0.5. ' + "Use '/v1/gen_tables/{table_type}/embed_file' instead." + '"' + ) + return await v2.embed_file( + request=request, + auth_info=auth_info, + data=data, + ) + + +class TableDataImportFormData(BaseModel): + file: Annotated[UploadFile, File(description="The CSV or TSV file.")] + file_name: Annotated[str, Field(description="File name.", deprecated=True)] = "" + table_id: Annotated[ + SanitisedNonEmptyStr, + Field(description="ID or name of the table that the data should be imported into."), + ] + stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( + True + ) + # List of inputs is bugged as of 2024-07-14: https://github.com/tiangolo/fastapi/pull/9928/files + # TODO: Maybe we can re-enable these since the bug is for direct `Form` declaration and not Form Model + # column_names: Annotated[ + # list[ColName] | None, + # Field( + # description="_Optional_. A list of columns names if the CSV does not have header row. Defaults to None (read from CSV).", + # ), + # ] = None + # columns: Annotated[ + # list[ColName] | None, + # Field( + # description="_Optional_. A list of columns to be imported. Defaults to None (import all columns except 'ID' and 'Updated at').", + # ), + # ] = None + delimiter: Annotated[ + CSVDelimiter, + Field(description='The delimiter, can be "," or "\\t". 
Defaults to ",".'), + ] = CSVDelimiter.COMMA + + +@router.post( + "/v1/gen_tables/{table_type}/import_data", + summary="Import data into a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def import_table_data( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + data: Annotated[TableDataImportFormData, Form()], +): + return await v2.import_table_data( + request=request, auth_info=auth_info, table_type=table_type, data=data + ) + + +@router.get( + "/v1/gen_tables/{table_type}/{table_id}/export_data", + summary="Export data from a table.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def export_table_data( + request: Request, + bg_tasks: BackgroundTasks, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: Annotated[str, Path(description="ID or name of the table to be exported.")], + delimiter: Annotated[ + CSVDelimiter, + Query(description='The delimiter, can be "," or "\\t". Defaults to ",".'), + ] = CSVDelimiter.COMMA, + columns: Annotated[ + list[str] | None, + Query( + min_length=1, + description="_Optional_. A list of columns to be exported. Defaults to None (export all columns).", + ), + ] = None, +) -> FileResponse: + return await v2.export_table_data( + request=request, + bg_tasks=bg_tasks, + auth_info=auth_info, + table_type=table_type, + params=v2.ExportTableDataQuery(table_id=table_id, delimiter=delimiter, columns=columns), + ) + + +class TableImportFormData(BaseModel): + file: Annotated[UploadFile, File(description="The Parquet file.")] + table_id_dst: Annotated[ + SanitisedNonEmptyStr | None, + BeforeValidator(empty_string_to_none), + Field(description="The ID or name of the new table."), + ] = None + blocking: Annotated[ + bool, + Field( + description=( + "If True, waits until import finishes. " + "If False, the task is submitted to a task queue and returns immediately." + ), + ), + ] = True + progress_key: Annotated[ + str, + Field( + default_factory=uuid7_str, + description="The key to use to query progress. 
Defaults to a random string.", + ), + ] + migrate: Annotated[bool, Field(description="Whether to import in migration mode.")] = False + + +@router.post( + "/v1/gen_tables/{table_type}/import", + summary="Import a table including its metadata.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def import_table( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + data: Annotated[TableImportFormData, Form()], +) -> TableMetaResponse: + return await v2.import_table( + request=request, + auth_info=auth_info, + table_type=table_type, + data=data, + ) + + +@router.get( + "/v1/gen_tables/{table_type}/{table_id}/export", + summary="Export a table including its metadata.", + description="Permissions: `organization.MEMBER` OR `project.MEMBER`.", +) +@handle_exception +async def export_table( + request: Request, + bg_tasks: BackgroundTasks, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + table_type: Annotated[TableType, Path(description="Table type.")], + table_id: Annotated[str, Path(description="ID or name of the table to be exported.")], +) -> FileResponse: + return await v2.export_table( + request=request, + bg_tasks=bg_tasks, + auth_info=auth_info, + table_type=table_type, + table_id=table_id, + ) diff --git a/services/api/src/owl/routers/llm.py b/services/api/src/owl/routers/llm.py deleted file mode 100644 index 35d754a..0000000 --- a/services/api/src/owl/routers/llm.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -LLM operations. -""" - -import base64 -from typing import Annotated - -import numpy as np -from fastapi import APIRouter, Depends, Query, Request -from fastapi.responses import StreamingResponse - -from jamaibase.exceptions import ResourceNotFoundError -from owl.llm import LLMEngine -from owl.models import CloudEmbedder -from owl.protocol import ( - EXAMPLE_CHAT_MODEL_IDS, - ChatRequest, - ChatRequestWithTools, - EmbeddingRequest, - EmbeddingResponse, - EmbeddingResponseData, - ModelCapability, - ModelInfoResponse, -) -from owl.utils.auth import auth_user_project -from owl.utils.exceptions import handle_exception - -router = APIRouter(dependencies=[Depends(auth_user_project)]) - - -@router.get( - "/v1/models", - summary="List the info of models available.", - description="List the info of models available with the specified name and capabilities.", -) -@handle_exception -async def get_model_info( - request: Request, - model: Annotated[ - str, - Query( - description="ID of the requested model.", - examples=EXAMPLE_CHAT_MODEL_IDS, - ), - ] = "", - capabilities: Annotated[ - list[ModelCapability] | None, - Query( - description=( - "Filter the model info by model's capabilities. " - "Leave it blank to disable filter." - ), - examples=[[ModelCapability.CHAT]], - ), - ] = None, -) -> ModelInfoResponse: - try: - if capabilities is not None: - capabilities = [c.value for c in capabilities] - return LLMEngine(request=request).model_info( - model=model, - capabilities=capabilities, - ) - except ResourceNotFoundError: - return ModelInfoResponse(data=[]) - - -@router.get( - "/v1/model_names", - summary="List the ID of models available.", - description=( - "List the ID of models available with the specified capabilities with an optional preferred model. " - "If the preferred model is not available, then return the first available model." 
- ), -) -@handle_exception -async def get_model_names( - request: Request, - prefer: Annotated[ - str, - Query( - description="ID of the preferred model.", - examples=EXAMPLE_CHAT_MODEL_IDS, - ), - ] = "", - capabilities: Annotated[ - list[ModelCapability] | None, - Query( - description=( - "Filter the model info by model's capabilities. " - "Leave it blank to disable filter." - ), - examples=[[ModelCapability.CHAT]], - ), - ] = None, -) -> list[str]: - try: - if capabilities is not None: - capabilities = [c.value for c in capabilities] - return LLMEngine(request=request).model_names( - prefer=prefer, - capabilities=capabilities, - ) - except ResourceNotFoundError: - return [] - - -@router.post( - "/v1/chat/completions", - description="Given a list of messages comprising a conversation, the model will return a response.", -) -@handle_exception -async def generate_completions(request: Request, body: ChatRequest | ChatRequestWithTools): - # Check quota - request.state.billing.check_llm_quota(body.model) - request.state.billing.check_egress_quota() - # Run LLM - llm = LLMEngine(request=request) - # object key could cause issue to some LLM provider, ex: Anthropic - body.id = request.state.id - hyperparams = body.model_dump(exclude_none=True, exclude={"object"}) - if body.stream: - - async def _generate(): - content_length = 0 - async for chunk in llm.rag_stream(**hyperparams): - sse = f"data: {chunk.model_dump_json()}\n\n" - content_length += len(sse.encode("utf-8")) - yield sse - sse = "data: [DONE]\n\n" - content_length += len(sse.encode("utf-8")) - yield sse - request.state.billing.create_egress_events(content_length / (1024**3)) - - response = StreamingResponse( - content=_generate(), - status_code=200, - media_type="text/event-stream", - headers={"X-Accel-Buffering": "no"}, - ) - - else: - response = await llm.rag(**hyperparams) - request.state.billing.create_egress_events( - len(response.model_dump_json().encode("utf-8")) / (1024**3) - ) - return response - - -@router.post( - "/v1/embeddings", - description=( - "Get a vector representation of a given input that can be " - "easily consumed by machine learning models and algorithms. " - "Note that the vectors are NOT normalized." 
- ), -) -@handle_exception -async def generate_embeddings(request: Request, body: EmbeddingRequest) -> EmbeddingResponse: - embedder = CloudEmbedder(request=request) - if isinstance(body.input, str): - body.input = [body.input] - if body.type == "document": - embeddings = await embedder.embed_documents(embedder_name=body.model, texts=body.input) - else: - embeddings = await embedder.embed_queries(embedder_name=body.model, texts=body.input) - if body.encoding_format == "base64": - embeddings.data = [ - EmbeddingResponseData( - embedding=base64.b64encode(np.asarray(e.embedding, dtype=np.float32)), index=i - ) - for i, e in enumerate(embeddings.data) - ] - return embeddings diff --git a/services/api/src/owl/routers/meters.py b/services/api/src/owl/routers/meters.py new file mode 100644 index 0000000..0ef6e2e --- /dev/null +++ b/services/api/src/owl/routers/meters.py @@ -0,0 +1,631 @@ +from datetime import datetime +from typing import Annotated, Literal + +from fastapi import APIRouter, Depends, Query + +from owl.db import SCHEMA, async_session, cached_text +from owl.types import UsageResponse, UserAuth +from owl.utils.auth import ( + auth_user_service_key, + has_permissions, +) +from owl.utils.billing import CLICKHOUSE_CLIENT +from owl.utils.billing_metrics import BillingMetrics +from owl.utils.exceptions import ( + BadInputError, + handle_exception, +) +from owl.utils.metrics import Telemetry + +router = APIRouter() +telemetry = Telemetry() + +billing_metrics = BillingMetrics(clickhouse_client=CLICKHOUSE_CLIENT) + + +async def _check_permissions( + user: UserAuth, + org_ids: list[str] | None, + proj_ids: list[str] | None, +) -> None: + if org_ids is None and proj_ids is None: + # This will return usages across ALL orgs and ALL projects + has_permissions(user, ["system.MEMBER"]) + else: + if org_ids: + for org_id in org_ids: + has_permissions(user, ["organization.MEMBER"], organization_id=org_id) + if proj_ids: + for proj_id in proj_ids: + async with async_session() as session: + stmt = f"""SELECT organization_id FROM {SCHEMA}."Project" WHERE id = '{proj_id}';""" + org_id = (await session.exec(cached_text(stmt))).one() + has_permissions( + user, + ["organization.MEMBER", "project.MEMBER"], + organization_id=org_id, + project_id=proj_id, + ) + + +@router.get( + "/v2/meters/usages", + summary="Get the usage metrics of the specified type (llm, embedding, reranking).", + description=( + "Permissions: `system.MEMBER` to retrieve metrics for all organizations or all projects; " + "`organization.MEMBER` to retrieve metrics for a specific organization; " + "`project.MEMBER` to retrieve metrics for a specific project." + ), + response_model=UsageResponse, +) +@handle_exception +async def get_usage_metrics( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + type: Annotated[ + Literal["llm", "embedding", "reranking"], + Query( + min_length=1, + description="Type of usage data to query. Must be one of: 'llm', 'embedding', or 'reranking'.", + ), + ], + from_: Annotated[ + datetime, Query(alias="from", description="Start datetime for the usage data query.") + ], + window_size: Annotated[ + str, + Query( + min_length=1, + description="The aggregation window size (e.g., '1d' for daily, '1w' for weekly).", + alias="windowSize", + ), + ], + org_ids: Annotated[ + list[str] | None, + Query( + description="List of organization IDs to filter the query. 
If not provided, data for all organizations is returned.",
+            alias="orgIds",
+        ),
+    ] = None,
+    proj_ids: Annotated[
+        list[str] | None,
+        Query(
+            description="List of project IDs to filter the query. If not provided, data for all projects is returned.",
+            alias="projIds",
+        ),
+    ] = None,
+    to: Annotated[
+        datetime | None,
+        Query(
+            description="End datetime for the usage data query. If not provided, data up to the current datetime is returned."
+        ),
+    ] = None,
+    group_by: Annotated[
+        list[str] | None,
+        Query(
+            min_length=1,
+            description="List of fields to group the usage data by. If not provided, no grouping is applied.",
+            alias="groupBy",
+        ),
+    ] = None,
+    data_source: Annotated[
+        Literal["clickhouse", "victoriametrics"],
+        Query(description="Data source to query. Defaults to 'clickhouse'.", alias="dataSource"),
+    ] = "clickhouse",
+) -> UsageResponse:
+    """
+    Retrieves usage metrics based on the provided filters.
+    Requires `system.MEMBER` permission for queries across all organizations or projects;
+    `organization.MEMBER` or `project.MEMBER` suffices for queries scoped to specific
+    organizations or projects.
+
+    This endpoint allows querying usage data for specific organizations within a given time range.
+    The results can be grouped by specified fields and aggregated using a window size.
+
+    Args:
+        user (UserAuth): The authenticated user making the request.
+        type (str): Type of usage data to query. One of: llm, embedding, reranking.
+        from_ (datetime): The start of the time range for the usage data.
+        window_size (str): The size of the time window for aggregating usage data
+            (e.g., "1d" for daily, "1w" for weekly).
+        org_ids (list[str] | None): A list of organization IDs to filter the usage data.
+            If not provided, data for all organizations will be returned.
+        proj_ids (list[str] | None): A list of project IDs to filter the usage data.
+            If not provided, data for all projects will be returned.
+        to (datetime | None): The end of the time range for the usage data.
+            If not provided, data up to the current date will be returned.
+        group_by (list[str] | None): A list of fields to group the usage data by.
+            If not provided, the data will not be grouped.
+        data_source (str): The data source to query. Defaults to "clickhouse".
+
+    Returns:
+        UsageResponse: A response containing window_size and a list of the usage metrics.
+
+    Raises:
+        BadInputError: If the 'type' parameter is invalid (not one of 'llm',
+            'embedding', or 'reranking').
+    """
+    # RBAC
+    await _check_permissions(user, org_ids, proj_ids)
+    # Fetch
+    if group_by is None:
+        group_by = []
+    if data_source == "clickhouse":
+        metrics_client = billing_metrics
+    elif data_source == "victoriametrics":
+        metrics_client = telemetry
+    if type == "llm":
+        return await metrics_client.query_llm_usage(
+            org_ids, proj_ids, from_, to, group_by, window_size
+        )
+    elif type == "embedding":
+        return await metrics_client.query_embedding_usage(
+            org_ids, proj_ids, from_, to, group_by, window_size
+        )
+    elif type == "reranking":
+        return await metrics_client.query_reranking_usage(
+            org_ids, proj_ids, from_, to, group_by, window_size
+        )
+    raise BadInputError(f"type: {type} invalid. 
Must be one of: llm, embedding, reranking") + + +@router.get( + "/v2/meters/billings", + summary="Get billing metrics.", + description="Permissions: `system.MEMBER`.", + response_model=UsageResponse, +) +@handle_exception +async def get_billing_metrics( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + from_: Annotated[ + datetime, + Query( + alias="from", + description="Start datetime for the billing data query.", + ), + ], + window_size: Annotated[ + str, + Query( + min_length=1, + description="The aggregation window size (e.g., '1d' for daily, '1w' for weekly).", + alias="windowSize", + ), + ], + org_ids: Annotated[ + list[str] | None, + Query( + description="List of organization IDs to filter the query. If not provided, data for all organizations is returned.", + alias="orgIds", + ), + ] = None, + proj_ids: Annotated[ + list[str] | None, + Query( + description="List of project IDs to filter the query. If not provided, data for all projects is returned.", + alias="projIds", + ), + ] = None, + to: Annotated[ + datetime | None, + Query( + description="End datetime for the billing data query. If not provided, data up to the current datetime is returned.", + ), + ] = None, + group_by: Annotated[ + list[str] | None, + Query( + min_length=1, + description="List of fields to group the billing data by. If not provided, no grouping is applied.", + alias="groupBy", + ), + ] = None, + data_source: Annotated[ + Literal["clickhouse", "victoriametrics"], + Query(description="Data source to query. Defaults to 'clickhouse'.", alias="dataSource"), + ] = "clickhouse", +) -> UsageResponse: + """ + Retrieves billing metrics based on the provided filters. + This endpoint requires `system.MEMBER` permission. + + This endpoint allows querying billing data for specific organizations within a given time range. + The results can be grouped by specified fields and aggregated using a window size. + + Args: + user (str): The authenticated user making the request. + from_ (datetime): The start of the time range for the billing data. + window_size (str): The size of the time window for aggregating billing data + (e.g., "1d" for daily, "1w" for weekly). + org_ids (list[str] | None): A list of organization IDs to filter the billing data. + If not provided, data for all organizations will be returned. + proj_ids (list[str] | None): A list of project IDs to filter the billing data. + If not provided, data for all projects will be returned. + to (datetime | None): The end of the time range for the billing data. + If not provided, data up to the current date will be returned. + group_by (list[str] | None): A list of fields to group the billing data by. + If not provided, the data will not be grouped. + data_source (str): The data source to query. Defaults to "clickhouse". + + Returns: + UsageResponse: A response containing window_size and a list of the billing metrics. 
+ """ + # RBAC + await _check_permissions(user, org_ids, proj_ids) + # Fetch + if group_by is None: + group_by = [] + if data_source == "clickhouse": + metrics_client = billing_metrics + elif data_source == "victoriametrics": + metrics_client = telemetry + return await metrics_client.query_billing(org_ids, proj_ids, from_, to, group_by, window_size) + + +@router.get( + "/v2/meters/bandwidths", + summary="Get bandwidth usage metrics.", + description="Permissions: `system.MEMBER`.", + response_model=UsageResponse, +) +@handle_exception +async def get_bandwidth_metrics( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + from_: Annotated[ + datetime, + Query( + alias="from", + description="Start datetime for the bandwidth data query.", + ), + ], + window_size: Annotated[ + str, + Query( + min_length=1, + description="The aggregation window size (e.g., '1d' for daily, '1w' for weekly).", + alias="windowSize", + ), + ], + org_ids: Annotated[ + list[str] | None, + Query( + description="List of organization IDs to filter the query. If not provided, data for all organizations is returned.", + alias="orgIds", + ), + ] = None, + proj_ids: Annotated[ + list[str] | None, + Query( + description="List of project IDs to filter the query. If not provided, data for all projects is returned.", + alias="projIds", + ), + ] = None, + to: Annotated[ + datetime | None, + Query( + description="End datetime for the bandwidth data query. If not provided, data up to the current datetime is returned.", + ), + ] = None, + group_by: Annotated[ + list[str] | None, + Query( + min_length=1, + description="List of fields to group the bandwidth data by. If not provided, no grouping is applied.", + alias="groupBy", + ), + ] = None, + data_source: Annotated[ + Literal["clickhouse", "victoriametrics"], + Query(description="Data source to query. Defaults to 'clickhouse'.", alias="dataSource"), + ] = "clickhouse", +) -> UsageResponse: + """ + Retrieves bandwidth metrics based on the provided filters. + This endpoint requires `system.MEMBER` permission. + + This endpoint allows querying bandwidth data for specific organizations within a given time range. + The results can be grouped by specified fields and aggregated using a window size. + + Args: + user (str): The authenticated user making the request. + from_ (datetime): The start of the time range for the bandwidth data. + window_size (str): The size of the time window for aggregating bandwidth data + (e.g., "1d" for daily, "1w" for weekly). + org_ids (list[str] | None): A list of organization IDs to filter the bandwidth data. + If not provided, data for all organizations will be returned. + proj_ids (list[str] | None): A list of project IDs to filter the bandwidth data. + If not provided, data for all projects will be returned. + to (datetime | None): The end of the time range for the bandwidth data. + If not provided, data up to the current date will be returned. + group_by (list[str] | None): A list of fields to group the bandwidth data by. + If not provided, the data will not be grouped. + data_source (str): The data source to query. Defaults to "clickhouse". + + Returns: + UsageResponse: A response containing window_size and a list of the bandwidth metrics. 
+ """ + # RBAC + await _check_permissions(user, org_ids, proj_ids) + # Fetch + if group_by is None: + group_by = [] + if data_source == "clickhouse": + metrics_client = billing_metrics + elif data_source == "victoriametrics": + metrics_client = telemetry + return await metrics_client.query_bandwidth( + org_ids, proj_ids, from_, to, group_by, window_size + ) + + +@router.get( + "/v2/meters/storages", + summary="Get storage usage metrics.", + description="Permissions: `system.MEMBER`.", + response_model=UsageResponse, +) +@handle_exception +async def get_storage_metrics( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + from_: Annotated[ + datetime, + Query( + alias="from", + description="Start datetime for the storage data query.", + ), + ], + window_size: Annotated[ + str, + Query( + min_length=1, + description="The aggregation window size (e.g., '1d' for daily, '1w' for weekly).", + alias="windowSize", + ), + ], + org_ids: Annotated[ + list[str] | None, + Query( + description="List of organization IDs to filter the query. If not provided, data for all organizations is returned.", + alias="orgIds", + ), + ] = None, + proj_ids: Annotated[ + list[str] | None, + Query( + description="List of project IDs to filter the query. If not provided, data for all projects is returned.", + alias="projIds", + ), + ] = None, + to: Annotated[ + datetime | None, + Query( + description="End datetime for the storage data query. If not provided, data up to the current datetime is returned.", + ), + ] = None, + group_by: Annotated[ + list[str] | None, + Query( + min_length=1, + description="List of fields to group the storage data by. If not provided, no grouping is applied.", + alias="groupBy", + ), + ] = None, + data_source: Annotated[ + Literal["clickhouse", "victoriametrics"], + Query(description="Data source to query. Defaults to 'clickhouse'.", alias="dataSource"), + ] = "clickhouse", +) -> UsageResponse: + """ + Retrieves storage metrics based on the provided filters. + This endpoint requires `system.MEMBER` permission. + + This endpoint allows querying storage data for specific organizations within a given time range. + The results can be grouped by specified fields and aggregated using a window size. + + Args: + user (str): The authenticated user making the request. + from_ (datetime): The start of the time range for the storage data. + window_size (str): The size of the time window for aggregating storage data + (e.g., "1d" for daily, "1w" for weekly). + org_ids (list[str] | None): A list of organization IDs to filter the storage data. + If not provided, data for all organizations will be returned. + proj_ids (list[str] | None): A list of project IDs to filter the storage data. + If not provided, data for all projects will be returned. + to (datetime | None): The end of the time range for the storage data. + If not provided, data up to the current date will be returned. + group_by (list[str] | None): A list of fields to group the storage data by. + If not provided, the data will not be grouped. + + Returns: + UsageResponse: A response containing window_size and a list of the storage metrics. 
+ """ + # RBAC + await _check_permissions(user, org_ids, proj_ids) + # Fetch + if group_by is None: + group_by = [] + if data_source == "clickhouse": + metrics_client = billing_metrics + elif data_source == "victoriametrics": + metrics_client = telemetry + return await metrics_client.query_storage(org_ids, proj_ids, from_, to, group_by, window_size) + + +# @router.get( +# "/v2/meters/models/throughput", +# summary="Get the model throughput statistics of the specified model type (llm, embedding, reranking), and metric type.", +# description="Permissions: `system.models` OR `system.metrics`.", +# response_model=UsageResponse, +# ) +# @handle_exception +# async def get_model_throughput_metrics( +# user: Annotated[UserAuth, Depends(auth_user_service_key)], +# type: Annotated[ +# Literal["llm", "embedding", "reranking"], +# Query( +# min_length=1, +# description="Type of usage data to query. Must be one of: 'llm', 'embedding', or 'reranking'.", +# ), +# ], +# metric_type: Annotated[ +# Literal["tpm", "rpm", "spm"], +# Query( +# min_length=1, +# description=( +# "Type of metric to query, " +# "Here is the list of possible metric type: " +# "llm: tpm, rpm" +# "embedding: tpm, rpm" +# "reranking: spm, rpm" +# "tpm (tokens per minute), rpm (requests per minute), spm (searches per minute) " +# ), +# ), +# ], +# from_: Annotated[ +# datetime, Query(alias="from", description="Start datetime for the usage data query.") +# ], +# to: Annotated[ +# datetime | None, +# Query( +# description="End datetime for the usage data query. If not provided, data up to the current datetime is returned." +# ), +# ] = None, +# ) -> UsageResponse: +# """ +# Retrieves model throughput statistics based on the provided filters. +# This endpoint requires `system.metrics` permission. + +# This endpoint allows querying model throughput statistics data for specific metric type within a given time range. + +# Args: +# user (UserAuth): The authenticated user making the request. +# type (str): Type of usage data to query. One of: llm, embedding, reranking. +# metric_type (str): Type of metric to query. One of tpm, rpm, spm. +# Valid metric_type depends on model type: +# llm: tpm, rpm +# embedding: tpm, rpm +# reranking: spm, rpm +# from_ (datetime): The start of the time range for the usage data. +# to (datetime | None): The end of the time range for the usage data. +# If not provided, data up to the current date will be returned. + +# Returns: +# UsageResponse: A response containing window_size and a list of the usage metrics. + +# Raises: +# BadInputError: If the 'type' parameter is invalid (not one of 'llm', +# 'embedding', or 'reranking'). Or if the 'model_type' parameter is invalid. +# """ +# has_permissions(user, ["system.models", "system.metrics"]) +# if type == "llm": +# if metric_type == "tpm": +# return await telemetry.query_llm_tpm(from_, to) +# elif metric_type == "rpm": +# return await telemetry.query_llm_rpm(from_, to) +# elif type == "embedding": +# if metric_type == "tpm": +# return await telemetry.query_embed_tpm(from_, to) +# elif metric_type == "rpm": +# return await telemetry.query_embed_rpm(from_, to) +# elif type == "reranking": +# if metric_type == "spm": +# return await telemetry.query_rerank_spm(from_, to) +# elif metric_type == "rpm": +# return await telemetry.query_rerank_rpm(from_, to) +# raise BadInputError( +# f"type: {type} with metric type: {metric_type} invalid. 
Must be one of: llm (tpm, rpm), embedding (tpm, rpm), reranking (tpm, rpm)" +# ) + + +# @router.get( +# "/v2/meters/models/latency", +# summary="Get the model latency past hour statistics of the specified model type (llm, embedding, reranking), and metric type.", +# description="Permissions: `system.models` OR `system.metrics`.", +# response_model=UsageResponse, +# ) +# @handle_exception +# async def get_model_latency_metrics( +# user: Annotated[UserAuth, Depends(auth_user_service_key)], +# type: Annotated[ +# Literal["llm", "embedding", "reranking"], +# Query( +# min_length=1, +# description="Type of usage data to query. Must be one of: 'llm', 'embedding', or 'reranking'.", +# ), +# ], +# metric_type: Annotated[ +# Literal["itl", "ttft", "tpot", "rt"], +# Query( +# min_length=1, +# description=( +# "Type of metric to query, " +# "Here is the list of possible metric type: " +# "llm: itl, ttft, tpot " +# "embedding: rt " +# "reranking: rt " +# "itl (inter-token latency), ttft (time to first token), tpot (time per output token), rt (response time)" +# ), +# ), +# ], +# quantile: Annotated[ +# float, +# Query( +# ge=0, +# le=1, +# description=("Quantile of latency to query, ex: 0.95 means 95th percentile latency."), +# ), +# ], +# from_: Annotated[ +# datetime, Query(alias="from", description="Start datetime for the usage data query.") +# ], +# to: Annotated[ +# datetime | None, +# Query( +# description="End datetime for the usage data query. If not provided, data up to the current datetime is returned." +# ), +# ] = None, +# ) -> UsageResponse: +# """ +# Retrieves model latency statistics based on the provided filters. +# This endpoint requires `system.metrics` permission. + +# This endpoint allows querying model latency statistics data for specific metric type within a given time range. + +# Args: +# user (UserAuth): The authenticated user making the request. +# type (str): Type of usage data to query. One of: llm, embedding, reranking. +# metric_type (str): Type of metric to query. One of ttft, tpot, rt. +# Valid metric_type depends on model type: +# llm: itl, ttft, tpot +# embedding: rt +# reranking: rt +# quantile (float): Quantile of latency to query, ex: 0.95 means 95th percentile latency. +# from_ (datetime): The start of the time range for the usage data. +# to (datetime | None): The end of the time range for the usage data. +# If not provided, data up to the current date will be returned. + +# Returns: +# Each data point is the quantile latency based on past 1 hour data, with 1 minute resolution. +# UsageResponse: A response containing window_size and a list of the usage metrics. + +# Raises: +# BadInputError: If the 'type' parameter is invalid (not one of 'llm', +# 'embedding', or 'reranking'). Or if the 'model_type' parameter is invalid. 
+# """ +# has_permissions(user, ["system.models", "system.metrics"]) +# if type == "llm": +# if metric_type == "ttft": +# return await telemetry.query_hourly_llm_ttft_quantile(from_, to, quantile) +# elif metric_type == "tpot": +# return await telemetry.query_hourly_llm_tpot_quantile(from_, to, quantile) +# elif metric_type == "itl": +# return await telemetry.query_hourly_llm_itl_quantile(from_, to, quantile) +# elif type == "embedding": +# if metric_type == "rt": +# return await telemetry.query_hourly_embed_completion_time_quantile(from_, to, quantile) +# elif type == "reranking": +# if metric_type == "rt": +# return await telemetry.query_hourly_rerank_completion_time_quantile( +# from_, to, quantile +# ) +# raise BadInputError( +# f"type: {type} with metric type: {metric_type} invalid. Must be one of: llm (itl, ttft, tpot), embedding (rt), reranking (rt)" +# ) diff --git a/services/api/src/owl/routers/models.py b/services/api/src/owl/routers/models.py new file mode 100644 index 0000000..4418270 --- /dev/null +++ b/services/api/src/owl/routers/models.py @@ -0,0 +1,318 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, Request +from loguru import logger +from sqlmodel import func, select + +from owl.db import AsyncSession, yield_async_session +from owl.db.models import Deployment, ModelConfig +from owl.types import ( + CloudProvider, + DeploymentCreate, + DeploymentRead, + DeploymentUpdate, + ListQuery, + ModelConfigCreate, + ModelConfigRead, + ModelConfigUpdate, + OkResponse, + Page, + UserAuth, +) +from owl.utils.auth import auth_user_service_key, has_permissions +from owl.utils.dates import now +from owl.utils.exceptions import ( + BadInputError, + ResourceExistsError, + ResourceNotFoundError, + handle_exception, +) + +router = APIRouter() + + +@router.post( + "/v2/models/configs", + summary="Create a model config.", + description="Permissions: `system.MEMBER`. Prerequisite for creating a deployment.", +) +@handle_exception +async def create_model_config( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: ModelConfigCreate, +) -> ModelConfigRead: + has_permissions(user, ["system.MEMBER"]) + if (await session.get(ModelConfig, body.id)) is not None: + raise ResourceExistsError(f'ModelConfig "{body.id}" already exists.') + model = ModelConfig.model_validate(body) + session.add(model) + await session.commit() + await session.refresh(model) + logger.bind(user_id=user.id).success( + f'{user.name} ({user.email}) created a model config for "{model.name}" ({model.id}).' 
+    )
+    logger.bind(user_id=user.id).info(f"{request.state.id} - Created model config: {model}")
+    return model
+
+
+@router.get(
+    "/v2/models/configs/list",
+    summary="List system-wide model configs.",
+    description="Permissions: `system`.",
+)
+@handle_exception
+async def list_model_configs(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    params: Annotated[ListQuery, Query()],
+) -> Page[ModelConfigRead]:
+    has_permissions(user, ["system"])
+    return await ModelConfig.list_(
+        session=session,
+        return_type=ModelConfigRead,
+        organization_id=None,
+        offset=params.offset,
+        limit=params.limit,
+        order_by=params.order_by,
+        order_ascending=params.order_ascending,
+        search_query=params.search_query,
+        search_columns=params.search_columns,
+        after=params.after,
+    )
+
+
+@router.get(
+    "/v2/models/configs",
+    summary="Get a model config.",
+    description="Permissions: `system`.",
+)
+@handle_exception
+async def get_model_config(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    model_id: Annotated[str, Query(min_length=1, description="Model config ID.")],
+) -> ModelConfigRead:
+    has_permissions(user, ["system"])
+    return await ModelConfig.get(session, model_id, name="Model config")
+
+
+@router.patch(
+    "/v2/models/configs",
+    summary="Update a model config.",
+    description="Permissions: `system.MEMBER`.",
+)
+@handle_exception
+async def update_model_config(
+    request: Request,
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    model_id: Annotated[str, Query(min_length=1, description="Model config ID.")],
+    body: ModelConfigUpdate,
+) -> ModelConfigRead:
+    has_permissions(user, ["system.MEMBER"])
+    model = await ModelConfig.get(session, model_id, name="Model config")
+    updates = body.model_dump(exclude_unset=True)
+    ModelConfigCreate.validate_updates(base=model, updates=updates)
+    for key, value in updates.items():
+        setattr(model, key, value)
+    model.updated_at = now()
+    session.add(model)
+    await session.commit()
+    await session.refresh(model)
+    logger.bind(user_id=user.id).success(
+        (
+            f"{user.name} ({user.email}) updated the attributes "
+            f'{list(updates.keys())} of the model config for "{model.name}" ({model.id}).'
+        )
+    )
+    logger.bind(user_id=user.id).info(f"{request.state.id} - Updated model config: {model}")
+    return model
+
+
+@router.delete(
+    "/v2/models/configs",
+    summary="Delete a model config.",
+    description="Permissions: `system.MEMBER`.",
+)
+@handle_exception
+async def delete_model_config(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    model_id: Annotated[str, Query(min_length=1, description="Model config ID.")],
+) -> OkResponse:
+    has_permissions(user, ["system.MEMBER"])
+    model = await session.get(ModelConfig, model_id)
+    if model is None:
+        raise ResourceNotFoundError(f'ModelConfig "{model_id}" is not found.')
+    # Check deployments
+    num_deployments = (
+        await session.exec(
+            select(func.count(Deployment.id)).where(Deployment.model_id == model_id)
+        )
+    ).one()
+    if num_deployments > 0:
+        raise BadInputError(
+            (
+                f'Cannot delete model "{model_id}" because it still has {num_deployments:,d} deployments. '
+                "Please delete the deployments first."
+ ) + ) + await session.delete(model) + await session.commit() + return OkResponse() + + +@router.get( + "/v2/models/deployments/providers/cloud", + summary="List available cloud providers.", + description="Permissions: `system`.", +) +@handle_exception +async def list_available_providers( + user: Annotated[UserAuth, Depends(auth_user_service_key)], +) -> list[str]: + has_permissions(user, ["system"]) + return list(CloudProvider) + + +@router.post( + "/v2/models/deployments/cloud", + summary="Create an external cloud deployment.", + description=( + "Permissions: `system.MEMBER`. " + "Note that a model config must be created before creating a deployment. " + "Request body format: " + "`provider` must be a valid Provider enum. " + "`routing_id` must be a string. " + "`api_base` is an OPTIONAL string. " + ), +) +@handle_exception +async def create_deployment( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: DeploymentCreate, +) -> DeploymentRead: + logger.info(f"{request.state.id} - Creating deployment: {body}") + has_permissions(user, ["system.MEMBER"]) + # Check if the associated model exists + model = await session.get(ModelConfig, body.model_id) + if model is None: + raise ResourceNotFoundError(f'Model "{body.model_id}" does not exist.') + deployment = Deployment.model_validate(body) + session.add(deployment) + await session.commit() + await session.refresh(deployment) + logger.bind(user_id=user.id).success( + ( + f"{user.name} ({user.email}) created a cloud deployment " + f'"{deployment.name}" ({deployment.id}) for model "{model.name}" with ' + f'provider "{deployment.provider}".' + ) + ) + logger.info(f"{request.state.id} - Created cloud deployment: {deployment}") + return deployment + + +@router.get( + "/v2/models/deployments/list", + summary="List deployments.", + description="Permissions: `system`.", +) +@handle_exception +async def list_deployments( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ListQuery, Query()], +) -> Page[DeploymentRead]: + has_permissions(user, ["system"]) + return await Deployment.list_( + session=session, + return_type=DeploymentRead, + offset=params.offset, + limit=params.limit, + order_by=params.order_by, + order_ascending=params.order_ascending, + search_query=params.search_query, + search_columns=params.search_columns, + after=params.after, + ) + + +@router.get( + "/v2/models/deployments", + summary="Get a deployment.", + description="Permissions: `system`.", +) +@handle_exception +async def get_deployment( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + deployment_id: Annotated[str, Query(min_length=1, description="Deployment ID.")], +) -> DeploymentRead: + has_permissions(user, ["system"]) + deployment = await session.get(Deployment, deployment_id) + if deployment is None: + raise ResourceNotFoundError(f'Deployment "{deployment_id}" is not found.') + return deployment + + +@router.patch( + "/v2/models/deployments", + summary="Update a deployment.", + description="Permissions: `system.MEMBER`.", +) +@handle_exception +async def update_deployment( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + deployment_id: Annotated[str, Query(min_length=1, description="Deployment 
ID.")], + body: DeploymentUpdate, +) -> DeploymentRead: + has_permissions(user, ["system.MEMBER"]) + deployment = await session.get(Deployment, deployment_id) + if deployment is None: + raise ResourceNotFoundError(f'Deployment "{deployment_id}" is not found.') + logger.info(f"Current deployment: {deployment}") + # Perform update + updates = body.model_dump(exclude=["id"], exclude_unset=True) + for key, value in updates.items(): + setattr(deployment, key, value) + deployment.updated_at = now() + session.add(deployment) + await session.commit() + await session.refresh(deployment) + logger.bind(user_id=user.id).success( + ( + f"{user.name} ({user.email}) updated the attributes " + f'{list(updates.keys())} of a deployment "{deployment.name}" ({deployment.id}).' + ) + ) + logger.info(f"{request.state.id} - Updated deployment: {deployment}") + return deployment + + +@router.delete( + "/v2/models/deployments", + summary="Delete a deployment.", + description="Permissions: `system.MEMBER`.", +) +@handle_exception +async def delete_deployment( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + deployment_id: Annotated[str, Query(min_length=1, description="Deployment ID.")], +) -> OkResponse: + logger.info(f"{request.state.id} - Deleting deployment: {deployment_id}") + has_permissions(user, ["system.MEMBER"]) + deployment = await session.get(Deployment, deployment_id) + if deployment is None: + raise ResourceNotFoundError(f'Deployment "{deployment_id}" is not found.') + await session.delete(deployment) + await session.commit() + return OkResponse() diff --git a/services/api/src/owl/routers/org_admin.py b/services/api/src/owl/routers/org_admin.py deleted file mode 100644 index e1f66e4..0000000 --- a/services/api/src/owl/routers/org_admin.py +++ /dev/null @@ -1,703 +0,0 @@ -import pathlib -from datetime import datetime -from io import BytesIO -from os.path import join -from tempfile import TemporaryDirectory -from time import perf_counter -from typing import Annotated, Literal, Mapping - -import pyarrow as pa -from fastapi import ( - APIRouter, - BackgroundTasks, - Depends, - File, - Form, - Path, - Query, - Request, - UploadFile, -) -from fastapi.responses import FileResponse -from loguru import logger -from pyarrow.parquet import read_table as read_parquet_table -from pyarrow.parquet import write_table as write_parquet_table -from sqlalchemy import func -from sqlmodel import Session, select - -from jamaibase.exceptions import ( - BadInputError, - ForbiddenError, - ResourceExistsError, - ResourceNotFoundError, - UpgradeTierError, - make_validation_error, -) -from jamaibase.utils.io import json_dumps, json_loads, read_json -from owl.configs.manager import CONFIG, ENV_CONFIG -from owl.db import MAIN_ENGINE, UserSQLModel, cached_text, create_sql_tables -from owl.db.gen_table import GenerativeTable -from owl.protocol import ( - AdminOrderBy, - ModelListConfig, - Name, - OkResponse, - Page, - TableMeta, - TableMetaResponse, - TableType, - TemplateMeta, -) -from owl.utils import datetime_now_iso -from owl.utils.auth import WRITE_METHODS, AuthReturn, auth_user -from owl.utils.crypt import generate_key -from owl.utils.exceptions import handle_exception - -if ENV_CONFIG.is_oss: - from owl.db.oss_admin import ( - Organization, - OrganizationRead, - Project, - ProjectCreate, - ProjectRead, - ProjectUpdate, - ) -else: - from owl.db.cloud_admin import ( - Organization, - OrganizationRead, - Project, - ProjectCreate, - 
ProjectRead, - ProjectUpdate, - ) - - -CURR_DIR = pathlib.Path(__file__).resolve().parent -TEMPLATE_DIR = CURR_DIR.parent / "templates" -router = APIRouter() - - -@router.on_event("startup") -async def startup(): - create_sql_tables(UserSQLModel, MAIN_ENGINE) - - -def _get_session(): - with Session(MAIN_ENGINE) as session: - yield session - - -def _check_access( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - auth_info: AuthReturn, - org_or_id: str | Organization, -) -> Organization: - if isinstance(org_or_id, str): - if ENV_CONFIG.is_oss: - # OSS only has one default organization - org_or_id = ENV_CONFIG.default_org_id - organization = session.get(Organization, org_or_id) - if organization is None: - raise ResourceNotFoundError(f'Organization "{org_or_id}" is not found.') - else: - organization = org_or_id - if ENV_CONFIG.is_oss: - return organization - - user, org = auth_info - if user is not None: - user_roles = {m.organization_id: m.role for m in user.member_of} - user_role = user_roles.get(organization.id, None) - if user_role is None: - raise ForbiddenError(f'You do not have access to organization "{organization.id}".') - if user_role == "guest" and request.method in WRITE_METHODS: - raise ForbiddenError( - f'You do not have write access to organization "{organization.id}".' - ) - if org is not None and org.id != organization.id: - raise ForbiddenError(f'You do not have access to organization "{organization.id}".') - # Non-activated orgs can only perform GET requests - if (not organization.active) and (request.method != "GET"): - raise UpgradeTierError(f'Your organization "{organization.id}" is not activated.') - return organization - - -def _get_organization_from_path( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - auth_info: Annotated[AuthReturn, Depends(auth_user)], - organization_id: Annotated[str, Path(min_length=1, description='Organization ID "org_xxx".')], -) -> Organization: - return _check_access( - session=session, request=request, auth_info=auth_info, org_or_id=organization_id - ) - - -def _get_organization_from_query( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - auth_info: Annotated[AuthReturn, Depends(auth_user)], - organization_id: Annotated[str, Query(min_length=1, description='Organization ID "org_xxx".')], -) -> Organization: - return _check_access( - session=session, request=request, auth_info=auth_info, org_or_id=organization_id - ) - - -def _get_project_from_path( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - auth_info: Annotated[AuthReturn, Depends(auth_user)], - project_id: Annotated[str, Path(min_length=1, description='Project ID "proj_xxx".')], -) -> Project: - proj = session.get(Project, project_id) - if proj is None: - raise ResourceNotFoundError(f'Project "{project_id}" is not found.') - _check_access( - session=session, request=request, auth_info=auth_info, org_or_id=proj.organization - ) - return proj - - -@router.get("/v1/models/{organization_id}") -@handle_exception -def get_org_model_config( - organization: Annotated[Organization, Depends(_get_organization_from_path)], -) -> ModelListConfig: - # Get only org models - return ModelListConfig.model_validate(organization.models) - - -@router.patch("/v1/models/{organization_id}") -@handle_exception -def set_org_model_config( - *, - session: Annotated[Session, Depends(_get_session)], - organization: Annotated[Organization, Depends(_get_organization_from_path)], - body: 
ModelListConfig, -) -> OkResponse: - # Validate - _ = body + CONFIG.get_model_config() - for m in body.models: - m.owned_by = "custom" - organization.models = body.model_dump(mode="json") - organization.updated_at = datetime_now_iso() - session.add(organization) - session.commit() - return OkResponse() - - -@router.post("/v1/projects") -@handle_exception -def create_project( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - auth_info: Annotated[AuthReturn, Depends(auth_user)], - body: ProjectCreate, -) -> ProjectRead: - if ENV_CONFIG.is_oss: - body.organization_id = ENV_CONFIG.default_org_id - _check_access( - session=session, request=request, auth_info=auth_info, org_or_id=body.organization_id - ) - same_name_count = session.exec( - select( - func.count(Project.id).filter( - Project.organization_id == body.organization_id, Project.name == body.name - ) - ) - ).one() - if same_name_count > 0: - raise ResourceExistsError("Project with the same name exists.") - project_id = generate_key(24, "proj_") - while session.get(Project, project_id) is not None: - project_id = generate_key(24, "proj_") - proj = Project( - id=project_id, - name=body.name, - organization_id=body.organization_id, - ) - session.add(proj) - session.commit() - session.refresh(proj) - logger.info(f"{request.state.id} - Project created: {proj}") - return ProjectRead( - **proj.model_dump(), - organization=OrganizationRead( - **proj.organization.model_dump(), - members=proj.organization.members, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain), - ) - - -@router.patch("/v1/projects") -@handle_exception -def update_project( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - auth_info: Annotated[AuthReturn, Depends(auth_user)], - body: ProjectUpdate, -) -> ProjectRead: - proj = session.get(Project, body.id) - if proj is None: - raise ResourceNotFoundError(f'Project "{body.id}" is not found.') - _check_access( - session=session, request=request, auth_info=auth_info, org_or_id=proj.organization - ) - for key, value in body.model_dump(exclude=["id"], exclude_none=True).items(): - if key == "name": - same_name_count = session.exec( - select( - func.count(Project.id).filter( - Project.organization_id == proj.organization_id, - Project.name == body.name, - ) - ) - ).one() - if same_name_count > 0: - raise ResourceExistsError("Project with the same name exists.") - setattr(proj, key, value) - proj.updated_at = datetime_now_iso() - session.add(proj) - session.commit() - session.refresh(proj) - logger.info(f"{request.state.id} - Project updated: {proj}") - return ProjectRead( - **proj.model_dump(), - organization=OrganizationRead( - **proj.organization.model_dump(), - members=proj.organization.members, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain), - ) - - -@router.patch("/v1/projects/{project_id}") -@handle_exception -def set_project_updated_at( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - project: Annotated[Project, Depends(_get_project_from_path)], - updated_at: Annotated[ - str | None, Query(min_length=1, description="Project update datetime (ISO 8601 UTC).") - ] = None, -) -> OkResponse: - if updated_at is None: - updated_at = datetime_now_iso() - else: - try: - tz = str(datetime.fromisoformat(updated_at).tzinfo) - except Exception as e: - raise BadInputError("`updated_at` must be a ISO 8601 UTC datetime string.") from e - if tz != "UTC": - raise BadInputError(f'`updated_at` must be UTC, but received "{tz}".') - project.updated_at = 
updated_at - session.add(project) - session.commit() - logger.info(f"{request.state.id} - Project updated_at set to: {updated_at}") - return OkResponse() - - -@router.get("/v1/projects") -@handle_exception -def list_projects( - *, - session: Annotated[Session, Depends(_get_session)], - organization: Annotated[Organization, Depends(_get_organization_from_query)], - search_query: Annotated[ - str, - Query( - max_length=10_000, - description='_Optional_. A string to search for within project names as a filter. Defaults to "" (no filter).', - ), - ] = "", - offset: Annotated[int, Query(ge=0)] = 0, - limit: Annotated[int, Query(gt=0, le=100)] = 100, - order_by: Annotated[ - AdminOrderBy, - Query( - min_length=1, - description='_Optional_. Sort projects by this attribute. Defaults to "updated_at".', - ), - ] = AdminOrderBy.UPDATED_AT, - order_descending: Annotated[ - bool, - Query(description="_Optional_. Whether to sort by descending order. Defaults to True."), - ] = True, -) -> Page[ProjectRead]: - organization_id = organization.id - org = session.get(Organization, organization_id) - if org is None: - raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') - search_query = search_query.strip() - selection = select(Project).where(Project.organization_id == organization_id) - count = func.count(Project.id).filter(Project.organization_id == organization_id) - if search_query: - selection = selection.where(Project.name.ilike(f"%{search_query}%")) - count = count.filter(Project.name.ilike(f"%{search_query}%")) - order_by = f"LOWER({order_by})" - selection = selection.order_by( - cached_text(f"{order_by} DESC" if order_descending else f"{order_by} ASC") - ) - projects = session.exec(selection.offset(offset).limit(limit)).all() - total = session.exec(select(count)).one() - return Page[ProjectRead]( - items=projects, - offset=offset, - limit=limit, - total=total, - ) - - -@router.get("/v1/projects/{project_id}") -@handle_exception -def get_project( - project: Annotated[Project, Depends(_get_project_from_path)], -) -> ProjectRead: - proj = ProjectRead( - **project.model_dump(), - organization=OrganizationRead( - **project.organization.model_dump(), - members=project.organization.members, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain), - ) - return proj - - -@router.delete("/v1/projects/{project_id}") -@handle_exception -def delete_project( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - project: Annotated[Project, Depends(_get_project_from_path)], -) -> OkResponse: - project_id = project.id - session.delete(project) - session.commit() - logger.info(f"{request.state.id} - Project deleted: {project_id}") - return OkResponse() - - -def _package_project_tables(project: Project) -> list[tuple[str, TableMetaResponse, bytes]]: - data = [] - table_types = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] - for table_type in table_types: - table = GenerativeTable.from_ids(project.organization_id, project.id, table_type) - with table.create_session() as session: - # Lance tables could be on S3 so we use list_meta instead of listdir - batch_size, offset, total = 200, 0, 1 - while offset < total: - metas, total = table.list_meta( - session, - offset=offset, - limit=batch_size, - remove_state_cols=True, - parent_id=None, - ) - offset += batch_size - for meta in metas: - with BytesIO() as f: - table.dump_parquet(session=session, table_id=meta.id, dest=f) - data.append((table_type.value, meta, f.getvalue())) - return data - - -def _export_project( - *, - 
request: Request, - bg_tasks: BackgroundTasks, - project: Project, - output_file_ext: str, - compression: Literal["NONE", "ZSTD", "LZ4", "SNAPPY"] = "ZSTD", - extra_metas: Mapping[str, str] | None = None, -) -> FileResponse: - t0 = perf_counter() - # Check quota - request.state.billing.check_egress_quota() - # Check extra metadata - extra_metas = extra_metas or {} - for k, v in extra_metas.items(): - if not isinstance(v, str): - raise BadInputError(f'Invalid extra metadata: value of key "{k}" is not a string.') - # Dump all tables as parquet files - data = _package_project_tables(project) - if len(data) == 0: - metas = [] - pa_table = pa.Table.from_pydict({"table_type": pa.array([]), "data": pa.array([])}) - else: - metas = [] - for table_type, meta, _ in data: - metas.append({"table_type": table_type, "table_meta": meta.model_dump(mode="json")}) - data = list(zip(*data, strict=True)) - pa_table = pa.Table.from_pydict( - {"table_type": pa.array(data[0]), "data": pa.array(data[2])} - ) - pa_meta = pa_table.schema.metadata or {} - pa_table = pa_table.replace_schema_metadata( - { - "project_meta": project.model_dump_json(), - "table_metas": json_dumps(metas), - **extra_metas, - **pa_meta, - } - ) - tmp_dir = TemporaryDirectory() - filename = f"{project.id}{output_file_ext}" - filepath = join(tmp_dir.name, filename) - # Keep a reference to the directory and only delete upon completion - bg_tasks.add_task(tmp_dir.cleanup) - write_parquet_table(pa_table, where=filepath, compression=compression) - logger.info( - f'{request.state.id} - Project "{project.id}" exported in {perf_counter() - t0:,.2f} s.' - ) - return FileResponse( - path=filepath, - filename=filename, - media_type="application/octet-stream", - ) - - -@router.get("/v1/projects/{project_id}/export") -@handle_exception -def export_project( - *, - request: Request, - bg_tasks: BackgroundTasks, - project: Annotated[Project, Depends(_get_project_from_path)], - compression: Annotated[ - Literal["NONE", "ZSTD", "LZ4", "SNAPPY"], - Query(description="Parquet compression codec."), - ] = "ZSTD", -) -> FileResponse: - return _export_project( - request=request, - bg_tasks=bg_tasks, - project=project, - output_file_ext=".parquet", - compression=compression, - ) - - -@router.get("/v1/projects/{project_id}/export/template") -@handle_exception -def export_project_as_template( - *, - request: Request, - bg_tasks: BackgroundTasks, - project: Annotated[Project, Depends(_get_project_from_path)], - name: Annotated[Name, Query(description="Template name.")], - tags: Annotated[list[str], Query(description="Template tags.")], - description: Annotated[str, Query(description="Template description.")], - compression: Annotated[ - Literal["NONE", "ZSTD", "LZ4", "SNAPPY"], - Query(description="Parquet compression codec."), - ] = "ZSTD", -) -> FileResponse: - template_meta = TemplateMeta(name=name, description=description, tags=tags) - return _export_project( - request=request, - bg_tasks=bg_tasks, - project=project, - output_file_ext=".template.parquet", - compression=compression, - extra_metas={"template_meta": template_meta.model_dump_json()}, - ) - - -@router.post("/v1/projects/import/{organization_id}") -@handle_exception -async def import_project( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - organization: Annotated[Organization, Depends(_get_organization_from_path)], - file: Annotated[UploadFile, File(description="Project or Template Parquet file.")], - project_id_dst: Annotated[ - str, - Form( - description=( - 
"_Optional_. ID of the project to import tables into. " - "Defaults to creating new project." - ), - ), - ] = "", -) -> ProjectRead: - t0 = perf_counter() - organization_id = organization.id - if project_id_dst == "": - proj = None - else: - proj = session.get(Project, project_id_dst) - if proj is None: - raise ResourceNotFoundError(f'Project "{project_id_dst}" is not found.') - if proj.organization_id != organization_id: - raise ForbiddenError( - f'You do not have access to organization "{proj.organization_id}".' - ) - try: - with BytesIO(await file.read()) as source: - # Read metadata - pa_table = read_parquet_table(source, columns=[], use_threads=False, memory_map=True) - metadata = pa_table.schema.metadata - if proj is None: - # Create the project - project_meta = metadata.get(b"template_meta", None) - if project_meta is None: - project_meta = metadata.get(b"project_meta", None) - if project_meta is None: - raise BadInputError("Missing template or table metadata in the Parquet file.") - try: - project_meta = json_loads(project_meta) - except Exception as e: - raise BadInputError( - "Invalid template or table metadata in the Parquet file." - ) from e - proj = Project(name=project_meta["name"], organization_id=organization_id) - session.add(proj) - session.commit() - session.refresh(proj) - project_id_dst = proj.id - else: - # Check if all the table IDs have no conflict - try: - type_metas = json_loads(metadata[b"table_metas"]) - except KeyError as e: - raise BadInputError("Missing table metadata in the Parquet file.") from e - except Exception as e: - raise BadInputError("Invalid table metadata in the Parquet file.") from e - for type_meta in type_metas: - table = GenerativeTable.from_ids( - organization_id, project_id_dst, type_meta["table_type"] - ) - with table.create_session() as gt_sess: - table_id = type_meta["table_meta"]["id"] - meta = gt_sess.get(TableMeta, table_id) - if meta is not None: - raise ResourceExistsError(f'Table "{table_id}" already exists.') - logger.info( - f'{request.state.id} - Project "{proj.id}" metadata imported in {perf_counter() - t0:,.2f} s.' - ) - # Create the tables - pa_table = read_parquet_table(source, columns=None, use_threads=False, memory_map=True) - for row in pa_table.to_pylist(): - table_type = row["table_type"] - with BytesIO(row["data"]) as pq_source: - table = GenerativeTable.from_ids(organization_id, project_id_dst, table_type) - with table.create_session() as gt_sess: - await table.import_parquet( - session=gt_sess, - source=pq_source, - table_id_dst=None, - ) - logger.info( - f'{request.state.id} - Project "{proj.id}" imported in {perf_counter() - t0:,.2f} s.' - ) - except pa.ArrowInvalid as e: - raise make_validation_error( - e, - loc=("body", "file"), - ) from e - return ProjectRead( - **proj.model_dump(), - organization=OrganizationRead( - **proj.organization.model_dump(), - members=proj.organization.members, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain), - ) - - -@router.post("/v1/projects/import/{organization_id}/templates/{template_id}") -@handle_exception -async def import_project_from_template( - *, - session: Annotated[Session, Depends(_get_session)], - organization: Annotated[Organization, Depends(_get_organization_from_path)], - template_id: Annotated[str, Path(description="ID of the template to import from.")], - project_id_dst: Annotated[ - str, - Query( - description=( - "_Optional_. ID of the project to import tables into. " - "Defaults to creating new project." 
- ), - ), - ] = "", -) -> ProjectRead: - template_dir = TEMPLATE_DIR / template_id - if not template_dir.is_dir(): - raise ResourceNotFoundError(f'Template "{template_id}" is not found.') - organization_id = organization.id - if project_id_dst == "": - proj = None - else: - proj = session.get(Project, project_id_dst) - if proj is None: - raise ResourceNotFoundError(f'Project "{project_id_dst}" is not found.') - if proj.organization_id != organization_id: - raise ForbiddenError( - f'You do not have access to organization "{proj.organization_id}".' - ) - if proj is None: - # Create the project - template_meta = read_json(template_dir / "template_meta.json") - proj = Project(name=template_meta["name"], organization_id=organization_id) - session.add(proj) - session.commit() - session.refresh(proj) - project_id_dst = proj.id - else: - # Check if all the table IDs have no conflict - for table_type in [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT]: - table_dir = template_dir / table_type - if not table_dir.is_dir(): - continue - table = GenerativeTable.from_ids(organization_id, project_id_dst, table_type) - for pq_source in table_dir.iterdir(): - if not pq_source.is_file(): - continue - pa_table = read_parquet_table( - pq_source, columns=[], use_threads=False, memory_map=True - ) - try: - pq_meta = TableMeta.model_validate_json( - pa_table.schema.metadata[b"gen_table_meta"] - ) - except KeyError as e: - raise BadInputError("Missing table metadata in the Parquet file.") from e - except Exception as e: - raise BadInputError("Invalid table metadata in the Parquet file.") from e - with table.create_session() as gt_sess: - meta = gt_sess.get(TableMeta, pq_meta.id) - if meta is not None: - raise ResourceExistsError(f'Table "{pq_meta.id}" already exists.') - # Create the tables - for table_type in [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT]: - table_dir = template_dir / table_type - if not table_dir.is_dir(): - continue - for pq_source in table_dir.iterdir(): - if not pq_source.is_file(): - continue - table = GenerativeTable.from_ids(organization_id, project_id_dst, table_type) - with table.create_session() as gt_sess: - await table.import_parquet( - session=gt_sess, - source=pq_source, - table_id_dst=None, - ) - return ProjectRead( - **proj.model_dump(), - organization=OrganizationRead( - **proj.organization.model_dump(), - members=proj.organization.members, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain), - ) diff --git a/services/api/src/owl/routers/organizations/__init__.py b/services/api/src/owl/routers/organizations/__init__.py new file mode 100644 index 0000000..8c5dbec --- /dev/null +++ b/services/api/src/owl/routers/organizations/__init__.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter + +from owl.configs import ENV_CONFIG +from owl.routers.organizations.oss import router as oss_router + +router = APIRouter() +router.include_router(oss_router) + +if ENV_CONFIG.is_cloud: + from owl.routers.organizations.cloud import router as cloud_router + + router.include_router(cloud_router) diff --git a/services/api/src/owl/routers/organizations/oss.py b/services/api/src/owl/routers/organizations/oss.py new file mode 100644 index 0000000..55d223a --- /dev/null +++ b/services/api/src/owl/routers/organizations/oss.py @@ -0,0 +1,732 @@ +from datetime import datetime, timezone +from typing import Annotated, Literal + +from fastapi import APIRouter, Depends, Query, Request +from loguru import logger +from sqlmodel import delete, func, select + +from owl.configs import CACHE, ENV_CONFIG 
+from owl.db import TEMPLATE_ORG_ID, AsyncSession, async_session, yield_async_session +from owl.db.gen_table import GenerativeTableCore +from owl.db.models import ( + BASE_PLAN_ID, + Deployment, + ModelConfig, + Organization, + OrgMember, + PricePlan, + Project, + ProjectMember, + User, +) +from owl.types import ( + ListQuery, + ListQueryByOrg, + ModelConfigRead, + OkResponse, + OrganizationCreate, + OrganizationRead, + OrganizationReadDecrypt, + OrganizationUpdate, + OrgMemberRead, + OrgModelCatalogueQuery, + Page, + PricePlanCreate, + Role, + UsageResponse, + UserAuth, +) +from owl.utils import mask_dict +from owl.utils.auth import auth_user_service_key, has_permissions +from owl.utils.billing import CLICKHOUSE_CLIENT, STRIPE_CLIENT, BillingManager +from owl.utils.billing_metrics import BillingMetrics +from owl.utils.crypt import decrypt, encrypt_random, generate_key +from owl.utils.dates import now +from owl.utils.exceptions import ( + BadInputError, + BaseTierCountError, + ForbiddenError, + NoTierError, + ResourceExistsError, + ResourceNotFoundError, + UnexpectedError, + UpgradeTierError, + handle_exception, +) +from owl.utils.mcp import MCP_TOOL_TAG +from owl.utils.metrics import Telemetry + +router = APIRouter() +telemetry = Telemetry() + +billing_metrics = BillingMetrics(clickhouse_client=CLICKHOUSE_CLIENT) + + +def _encrypt_dict(value: dict[str, str]) -> dict[str, str]: + return {k: encrypt_random(v, ENV_CONFIG.encryption_key_plain) for k, v in value.items()} + + +def _decrypt_dict(value: dict[str, str]) -> dict[str, str]: + return {k: decrypt(v, ENV_CONFIG.encryption_key_plain) for k, v in value.items()} + + +@router.post( + "/v2/organizations", + summary="Create an organization.", + description="Permissions: None.", +) +@handle_exception +async def create_organization( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + body: OrganizationCreate, + organization_id: str = "", +) -> OrganizationReadDecrypt: + # There must always be a free plan + async with async_session() as session: + base_plan = await session.get(PricePlan, BASE_PLAN_ID) + if base_plan is None: + session.add( + PricePlan( + id=BASE_PLAN_ID, + **PricePlanCreate.free().model_dump(mode="json", exclude={"id"}), + ) + ) + await session.commit() + # This is mainly for migration and not exposed as REST API + if organization_id: + if await session.get(Organization, organization_id) is not None: + raise ResourceExistsError(f'Organization "{organization_id}" already exists.') + else: + # There must always be a system admin org + if await session.get(Organization, "0") is None: + organization_id = "0" + else: + organization_id = generate_key(24, "org_") + num_base_tier_orgs = len(await Organization.list_base_tier_orgs(session, user.id)) + if organization_id != "0" and ENV_CONFIG.is_cloud: + # A user can only have one free organization + if num_base_tier_orgs > 1: + raise BaseTierCountError + # Create Stripe customer + if STRIPE_CLIENT is None: + stripe_id = None + else: + customer = await STRIPE_CLIENT.customers.create_async( + dict( + name=f"{user.name} | {body.name}", + email=user.email, + metadata=dict(organization_id=organization_id), + ) + ) + logger.bind(user_id=user.id, org_id=organization_id).info( + f"Stripe customer created: {customer}" + ) + stripe_id = customer.id + async with async_session() as session: + org = Organization( + **body.model_dump(exclude={"external_keys"}), + id=organization_id, + created_by=user.id, + owner=user.id, + stripe_id=stripe_id, + 
external_keys=_encrypt_dict(body.external_keys),
+        )
+        session.add(org)
+        await session.commit()
+        await session.refresh(org)
+        logger.bind(user_id=user.id, org_id=org.id).success(
+            f'{user.name} ({user.email}) created an organization "{org.name}".'
+        )
+        logger.bind(user_id=user.id, org_id=org.id).info(
+            f"{request.state.id} - Created organization: {org}"
+        )
+        # Add user as admin
+        org_member = OrgMember(user_id=user.id, organization_id=org.id, role=Role.ADMIN)
+        session.add(org_member)
+        await session.commit()
+        await session.refresh(org_member)
+        logger.bind(user_id=user.id, org_id=org.id).success(
+            f'{user.name} ({user.email}) joined organization "{org.name}" as an admin.'
+        )
+        logger.info(f"{request.state.id} - Created organization member: {org_member}")
+        # Create template org
+        if await session.get(Organization, TEMPLATE_ORG_ID) is None:
+            session.add(
+                Organization(
+                    id=TEMPLATE_ORG_ID,
+                    name="Template",
+                    created_by=user.id,
+                    owner=user.id,
+                )
+            )
+            await session.commit()
+            logger.bind(user_id=user.id, org_id=org.id).success(
+                f"{user.name} ({user.email}) created template organization."
+            )
+            session.add(
+                OrgMember(user_id=user.id, organization_id=TEMPLATE_ORG_ID, role=Role.ADMIN)
+            )
+            await session.commit()
+            logger.bind(user_id=user.id, org_id=org.id).success(
+                f"{user.name} ({user.email}) joined template organization as an admin."
+            )
+    # Subscribe to base plan if the user has no base tier org
+    if ENV_CONFIG.is_cloud and num_base_tier_orgs == 0:
+        from owl.routers.organizations.cloud import subscribe_plan
+
+        async with async_session() as session:
+            user = UserAuth.model_validate(
+                await session.get(User, user.id, populate_existing=True)
+            )
+            await subscribe_plan(user, org.id, BASE_PLAN_ID)
+    async with async_session() as session:
+        org = await session.get(Organization, org.id, populate_existing=True)
+    return org
+
+
+@router.get(
+    "/v2/organizations/list",
+    summary="List organizations.",
+    description="Permissions: `system`.",
+    tags=[MCP_TOOL_TAG, "system"],
+)
+@handle_exception
+async def list_organizations(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    params: Annotated[ListQuery, Query()],
+) -> Page[OrganizationRead]:
+    has_permissions(user, ["system"])
+    return await Organization.list_(
+        session=session,
+        return_type=OrganizationRead,
+        offset=params.offset,
+        limit=params.limit,
+        order_by=params.order_by,
+        order_ascending=params.order_ascending,
+        search_query=params.search_query,
+        search_columns=params.search_columns,
+        after=params.after,
+    )
+
+
+@router.get(
+    "/v2/organizations",
+    summary="Get an organization.",
+    description="Permissions: `system` OR `organization`. 
Only `organization.ADMIN` can view API keys.", + tags=[MCP_TOOL_TAG, "system", "organization"], +) +@handle_exception +async def get_organization( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + organization_id: Annotated[str, Query(min_length=1, description="Organization ID.")], +) -> OrganizationRead: + has_permissions(user, ["system", "organization"], organization_id=organization_id) + org = await session.get(Organization, organization_id) + if org is None: + raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') + org = OrganizationReadDecrypt.model_validate(org) + # Whether we need to mask external API keys + if not has_permissions( + user, ["organization.ADMIN"], organization_id=org.id, raise_error=False + ): + org.external_keys = mask_dict(org.external_keys) + # Update billing data if needed + request.state.billing = BillingManager( + organization=org, + project_id="", + user_id=user.id, + request=request, + models=None, + ) + return org + + +@router.patch( + "/v2/organizations", + summary="Update an organization.", + description="Permissions: `organization.ADMIN`.", +) +@handle_exception +async def update_organization( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + organization_id: Annotated[str, Query(min_length=1, description="Organization ID.")], + body: OrganizationUpdate, +) -> OrganizationRead: + has_permissions(user, ["organization.ADMIN"], organization_id=organization_id) + org = await session.get(Organization, organization_id) + if org is None: + raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') + # Perform update + updates = body.model_dump(exclude=["id"], exclude_unset=True) + for key, value in updates.items(): + if key == "external_keys": + value = _encrypt_dict(value) + setattr(org, key, value) + org.updated_at = now() + session.add(org) + await session.commit() + await session.refresh(org) + logger.bind(user_id=user.id, org_id=org.id).success( + ( + f"{user.name} ({user.email}) updated the attributes " + f'{list(updates.keys())} of organization "{org.name}".' + ) + ) + org = OrganizationReadDecrypt.model_validate(org) + if not has_permissions( + user, ["organization.ADMIN"], organization_id=org.id, raise_error=False + ): + org.external_keys = mask_dict(org.external_keys) + # Clear cache + await CACHE.clear_organization_async(organization_id) + return org + + +@router.delete( + "/v2/organizations", + summary="Delete an organization.", + description=( + "Permissions: Only the owner can delete an organization. " + 'WARNING: Deleting system organization "0" will also delete ALL data.' 
+ ), +) +@handle_exception +async def delete_organization( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + organization_id: Annotated[str, Query(min_length=1, description="Organization ID.")], +) -> OkResponse: + organization = await session.get(Organization, organization_id) + if organization is None: + raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') + # TODO: Create an endpoint to transfer ownership + if organization.owner != user.id: + raise ForbiddenError("Only the owner can delete an organization.") + logger.info(f'{request.state.id} - Deleting organization: "{organization_id}"') + # Delete Generative Tables + await session.refresh(organization, ["projects"]) + for project in organization.projects: + await GenerativeTableCore.drop_schemas(project_id=project.id) + # Delete related resources + await session.exec(delete(Organization).where(Organization.id == organization_id)) + await session.exec(delete(Project).where(Project.organization_id == organization_id)) + if ENV_CONFIG.is_cloud: + from owl.db.models.cloud import VerificationCode + + await session.exec( + delete(VerificationCode).where(VerificationCode.organization_id == organization_id) + ) + if organization_id == "0": + await session.exec(delete(Deployment)) + await session.exec(delete(ModelConfig)) + await session.exec(delete(Organization).where(Organization.id == TEMPLATE_ORG_ID)) + # Delete Stripe customer + if STRIPE_CLIENT is not None and organization.stripe_id is not None: + customer = await STRIPE_CLIENT.customers.delete_async(organization.stripe_id) + logger.info( + f'Stripe customer "{customer.id}" deleted for organization "{organization_id}".' + ) + await session.commit() + if organization_id == "0": + logger.bind(user_id=user.id, org_id=organization_id).success( + f"{user.name} ({user.email}) deleted all templates, models and deployments." + ) + logger.bind(user_id=user.id, org_id=organization_id).success( + f'{user.name} ({user.email}) deleted organization "{organization.name}".' + ) + logger.info(f"{request.state.id} - Deleted organization: {organization_id}") + # Clear cache + await CACHE.clear_organization_async(organization_id) + return OkResponse() + + +@router.post( + "/v2/organizations/members", + summary="Join an organization.", + description=( + "Permissions: `organization.ADMIN`. " + "Permissions are only needed if adding another user or invite code is not provided." + ), +) +@handle_exception +async def join_organization( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[str, Query(min_length=1, description="ID of the user joining the org.")], + invite_code: Annotated[ + str | None, + Query(min_length=1, description="(Optional) Invite code for validation."), + ] = None, + organization_id: Annotated[ + str | None, + Query( + min_length=1, + description="(Optional) Organization ID. Ignored if invite code is provided.", + ), + ] = None, + role: Annotated[ + Role | None, + Query(min_length=1, description="(Optional) Role. 
Ignored if invite code is provided."), + ] = None, +) -> OrgMemberRead: + joining_user = await session.get(User, user_id) + if joining_user is None: + raise ResourceNotFoundError(f'User "{user_id}" is not found.') + if invite_code is None: + if organization_id is None or role is None: + raise BadInputError("Missing organization ID or role.") + invite = None + else: + if ENV_CONFIG.is_oss: + raise BadInputError("Invite code is not supported in OSS.") + else: + from owl.db.models.cloud import VerificationCode + + # Fetch code + invite = await session.get(VerificationCode, invite_code) + if ( + invite is None + or invite.organization_id is None + or invite.purpose not in ("organization_invite", None) + or now() > invite.expiry + or invite.revoked_at is not None + or invite.used_at is not None + or invite.user_email != joining_user.preferred_email + ): + raise ResourceNotFoundError(f'Invite code "{invite_code}" is invalid.') + organization_id = invite.organization_id + role = invite.role + # RBAC + if user.id != user_id or invite_code is None: + has_permissions(user, ["organization.ADMIN"], organization_id=organization_id) + # Check for existing membership + organization = await session.get(Organization, organization_id) + if organization is None: + raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') + if await session.get(OrgMember, (user_id, organization_id)) is not None: + raise ResourceExistsError("You are already in the organization.") + # Enforce member count limit (cloud only) + if ENV_CONFIG.is_cloud and organization.id not in ["0", TEMPLATE_ORG_ID]: + if (plan := organization.price_plan) is None: + raise NoTierError + else: + if plan.max_users is not None: + member_count = ( + await session.exec( + select(func.count(OrgMember.user_id)).where( + OrgMember.organization_id == organization_id + ) + ) + ).one() + if member_count >= plan.max_users: + raise UpgradeTierError( + ( + f"Your subscribed plan only supports {plan.max_users:,d} members. " + "Consider upgrading your plan or remove existing member before adding more." + ) + ) + # Add member + org_member = OrgMember(user_id=user_id, organization_id=organization_id, role=role) + session.add(org_member) + await session.commit() + await session.refresh(org_member) + # Consume invite code + if invite is not None: + invite.used_at = now() + session.add(invite) + await session.commit() + logger.bind(user_id=joining_user.id, org_id=organization.id).success( + ( + f"{joining_user.preferred_name} ({joining_user.preferred_email}) joined " + f'organization "{organization.name}" as "{role.name}".' 
+        )
+    )
+    return org_member
+
+
+@router.get(
+    "/v2/organizations/members/list",
+    summary="List organization members.",
+    description="Permissions: `system` OR `organization`.",
+)
+@handle_exception
+async def list_organization_members(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    params: Annotated[ListQueryByOrg[Literal["id", "created_at", "updated_at"]], Query()],
+) -> Page[OrgMemberRead]:
+    has_permissions(user, ["system", "organization"], organization_id=params.organization_id)
+    return await OrgMember.list_(
+        session=session,
+        return_type=OrgMemberRead,
+        offset=params.offset,
+        limit=params.limit,
+        order_by=params.order_by,
+        order_ascending=params.order_ascending,
+        search_query=params.search_query,
+        search_columns=params.search_columns,
+        filters=dict(organization_id=params.organization_id),
+        after=params.after,
+    )
+
+
+@router.get(
+    "/v2/organizations/members",
+    summary="Get an organization member.",
+    description="Permissions: `system` OR `organization`.",
+)
+@handle_exception
+async def get_organization_member(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    user_id: Annotated[str, Query(min_length=1, description="User ID.")],
+    organization_id: Annotated[str, Query(min_length=1, description="Organization ID.")],
+) -> OrgMemberRead:
+    has_permissions(user, ["system", "organization"], organization_id=organization_id)
+    member_id = (user_id, organization_id)
+    member = await session.get(OrgMember, member_id)
+    if member is None:
+        raise ResourceNotFoundError(f'Organization member "{member_id}" is not found.')
+    return member
+
+
+@router.patch(
+    "/v2/organizations/members/role",
+    summary="Update an organization member's role.",
+    description="Permissions: `organization.ADMIN`.",
+)
+@handle_exception
+async def update_member_role(
+    user: Annotated[UserAuth, Depends(auth_user_service_key)],
+    session: Annotated[AsyncSession, Depends(yield_async_session)],
+    user_id: Annotated[str, Query(min_length=1, description="User ID.")],
+    organization_id: Annotated[str, Query(min_length=1, description="Organization ID.")],
+    role: Annotated[Role, Query(description="New role.")],
+) -> OrgMemberRead:
+    # Check permissions
+    has_permissions(user, ["organization.ADMIN"], organization_id=organization_id)
+    # Fetch the member
+    member = await session.get(OrgMember, (user_id, organization_id))
+    if member is None:
+        raise ResourceNotFoundError(
+            f'User "{user_id}" is not a member of organization "{organization_id}".'
+        )
+    # Update
+    member.role = role
+    await session.commit()
+    return member
+
+
+@router.delete(
+    "/v2/organizations/members",
+    summary="Leave an organization.",
+    description=(
+        "Permissions: `organization.ADMIN`. "
+        "Permissions are only needed if deleting another user's membership."
+ ), +) +@handle_exception +async def leave_organization( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[str, Query(min_length=1, description="User ID.")], + organization_id: Annotated[str, Query(min_length=1, description="Organization ID.")], +) -> OkResponse: + if user.id != user_id: + has_permissions(user, ["organization.ADMIN"], organization_id=organization_id) + leaving_user = await session.get(User, user_id) + if leaving_user is None: + raise ResourceNotFoundError(f'User "{user_id}" is not found.') + organization = await session.get(Organization, organization_id) + if organization is None: + raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') + if user_id == organization.created_by: + raise ForbiddenError("Owner cannot leave the organization.") + org_member = await session.get(OrgMember, (user_id, organization_id)) + if org_member is None: + raise ResourceNotFoundError( + f"Organization membership {(user_id, organization_id)} is not found." + ) + await session.delete(org_member) + await session.commit() + # If the user has no remaining membership with the org, remove them from all projects + num_memberships = ( + await session.exec( + select(func.count(OrgMember.user_id)).where( + OrgMember.user_id == user_id, + OrgMember.organization_id == organization_id, + ) + ) + ).one() + if num_memberships == 0: + projects = ( + await session.exec(select(Project).where(Project.organization_id == organization_id)) + ).all() + for p in projects: + try: + await session.exec( + delete(ProjectMember).where( + ProjectMember.user_id == user_id, + ProjectMember.project_id == p.id, + ) + ) + await session.commit() + except Exception as e: + logger.warning( + f'Failed to remove "{user_id}" from project "{p.id}" due to {repr(e)}' + ) + logger.bind(user_id=leaving_user.id, org_id=organization.id).success( + ( + f"{leaving_user.preferred_name} ({leaving_user.preferred_email}) left " + f'organization "{organization.name}".' 
+ ) + ) + return OkResponse() + + +@router.get( + "/v2/organizations/models/catalogue", + summary="List models AVAILABLE to an organization.", + description="Permissions: `system` OR `organization`.", +) +@handle_exception +async def organization_model_catalogue( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[OrgModelCatalogueQuery, Query()], +) -> Page[ModelConfigRead]: + has_permissions(user, ["system", "organization"], organization_id=params.organization_id) + return await ModelConfig.list_( + session=session, + return_type=ModelConfigRead, + organization_id=params.organization_id, + offset=params.offset, + limit=params.limit, + order_by=params.order_by, + order_ascending=params.order_ascending, + search_query=params.search_query, + search_columns=params.search_columns, + after=params.after, + capabilities=params.capabilities, + exclude_inactive=True, + ) + + +@router.get("/v2/organizations/meters/query") +async def get_organization_metrics( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + metric_id: Annotated[ + Literal["llm", "embedding", "reranking", "spent", "bandwidth", "storage"], + Query(alias="metricId", description="Type of usage data to query."), + ], + from_: Annotated[ + datetime, + Query(alias="from", description="Start datetime for the usage data query."), + ], + window_size: Annotated[ + str | None, + Query( + min_length=1, + alias="windowSize", + description="The aggregation window size (e.g., '1d' for daily, '1w' for weekly).", + ), + ], + org_id: Annotated[ + str, + Query( + min_length=1, + alias="orgId", + description="Organization ID to filter the usage data.", + ), + ], + proj_ids: Annotated[ + list[str] | None, + Query( + min_length=1, + alias="projIds", + description="List of project IDs to filter the usage data. If not provided, data for all projects is returned.", + ), + ] = None, + to: Annotated[ + datetime | None, + Query( + description="End datetime for the usage data query. If not provided, data up to the current datetime is returned." + ), + ] = None, + group_by: Annotated[ + list[str] | None, + Query( + min_length=1, + alias="groupBy", + description="List of fields to group the usage data by. If not provided, no grouping is applied.", + ), + ] = None, + data_source: Annotated[ + Literal["clickhouse", "victoriametrics"], + Query(description="Data source to query. 
Defaults to 'clickhouse'.", alias="dataSource"), + ] = "clickhouse", +) -> UsageResponse: + has_permissions(user, ["organization.MEMBER"], organization_id=org_id) + try: + # always add org_id to group_by + if to is None: + to = datetime.now(tz=timezone.utc).replace(minute=0, second=0, microsecond=0) + # set to default [] + if group_by is None: + group_by = [] + + if data_source == "clickhouse": + metrics_client = billing_metrics + elif data_source == "victoriametrics": + metrics_client = telemetry + + if metric_id == "llm": + results = await metrics_client.query_llm_usage( + [org_id], + proj_ids, + from_, + to, + group_by, + window_size, + ) + elif metric_id == "embedding": + results = await metrics_client.query_embedding_usage( + [org_id], + proj_ids, + from_, + to, + group_by, + window_size, + ) + elif metric_id == "reranking": + results = await metrics_client.query_reranking_usage( + [org_id], + proj_ids, + from_, + to, + group_by, + window_size, + ) + elif metric_id == "spent": + results = await metrics_client.query_billing( + [org_id], proj_ids, from_, to, group_by, window_size + ) + elif metric_id == "bandwidth": + results = await metrics_client.query_bandwidth( + [org_id], proj_ids, from_, to, group_by, window_size + ) + elif metric_id == "storage": + results = await metrics_client.query_storage( + [org_id], proj_ids, from_, to, group_by, window_size + ) + return results + except Exception as e: + err = f"Failed to fetch Metrics Data events: {e}" + logger.error(err) + raise UnexpectedError(err) from e diff --git a/services/api/src/owl/routers/oss_admin.py b/services/api/src/owl/routers/oss_admin.py deleted file mode 100644 index 29a0cca..0000000 --- a/services/api/src/owl/routers/oss_admin.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Depends, Path, Request -from loguru import logger -from sqlmodel import Session - -from jamaibase.exceptions import ResourceNotFoundError -from owl.configs.manager import CONFIG, ENV_CONFIG -from owl.db import MAIN_ENGINE, UserSQLModel, create_sql_tables -from owl.db.oss_admin import ( - Organization, - OrganizationRead, - OrganizationUpdate, -) -from owl.protocol import ModelListConfig, OkResponse -from owl.utils import datetime_now_iso -from owl.utils.crypt import encrypt_random -from owl.utils.exceptions import handle_exception - -router = APIRouter() -public_router = APIRouter() # Dummy router to be compatible with cloud admin - - -@router.on_event("startup") -async def startup(): - create_sql_tables(UserSQLModel, MAIN_ENGINE) - - -def _get_session(): - with Session(MAIN_ENGINE) as session: - yield session - - -@router.patch("/admin/backend/v1/organizations") -@handle_exception -def update_organization( - *, - session: Annotated[Session, Depends(_get_session)], - request: Request, - body: OrganizationUpdate, -) -> OrganizationRead: - body.id = ENV_CONFIG.default_org_id - org = session.get(Organization, body.id) - if org is None: - raise ResourceNotFoundError(f'Organization "{body.id}" is not found.') - - # --- Perform update --- # - for key, value in body.model_dump(exclude=["id"], exclude_none=True).items(): - if key == "external_keys": - value = { - k: encrypt_random(v, ENV_CONFIG.owl_encryption_key_plain) for k, v in value.items() - } - setattr(org, key, value) - org.updated_at = datetime_now_iso() - session.add(org) - session.commit() - session.refresh(org) - logger.info(f"{request.state.id} - Organization updated: {org}") - org = OrganizationRead( - **org.model_dump(), - 
projects=org.projects, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain) - return org - - -@router.get("/admin/backend/v1/organizations/{org_id}") -@handle_exception -def get_organization( - *, - session: Annotated[Session, Depends(_get_session)], - org_id: Annotated[str, Path(min_length=1)], -) -> OrganizationRead: - org = session.get(Organization, org_id) - if org is None: - raise ResourceNotFoundError(f'Organization "{org_id}" is not found.') - org = OrganizationRead( - **org.model_dump(), - projects=org.projects, - ).decrypt(ENV_CONFIG.owl_encryption_key_plain) - return org - - -@router.get("/admin/backend/v1/models") -@handle_exception -def get_model_config() -> ModelListConfig: - # Get model config (exclude org models) - return CONFIG.get_model_config() - - -@router.patch("/admin/backend/v1/models") -@handle_exception -def set_model_config(body: ModelListConfig) -> OkResponse: - CONFIG.set_model_config(body) - return OkResponse() diff --git a/services/api/src/owl/routers/projects/__init__.py b/services/api/src/owl/routers/projects/__init__.py new file mode 100644 index 0000000..40533ab --- /dev/null +++ b/services/api/src/owl/routers/projects/__init__.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter + +from owl.configs import ENV_CONFIG +from owl.routers.projects.oss import router as oss_router + +router = APIRouter() +router.include_router(oss_router) + +if ENV_CONFIG.is_cloud: + from owl.routers.projects.cloud import router as cloud_router + + router.include_router(cloud_router) diff --git a/services/api/src/owl/routers/projects/oss.py b/services/api/src/owl/routers/projects/oss.py new file mode 100644 index 0000000..8518e47 --- /dev/null +++ b/services/api/src/owl/routers/projects/oss.py @@ -0,0 +1,927 @@ +import base64 +from io import BytesIO +from os.path import join +from tempfile import TemporaryDirectory +from typing import Annotated, Literal + +import pyarrow as pa +import pyarrow.parquet as pq +from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Query, Request, UploadFile +from fastapi.responses import FileResponse +from loguru import logger +from pydantic import BaseModel, Field +from sqlmodel import delete, func, select + +from owl.configs import ENV_CONFIG +from owl.db import AsyncSession, async_session, cached_text, yield_async_session +from owl.db.gen_table import ( + ActionTable, + ChatTable, + ColumnMetadata, + KnowledgeTable, + TableMetadata, +) +from owl.db.models import ( + Organization, + Project, + ProjectMember, + User, +) +from owl.types import ( + ListQueryByOrg, + ListQueryByProject, + OkResponse, + OrganizationRead, + Page, + ProjectCreate, + ProjectMemberRead, + ProjectRead, + ProjectUpdate, + Role, + TableMetaResponse, + TableType, + UserAuth, +) +from owl.utils.auth import auth_user, has_permissions +from owl.utils.billing import BillingManager +from owl.utils.dates import now +from owl.utils.exceptions import ( + BadInputError, + ForbiddenError, + ResourceExistsError, + ResourceNotFoundError, + UnexpectedError, + handle_exception, +) +from owl.utils.io import json_dumps, json_loads, open_uri_async, s3_upload +from owl.utils.mcp import MCP_TOOL_TAG + +router = APIRouter() + + +async def _count_project_name( + session: AsyncSession, + organization_id: str, + name: str, +) -> int: + return ( + await session.exec( + select( + func.count(Project.id).filter( + Project.organization_id == organization_id, Project.name == name + ) + ) + ) + ).one() + + +@router.post( + "/v2/projects", + summary="Create a new project under an organization.", + 
description="Permissions: `organization.ADMIN`.", + tags=[MCP_TOOL_TAG, "organization.ADMIN"], +) +@handle_exception +async def create_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: ProjectCreate, + project_id: str = "", +) -> ProjectRead: + has_permissions(user, ["organization.ADMIN"], organization_id=body.organization_id) + # Check for duplicate project ID + if project_id and await session.get(Project, project_id) is not None: + raise ResourceExistsError(f'Project "{project_id}" already exists.') + # Ensure the organization exists + organization = await session.get(Organization, body.organization_id) + if organization is None: + raise ResourceNotFoundError(f'Organization "{body.organization_id}" is not found.') + # Try assigning a unique name + name_count = await _count_project_name(session, body.organization_id, body.name) + if name_count > 0: + idx = name_count + while ( + await _count_project_name(session, body.organization_id, f"{body.name} ({idx})") + ) > 0: + idx += 1 + body.name = f"{body.name} ({idx})" + + # Create project + project = Project( + **body.model_dump(), + created_by=user.id, + owner=user.id, + ) + if project_id: + project.id = project_id + else: + project_id = project.id + session.add(project) + await session.commit() + await session.refresh(project) + logger.bind(user_id=user.id, org_id=organization.id, proj_id=project_id).info( + f"{request.state.id} - Created project: {project}" + ) + logger.bind(user_id=user.id, org_id=organization.id, proj_id=project_id).success( + f'{user.name} ({user.email}) created a project "{project.name}"' + ) + # Create membership + project_member = ProjectMember( + user_id=user.id, + project_id=project_id, + role=Role.ADMIN, + ) + session.add(project_member) + await session.commit() + await session.refresh(project_member) + logger.bind(user_id=user.id, org_id=organization.id, proj_id=project_id).info( + f"{request.state.id} - Created project member: {project_member}" + ) + logger.bind(user_id=user.id, org_id=organization.id, proj_id=project_id).success( + f'{user.name} ({user.email}) joined project "{project.name}" as as admin.' + ) + # Create Generative Table schemas + for table_type in TableType: + schema_id = f"{project_id}_{table_type}" + await session.exec(cached_text(f'CREATE SCHEMA IF NOT EXISTS "{schema_id}"')) + await session.exec(cached_text(TableMetadata.sql_create(schema_id))) + await session.exec(cached_text(ColumnMetadata.sql_create(schema_id))) + return project + + +class ListProjectQuery(ListQueryByOrg): + search_query: Annotated[ + str, + Field( + max_length=255, + description=( + "_Optional_. A string to search for within project names as a filter. " + 'Defaults to "" (no filter).' + ), + ), + ] = "" + list_chat_agents: Annotated[ + bool, Field(description="_Optional_. List chat agents. 
Defaults to False.") + ] = False + + +@router.get( + "/v2/projects/list", + summary="List all projects within an organization.", + description="Permissions: `system` OR `organization`.", + tags=[MCP_TOOL_TAG, "system", "organization"], +) +@handle_exception +async def list_projects( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ListProjectQuery, Query()], +) -> Page[ProjectRead]: + org_id = params.organization_id + has_permissions(user, ["system", "organization"], organization_id=org_id) + # Ensure the organization exists + org_role = next((r.role for r in user.org_memberships if r.organization_id == org_id), None) + if org_role is None: + raise ResourceNotFoundError(f'Organization "{org_id}" is not found.') + # List + response = await Project.list_( + session=session, + return_type=ProjectRead, + offset=params.offset, + limit=params.limit, + order_by=params.order_by, + order_ascending=params.order_ascending, + search_query=params.search_query, + search_columns=params.search_columns, + filters=dict(organization_id=org_id), + after=params.after, + filter_by_user=user.id if org_role == Role.GUEST else "", + ) + if params.list_chat_agents: + for p in response.items: + metas = await ChatTable.list_tables( + project_id=p.id, + limit=None, + offset=0, + order_by="id", + order_ascending=True, + parent_id="_agent_", + ) + p.chat_agents = metas.items + return response + + +@router.get( + "/v2/projects", + summary="Get a project.", + description="Permissions: `system` OR `organization`.", + tags=[MCP_TOOL_TAG, "system", "organization"], +) +@handle_exception +async def get_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], +) -> ProjectRead: + # Fetch the project + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + has_permissions(user, ["system", "organization"], organization_id=project.organization_id) + # Update billing data if needed + request.state.billing = BillingManager( + organization=OrganizationRead.model_validate(project.organization), + project_id="", # Skip egress charge + user_id=user.id, + request=request, + models=None, + ) + return project + + +@router.patch( + "/v2/projects", + summary="Update a project.", + description="Permissions: `organization.ADMIN` OR `project.ADMIN`.", +) +@handle_exception +async def update_project( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], + body: ProjectUpdate, +) -> ProjectRead: + # Fetch + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + # Check permissions + has_permissions( + user, + ["organization.ADMIN", "project.ADMIN"], + organization_id=project.organization_id, + project_id=project_id, + ) + # Update + updates = body.model_dump(exclude={"id"}, exclude_unset=True) + for key, value in updates.items(): + setattr(project, key, value) + project.updated_at = now() + session.add(project) + await session.commit() + await session.refresh(project) + logger.bind(user_id=user.id, proj_id=project.id).success( + ( + f"{user.name} ({user.email}) updated the attributes " + 
f'{list(updates.keys())} of project "{project.name}".' + ) + ) + return project + + +@router.delete( + "/v2/projects", + summary="Delete a project.", + description="Permissions: `organization.ADMIN`, OR None for the project owner.", +) +@handle_exception +async def delete_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], +) -> OkResponse: + # Fetch + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + # Check permissions + has_permissions(user, ["organization.ADMIN"], organization_id=project.organization_id) + # Delete Generative Tables + for table_type in TableType: + schema_id = f"{project_id}_{table_type}" + await session.exec(cached_text(f'DROP SCHEMA IF EXISTS "{schema_id}" CASCADE')) + # Delete related resources + await session.exec(delete(ProjectMember).where(ProjectMember.project_id == project_id)) + if ENV_CONFIG.is_cloud: + from owl.db.models.cloud import ProjectKey, VerificationCode + + await session.exec( + delete(VerificationCode).where(VerificationCode.project_id == project_id) + ) + await session.exec(delete(ProjectKey).where(ProjectKey.project_id == project_id)) + await session.delete(project) + await session.commit() + logger.bind(user_id=user.id, org_id=project.id).success( + f'{user.name} ({user.email}) deleted project "{project.name}".' + ) + logger.info(f"{request.state.id} - Deleted project: {project.id}") + return OkResponse() + + +@router.post( + "/v2/projects/members", + summary="Join a project.", + description=( + "Permissions: `organization.ADMIN` OR `project.ADMIN`. " + "Permissions are only needed if adding another user or invite code is not provided." + ), +) +@handle_exception +async def join_project( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[ + str, Query(min_length=1, description="ID of the user joining the project.") + ], + invite_code: Annotated[ + str | None, + Query(min_length=1, description="(Optional) Invite code for validation."), + ] = None, + project_id: Annotated[ + str | None, + Query( + min_length=1, + description="(Optional) Project ID. Ignored if invite code is provided.", + ), + ] = None, + role: Annotated[ + Role | None, + Query( + min_length=1, + description="(Optional) Project role. 
Ignored if invite code is provided.", + ), + ] = None, +) -> ProjectMemberRead: + joining_user = await session.get(User, user_id) + if joining_user is None: + raise ResourceNotFoundError(f'User "{user_id}" is not found.') + if invite_code is None: + if project_id is None or role is None: + raise BadInputError("Missing project ID or role.") + invite = None + else: + if ENV_CONFIG.is_oss: + raise BadInputError("Invite code is not supported in OSS.") + else: + from owl.db.models.cloud import VerificationCode + + # Fetch code + invite = await session.get(VerificationCode, invite_code) + if ( + invite is None + or invite.project_id is None + or invite.purpose not in ("project_invite", None) + or now() > invite.expiry + or invite.revoked_at is not None + or invite.used_at is not None + or invite.user_email != joining_user.preferred_email + ): + raise ResourceNotFoundError(f'Invite code "{invite_code}" is invalid.') + project_id = invite.project_id + role = invite.role + # Fetch + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + if project.organization_id not in [r.organization_id for r in joining_user.org_memberships]: + raise ForbiddenError("You are not a member of this project's organization.") + if await session.get(ProjectMember, (user_id, project_id)) is not None: + raise ResourceExistsError("You are already in the project.") + # RBAC + if user.id != user_id or invite_code is None: + has_permissions( + user, + ["organization.ADMIN", "project.ADMIN"], + organization_id=project.organization_id, + project_id=project_id, + ) + project_member = ProjectMember(user_id=user_id, project_id=project_id, role=role) + session.add(project_member) + await session.commit() + await session.refresh(project_member) + # Consume invite code + if invite is not None: + invite.used_at = now() + session.add(invite) + await session.commit() + logger.bind(user_id=joining_user.id, proj_id=project.id).success( + ( + f"{joining_user.preferred_name} ({joining_user.preferred_email}) joined " + f'project "{project.name}" as "{role.name}".' 
+ ) + ) + return project_member + + +@router.get( + "/v2/projects/members/list", + summary="List project members.", + description="Permissions: `system` OR `organization` OR `project`.", +) +@handle_exception +async def list_project_members( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ListQueryByProject[Literal["id", "created_at", "updated_at"]], Query()], +) -> Page[ProjectMemberRead]: + project_id = params.project_id + # Fetch the project + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + has_permissions( + user, + ["system", "organization", "project"], + organization_id=project.organization_id, + project_id=project_id, + ) + return await ProjectMember.list_( + session=session, + return_type=ProjectMemberRead, + offset=params.offset, + limit=params.limit, + order_by=params.order_by, + order_ascending=params.order_ascending, + search_query=params.search_query, + search_columns=params.search_columns, + filters=dict(project_id=project_id), + after=params.after, + ) + + +@router.get( + "/v2/projects/members", + summary="Get a project member.", + description="Permissions: `system` OR `organization` OR `project`.", +) +@handle_exception +async def get_project_member( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[str, Query(min_length=1, description="User ID.")], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], +) -> ProjectMemberRead: + # Fetch the project + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + has_permissions( + user, + ["system", "organization", "project"], + organization_id=project.organization_id, + project_id=project_id, + ) + member = await session.get(ProjectMember, (user_id, project_id)) + if member is None: + raise ResourceNotFoundError(f'User "{user_id}" is not a member of project "{project_id}".') + return member + + +@router.patch( + "/v2/projects/members/role", + summary="Update a project member's role.", + description="Permissions: `organization.ADMIN` OR `project.ADMIN`.", +) +@handle_exception +async def update_member_role( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[str, Query(min_length=1, description="User ID.")], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], + role: Annotated[Role, Query(description="New role.")], +) -> ProjectMemberRead: + # Fetch the project + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + # Check permissions + has_permissions( + user, + ["organization.ADMIN", "project.ADMIN"], + organization_id=project.organization_id, + project_id=project.id, + ) + # Fetch the member + member = await session.get(ProjectMember, (user_id, project_id)) + if member is None: + raise ResourceNotFoundError(f'User "{user_id}" is not a member of project "{project_id}".') + # Update + member.role = role + await session.commit() + return member + + +@router.delete( + "/v2/projects/members", + summary="Leave a project.", + description=( + "Permissions: `organization.ADMIN` OR `project.ADMIN`. 
" + "Permissions are only needed if deleting other user's membership." + ), +) +@handle_exception +async def leave_project( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[str, Query(min_length=1, description="User ID.")], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], +) -> OkResponse: + leaving_user = await session.get(User, user_id) + if leaving_user is None: + raise ResourceNotFoundError(f'User "{user_id}" is not found.') + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + if user.id != user_id: + has_permissions( + user, + ["organization.ADMIN", "project.ADMIN"], + organization_id=project.organization_id, + project_id=project_id, + ) + project_member = await session.get(ProjectMember, (user_id, project_id)) + if project_member is None: + raise ResourceNotFoundError(f'User "{user_id}" is not a member of project "{project_id}".') + await session.delete(project_member) + await session.commit() + logger.bind(user_id=leaving_user.id, proj_id=project.id).success( + ( + f"{leaving_user.preferred_name} ({leaving_user.preferred_email}) left " + f'project "{project.name}".' + ) + ) + return OkResponse() + + +TABLE_CLS: dict[TableType, ActionTable | KnowledgeTable | ChatTable] = { + TableType.ACTION: ActionTable, + TableType.KNOWLEDGE: KnowledgeTable, + TableType.CHAT: ChatTable, +} + + +async def _export_project_as_pa_table( + request: Request, + user: UserAuth, + project: Project, +) -> pa.Table: + organization = OrganizationRead.model_validate(project.organization) + # Check quota + billing = BillingManager( + organization=organization, + project_id=project.id, + user_id=user.id, + request=request, + models=None, + ) + billing.has_egress_quota() + # Dump all tables as parquet files + data = [] + table_types = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] + for table_type in table_types: + metas = ( + await TABLE_CLS[table_type].list_tables( + project_id=project.id, + limit=None, + offset=0, + parent_id=None, + count_rows=False, + ) + ).items + for meta in metas: + table = await TABLE_CLS[table_type].open_table(project_id=project.id, table_id=meta.id) + with BytesIO() as f: + await table.export_table(f) + data.append((table_type, meta, f.getvalue())) + if len(data) == 0: + raise BadInputError(f'Project "{project.id}" is empty with no tables.') + # Download project pictures + project_meta = project.model_dump() + for pic_type in ["profile_picture", "cover_picture"]: + uri: str | None = project_meta.get(f"{pic_type}_url", None) + if uri is None: + continue + async with open_uri_async(uri) as (f, mime): + project_meta[pic_type] = ( + f"data:{mime};base64,{base64.b64encode(await f.read()).decode('utf-8')}" + ) + # Bundle everything into a single PyArrow Table + table_metas = [ + {"table_type": table_type, "table_meta": meta.model_dump(mode="json")} + for table_type, meta, _ in data + ] + data = list(zip(*data, strict=True)) + pa_table = pa.table( + {"table_type": pa.array(data[0], pa.utf8()), "data": pa.array(data[2], pa.binary())}, + metadata={ + "project_meta": json_dumps(project_meta), + "table_metas": json_dumps(table_metas), + }, + ) + return pa_table + + +@router.get("/v2/projects/export") +@handle_exception +async def export_project( + request: Request, + bg_tasks: BackgroundTasks, + user: Annotated[UserAuth, Depends(auth_user)], + project_id: Annotated[str, 
Query(min_length=1, description="Project ID.")], +) -> FileResponse: + # Fetch the project + async with async_session() as session: + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + pa_table = await _export_project_as_pa_table( + request=request, + user=user, + project=project, + ) + # Temporary file + tmp_dir = TemporaryDirectory() + filename = f"{project.id}.parquet" + filepath = join(tmp_dir.name, filename) + # Keep a reference to the directory and only delete upon completion + bg_tasks.add_task(tmp_dir.cleanup) + pq.write_table(pa_table, filepath, compression="ZSTD") + logger.bind(user_id=user.id).success( + f'{user.name} ({user.email}) exported project "{project.name}" ({project.id}).' + ) + return FileResponse( + path=filepath, + filename=filename, + media_type="application/octet-stream", + ) + + +async def _import_project_from_pa_table( + request: Request, + user: UserAuth, + *, + organization_id: str, + project_id: str, + pa_table: pa.Table, + keep_original_ids: bool = False, + check_quota: bool = True, + raise_error: bool = True, + verbose: bool = False, +) -> ProjectRead: + has_permissions(user, ["system", "organization"], organization_id=organization_id) + async with async_session() as session: + if project_id: + # Fetch the project + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + organization = project.organization + else: + if not organization_id: + raise BadInputError("Organization ID is required when project ID is not provided.") + project = None + # Fetch the organization + organization = await session.get(Organization, organization_id) + if organization is None: + raise ResourceNotFoundError(f'Organization "{organization_id}" is not found.') + organization = OrganizationRead.model_validate(organization) + # Check quota + if check_quota: + billing = BillingManager( + organization=organization, + project_id="", # Not needed to check storage quotas + user_id=user.id, + request=request, + models=None, + ) + billing.has_db_storage_quota() + billing.has_file_storage_quota() + # Create the project if needed + if project is None: + try: + project_meta = json_loads(pa_table.schema.metadata[b"project_meta"]) + except KeyError as e: + raise BadInputError("Missing project metadata in the Parquet file.") from e + except Exception as e: + raise BadInputError("Invalid project metadata in the Parquet file.") from e + body = {k: v for k, v in project_meta.items() if k not in ["id", "organization_id"]} + body["organization_id"] = organization_id + async with async_session() as session: + project = await create_project( + request=request, + user=user, + session=session, + body=ProjectCreate.model_validate(body), + project_id=project_meta.get("id", "") if keep_original_ids else "", + ) + # Upload and update project picture URL + project = await session.get(Project, project.id) + if project is None: + raise UnexpectedError(f'Project "{project.id}" is not found.') + for pic_type in ["profile_picture", "cover_picture"]: + data: str | None = project_meta.get(pic_type, None) + uri_ori: str | None = project_meta.get(f"{pic_type}_url", None) + if data is None or uri_ori is None: + uri = None + else: + # f"data:{mime};base64,{base64.b64encode(await f.read()).decode('utf-8')}" + mime_type, b64 = data.replace("data:", "", 1).split(";base64,") + uri = await s3_upload( + organization.id, + project.id, + 
base64.b64decode(b64.encode("utf-8")), + content_type=mime_type, + filename=uri_ori.split("/")[-1], + ) + setattr(project, f"{pic_type}_url", uri) + await session.commit() + await session.refresh(project) + if verbose: + logger.info( + f'Importing project "{project.name}" ({project.id}): Project metadata imported.' + ) + + # Import Knowledge Tables first + async def _import_table(_data: bytes, _type: str): + with BytesIO(_data) as source: + await TABLE_CLS[_type].import_table( + project_id=project.id, + source=source, + table_id_dst=None, + reupload_files=not keep_original_ids, + verbose=verbose, + ) + + table_metas = json_loads(pa_table.schema.metadata[b"table_metas"]) + rows = pa_table.to_pylist() + i = 1 + for row, meta in zip(rows, table_metas, strict=True): + if row["table_type"] != TableType.KNOWLEDGE: + continue + meta = TableMetaResponse.model_validate(meta["table_meta"]) + if verbose: + logger.info( + ( + f'Importing project "{project.name}" ({project.id}): ' + f'Importing table "{meta.id}" ({i} of {len(rows)}) ...' + ) + ) + try: + await _import_table(row["data"], row["table_type"]) + except ResourceExistsError as e: + logger.info(f'Importing project "{project.name}" ({project.id}): {e}') + if raise_error: + raise + except Exception as e: + logger.exception( + f'Importing project "{project.name}" ({project.id}): Failed to import table "{meta.id}": {e}' + ) + if raise_error: + raise + i += 1 + # Import the rest + for row, meta in zip(rows, table_metas, strict=True): + if row["table_type"] == TableType.KNOWLEDGE: + continue + meta = TableMetaResponse.model_validate(meta["table_meta"]) + if verbose: + logger.info( + ( + f'Importing project "{project.name}" ({project.id}): ' + f'Importing table "{meta.id}" ({i} of {len(rows)}) ...' + ) + ) + try: + await _import_table(row["data"], row["table_type"]) + except ResourceExistsError as e: + logger.info(f'Importing project "{project.name}" ({project.id}): {e}') + if raise_error: + raise + except Exception as e: + logger.exception( + f'Importing project "{project.name}" ({project.id}): Failed to import table "{meta.id}": {e}' + ) + if raise_error: + raise + i += 1 + return project + + +class ProjectImportFormData(BaseModel): + file: Annotated[UploadFile, File(description="The project or template Parquet file.")] + project_id: Annotated[ + str, + Field( + description='If given, import tables into this project. Defaults to "" (create new project).' + ), + ] = "" + organization_id: Annotated[ + str, + Field( + description="Organization ID of the new project. Only required if creating a new project." 
+ ), + ] = "" + + +@router.post("/v2/projects/import/parquet") +@handle_exception +async def import_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + data: Annotated[ProjectImportFormData, Form()], +) -> ProjectRead: + # Load Parquet file + try: + with BytesIO(await data.file.read()) as source: + # TODO: Perhaps check the metadata with `columns=[]` first and avoid parsing the whole file + pa_table: pa.Table = pq.read_table( + source, columns=None, use_threads=False, memory_map=True + ) + except Exception as e: + raise BadInputError("Failed to parse Parquet file.") from e + return await _import_project_from_pa_table( + request, + user, + organization_id=data.organization_id, + project_id=data.project_id, + pa_table=pa_table, + ) + + +@router.post("/v2/projects/import/parquet/migration") +@handle_exception +async def import_project_migration( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + data: Annotated[ProjectImportFormData, Form()], +) -> ProjectRead: + # Load Parquet file + try: + with BytesIO(await data.file.read()) as source: + pa_table: pa.Table = pq.read_table( + source, columns=None, use_threads=False, memory_map=True + ) + except Exception as e: + raise BadInputError("Failed to parse Parquet file.") from e + try: + return await _import_project_from_pa_table( + request, + user, + organization_id=data.organization_id, + project_id=data.project_id, + pa_table=pa_table, + keep_original_ids=True, + check_quota=False, + raise_error=False, + verbose=True, + ) + except Exception as e: + logger.exception(e) + raise + + +class TemplateImportQuery(BaseModel): + template_id: Annotated[str, Field(description="Template ID.")] + project_id: Annotated[ + str, + Field( + description='If given, import tables into this project. Defaults to "" (create new project).' + ), + ] = "" + organization_id: Annotated[ + str, + Field( + description="Organization ID of the new project. Only required if creating a new project." 
+ ), + ] = "" + + +@router.post("/v2/projects/import/template") +@handle_exception +async def import_template( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + params: Annotated[TemplateImportQuery, Query()], +) -> ProjectRead: + # Fetch the project + async with async_session() as session: + template = await session.get(Project, params.template_id) + if template is None: + raise ResourceNotFoundError(f'Template "{params.template_id}" is not found.') + # Export template + pa_table = await _export_project_as_pa_table( + request=request, + user=user, + project=template, + ) + # Import + return await _import_project_from_pa_table( + request, + user, + organization_id=params.organization_id, + project_id=params.project_id, + pa_table=pa_table, + ) diff --git a/services/api/src/owl/routers/projects/v1.py b/services/api/src/owl/routers/projects/v1.py new file mode 100644 index 0000000..0bbe0a4 --- /dev/null +++ b/services/api/src/owl/routers/projects/v1.py @@ -0,0 +1,156 @@ +from enum import StrEnum +from typing import Annotated + +from fastapi import ( + APIRouter, + BackgroundTasks, + Depends, + File, + Form, + Path, + Query, + Request, + UploadFile, +) +from fastapi.responses import FileResponse + +from owl.db import AsyncSession, yield_async_session +from owl.routers.projects import oss as v2 +from owl.types import ( + OkResponse, + Page, + ProjectCreate, + ProjectRead, + ProjectUpdate, + UserAuth, +) +from owl.utils.auth import auth_user +from owl.utils.exceptions import handle_exception + +router = APIRouter() + + +@router.post("/v1/projects") +@handle_exception +async def create_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: ProjectCreate, + project_id: str = "", +) -> ProjectRead: + return await v2.create_project(request, user, session, body, project_id=project_id) + + +class AdminOrderBy(StrEnum): + ID = "id" + NAME = "name" + CREATED_AT = "created_at" + UPDATED_AT = "updated_at" + + +@router.get("/v1/projects") +@handle_exception +async def list_projects( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + organization_id: Annotated[str, Query(min_length=1, description='Organization ID "org_xxx".')], + search_query: Annotated[ + str, + Query( + max_length=10_000, + description='_Optional_. A string to search for within project names as a filter. Defaults to "" (no filter).', + ), + ] = "", + offset: Annotated[int, Query(ge=0)] = 0, + limit: Annotated[int, Query(gt=0, le=100)] = 100, + order_by: Annotated[ + AdminOrderBy, + Query( + min_length=1, + description='_Optional_. Sort projects by this attribute. Defaults to "updated_at".', + ), + ] = AdminOrderBy.UPDATED_AT, + order_descending: Annotated[ + bool, + Query(description="_Optional_. Whether to sort by descending order. 
Defaults to True."), + ] = True, +) -> Page[ProjectRead]: + params = v2.ListProjectQuery( + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=not order_descending, + organization_id=organization_id, + search_query=search_query, + ) + return await v2.list_projects(user, session, params) + + +@router.get("/v1/projects/{project_id}") +@handle_exception +async def get_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + project_id: Annotated[str, Path(min_length=1, description='Project ID "proj_xxx".')], +) -> ProjectRead: + return await v2.get_project(request, user, session, project_id) + + +@router.patch("/v1/projects") +@handle_exception +async def update_project( + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + project_id: Annotated[str, Query(min_length=1, description="Project ID.")], + body: ProjectUpdate, +) -> ProjectRead: + return await v2.update_project(user, session, project_id, body) + + +@router.delete("/v1/projects/{project_id}") +@handle_exception +async def delete_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + project_id: Annotated[str, Path(min_length=1, description='Project ID "proj_xxx".')], +) -> OkResponse: + return await v2.delete_project(request, user, session, project_id) + + +@router.get("/v1/projects/{project_id}/export") +@handle_exception +async def export_project( + request: Request, + bg_tasks: BackgroundTasks, + user: Annotated[UserAuth, Depends(auth_user)], + project_id: Annotated[str, Path(min_length=1, description='Project ID "proj_xxx".')], +) -> FileResponse: + return await v2.export_project(request, bg_tasks, user, project_id) + + +@router.post("/v1/projects/import/{organization_id}") +@handle_exception +async def import_project( + request: Request, + user: Annotated[UserAuth, Depends(auth_user)], + organization_id: Annotated[str, Path(min_length=1, description='Organization ID "org_xxx".')], + file: Annotated[UploadFile, File(description="Project or Template Parquet file.")], + project_id_dst: Annotated[ + str, + Form( + description=( + "_Optional_. ID of the project to import tables into. " + "Defaults to creating new project." 
+ ), + ), + ] = "", +) -> ProjectRead: + data = v2.ProjectImportFormData( + file=file, + project_id=project_id_dst, + organization_id=organization_id, + ) + return await v2.import_project(request, user, data) diff --git a/services/api/src/owl/routers/serving.py b/services/api/src/owl/routers/serving.py new file mode 100644 index 0000000..37305a0 --- /dev/null +++ b/services/api/src/owl/routers/serving.py @@ -0,0 +1,270 @@ +from typing import Annotated, Literal + +from fastapi import APIRouter, Depends, Query, Request, Response +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field + +from owl.db import AsyncSession, yield_async_session +from owl.db.gen_executor import GenExecutor +from owl.db.models import ModelConfig +from owl.types import ( + EXAMPLE_CHAT_MODEL_IDS, + ChatRequest, + EmbeddingRequest, + EmbeddingResponse, + ModelCapability, + ModelInfoListResponse, + ModelInfoRead, + OrganizationRead, + ProjectRead, + RerankingRequest, + RerankingResponse, + UserAuth, +) +from owl.utils.auth import auth_user_project +from owl.utils.billing import BillingManager +from owl.utils.exceptions import ResourceNotFoundError, handle_exception +from owl.utils.lm import LMEngine +from owl.utils.mcp import MCP_TOOL_TAG + +router = APIRouter() + + +class _ListQuery(BaseModel): + order_by: Literal["id", "name", "created_at", "updated_at"] = Field( + "id", + description='Sort by this attribute. Defaults to "id".', + ) + order_ascending: bool = Field( + True, + description="Whether to sort in ascending order. Defaults to True.", + ) + capabilities: list[ModelCapability] | None = Field( + None, + description=( + "Filter the model info by model's capabilities. Defaults to None (no filter)." + ), + examples=[[ModelCapability.CHAT]], + ) + + +class ModelInfoListQuery(_ListQuery): + model: str = Field( + "", + description="ID of the requested model.", + examples=EXAMPLE_CHAT_MODEL_IDS, + ) + + +@router.get( + "/v1/models", + summary="List the info of models available.", + description="List the info of models available with the specified name and capabilities.", + tags=[MCP_TOOL_TAG, "project"], +) +@handle_exception +async def model_info( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ModelInfoListQuery, Query()], +) -> ModelInfoListResponse: + _, _, org = auth_info + try: + models = ( + await ModelConfig.list_( + session=session, + return_type=ModelInfoRead, + organization_id=org.id, + order_by=params.order_by, + order_ascending=params.order_ascending, + capabilities=params.capabilities, + exclude_inactive=True, + ) + ).items + # Filter by name + if params.model != "": + models = [m for m in models if m.id == params.model] + return ModelInfoListResponse(data=models) + except ResourceNotFoundError: + return ModelInfoListResponse(data=[]) + + +class ModelIdListQuery(_ListQuery): + prefer: str = Field( + "", + description="ID of the preferred model.", + examples=EXAMPLE_CHAT_MODEL_IDS, + ) + + +@router.get( + "/v1/models/ids", + summary="List the ID of models available.", + description=( + "List the ID of models available with the specified capabilities with an optional preferred model. " + "If the preferred model is not available, then return the first available model." + ), +) +@router.get( + "/v1/model_names", + deprecated=True, + summary="List the ID of models available.", + description="Deprecated, use `/v1/models/ids` instead. 
List the ID of models available.", +) +@handle_exception +async def model_ids( + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ModelIdListQuery, Query()], +) -> list[str]: + models = await model_info( + auth_info, + session, + ModelInfoListQuery( + order_by=params.order_by, + order_ascending=params.order_ascending, + capabilities=params.capabilities, + model="", + ), + ) + names = [m.id for m in models.data] + if params.prefer in names: + names.remove(params.prefer) + names.insert(0, params.prefer) + return names + + +async def _empty_async_generator(): + """Returns an empty asynchronous generator.""" + return + # This line is never reached, but makes it an async generator + yield + + +@router.post( + "/v1/chat/completions", + summary="Chat completion.", + description="Given a list of messages comprising a conversation, returns a response from the model.", +) +@handle_exception +async def chat_completion( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: ChatRequest, +) -> Response: + # Check quota + billing: BillingManager = request.state.billing + billing.has_llm_quota(body.model) + billing.has_egress_quota() + _, project, org = auth_info + body.id = request.state.id + llm = LMEngine(organization=org, project=project, request=request) + body, references = await GenExecutor.setup_rag( + project=project, lm=llm, body=body, request_id=body.id + ) + if body.stream: + agen = llm.chat_completion_stream(messages=body.messages, **body.hyperparams) + try: + chunk = await anext(agen) + except StopAsyncIteration: + return StreamingResponse( + content=_empty_async_generator(), + status_code=200, + media_type="text/event-stream", + headers={"X-Accel-Buffering": "no"}, + ) + + async def _generate(): + content_length = 1 + if references is not None: + sse = f"data: {references.model_dump_json()}\n\n" + content_length += len(sse.encode("utf-8")) + yield sse + nonlocal chunk + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + async for chunk in agen: + sse = f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + content_length += len(sse.encode("utf-8")) + yield sse + sse = "data: [DONE]\n\n" + content_length += len(sse.encode("utf-8")) + yield sse + # NOTE: We must create egress events here as SSE cannot be handled in the middleware + billing.create_egress_events(content_length / (1024**3)) + + response = StreamingResponse( + content=_generate(), + status_code=200, + media_type="text/event-stream", + headers={"X-Accel-Buffering": "no"}, + ) + else: + response = await llm.chat_completion(messages=body.messages, **body.hyperparams) + if references is not None: + response.references = references + # NOTE: Do not create egress events here as it is handled in the middleware + return response + + +@router.post( + "/v1/embeddings", + summary="Embeds texts as vectors.", + description=( + "Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. " + "Note that the vectors are NOT normalized." 
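A minimal sketch of consuming the chat completion handler's SSE output, assuming the localhost base URL from the env example, a placeholder bearer token, and a hypothetical model ID. Each event is a `data: {...}` frame (optionally preceded by a references payload) and the stream ends with `data: [DONE]`.

import json
import httpx

API_BASE = "http://localhost:6969/api"
headers = {"Authorization": "Bearer <JAMAI_TOKEN>"}  # placeholder token

payload = {
    "model": "openai/gpt-4o-mini",  # hypothetical model ID
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}
with httpx.Client(base_url=API_BASE, headers=headers, timeout=None) as client:
    with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank keep-alive lines
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            # Either the references payload or a completion chunk
            print(json.loads(data))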
+ ), +) +@handle_exception +async def generate_embeddings( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: EmbeddingRequest, +) -> EmbeddingResponse: + # Check quota + billing: BillingManager = request.state.billing + billing.has_embedding_quota(body.model) + billing.has_egress_quota() + _, project, org = auth_info + embedder = LMEngine(organization=org, project=project, request=request) + if isinstance(body.input, str): + body.input = [body.input] + kwargs = dict( + model=body.model, + texts=body.input, + encoding_format=body.encoding_format, + ) + if body.type == "document": + embeddings = await embedder.embed_documents(**kwargs) + else: + embeddings = await embedder.embed_queries(**kwargs) + return embeddings + + +@router.post( + "/v1/rerank", + summary="Ranks each text input to the query text.", + description="Get the similarity score of each text input to query by giving a query and list of text inputs.", +) +@handle_exception +async def generate_rankings( + request: Request, + auth_info: Annotated[ + tuple[UserAuth, ProjectRead, OrganizationRead], Depends(auth_user_project) + ], + body: RerankingRequest, +) -> RerankingResponse: + # Check quota + billing: BillingManager = request.state.billing + billing.has_reranker_quota(body.model) + billing.has_egress_quota() + _, project, org = auth_info + reranker = LMEngine(organization=org, project=project, request=request) + return await reranker.rerank_documents(**body.model_dump()) diff --git a/services/api/src/owl/routers/tasks.py b/services/api/src/owl/routers/tasks.py new file mode 100644 index 0000000..9587498 --- /dev/null +++ b/services/api/src/owl/routers/tasks.py @@ -0,0 +1,24 @@ +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Query + +from owl.configs import CACHE +from owl.types import UserAuth +from owl.utils.auth import auth_user_service_key +from owl.utils.exceptions import handle_exception + +router = APIRouter() + + +@router.get( + "/v2/progress", + summary="Get progress data.", + description="Permissions: None as long as signed-in.", +) +@handle_exception +async def get_progress( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + key: Annotated[str, Query(min_length=1, description="Progress key.")], +) -> dict[str, Any]: + del user + return (await CACHE.get_progress(key, None)) or {} diff --git a/services/api/src/owl/routers/template.py b/services/api/src/owl/routers/template.py deleted file mode 100644 index 866a7d2..0000000 --- a/services/api/src/owl/routers/template.py +++ /dev/null @@ -1,417 +0,0 @@ -import os -import pathlib -from io import BytesIO -from shutil import rmtree -from time import perf_counter -from typing import Annotated, Any - -import duckdb -import pyarrow as pa -from fastapi import ( - APIRouter, - Depends, - File, - Form, - Path, - Query, - Request, - UploadFile, -) -from filelock import FileLock, Timeout -from loguru import logger -from pyarrow.parquet import read_table as read_parquet_table -from sqlmodel import Session, select - -from jamaibase.exceptions import ( - BadInputError, - ResourceExistsError, - ResourceNotFoundError, - UnexpectedError, -) -from jamaibase.utils.io import dump_json, json_loads, read_json -from owl.db import create_sql_tables, create_sqlite_engine -from owl.db.gen_table import GenerativeTable -from owl.db.template import Tag, Template, TemplateRead, TemplateSQLModel -from owl.protocol import ( - TABLE_NAME_PATTERN, - ColName, - GenTableOrderBy, - 
OkResponse, - Page, - TableMetaResponse, - TableType, - TemplateMeta, -) -from owl.utils.auth import auth_internal -from owl.utils.exceptions import handle_exception - -CURR_DIR = pathlib.Path(__file__).resolve().parent -TEMPLATE_DIR = CURR_DIR.parent / "templates" -DB_PATH = TEMPLATE_DIR / "template.db" -TEMPLATE_ID_PATTERN = r"^[A-Za-z0-9]([A-Za-z0-9_-]{0,98}[A-Za-z0-9])?$" - -router = APIRouter(dependencies=[Depends(auth_internal)]) -public_router = APIRouter() - - -@router.on_event("startup") -async def startup(): - global ENGINE - ENGINE = create_sqlite_engine(f"sqlite:///{DB_PATH}") - _populate_template_db() - - -def _populate_template_db(timeout: float = 0.0): - lock = FileLock(TEMPLATE_DIR / "template.lock", timeout=timeout) - try: - with lock: - t0 = perf_counter() - if DB_PATH.exists(): - os.remove(DB_PATH) - create_sql_tables(TemplateSQLModel, ENGINE) - metas = [] - for template_dir in TEMPLATE_DIR.iterdir(): - if not template_dir.is_dir(): - continue - template_filepath = template_dir / "template_meta.json" - if not template_filepath.is_file(): - logger.warning(f"Missing template metadata JSON in {template_dir}") - continue - metas.append((template_dir.name, read_json(template_dir / "template_meta.json"))) - tags = sum([meta["tags"] for _, meta in metas], []) - tags = {t: t for t in tags} - with Session(ENGINE) as session: - for tag in tags: - tag = Tag(id=tag) - session.add(tag) - tags[tag.id] = tag - session.commit() - for template_id, meta in metas: - meta = TemplateMeta.model_validate(meta) - session.add( - Template( - id=template_id, - name=meta.name, - description=meta.description, - created_at=meta.created_at, - tags=[tags[t] for t in meta.tags], - ) - ) - session.commit() - logger.info(f"Populated template DB in {perf_counter() - t0:,.2f} s") - except Timeout: - pass - except Exception as e: - logger.exception(f"Failed to populate template DB due to {e}") - - -def _get_session(): - with Session(ENGINE) as session: - yield session - - -@router.post("/admin/backend/v1/templates/import") -@handle_exception -async def add_template( - *, - request: Request, - file: Annotated[UploadFile, File(description="Template Parquet file.")], - template_id_dst: Annotated[ - str, Form(pattern=TEMPLATE_ID_PATTERN, description="The ID of the new template.") - ], - exist_ok: Annotated[ - bool, Form(description="_Optional_. 
Whether to overwrite existing template.") - ] = False, -) -> OkResponse: - t0 = perf_counter() - dst_dir = TEMPLATE_DIR / template_id_dst - if exist_ok: - try: - rmtree(dst_dir) - except (NotADirectoryError, FileNotFoundError): - pass - elif dst_dir.is_dir(): - raise ResourceExistsError(f'Template "{template_id_dst}" already exists.') - os.makedirs(dst_dir, exist_ok=True) - try: - with BytesIO(await file.read()) as source: - # Write the template metadata JSON - pa_table = read_parquet_table(source, columns=None, use_threads=False, memory_map=True) - metadata = pa_table.schema.metadata - try: - template_meta = json_loads(metadata[b"template_meta"]) - except KeyError as e: - raise BadInputError("Missing template metadata in the Parquet file.") from e - except Exception as e: - raise BadInputError("Invalid template metadata in the Parquet file.") from e - dump_json(template_meta, dst_dir / "template_meta.json") - # Write the table parquet files - try: - type_metas = json_loads(metadata[b"table_metas"]) - except KeyError as e: - raise BadInputError("Missing table metadata in the Parquet file.") from e - except Exception as e: - raise BadInputError("Invalid table metadata in the Parquet file.") from e - for row, type_meta in zip(pa_table.to_pylist(), type_metas, strict=True): - table_type = type_meta["table_type"] - table_id = type_meta["table_meta"]["id"] - os.makedirs(dst_dir / table_type, exist_ok=True) - with open(dst_dir / table_type / f"{table_id}.parquet", "wb") as f: - f.write(row["data"]) - logger.info( - f'{request.state.id} - Template "{template_id_dst}" imported in {perf_counter() - t0:,.2f} s.' - ) - except pa.ArrowInvalid as e: - raise BadInputError(str(e)) from e - _populate_template_db(30.0) - return OkResponse() - - -@router.post("/admin/backend/v1/templates/populate") -@handle_exception -def populate_templates( - *, - timeout: Annotated[ - float, - Query(ge=0, description="_Optional_. Timeout in seconds, must be >= 0. Defaults to 30.0."), - ] = 30.0, -) -> OkResponse: - _populate_template_db(timeout=timeout) - return OkResponse() - - -@public_router.get("/public/v1/templates") -@handle_exception -def list_templates( - *, - session: Annotated[Session, Depends(_get_session)], - search_query: Annotated[ - str, - Query( - max_length=10_000, - description='_Optional_. A string to search for within template names. Defaults to "" (no filter).', - ), - ] = "", -) -> Page[TemplateRead]: - selection = select(Template) - if search_query != "": - selection = selection.where(Template.name.ilike(f"%{search_query}%")) - items = session.exec(selection).all() - total = len(items) - return Page[TemplateRead](items=items, offset=0, limit=total, total=total) - - -@public_router.get("/public/v1/templates/{template_id}") -@handle_exception -def get_template( - *, - session: Annotated[Session, Depends(_get_session)], - template_id: Annotated[ - str, - Path(max_length=10_000, description="Template ID."), - ], -) -> TemplateRead: - template = session.get(Template, template_id) - if template is None: - raise ResourceNotFoundError(f'Template "{template_id}" is not found.') - return template - - -@public_router.get("/public/v1/templates/{template_id}/gen_tables/{table_type}") -@handle_exception -def list_tables( - *, - template_id: Annotated[ - str, - Path(max_length=10_000, description="Template ID."), - ], - table_type: Annotated[TableType, Path(description="Table type.")], - offset: Annotated[ - int, - Query( - ge=0, - description="_Optional_. Item offset for pagination. 
Defaults to 0.", - ), - ] = 0, - limit: Annotated[ - int, - Query( - gt=0, - le=100, - description="_Optional_. Number of tables to return (min 1, max 100). Defaults to 100.", - ), - ] = 100, - search_query: Annotated[ - str, - Query( - max_length=100, - description='_Optional_. A string to search for within table IDs as a filter. Defaults to "" (no filter).', - ), - ] = "", - order_by: Annotated[ - GenTableOrderBy, - Query( - min_length=1, - description='_Optional_. Sort tables by this attribute. Defaults to "updated_at".', - ), - ] = GenTableOrderBy.UPDATED_AT, - order_descending: Annotated[ - bool, - Query(description="_Optional_. Whether to sort by descending order. Defaults to True."), - ] = True, -) -> Page[TableMetaResponse]: - template_dir = TEMPLATE_DIR / template_id - if not template_dir.is_dir(): - raise ResourceNotFoundError(f'Template "{template_id}" is not found.') - table_dir = template_dir / table_type - if not table_dir.is_dir(): - return Page[TableMetaResponse](items=[], offset=0, limit=100, total=0) - metas: list[TableMetaResponse] = [] - for table_path in sorted(table_dir.iterdir()): - table = read_parquet_table(table_path, columns=[], use_threads=False, memory_map=True) - try: - table_meta = table.schema.metadata[b"gen_table_meta"] - except KeyError as e: - raise UnexpectedError( - f'Missing table metadata in "templates/{template_id}/gen_tables/{table_type}/{table_path.name}".' - ) from e - except Exception as e: - raise UnexpectedError( - f'Invalid table metadata in "templates/{template_id}/gen_tables/{table_type}/{table_path.name}".' - ) from e - metas.append(TableMetaResponse.model_validate_json(table_meta)) - metas = [ - m - for m in sorted(metas, key=lambda m: getattr(m, order_by), reverse=order_descending) - if search_query.lower() in m.id.lower() - ] - total = len(metas) - return Page[TableMetaResponse]( - items=metas[offset : offset + limit], offset=offset, limit=limit, total=total - ) - - -@public_router.get("/public/v1/templates/{template_id}/gen_tables/{table_type}/{table_id}") -@handle_exception -def get_table( - *, - template_id: Annotated[ - str, - Path(max_length=10_000, description="Template ID."), - ], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id: str = Path(pattern=TABLE_NAME_PATTERN, description="Table ID or name."), -) -> TableMetaResponse: - template_dir = TEMPLATE_DIR / template_id - if not template_dir.is_dir(): - raise ResourceNotFoundError(f'Template "{template_id}" is not found.') - table_path = template_dir / table_type / f"{table_id}.parquet" - if not table_path.is_file(): - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') - table = read_parquet_table(table_path, columns=[], use_threads=False, memory_map=True) - try: - meta = TableMetaResponse.model_validate_json(table.schema.metadata[b"gen_table_meta"]) - except KeyError as e: - raise UnexpectedError( - f'Missing table metadata in "templates/{template_id}/gen_tables/{table_type}/{table_path.name}".' - ) from e - except Exception as e: - raise UnexpectedError( - f'Invalid table metadata in "templates/{template_id}/gen_tables/{table_type}/{table_path.name}".' 
- ) from e - return meta - - -@public_router.get("/public/v1/templates/{template_id}/gen_tables/{table_type}/{table_id}/rows") -@handle_exception -def list_table_rows( - *, - template_id: Annotated[ - str, - Path(max_length=10_000, description="Template ID."), - ], - table_type: Annotated[TableType, Path(description="Table type.")], - table_id: str = Path(pattern=TABLE_NAME_PATTERN, description="Table ID or name."), - starting_after: Annotated[ - str | None, - Query( - min_length=1, - description=( - "_Optional_. A cursor for use in pagination. Only rows with ID > `starting_after` will be returned. " - 'For instance, if your call receives 100 rows ending with ID "x", ' - 'your subsequent call can include `starting_after="x"` in order to fetch the next page of the list.' - ), - ), - ] = None, - offset: Annotated[ - int, - Query( - ge=0, - description="_Optional_. Item offset. Defaults to 0.", - ), - ] = 0, - limit: Annotated[ - int, - Query( - gt=0, - le=100, - description="_Optional_. Number of rows to return (min 1, max 100). Defaults to 100.", - ), - ] = 100, - order_by: Annotated[ - str, - Query( - min_length=1, - description='_Optional_. Sort rows by this column. Defaults to "Updated at".', - ), - ] = "Updated at", - order_descending: Annotated[ - bool, - Query(description="_Optional_. Whether to sort by descending order. Defaults to True."), - ] = True, - float_decimals: int = Query( - default=0, - ge=0, - description="_Optional_. Number of decimals for float values. Defaults to 0 (no rounding).", - ), - vec_decimals: int = Query( - default=0, - description="_Optional_. Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding).", - ), -) -> Page[dict[ColName, Any]]: - template_dir = TEMPLATE_DIR / template_id - if not template_dir.is_dir(): - raise ResourceNotFoundError(f'Template "{template_id}" is not found.') - table_path = template_dir / table_type / f"{table_id}.parquet" - if not table_path.is_file(): - raise ResourceNotFoundError(f'Table "{table_id}" is not found.') - - query = GenerativeTable._list_rows_query( - table_name=table_path, - sort_by=order_by, - sort_order="DESC" if order_descending else "ASC", - starting_after=starting_after, - id_column="ID", - offset=offset, - limit=limit, - ) - df = duckdb.sql(query).df() - df = GenerativeTable._post_process_rows_df( - df, - columns=None, - convert_null=True, - remove_state_cols=True, - json_safe=True, - include_original=True, - float_decimals=float_decimals, - vec_decimals=vec_decimals, - ) - rows = df.to_dict("records") - total = duckdb.sql(GenerativeTable._count_rows_query(table_path)).fetchone()[0] - return Page[dict[ColName, Any]]( - items=rows, - offset=offset, - limit=limit, - total=total, - starting_after=starting_after, - ) diff --git a/services/api/src/owl/routers/templates.py b/services/api/src/owl/routers/templates.py new file mode 100644 index 0000000..55eb71f --- /dev/null +++ b/services/api/src/owl/routers/templates.py @@ -0,0 +1,216 @@ +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Path, Query +from loguru import logger +from pydantic import BaseModel, Field + +from owl.db import TEMPLATE_ORG_ID, AsyncSession, yield_async_session +from owl.db.gen_table import ( + ActionTable, + ChatTable, + KnowledgeTable, +) +from owl.db.models import Organization, Project +from owl.types import ( + GetTableRowQuery, + ListQuery, + ListTableQuery, + ListTableRowQuery, + Page, + ProjectRead, + SanitisedNonEmptyStr, + TableMetaResponse, + TableType, +) +from 
owl.utils.exceptions import ( + ResourceNotFoundError, + handle_exception, +) + +router = APIRouter() + + +class ListTemplateQuery(ListQuery): + search_query: Annotated[ + str, + Field( + max_length=255, + description='_Optional_. A string to search for within project names as a filter. Defaults to "" (no filter).', + ), + ] = "" + + +@router.get( + "/v2/templates/list", + summary="List templates.", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def list_templates( + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ListTemplateQuery, Query()], +) -> Page[ProjectRead]: + # Ensure the organization exists + if (await session.get(Organization, TEMPLATE_ORG_ID)) is None: + logger.warning(f'Template organization "{TEMPLATE_ORG_ID}" does not exist.') + return Page[ProjectRead]( + items=[], + offset=params.offset, + limit=params.limit, + total=0, + ) + # List + return await Project.list_( + session=session, + return_type=ProjectRead, + offset=params.offset, + limit=params.limit, + order_by=params.order_by, + order_ascending=params.order_ascending, + search_query=params.search_query, + search_columns=params.search_columns, + filters=dict(organization_id=TEMPLATE_ORG_ID), + after=params.after, + ) + + +@router.get( + "/v2/templates", + summary="Get a specific template.", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def get_template( + session: Annotated[AsyncSession, Depends(yield_async_session)], + template_id: Annotated[str, Query(min_length=1, description="Template ID.")], +) -> ProjectRead: + # Fetch the template + template = await session.get(Project, template_id) + if template is None: + raise ResourceNotFoundError(f'Template "{template_id}" is not found.') + return template + + +TABLE_CLS: dict[TableType, ActionTable | KnowledgeTable | ChatTable] = { + TableType.ACTION: ActionTable, + TableType.KNOWLEDGE: KnowledgeTable, + TableType.CHAT: ChatTable, +} + + +class _ListTableQuery(ListTableQuery): + template_id: Annotated[str, Field(min_length=1, description="Template ID.")] + + +@router.get( + "/v2/templates/gen_tables/{table_type}/list", + summary="List tables in a template.", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def list_tables( + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[_ListTableQuery, Query()], +) -> Page[TableMetaResponse]: + metas = await TABLE_CLS[table_type].list_tables( + project_id=params.template_id, + limit=params.limit, + offset=params.offset, + parent_id=params.parent_id, + search_query=params.search_query, + order_by=params.order_by, + order_ascending=params.order_ascending, + count_rows=params.count_rows, + ) + return metas + + +class GetTableQuery(BaseModel): + template_id: Annotated[str, Field(min_length=1, description="Template ID.")] + table_id: Annotated[SanitisedNonEmptyStr, Field(description="The ID of the table to fetch.")] + + +@router.get( + "/v2/templates/gen_tables/{table_type}", + summary="Get a specific table from a template.", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def get_table( + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[GetTableQuery, Query()], +) -> TableMetaResponse: + table = await TABLE_CLS[table_type].open_table( + project_id=params.template_id, table_id=params.table_id + ) + return table.v1_meta_response + + +class _ListTableRowQuery(ListTableRowQuery): + 
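The template routes in this module are publicly accessible per their descriptions, so a browsing sketch needs no credentials. Assumptions here: the localhost base URL and the `action` value for the table type path segment.

import httpx

API_BASE = "http://localhost:6969/api"

with httpx.Client(base_url=API_BASE, timeout=30.0) as client:
    # Browse available templates.
    templates = client.get("/v2/templates/list", params={"limit": 20}).json()["items"]
    for tpl in templates:
        print(tpl["id"], tpl["name"])

    if templates:
        template_id = templates[0]["id"]
        # List the action tables bundled with the first template.
        tables = client.get(
            "/v2/templates/gen_tables/action/list",
            params={"template_id": template_id},
        ).json()["items"]
        print([t["id"] for t in tables])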
template_id: Annotated[str, Field(min_length=1, description="Template ID.")] + + +@router.get( + "/v2/templates/gen_tables/{table_type}/rows/list", + summary="List rows in a template table.", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def list_table_rows( + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[_ListTableRowQuery, Query()], +) -> Page[dict[str, Any]]: + table = await TABLE_CLS[table_type].open_table( + project_id=params.template_id, table_id=params.table_id + ) + rows = await table.list_rows( + limit=params.limit, + offset=params.offset, + order_by=[params.order_by], + order_ascending=params.order_ascending, + columns=params.columns, + where=params.where, + search_query=params.search_query, + search_columns=params.search_columns, + remove_state_cols=False, + ) + return Page[dict[str, Any]]( + items=table.postprocess_rows( + rows.items, + float_decimals=params.float_decimals, + vec_decimals=params.vec_decimals, + ), + offset=params.offset, + limit=params.limit, + total=rows.total, + ) + + +class _GetTableRowQuery(GetTableRowQuery): + template_id: Annotated[str, Field(min_length=1, description="Template ID.")] + + +@router.get( + "/v2/templates/gen_tables/{table_type}/rows", + summary="Get a specific row from a template table.", + description="Permissions: None, publicly accessible.", +) +@handle_exception +async def get_table_row( + table_type: Annotated[TableType, Path(description="Table type.")], + params: Annotated[_GetTableRowQuery, Query()], +) -> dict[str, Any]: + table = await TABLE_CLS[table_type].open_table( + project_id=params.template_id, table_id=params.table_id + ) + row = await table.get_row( + row_id=params.row_id, + columns=params.columns, + remove_state_cols=False, + ) + row = table.postprocess_rows( + [row], + float_decimals=params.float_decimals, + vec_decimals=params.vec_decimals, + )[0] + return row diff --git a/services/api/src/owl/routers/users/__init__.py b/services/api/src/owl/routers/users/__init__.py new file mode 100644 index 0000000..e9dc821 --- /dev/null +++ b/services/api/src/owl/routers/users/__init__.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter + +from owl.configs import ENV_CONFIG +from owl.routers.users.oss import router as oss_router + +router = APIRouter() +router.include_router(oss_router) + +if ENV_CONFIG.is_cloud: + from owl.routers.users.cloud import router as cloud_router + + router.include_router(cloud_router) diff --git a/services/api/src/owl/routers/users/oss.py b/services/api/src/owl/routers/users/oss.py new file mode 100644 index 0000000..8a03ba3 --- /dev/null +++ b/services/api/src/owl/routers/users/oss.py @@ -0,0 +1,207 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, Request +from loguru import logger +from sqlmodel import delete, func, select + +from owl.configs import ENV_CONFIG +from owl.db import AsyncSession, yield_async_session +from owl.db.models import ( + Organization, + OrgMember, + ProjectMember, + User, +) +from owl.types import ( + ListQuery, + OkResponse, + Page, + UserAuth, + UserCreate, + UserReadObscured, + UserUpdate, +) +from owl.utils.auth import auth_service_key, auth_user_service_key, has_permissions +from owl.utils.dates import now +from owl.utils.exceptions import ( + ResourceExistsError, + ResourceNotFoundError, + handle_exception, +) + +router = APIRouter() + + +async def _count_email(session: AsyncSession, email: str) -> int: + return (await 
session.exec(select(func.count(User.id)).where(User.email == email))).one() + + +@router.post( + "/v2/users", + summary="Create a user.", + description="Permissions: None.", +) +@handle_exception +async def create_user( + request: Request, + token: Annotated[str, Depends(auth_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: UserCreate, +) -> UserReadObscured: + del token + # Unless explicitly specified, create the first user with ID 0 + if ( + "id" not in body.model_dump(exclude_unset=True) + and (await session.exec(select(func.count(User.id)))).one() == 0 + ): + body.id = "0" + # Check if user already exists + if (await session.get(User, body.id)) is not None: + raise ResourceExistsError(f'User "{body.id}" already exists.') + if await _count_email(session, body.email) > 0: + raise ResourceExistsError(f'User with email "{body.email}" already exists.') + user = User.model_validate(body) + # Auth0 handles email verification + if ENV_CONFIG.auth0_api_key_plain: + user.email_verified = True + session.add(user) + await session.commit() + await session.refresh(user) + logger.info( + f"{request.state.id} - Created user: {user.model_dump(exclude={'password', 'password_hash'})}" + ) + logger.bind(user_id=user.id).success(f"{user.name} ({user.email}) created their account.") + user = await User.get(session, user.id, populate_existing=True) + return user + + +@router.get( + "/v2/users/list", + summary="List users.", + description="Permissions: `system.ADMIN`.", +) +@handle_exception +async def list_users( + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + params: Annotated[ListQuery, Query()], +) -> Page[UserReadObscured]: + has_permissions(user, ["system.ADMIN"]) + return await User.list_( + session=session, + return_type=UserReadObscured, + offset=params.offset, + limit=params.limit, + order_by=params.order_by, + order_ascending=params.order_ascending, + search_query=params.search_query, + search_columns=params.search_columns, + after=params.after, + ) + + +@router.get( + "/v2/users", + summary="Get current user or a specific user.", + description=( + "Permissions: `system.ADMIN`. " + "Permissions are only needed if the queried user is not the current logged-in user." + ), +) +@handle_exception +async def get_user( + _user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + user_id: Annotated[ + str | None, Query(description="User ID. 
If not provided, the logged-in user is returned.") + ] = None, +) -> UserReadObscured: + if (not user_id) or (user_id == _user.id): + return await User.get(session, _user.id, populate_existing=True) + if _user.id != user_id: + has_permissions(_user, ["system.ADMIN"]) + return await User.get(session, user_id) + + +@router.patch( + "/v2/users", + summary="Update the current logged-in user.", + description="Permissions: None.", +) +@handle_exception +async def update_user( + *, + _user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], + body: UserUpdate, +) -> UserReadObscured: + user = await User.get(session, _user.id) + if user is None: + raise ResourceNotFoundError(f'User "{_user.id}" is not found.') + # Perform update + updates = body.model_dump(exclude={"id"}, exclude_unset=True) + for key, value in updates.items(): + if key == "email" and body.email != user.email: + if await _count_email(session, body.email) > 0: + raise ResourceExistsError(f'User with email "{body.email}" already exists.') + user.email_verified = False + setattr(user, key, value) + user.updated_at = now() + session.add(user) + await session.commit() + await session.refresh(user) + logger.bind(user_id=user.id).success( + ( + f"{user.name} ({user.email}) updated the attributes " + f"{list(updates.keys())} of their user account." + ) + ) + return user + + +@router.delete( + "/v2/users", + summary="Delete a user.", + description="Permissions: None.", +) +@handle_exception +async def delete_user( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + session: Annotated[AsyncSession, Depends(yield_async_session)], +) -> OkResponse: + user = await session.get(User, user.id) + if user is None: + raise ResourceNotFoundError(f'User "{user.id}" is not found.') + org_ids = [m.organization_id for m in user.org_memberships] + # Delete all related resources + logger.info(f'{request.state.id} - Deleting user: "{user.id}"') + await session.exec(delete(OrgMember).where(OrgMember.user_id == user.id)) + await session.exec(delete(ProjectMember).where(ProjectMember.user_id == user.id)) + if ENV_CONFIG.is_cloud: + from owl.db.models.cloud import ProjectKey, VerificationCode + + await session.exec(delete(ProjectKey).where(ProjectKey.user_id == user.id)) + await session.exec( + delete(VerificationCode).where(VerificationCode.user_email == user.email) + ) + await session.delete(user) + await session.commit() + # Delete organizations if the user was the last member + logger.info(f"{request.state.id} - Inspecting organizations: {org_ids}") + for org_id in org_ids: + member_count = ( + await session.exec( + select(func.count(OrgMember.user_id)).where(OrgMember.organization_id == org_id) + ) + ).one() + if member_count > 0: + continue + try: + await session.exec(delete(Organization).where(Organization.id == org_id)) + await session.commit() + logger.info(f'{request.state.id} - Deleting empty organization "{org_id}"') + except Exception as e: + logger.warning(f'Failed to delete organization "{org_id}" due to {repr(e)}') + logger.bind(user_id=user.id).success(f"{user.name} ({user.email}) deleted their account.") + return OkResponse() diff --git a/services/api/src/owl/scripts/backup_db.py b/services/api/src/owl/scripts/backup_db.py index 27fbaaf..d6e73e4 100644 --- a/services/api/src/owl/scripts/backup_db.py +++ b/services/api/src/owl/scripts/backup_db.py @@ -47,7 +47,7 @@ def restore(db_dir: str): ) ) src_path = join(proj_dir, bak_files[0]) - dst_path = 
join(proj_dir, f'{bak_files[0].split("_")[0]}.db') + dst_path = join(proj_dir, f"{bak_files[0].split('_')[0]}.db") os.remove(dst_path) copy2(src_path, dst_path) @@ -74,5 +74,5 @@ def find_sqlite_files(directory): os.makedirs(backup_dir, exist_ok=False) for j, db_file in enumerate(sqlite_files): - print(f"(DB {j+1:,d}/{len(sqlite_files):,d}): Processing: {db_file}") + print(f"(DB {j + 1:,d}/{len(sqlite_files):,d}): Processing: {db_file}") backup_db(db_file, backup_dir) diff --git a/services/api/src/owl/scripts/update_db.py b/services/api/src/owl/scripts/update_db.py deleted file mode 100644 index 188d030..0000000 --- a/services/api/src/owl/scripts/update_db.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -import sqlite3 -from datetime import datetime, timezone -from os.path import join -from pprint import pprint - -from pydantic_settings import BaseSettings, SettingsConfigDict - -from owl.configs.manager import ENV_CONFIG - - -class EnvConfig(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore", cli_parse_args=False - ) - owl_db_dir: str = "db" - - -NOW = datetime.now(tz=timezone.utc).isoformat() -backup_dir = f"{ENV_CONFIG.owl_db_dir}_BAK_{NOW}" -os.makedirs(backup_dir, exist_ok=False) - - -def add_columns(): - with sqlite3.connect(join(ENV_CONFIG.owl_db_dir, "main.db")) as src: - c = src.cursor() - # Add OAuth columns to user table - c.execute("ALTER TABLE user ADD COLUMN username TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN refresh_counter INTEGER DEFAULT 0") - c.execute("ALTER TABLE user ADD COLUMN google_id TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN google_name TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN google_username TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN google_email TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN google_picture_url TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN github_id INTEGER DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN github_name TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN github_username TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN github_email TEXT DEFAULT NULL") - c.execute("ALTER TABLE user ADD COLUMN github_picture_url TEXT DEFAULT NULL") - src.commit() - c.execute("CREATE UNIQUE INDEX idx_user_google_id ON user (google_id)") - c.execute("CREATE UNIQUE INDEX idx_user_github_id ON user (github_id)") - # Rename table - c.execute("ALTER TABLE `userorglink` RENAME TO `orgmember`") - # Flatten quota related columns to organization table - c.execute("ALTER TABLE organization ADD COLUMN credit REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN credit_grant REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN llm_tokens_quota_mtok REAL DEFAULT 0") - c.execute("ALTER TABLE organization ADD COLUMN llm_tokens_usage_mtok REAL DEFAULT 0") - c.execute("ALTER TABLE organization ADD COLUMN embedding_tokens_quota_mtok REAL DEFAULT 0") - c.execute("ALTER TABLE organization ADD COLUMN embedding_tokens_usage_mtok REAL DEFAULT 0") - c.execute("ALTER TABLE organization ADD COLUMN reranker_quota_ksearch REAL DEFAULT 0") - c.execute("ALTER TABLE organization ADD COLUMN reranker_usage_ksearch REAL DEFAULT 0") - c.execute("ALTER TABLE organization ADD COLUMN db_quota_gib REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN db_usage_gib REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN file_quota_gib REAL DEFAULT 0.0") 
- c.execute("ALTER TABLE organization ADD COLUMN file_usage_gib REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN egress_quota_gib REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN egress_usage_gib REAL DEFAULT 0.0") - c.execute("ALTER TABLE organization ADD COLUMN models JSON DEFAULT '{}'") - # Remove nested quota column - c.execute("ALTER TABLE organization DROP COLUMN quotas") - src.commit() - c.execute("PRAGMA table_info(organization)") - pprint(c.fetchall()) - c.close() - - -def update_oauth_info(): - with sqlite3.connect(join(ENV_CONFIG.owl_db_dir, "main.db")) as src: - c = src.cursor() - c.execute("SELECT id FROM User") - for row in c.fetchall(): - user_id = row[0] - if user_id.startswith("github|"): - c.execute( - "UPDATE User SET github_id = ? WHERE id = ?", - (int(user_id.split("|")[1]), user_id), - ) - src.commit() - elif user_id.startswith("google-oauth2|"): - c.execute( - "UPDATE User SET google_id = ? WHERE id = ?", - (user_id.split("|")[1], user_id), - ) - src.commit() - c.close() - - -if __name__ == "__main__": - with sqlite3.connect(join(ENV_CONFIG.owl_db_dir, "main.db")) as src: - with sqlite3.connect(join(backup_dir, "main.db")) as dst: - src.backup(dst) - add_columns() - update_oauth_info() diff --git a/services/api/src/owl/tasks/checks.py b/services/api/src/owl/tasks/checks.py new file mode 100644 index 0000000..de64814 --- /dev/null +++ b/services/api/src/owl/tasks/checks.py @@ -0,0 +1,88 @@ +from loguru import logger + +from jamaibase import JamAI +from jamaibase.types import ChatRequest +from owl.configs import ENV_CONFIG, celery_app + + +@celery_app.task +def test_models(): + client = JamAI( + api_base=f"http://localhost:{ENV_CONFIG.port}/api", + user_id="0", + token=ENV_CONFIG.service_key_plain, + ) + projects = client.projects.list_projects("0", limit=1).items + if len(projects) == 0: + logger.error("No projects found.") + return + project = projects[0] + client = JamAI( + api_base=f"http://localhost:{ENV_CONFIG.port}/api", + user_id="0", + project_id=project.id, + token=ENV_CONFIG.service_key_plain, + ) + + # Test chat completion + models = client.model_info(capabilities=["chat"]).data + status = {model.id: False for model in models} + for model in models: + logger.debug(f"------ {model.id} {model.name} ------") + for stream in [True, False]: + try: + response = client.generate_chat_completions( + ChatRequest( + model=model.id, + messages=[{"role": "user", "content": "Hello"}], + max_tokens=2, + stream=stream, + ), + ) + if stream: + for chunk in response: + logger.debug(chunk) + else: + logger.debug(response) + except Exception as e: + logger.error(f'Model "{model.name}" ({model.id}) failed: {repr(e)}') + status[model.id] = True + logger.info( + f"Chat model test: {sum(status.values()):,d} out of {len(status):,d} models passed." + ) + + # Test embedding + models = client.model_info(capabilities=["embed"]).data + status = {model.id: False for model in models} + for model in models: + logger.debug(f"------ {model.id} {model.name} ------") + for text in ["What is a llama?", ["What is a llama?", "What is an alpaca?"]]: + for encoding in ["float", "base64"]: + try: + response = client.generate_embeddings( + dict(model=model.id, input=text, encoding=encoding), + ) + logger.debug(response) + except Exception as e: + logger.error(f'Model "{model.name}" ({model.id}) failed: {repr(e)}') + status[model.id] = True + logger.info( + f"Embedding model test: {sum(status.values()):,d} out of {len(status):,d} models passed." 
+ ) + + # Test rerank + models = client.model_info(capabilities=["rerank"]).data + status = {model.id: False for model in models} + for model in models: + logger.debug(f"------ {model.id} {model.name} ------") + try: + response = client.rerank( + dict(model=model.id, documents=["Norway", "Sweden"], query="Stockholm"), + ) + logger.debug(response) + except Exception as e: + logger.error(f'Model "{model.name}" ({model.id}) failed: {repr(e)}') + status[model.id] = True + logger.info( + f"Reranking model test: {sum(status.values()):,d} out of {len(status):,d} models passed." + ) diff --git a/services/api/src/owl/tasks/database.py b/services/api/src/owl/tasks/database.py new file mode 100644 index 0000000..0c785db --- /dev/null +++ b/services/api/src/owl/tasks/database.py @@ -0,0 +1,12 @@ +import asyncio + +from owl.configs import celery_app +from owl.utils.billing import CLICKHOUSE_CLIENT + + +@celery_app.task +def run_periodic_flush_buffer(): + """ + Flush redis buffer to clickhouse. + """ + asyncio.get_event_loop().run_until_complete(CLICKHOUSE_CLIENT.flush_buffer()) diff --git a/services/api/src/owl/tasks/gen_table.py b/services/api/src/owl/tasks/gen_table.py new file mode 100644 index 0000000..4ff9738 --- /dev/null +++ b/services/api/src/owl/tasks/gen_table.py @@ -0,0 +1,62 @@ +import asyncio +from io import BytesIO + +from loguru import logger + +from owl.configs import celery_app +from owl.db.gen_table import ActionTable, ChatTable, KnowledgeTable +from owl.types import TableType +from owl.utils.exceptions import JamaiException, ResourceExistsError +from owl.utils.io import open_uri_async + +TABLE_CLS: dict[TableType, ActionTable | KnowledgeTable | ChatTable] = { + TableType.ACTION: ActionTable, + TableType.KNOWLEDGE: KnowledgeTable, + TableType.CHAT: ChatTable, +} + + +@celery_app.task +def import_gen_table( + source: str | bytes, + *, + project_id: str, + table_type: str, + table_id_dst: str | None, + reupload_files: bool = True, + progress_key: str = "", + verbose: bool = False, +) -> str: + async def _task(): + if isinstance(source, str): + async with open_uri_async(source) as (f, _): + data = await f.read() + else: + data = source + with BytesIO(data) as f: + try: + return await TABLE_CLS[table_type].import_table( + project_id=project_id, + source=f, + table_id_dst=table_id_dst, + reupload_files=reupload_files, + progress_key=progress_key, + verbose=verbose, + ) + except ResourceExistsError: + raise + except JamaiException as e: + logger.error( + f'Failed to import table "{table_id_dst}" into project "{project_id}": {repr(e)}' + ) + raise + except Exception as e: + logger.exception( + f'Failed to import table "{table_id_dst}" into project "{project_id}": {repr(e)}' + ) + raise + + logger.info("Generative Table import task started.") + table = asyncio.get_event_loop().run_until_complete(_task()) + logger.info("Generative Table import task completed.") + return table.v1_meta_response.model_dump_json() diff --git a/services/api/src/owl/tasks/genitor.py b/services/api/src/owl/tasks/genitor.py index 880d8b9..20757a1 100644 --- a/services/api/src/owl/tasks/genitor.py +++ b/services/api/src/owl/tasks/genitor.py @@ -1,30 +1,10 @@ -# tasks.py -import os -import pathlib -import tempfile from datetime import datetime, timedelta, timezone import boto3 from botocore.client import Config -from celery import Celery, chord from loguru import logger -from owl.configs.manager import ENV_CONFIG -from owl.db.gen_table import GenerativeTable -from owl.protocol import TableType - -# Set up Celery -app = 
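Illustrative sketch only: enqueue `import_gen_table` with a progress key and poll the new `/v2/progress` endpoint from tasks.py while it runs. The service key, project ID, source URI, table type string, and polling cadence are placeholders; the string `source` is assumed to be a URI readable by `open_uri_async`, and a Celery result backend is assumed to be configured for `ready()`/`get()`.

import time
import httpx
from owl.tasks.gen_table import import_gen_table

progress_key = "import:proj_xxx:my_table"  # hypothetical key
result = import_gen_table.apply_async(
    kwargs=dict(
        source="s3://bucket/my_table.parquet",  # placeholder URI (or raw bytes)
        project_id="proj_xxx",
        table_type="action",
        table_id_dst="my_table",
        progress_key=progress_key,
    )
)

headers = {"Authorization": "Bearer <OWL_SERVICE_KEY>"}  # placeholder service key
with httpx.Client(base_url="http://localhost:6969/api", headers=headers) as client:
    while not result.ready():
        progress = client.get("/v2/progress", params={"key": progress_key}).json()
        print(progress)
        time.sleep(2)

print(result.get())  # JSON-serialized table meta returned by the task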
Celery("tasks", broker=f"redis://{ENV_CONFIG.owl_redis_host}:{ENV_CONFIG.owl_redis_port}/0") - -# Configure Celery -app.conf.update( - result_backend=f"redis://{ENV_CONFIG.owl_redis_host}:{ENV_CONFIG.owl_redis_port}/0", - task_serializer="json", - accept_content=["json"], - result_serializer="json", - timezone="UTC", - enable_utc=True, -) +from owl.configs import ENV_CONFIG, celery_app AWS_DELETE_API_MAX_OBJECT_LIMIT = 1000 @@ -44,7 +24,7 @@ def _get_s3_client(): ) -@app.task +@celery_app.task def s3_cleanup(): s3_client = _get_s3_client() current_date = datetime.utcnow().date() @@ -144,35 +124,7 @@ def s3_cleanup(): logger.error(f"S3 Cleanup failed:\n {e}") -@app.task -def backup_to_s3(): - db_dir = pathlib.Path(ENV_CONFIG.owl_db_dir) - logger.info(f"DB PATH: {db_dir}") - all_chains = [] - - for org_dir in db_dir.iterdir(): - if not org_dir.is_dir() or not org_dir.name.startswith("org_"): - continue - for project_dir in org_dir.iterdir(): - if not project_dir.is_dir(): - continue - table_types = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] - - lance_chains = [ - backup_gen_table_parquet.s(str(org_dir.name), str(project_dir.name), table_type) - for table_type in table_types - ] - - all_chains.extend(lance_chains) - - if all_chains: - return chord(all_chains)(backup_project_results.s()) - else: - logger.warning("No tasks to execute in the chord.") - return None - - -@app.task +@celery_app.task def backup_project_results(results): failed_project = [] status_dict = {} @@ -189,42 +141,3 @@ def backup_project_results(results): logger.info( f"Total number of successful project backup: {true_count} out of {len(results)}. \n Failed projects: {failed_project}" ) - - -@app.task -def backup_gen_table_parquet(org_id: str, project_id: str, table_type: str): - try: - table = GenerativeTable.from_ids(org_id, project_id, table_type) - table_dir = f"{ENV_CONFIG.owl_db_dir}/{org_id}/{project_id}/{table_type}" - with table.create_session() as session: - offset, total = 0, 1 - while offset < total: - metas, total = table.list_meta( - session, - offset=offset, - limit=50, - remove_state_cols=True, - parent_id=None, - ) - offset += 50 - for meta in metas: - upload_path = ( - f"{get_timestamp()}/db/{org_id}/{project_id}/{table_type}/{meta.id}" - ) - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_file = os.path.join(tmp_dir, f"{meta.id}.parquet") - table.dump_parquet(session=session, table_id=meta.id, dest=tmp_file) - s3_client = _get_s3_client() - s3_client.upload_file( - tmp_file, - ENV_CONFIG.s3_backup_bucket_name, - f"{upload_path}.parquet", - ) - logger.info( - f"Backup to s3://{ENV_CONFIG.s3_backup_bucket_name}/{upload_path}.parquet" - ) - return True, org_id, project_id - - except Exception as e: - logger.error(f"Error backing up Lance table {table_dir}: {e}") - return False, org_id, project_id diff --git a/services/api/src/owl/tasks/restore.py b/services/api/src/owl/tasks/restore.py index 8b50a35..94a7688 100644 --- a/services/api/src/owl/tasks/restore.py +++ b/services/api/src/owl/tasks/restore.py @@ -1,25 +1,17 @@ import multiprocessing import os -import re import sqlite3 import time -from io import BytesIO import boto3 import click -import lance -import pyarrow.parquet as pq from botocore.client import Config from loguru import logger -from tqdm import tqdm -from owl import protocol as p -from owl.configs.manager import ENV_CONFIG -from owl.db.gen_table import GenerativeTable -from owl.protocol import TableMetaResponse +from owl.configs import ENV_CONFIG from owl.utils.logging import 
setup_logger_sinks -setup_logger_sinks(f"{ENV_CONFIG.owl_log_dir}/restoration.log") +setup_logger_sinks(f"{ENV_CONFIG.log_dir}/restoration.log") logger.info(f"Using configuration: {ENV_CONFIG}") @@ -41,7 +33,7 @@ def _initialize_databases(table_info_list): project_id = item["project_id"] table_type = item["table_type"] - lance_path = os.path.join(ENV_CONFIG.owl_db_dir, org_id, project_id, table_type) + lance_path = os.path.join(ENV_CONFIG.db_dir, org_id, project_id, table_type) sqlite_path = f"{lance_path}.db" if table_type != "file": if sqlite_path not in initialized_dbs: @@ -55,57 +47,57 @@ def get_default_workers(): return max(multiprocessing.cpu_count() * 8, 1) -def restore(item): - import asyncio - - try: - s3_client = _get_s3_client() - org_id = item["org_id"] - project_id = item["project_id"] - table_type = item["table_type"] - table_parquet = item["table_parquet"] - - if table_type == "file": - file_parquet_key = os.path.join( - item["datetime"], "db", org_id, project_id, "file", "file.parquet" - ) - file_lance_dir = os.path.join( - ENV_CONFIG.owl_db_dir, org_id, project_id, "file", "file.lance" - ) - logger.info(f"Processing {org_id}/{project_id}/{table_type}/{table_parquet}") - - if not os.path.exists(file_lance_dir): - response = s3_client.get_object( - Bucket=ENV_CONFIG.s3_backup_bucket_name, Key=file_parquet_key - ) - logger.info(f"Processing {org_id}/{project_id}/file/file.parquet") - body = response["Body"].read() - parquet_table = pq.read_table(BytesIO(body)) - lance.write_dataset(parquet_table, file_lance_dir) - else: - object_key = ( - f"{item['datetime']}/db/{org_id}/{project_id}/{table_type}/{table_parquet}" - ) - logger.info(f"Processing {org_id}/{project_id}/{table_type}/{table_parquet}") - response = s3_client.get_object( - Bucket=ENV_CONFIG.s3_backup_bucket_name, Key=object_key - ) - table_id = re.sub(r"\.parquet$", "", table_parquet, flags=re.IGNORECASE) - table = GenerativeTable.from_ids(org_id, project_id, p.TableType(table_type)) - - body = response["Body"].read() - with table.create_session() as session: - _, meta = asyncio.run( - table.import_parquet( - session=session, - source=BytesIO(body), - table_id_dst=table_id, - ) - ) - meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) - return meta - except Exception as e: - logger.error(f"Failed to import table from parquet due to {e.__class__.__name__}: {e}") +# def restore(item): +# import asyncio + +# try: +# s3_client = _get_s3_client() +# org_id = item["org_id"] +# project_id = item["project_id"] +# table_type = item["table_type"] +# table_parquet = item["table_parquet"] + +# if table_type == "file": +# file_parquet_key = os.path.join( +# item["datetime"], "db", org_id, project_id, "file", "file.parquet" +# ) +# file_lance_dir = os.path.join( +# ENV_CONFIG.db_dir, org_id, project_id, "file", "file.lance" +# ) +# logger.info(f"Processing {org_id}/{project_id}/{table_type}/{table_parquet}") + +# if not os.path.exists(file_lance_dir): +# response = s3_client.get_object( +# Bucket=ENV_CONFIG.s3_backup_bucket_name, Key=file_parquet_key +# ) +# logger.info(f"Processing {org_id}/{project_id}/file/file.parquet") +# body = response["Body"].read() +# parquet_table = pq.read_table(BytesIO(body)) +# lance.write_dataset(parquet_table, file_lance_dir) +# else: +# object_key = ( +# f"{item['datetime']}/db/{org_id}/{project_id}/{table_type}/{table_parquet}" +# ) +# logger.info(f"Processing {org_id}/{project_id}/{table_type}/{table_parquet}") +# response = s3_client.get_object( +# 
Bucket=ENV_CONFIG.s3_backup_bucket_name, Key=object_key +# ) +# table_id = re.sub(r"\.parquet$", "", table_parquet, flags=re.IGNORECASE) +# table = GenerativeTable.from_ids(org_id, project_id, p.TableType(table_type)) + +# body = response["Body"].read() +# with table.create_session() as session: +# _, meta = asyncio.get_event_loop().run_until_complete( +# table.import_parquet( +# session=session, +# source=BytesIO(body), +# table_id_dst=table_id, +# ) +# ) +# meta = TableMetaResponse(**meta.model_dump(), num_rows=table.count_rows(meta.id)) +# return meta +# except Exception as e: +# logger.error(f"Failed to import table from parquet due to {e.__class__.__name__}: {e}") @click.command() @@ -118,13 +110,13 @@ def main(): total_objects = 0 fetch_start_time = time.time() - # Ask for the number of workers - max_workers = get_default_workers() - workers = click.prompt( - f"Enter the number of worker processes to use (1-{max_workers}). Default:", - type=click.IntRange(1, max_workers), - default=max_workers, - ) + # # Ask for the number of workers + # max_workers = get_default_workers() + # workers = click.prompt( + # f"Enter the number of worker processes to use (1-{max_workers}). Default:", + # type=click.IntRange(1, max_workers), + # default=max_workers, + # ) click.echo("Fetching S3 objects...") while True: @@ -198,37 +190,37 @@ def main(): logger.error(f"An error occurred: {e}") # Check if database files exist and ask for overwrite confirmation - current_files = os.listdir(ENV_CONFIG.owl_db_dir) + current_files = os.listdir(ENV_CONFIG.db_dir) if current_files: - click.echo(f"Current database path: {ENV_CONFIG.owl_db_dir}") + click.echo(f"Current database path: {ENV_CONFIG.db_dir}") if not click.confirm("Do you want to overwrite the existing files?"): click.echo("Operation cancelled.") return else: - click.echo(f"Current database path: {ENV_CONFIG.owl_db_dir}") + click.echo(f"Current database path: {ENV_CONFIG.db_dir}") if not click.confirm("Confirm restoring to this directory?"): click.echo("Operation cancelled.") return - table_info_list = sorted(table_info_list, key=lambda x: x["org_id"]) - filtered_list = [item for item in table_info_list if item["datetime"] == specific_date] - - # Use this before starting the multiprocessing pool - _initialize_databases(filtered_list) - click.echo(f"Using {workers} worker processes") - tic = time.time() - - with multiprocessing.Pool(workers, maxtasksperchild=2) as pool: - list( - tqdm( - pool.imap_unordered(restore, filtered_list), - total=len(filtered_list), - desc="Importing tables", - unit="table", - ) - ) - - click.echo(f"Import completed successfully! {time.time() - tic:.2f}s") + # table_info_list = sorted(table_info_list, key=lambda x: x["org_id"]) + # filtered_list = [item for item in table_info_list if item["datetime"] == specific_date] + + # # Use this before starting the multiprocessing pool + # _initialize_databases(filtered_list) + # click.echo(f"Using {workers} worker processes") + # tic = time.time() + + # with multiprocessing.Pool(workers, maxtasksperchild=2) as pool: + # list( + # tqdm( + # pool.imap_unordered(restore, filtered_list), + # total=len(filtered_list), + # desc="Importing tables", + # unit="table", + # ) + # ) + + # click.echo(f"Import completed successfully! 
{time.time() - tic:.2f}s") except Exception as e: logger.error(f"Failed to import table from parquet: {e}") diff --git a/services/api/src/owl/tasks/storage.py b/services/api/src/owl/tasks/storage.py deleted file mode 100644 index 866db6c..0000000 --- a/services/api/src/owl/tasks/storage.py +++ /dev/null @@ -1,192 +0,0 @@ -import asyncio -import pathlib -from datetime import timedelta -from time import perf_counter - -from celery import Celery -from filelock import FileLock, Timeout -from loguru import logger - -from jamaibase import JamAI -from owl.billing import BillingManager -from owl.configs.manager import ENV_CONFIG -from owl.db.gen_table import GenerativeTable -from owl.protocol import TableType -from owl.utils.io import get_file_usage, get_storage_usage - -logger.info(f"Using configuration: {ENV_CONFIG}") -client = JamAI(token=ENV_CONFIG.service_key_plain, timeout=60.0) - -# Set up Celery -app = Celery("tasks", broker=f"redis://{ENV_CONFIG.owl_redis_host}:{ENV_CONFIG.owl_redis_port}/0") - -# Configure Celery -app.conf.update( - result_backend=f"redis://{ENV_CONFIG.owl_redis_host}:{ENV_CONFIG.owl_redis_port}/0", - task_serializer="json", - accept_content=["json"], - result_serializer="json", - timezone="UTC", - enable_utc=True, -) - -logger.info(f"Using configuration: {ENV_CONFIG}") - - -def _iter_all_tables(batch_size: int = 200): - table_types = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] - db_dir = pathlib.Path(ENV_CONFIG.owl_db_dir) - for org_dir in db_dir.iterdir(): - if not org_dir.is_dir() or not org_dir.name.startswith(("org_", "default")): - continue - for project_dir in org_dir.iterdir(): - if not project_dir.is_dir(): - continue - for table_type in table_types: - table = GenerativeTable.from_ids(org_dir.name, project_dir.name, table_type) - with table.create_session() as session: - offset, total = 0, 1 - while offset < total: - metas, total = table.list_meta( - session, - offset=offset, - limit=batch_size, - remove_state_cols=True, - parent_id=None, - ) - offset += batch_size - for meta in metas: - yield ( - session, - table, - meta, - f"{project_dir}/{table_type}/{meta.id}", - ) - - -@app.task -def periodic_storage_update(): - # Cloud client - if ENV_CONFIG.is_oss: - return - - lock = FileLock(f"{ENV_CONFIG.owl_db_dir}/periodic_storage_update.lock", blocking=False) - try: - t0 = perf_counter() - with lock: - file_usages = get_file_usage(ENV_CONFIG.owl_db_dir) - db_usages = get_storage_usage(ENV_CONFIG.owl_db_dir) - num_ok = num_skipped = num_failed = 0 - for org_id in db_usages: - if not org_id.startswith("org_"): - continue - db_usage_gib = db_usages[org_id] - file_usage_gib = file_usages[org_id] - try: - org = client.admin.backend.get_organization(org_id) - manager = BillingManager( - organization=org, - project_id="", - user_id="", - request=None, - ) - manager.create_storage_events(db_usage_gib, file_usage_gib) - asyncio.get_event_loop().run_until_complete(manager.process_all()) - num_ok += 1 - except Exception as e: - logger.warning((f"Storage usage update failed for {org_id}: {e}")) - num_failed += 1 - t = perf_counter() - t0 - logger.info( - ( - f"Periodic storage usage update completed (t={t:,.3f} s, " - f"{num_ok:,d} OK, {num_skipped:,d} skipped, {num_failed:,d} failed)." 
- ) - ) - except Timeout: - pass - except Exception as e: - logger.exception(f"Periodic storage usage update failed due to {e}") - - -@app.task -def lance_periodic_reindex(): - lock = FileLock(f"{ENV_CONFIG.owl_db_dir}/periodic_reindex.lock", timeout=0) - try: - with lock: - t0 = perf_counter() - num_ok = num_skipped = num_failed = 0 - for session, table, meta, table_path in _iter_all_tables(): - if session is None: - continue - try: - reindexed = table.create_indexes(session, meta.id) - if reindexed: - num_ok += 1 - else: - num_skipped += 1 - except Timeout: - logger.warning(f"Periodic Lance re-indexing skipped for table: {table_path}") - num_skipped += 1 - except Exception: - logger.exception(f"Periodic Lance re-indexing failed for table: {table_path}") - num_failed += 1 - t = perf_counter() - t0 - logger.info( - ( - f"Periodic Lance re-indexing completed (t={t:,.3f} s, " - f"{num_ok:,d} OK, {num_skipped:,d} skipped, {num_failed:,d} failed)." - ) - ) - except Timeout: - logger.info("Periodic Lance re-indexing skipped due to lock.") - except Exception as e: - logger.exception(f"Periodic Lance re-indexing failed due to {e}") - - -@app.task -def lance_periodic_optimize(): - lock = FileLock(f"{ENV_CONFIG.owl_db_dir}/periodic_optimization.lock", timeout=0) - try: - with lock: - t0 = perf_counter() - num_ok = num_skipped = num_failed = 0 - for _, table, meta, table_path in _iter_all_tables(): - done = True - try: - if meta is None: - done = done and table.compact_files() - done = done and table.cleanup_old_versions( - older_than=timedelta( - minutes=ENV_CONFIG.owl_remove_version_older_than_mins - ), - ) - else: - done = done and table.compact_files(meta.id) - done = done and table.cleanup_old_versions( - meta.id, - older_than=timedelta( - minutes=ENV_CONFIG.owl_remove_version_older_than_mins - ), - ) - if done: - num_ok += 1 - else: - num_skipped += 1 - except Timeout: - logger.warning(f"Periodic Lance optimization skipped for table: {table_path}") - num_skipped += 1 - except Exception: - logger.exception(f"Periodic Lance optimization failed for table: {table_path}") - num_failed += 1 - t = perf_counter() - t0 - logger.info( - ( - f"Periodic Lance optimization completed (t={t:,.3f} s, " - f"{num_ok:,d} OK, {num_skipped:,d} skipped, {num_failed:,d} failed)." 
- ) - ) - except Timeout: - logger.info("Periodic Lance optimization skipped due to lock.") - except Exception as e: - logger.exception(f"Periodic Lance optimization failed due to {e}") diff --git a/services/api/src/owl/templates/.gitignore b/services/api/src/owl/templates/.gitignore deleted file mode 100644 index 2782799..0000000 --- a/services/api/src/owl/templates/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -# Include Parquet files -!*.parquet \ No newline at end of file diff --git a/services/api/src/owl/templates/f1_due_diligence/action/Due_Diligence_ARM.parquet b/services/api/src/owl/templates/f1_due_diligence/action/Due_Diligence_ARM.parquet deleted file mode 100644 index 541d16b..0000000 --- a/services/api/src/owl/templates/f1_due_diligence/action/Due_Diligence_ARM.parquet +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:06f0f218779d43a88233d8c729b7e5d3701244629fea14eeb52a9088611bdf90 -size 45308 diff --git a/services/api/src/owl/templates/f1_due_diligence/knowledge/Form_F1_ARM.parquet b/services/api/src/owl/templates/f1_due_diligence/knowledge/Form_F1_ARM.parquet deleted file mode 100644 index 4062cdd..0000000 --- a/services/api/src/owl/templates/f1_due_diligence/knowledge/Form_F1_ARM.parquet +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6e62e7f76f7d254948b560afb91800e396ab04a2124811ef563697cf48e4ed2d -size 8949864 diff --git a/services/api/src/owl/templates/f1_due_diligence/template_meta.json b/services/api/src/owl/templates/f1_due_diligence/template_meta.json deleted file mode 100644 index 22973e4..0000000 --- a/services/api/src/owl/templates/f1_due_diligence/template_meta.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "Form F-1 Due Diligence", - "description": "Performs financial due diligence based on Form F-1 filings.", - "tags": ["sector:Finance", "task:Research"], - "created_at": "2024-09-30T15:38:13.747349+00:00" -} diff --git a/services/api/src/owl/types/__init__.py b/services/api/src/owl/types/__init__.py new file mode 100644 index 0000000..10ed78a --- /dev/null +++ b/services/api/src/owl/types/__init__.py @@ -0,0 +1,1032 @@ +from datetime import datetime +from enum import StrEnum +from os.path import splitext +from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar + +import pandas as pd +import pyarrow as pa +from fastapi import File, UploadFile +from pydantic import ( + AfterValidator, + BaseModel, + BeforeValidator, + Field, + model_validator, +) + +from jamaibase import types as t +from jamaibase.types import ( # noqa: F401 + CITATION_PATTERN, + DEFAULT_MUL_LANGUAGES, + EXAMPLE_CHAT_MODEL_IDS, + EXAMPLE_EMBEDDING_MODEL_IDS, + EXAMPLE_RERANKING_MODEL_IDS, + AgentMetaResponse, + AudioContent, + AudioContentData, + AudioResponse, + CellCompletionResponse, + CellReferencesResponse, + ChatCompletionChoice, + ChatCompletionChunkResponse, + ChatCompletionDelta, + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionUsage, + ChatContent, + ChatContentS3, + ChatEntry, + ChatRequest, + ChatRole, + ChatThreadEntry, + ChatThreadResponse, + ChatThreadsResponse, + Chunk, + CloudProvider, + CodeGenConfig, + CodeInterpreterTool, + ColumnDropRequest, + ColumnReorderRequest, + CompletionUsageDetails, + ConversationCreateRequest, + ConversationMetaResponse, + ConversationThreadsResponse, + CSVDelimiter, + DatetimeUTC, + DBStorageUsageData, + Deployment_, + DeploymentCreate, + DeploymentRead, + DeploymentUpdate, + DiscriminatedGenConfig, + EgressUsageData, + EmbeddingModelPrice, + 
EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + EmbeddingUsage, + EmbedGenConfig, + EmbedUsageData, + EmptyIfNoneStr, + FilePath, + FileStorageUsageData, + FileUploadResponse, + Function, + FunctionCall, + FunctionParameters, + GenConfigUpdateRequest, + GetURLRequest, + GetURLResponse, + Host, + ImageContent, + ImageContentData, + JSONInput, + JSONInputBin, + JSONOutput, + JSONOutputBin, + LanguageCodeList, + LLMGenConfig, + LLMModelPrice, + LlmUsageData, + LogProbs, + LogProbToken, + LogQueryResponse, + MessageAddRequest, + MessagesRegenRequest, + MessageUpdateRequest, + Metric, + ModelCapability, + ModelConfig_, + ModelConfigCreate, + ModelConfigRead, + ModelConfigUpdate, + ModelInfo, + ModelInfoListResponse, + ModelInfoRead, + ModelPrice, + ModelProvider, + ModelType, + MultiRowCompletionResponse, + MultiRowDeleteRequest, + NullableStr, + OkResponse, + OnPremProvider, + OrganizationCreate, + OrganizationUpdate, + OrgMember_, + OrgMemberCreate, + OrgMemberRead, + OrgMemberUpdate, + Page, + PasswordChangeRequest, + PasswordLoginRequest, + PaymentState, + PositiveInt, + PositiveNonZeroInt, + PricePlan_, + PricePlanCreate, + PricePlanRead, + PricePlanUpdate, + PriceTier, + Product, + Products, + ProductType, + Progress, + ProgressStage, + ProgressState, + Project_, + ProjectCreate, + ProjectKey_, + ProjectKeyCreate, + ProjectKeyRead, + ProjectKeyUpdate, + ProjectMember_, + ProjectMemberCreate, + ProjectMemberRead, + ProjectMemberUpdate, + ProjectRead, + ProjectUpdate, + PromptUsageDetails, + PythonGenConfig, + RAGParams, + RankedRole, + References, + RerankingApiVersion, + RerankingBilledUnits, + RerankingData, + RerankingMeta, + RerankingMetaUsage, + RerankingModelPrice, + RerankingRequest, + RerankingResponse, + RerankingUsage, + RerankUsageData, + Role, + RowCompletionResponse, + S3Content, + SanitisedMultilineStr, + SanitisedNonEmptyStr, + SanitisedStr, + SearchRequest, + SplitChunksParams, + SplitChunksRequest, + StripeEventData, + StripePaymentInfo, + TableDataImportRequest, + TableImportProgress, + TableImportRequest, + TableMeta, + TableMetaResponse, + TableType, + TextContent, + ToolCall, + ToolChoice, + ToolChoiceFunction, + ToolUsageDetails, + Usage, + UsageData, + UsageResponse, + User_, + UserAgent, + UserAuth, + UserRead, + UserReadObscured, + VerificationCode_, + VerificationCodeCreate, + VerificationCodeRead, + VerificationCodeUpdate, + WebSearchTool, + YAMLInput, + YAMLOutput, + empty_string_to_none, + none_to_empty_string, +) +from jamaibase.utils import uuid7_str +from owl.types.db import ( # noqa: F401 + Organization_, + OrganizationRead, + OrganizationReadDecrypt, + ProjectKeyReadDecrypt, + UserCreate, + UserUpdate, +) +from owl.version import __version__ + + +class StripeEventType(StrEnum): + INVOICE_PAID = "invoice.paid" + INVOICE_PAYMENT_FAILED = "invoice.payment_failed" + INVOICE_MARKED_UNCOLLECTIBLE = "invoice.marked_uncollectible" + INVOICE_VOIDED = "invoice.voided" + # PAYMENT_INTENT_PROCESSING = "payment_intent.processing" + # PAYMENT_INTENT_SUCCEEDED = "payment_intent.succeeded" + # CUSTOMER_SUBSCRIPTION_DELETED = "customer.subscription.deleted" + CHARGE_SUCCEEDED = "charge.succeeded" + CHARGE_REFUNDED = "charge.refunded" + + +TABLE_NAME_PATTERN = r"^[A-Za-z0-9]([A-Za-z0-9.?!@#$%^&*_()\- ]*[A-Za-z0-9.?!()\-])?$" +COLUMN_NAME_PATTERN = TABLE_NAME_PATTERN +GEN_CONFIG_VAR_PATTERN = r"(? 
str: + return value.replace("\0", "") + + +PostgresSafeStr = Annotated[ + str, + AfterValidator(_str_post_validator), +] + +_MAP_TO_POSTGRES_TYPE = { + "int": "INTEGER", + "int8": "INTEGER", + "float": "FLOAT", + "float32": "FLOAT", + "float16": "FLOAT", + "bool": "BOOL", + "str": "TEXT", + "image": "TEXT", + "audio": "TEXT", + "document": "TEXT", + "date-time": "TIMESTAMPTZ", + "json": "JSONB", +} +_MAP_TO_PYTHON_TYPE = { + "int": int, + "int8": int, + "float": float, + "float32": float, + "float16": float, + "bool": bool, + "str": PostgresSafeStr, + "image": PostgresSafeStr, + "audio": PostgresSafeStr, + "document": PostgresSafeStr, + "date-time": datetime, + "json": dict, +} +_MAP_TO_PANDAS_TYPE = { + "int": pd.Int64Dtype(), + "int8": pd.Int8Dtype(), + "float": pd.Float64Dtype(), + "float32": pd.Float32Dtype(), + "float16": pd.Float32Dtype(), + "bool": pd.BooleanDtype(), + "str": pd.StringDtype(), + "image": pd.StringDtype(), + "audio": pd.StringDtype(), + "document": pd.StringDtype(), + "date-time": pd.StringDtype(), # Convert to ISO format first + "json": pd.StringDtype(), # In general, we should not export JSON +} +_MAP_TO_PYARROW_TYPE = { + "int": pa.int64(), + "int8": pa.int8(), + "float": pa.float64(), + "float32": pa.float32(), + "float16": pa.float16(), + "bool": pa.bool_(), + "str": pa.utf8(), + "image": pa.utf8(), # Store URI + "audio": pa.utf8(), # Store URI + "document": pa.utf8(), # Store URI + "date-time": pa.timestamp("us", "UTC"), + "json": pa.utf8(), +} + + +class DBStorageUsage(BaseModel): + schema_name: str + table_names: list[str] + table_sizes: list[float] + + @property + def total_size(self) -> float: + return sum(self.table_sizes) + + +class ColumnDtype(StrEnum): + INT = "int" + FLOAT = "float" + BOOL = "bool" + STR = "str" + IMAGE = "image" + AUDIO = "audio" + DOCUMENT = "document" + # Internal types + # INT8 = "int8" + # FLOAT32 = "float32" + # FLOAT16 = "float16" + DATE_TIME = "date-time" + JSON = "json" + + def to_postgres_type(self) -> str: + """ + Returns the corresponding PostgreSQL type definition. + """ + return _MAP_TO_POSTGRES_TYPE[self] + + def to_python_type(self) -> Type[int | float | bool | str | datetime | dict]: + """ + Returns the corresponding Python type. + """ + return _MAP_TO_PYTHON_TYPE[self] + + def to_pandas_type( + self, + ) -> ( + pd.Int64Dtype + | pd.Int8Dtype + | pd.Float64Dtype + | pd.Float32Dtype + | pd.BooleanDtype + | pd.StringDtype + ): + """ + Returns the corresponding Python type. + """ + return _MAP_TO_PANDAS_TYPE[self] + + def to_pyarrow_type(self) -> pa.DataType: + """ + Returns the corresponding Python type. + """ + return _MAP_TO_PYARROW_TYPE[self] + + +class ColumnDtypeCreate(StrEnum): + INT = "int" + FLOAT = "float" + BOOL = "bool" + STR = "str" + IMAGE = "image" + AUDIO = "audio" + DOCUMENT = "document" + + def to_column_type(self) -> ColumnDtype: + """ + Returns the corresponding ColumnDtype. 
+ """ + return ColumnDtype(self) + + +class ColumnSchema(t.ColumnSchema): + dtype: ColumnDtype = Field( + ColumnDtype.STR, + description=f"Column data type, one of {list(map(str, ColumnDtype))}.", + ) + + +class ColumnSchemaCreate(t.ColumnSchemaCreate): + id: ColName = Field( + description="Column name.", + ) + dtype: ColumnDtypeCreate = Field( + ColumnDtypeCreate.STR, + description=f"Column data type, one of {list(map(str, ColumnDtypeCreate))}.", + ) + + @model_validator(mode="before") + @classmethod + def map_file_dtype_to_image(cls, data: dict[str, Any]) -> dict[str, Any]: + if data.get("dtype", "") == "file": + data["dtype"] = ColumnDtype.IMAGE + return data + + +class TableSchemaCreate(t.TableSchemaCreate): + id: TableName = Field( + description="Table name.", + ) + version: str = Field( + __version__, + description="Table version, following jamaibase version.", + ) + cols: list[ColumnSchemaCreate] = Field( + description="List of column schema.", + ) + + +class ActionTableSchemaCreate(TableSchemaCreate): + pass + + +class AddActionColumnSchema(ActionTableSchemaCreate): + pass + + +class KnowledgeTableSchemaCreate(TableSchemaCreate): + embedding_model: str + + +class AddKnowledgeColumnSchema(TableSchemaCreate): + pass + + +class ChatTableSchemaCreate(TableSchemaCreate): + pass + + +class AddChatColumnSchema(TableSchemaCreate): + pass + + +class ColumnRenameRequest(t.ColumnRenameRequest): + column_map: dict[str, ColName] = Field( + description="Mapping of old column names to new column names.", + ) + + +IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".gif", ".webp"] +AUDIO_FILE_EXTENSIONS = [".mp3", ".wav"] +DOCUMENT_FILE_EXTENSIONS = [ + ".csv", + ".docx", + ".html", + ".json", + ".jsonl", + ".md", + ".pdf", + ".pptx", + ".tsv", + ".txt", + ".xlsx", + ".xml", +] +ALLOWED_FILE_EXTENSIONS = set( + IMAGE_FILE_EXTENSIONS + AUDIO_FILE_EXTENSIONS + DOCUMENT_FILE_EXTENSIONS +) + + +def check_data(value: Any) -> Any: + if isinstance(value, str) and (value.startswith("s3://") or value.startswith("file://")): + extension = splitext(value)[1].lower() + if extension not in ALLOWED_FILE_EXTENSIONS: + raise ValueError( + "Unsupported file type. Make sure the file belongs to " + "one of the following formats: \n" + f"[Image File Types]: \n{IMAGE_FILE_EXTENSIONS} \n" + f"[Audio File Types]: \n{AUDIO_FILE_EXTENSIONS} \n" + f"[Document File Types]: \n{DOCUMENT_FILE_EXTENSIONS}" + ) + return value + + +CellValue = Annotated[Any, AfterValidator(check_data)] + + +class RowAdd(BaseModel): + table_id: str = Field( + description="Table name or ID.", + ) + data: dict[str, CellValue] = Field( + description="Mapping of column names to its value.", + ) + stream: bool = Field( + default=True, + description="Whether or not to stream the LLM generation.", + ) + concurrent: bool = Field( + default=True, + description="_Optional_. Whether or not to concurrently generate the output columns.", + ) + + +class MultiRowAddRequest(t.MultiRowAddRequest): + data: list[dict[str, CellValue]] = Field( + min_length=1, + description=( + "List of mapping of column names to its value. " + "In other words, each item in the list is a row, and each item is a mapping. " + "Minimum 1 row, maximum 100 rows." + ), + ) + + +class MultiRowAddRequestWithLimit(MultiRowAddRequest): + data: list[dict[str, CellValue]] = Field( + min_length=1, + max_length=100, + description=( + "List of mapping of column names to its value. " + "In other words, each item in the list is a row, and each item is a mapping. " + "Minimum 1 row, maximum 100 rows." 
+ ), + ) + + +class MultiRowUpdateRequest(t.MultiRowUpdateRequest): + data: dict[str, dict[str, CellValue]] = Field( + min_length=1, + description="Mapping of row IDs to row data, where each row data is a mapping of column names to its value.", + ) + + +class MultiRowUpdateRequestWithLimit(MultiRowUpdateRequest): + data: dict[str, dict[str, CellValue]] = Field( + min_length=1, + max_length=100, + description="Mapping of row IDs to row data, where each row data is a mapping of column names to its value.", + ) + + +class RowUpdateRequest(t.RowUpdateRequest): + data: dict[str, CellValue] = Field( + description="Mapping of column names to its value.", + ) + + +class RegenStrategy(StrEnum): + """Strategies for selecting columns during row regeneration.""" + + RUN_ALL = "run_all" + RUN_BEFORE = "run_before" + RUN_SELECTED = "run_selected" + RUN_AFTER = "run_after" + + +class RowRegen(t.RowRegen): + regen_strategy: RegenStrategy = Field( + default=RegenStrategy.RUN_ALL, + description=( + "_Optional_. Strategy for selecting columns to regenerate." + "Choose `run_all` to regenerate all columns in the specified row; " + "Choose `run_before` to regenerate columns up to the specified column_id; " + "Choose `run_selected` to regenerate only the specified column_id; " + "Choose `run_after` to regenerate columns starting from the specified column_id; " + ), + ) + + +class MultiRowRegenRequest(t.MultiRowRegenRequest): + regen_strategy: RegenStrategy = Field( + default=RegenStrategy.RUN_ALL, + description=( + "_Optional_. Strategy for selecting columns to regenerate." + "Choose `run_all` to regenerate all columns in the specified row; " + "Choose `run_before` to regenerate columns up to the specified column_id; " + "Choose `run_selected` to regenerate only the specified column_id; " + "Choose `run_after` to regenerate columns starting from the specified column_id; " + ), + ) + + @model_validator(mode="after") + def check_output_column_id_provided(self) -> Self: + if self.regen_strategy != RegenStrategy.RUN_ALL and self.output_column_id is None: + raise ValueError( + "`output_column_id` is required for regen_strategy other than 'run_all'." + ) + return self + + @model_validator(mode="after") + def sort_row_ids(self) -> Self: + self.row_ids = sorted(self.row_ids) + return self + + +class FileEmbedQuery(BaseModel): + table_id: SanitisedNonEmptyStr = Field( + description="Table name or ID.", + ) + file_id: SanitisedNonEmptyStr = Field( + description="ID of the file.", + ) + chunk_size: int = Field( + 1000, + gt=0, + description="Maximum chunk size (number of characters). Must be > 0.", + ) + chunk_overlap: int = Field( + 200, + ge=0, + description="Overlap in characters between chunks. Must be >= 0.", + ) + # stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( + # True + # ) + + +ORDERED_BY = TypeVar("ORDERED_BY", bound=Literal["id", "name", "created_at", "updated_at"]) + + +class ListQuery(BaseModel, Generic[ORDERED_BY]): + offset: Annotated[int, Field(ge=0, description="Items offset.")] = 0 + limit: Annotated[int, Field(gt=0, le=1000, description="Number of items.")] = 1000 + order_by: Annotated[ORDERED_BY, Field(description="Sort by this attribute.")] = "updated_at" + order_ascending: Annotated[bool, Field(description="Whether to sort in ascending order.")] = True # fmt: skip + search_query: Annotated[ + str, + Field( + max_length=10_000, + description=( + "A string to search for as a filter. 
" + 'The string is interpreted as both POSIX regular expression and literal string. Defaults to "" (no filter). ' + "It will be combined other filters using `AND`." + ), + ), + ] = "" + search_columns: Annotated[ + list[str], + Field( + min_length=1, + description='A list of attribute names to search for `search_query`. Defaults to `["name"]`.', + ), + ] = ["name"] + after: Annotated[ + str | None, + Field( + description=( + "Opaque cursor token to paginate results. " + "If provided, the query will return items after this cursor and `offset` will be ignored. " + "Defaults to `None` (no cursor)." + ), + ), + ] = None + + +class ListQueryByOrg(ListQuery): + organization_id: Annotated[SanitisedNonEmptyStr, Field(description="Organization ID.")] + + +class ListQueryByOrgOptional(ListQuery): + organization_id: Annotated[ + SanitisedNonEmptyStr | None, Field(None, description="Organization ID.") + ] + + +class ListQueryByProject(ListQuery): + project_id: Annotated[SanitisedNonEmptyStr, Field(description="Project ID.")] + + +class OrgModelCatalogueQuery(ListQueryByOrg): + capabilities: list[ModelCapability] | None = Field( + None, + min_length=1, + description="List of capabilities of model.", + ) + + +class DuplicateTableQuery(BaseModel): + table_id_src: Annotated[str, Field(description="Name of the table to be duplicated.")] + table_id_dst: Annotated[ + TableName | None, + Field( + description=( + "Name for the new table. " + "Defaults to None (automatically find the next available table name)." + ) + ), + ] = None + include_data: Annotated[ + bool, + Field(description=("Whether to include data from the source table. Defaults to `True`.")), + ] = True + create_as_child: Annotated[ + bool, + Field( + description=( + "Whether the new table is a child table. Defaults to `False`. " + "If this is `True`, then `include_data` will be set to `True`." + ) + ), + ] = False + + +class RenameTableQuery(BaseModel): + table_id_src: Annotated[str, Field(description="Source table name.")] + table_id_dst: Annotated[TableName, Field(description="Name for the new table.")] + + +class GetTableThreadQuery(BaseModel): + table_id: Annotated[str, Field(description="Table name.")] + column_id: Annotated[str, Field(description="Column to fetch as a conversation thread.")] + row_id: Annotated[ + str, + Field(description='ID of the last row in the thread. Defaults to "" (export all rows).'), + ] = "" + include: Annotated[ + bool, + Field(description="Whether to include the row specified by `row_id`. Defaults to True."), + ] = True + + +class GetTableThreadsQuery(BaseModel): + table_id: Annotated[str, Field(description="Table name.")] + column_ids: Annotated[ + list[str] | None, + Field( + description="Columns to fetch as conversation threads. Defaults to None (fetch all)." + ), + ] = None + row_id: Annotated[ + str, + Field(description='ID of the last row in the thread. Defaults to "" (export all rows).'), + ] = "" + include_row: Annotated[ + bool, + Field(description="Whether to include the row specified by `row_id`. Defaults to True."), + ] = True + + +class GetConversationThreadsQuery(BaseModel): + conversation_id: Annotated[str, Field(description="Conversation ID.")] + column_ids: Annotated[ + list[str] | None, + Field( + description="Columns to fetch as conversation threads. Defaults to None (fetch all)." + ), + ] = None + + +class ListTableQuery(BaseModel): + offset: Annotated[ + int, + Field(ge=0, description="Item offset for pagination. 
Defaults to 0."), + ] = 0 + limit: Annotated[ + int, + Field( + gt=0, + le=100, + description="Number of tables to return (min 1, max 100). Defaults to 100.", + ), + ] = 100 + order_by: Annotated[ + Literal["id", "table_id", "updated_at"], + Field(description='Sort tables by this attribute. Defaults to "updated_at".'), + ] = "updated_at" + order_ascending: Annotated[ + bool, + Field(description="Whether to sort by ascending order. Defaults to True."), + ] = True + parent_id: Annotated[ + str | None, + Field( + min_length=1, + description=( + "Parent ID of tables to return. Defaults to None (return all tables). " + "Additionally for Chat Table, you can list: " + '(1) all chat agents by passing in "_agent_"; or ' + '(2) all chats by passing in "_chat_".' + ), + ), + ] = None + search_query: Annotated[ + str, + Field( + max_length=255, + description='A string to search for within table IDs as a filter. Defaults to "" (no filter).', + ), + ] = "" + count_rows: Annotated[ + bool, + Field(description="Whether to count the rows of the tables. Defaults to False."), + ] = False + + +class ListRowQuery(BaseModel): + offset: Annotated[ + int, + Field(ge=0, description="Item offset for pagination. Defaults to 0."), + ] = 0 + limit: Annotated[ + int, + Field( + gt=0, + le=100, + description="Number of rows to return (min 1, max 100). Defaults to 100.", + ), + ] = 100 + order_by: Annotated[ + str, + Field(description='Sort rows by this column. Defaults to "ID".'), + ] = "ID" + order_ascending: Annotated[ + bool, + Field(description="Whether to sort by ascending order. Defaults to True."), + ] = True + columns: Annotated[ + list[str] | None, + Field( + description="A list of column names to include in the response. Default is to return all columns.", + ), + ] = None + where: Annotated[ + EmptyIfNoneStr, + Field( + description=( + "SQL where clause. " + "Can be nested ie `x = '1' AND (\"y (1)\" = 2 OR z = '3')`. " + "It will be combined with `row_ids` using `AND`. " + 'Defaults to "" (no filter).' + ), + ), + ] = "" + search_query: Annotated[ + str, + Field( + max_length=10_000, + description=( + "A string to search for within row data as a filter. " + 'The string is interpreted as both POSIX regular expression and literal string. Defaults to "" (no filter). ' + "It will be combined other filters using `AND`." + ), + ), + ] = "" + search_columns: Annotated[ + list[str] | None, + Field( + description="A list of column names to search for `search_query`. Defaults to None (search all columns).", + ), + ] = None + float_decimals: Annotated[ + int, + Field( + ge=0, + description="Number of decimals for float values. Defaults to 0 (no rounding).", + ), + ] = 0 + vec_decimals: Annotated[ + int, + Field( + description="Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding).", + ), + ] = 0 + + +class ListTableRowQuery(ListRowQuery): + table_id: Annotated[SanitisedNonEmptyStr, Field(description="Table ID or name.")] + + +class ListMessageQuery(ListRowQuery): + conversation_id: Annotated[ + SanitisedNonEmptyStr, Field(description="Conversation ID (Table ID) to fetch.") + ] + + +class GetTableRowQuery(BaseModel): + table_id: Annotated[SanitisedNonEmptyStr, Field(description="Table name.")] + row_id: Annotated[ + SanitisedNonEmptyStr, Field(description="The ID of the specific row to fetch.") + ] + columns: Annotated[ + list[SanitisedNonEmptyStr] | None, + Field( + description="A list of column names to include in the response. 
Default is to return all columns.", + ), + ] = None + float_decimals: Annotated[ + int, + Field( + ge=0, + description="Number of decimals for float values. Defaults to 0 (no rounding).", + ), + ] = 0 + vec_decimals: Annotated[ + int, + Field( + description="Number of decimals for vectors. If its negative, exclude vector columns. Defaults to 0 (no rounding).", + ), + ] = 0 + + +class FileEmbedFormData(BaseModel): + file: Annotated[UploadFile, File(description="The file.")] + file_name: Annotated[str, Field(description="File name.", deprecated=True)] = "" + table_id: Annotated[SanitisedNonEmptyStr, Field(description="Knowledge Table ID.")] + # overwrite: Annotated[ + # bool, Field(description="Whether to overwrite old file with the same name.") + # ] = False, + chunk_size: Annotated[ + int, Field(gt=0, description="Maximum chunk size (number of characters). Must be > 0.") + ] = 2000 + chunk_overlap: Annotated[ + int, Field(ge=0, description="Overlap in characters between chunks. Must be >= 0.") + ] = 200 + + +class TableDataImportFormData(BaseModel): + file: Annotated[UploadFile, File(description="The CSV or TSV file.")] + file_name: Annotated[str, Field(description="File name.", deprecated=True)] = "" + table_id: Annotated[ + SanitisedNonEmptyStr, + Field(description="ID or name of the table that the data should be imported into."), + ] + stream: Annotated[bool, Field(description="Whether or not to stream the LLM generation.")] = ( + True + ) + # List of inputs is bugged as of 2024-07-14: https://github.com/tiangolo/fastapi/pull/9928/files + # TODO: Maybe we can re-enable these since the bug is for direct `Form` declaration and not Form Model + # column_names: Annotated[ + # list[ColName] | None, + # Field( + # description="_Optional_. A list of columns names if the CSV does not have header row. Defaults to None (read from CSV).", + # ), + # ] = None + # columns: Annotated[ + # list[ColName] | None, + # Field( + # description="_Optional_. A list of columns to be imported. Defaults to None (import all columns except 'ID' and 'Updated at').", + # ), + # ] = None + delimiter: Annotated[ + CSVDelimiter, + Field(description='The delimiter, can be "," or "\\t". Defaults to ",".'), + ] = CSVDelimiter.COMMA + + +class ExportTableDataQuery(BaseModel): + table_id: Annotated[SanitisedNonEmptyStr, Field(description="Table name.")] + delimiter: Annotated[ + CSVDelimiter, + Field(description='The delimiter, can be "," or "\\t". Defaults to ",".'), + ] = CSVDelimiter.COMMA + columns: Annotated[ + list[SanitisedNonEmptyStr] | None, + Field( + min_length=1, + description="_Optional_. A list of columns to be exported. Defaults to None (export all columns).", + ), + ] = None + + +TableImportName = Annotated[ + str, + Field( + pattern=TABLE_NAME_PATTERN, + min_length=1, + max_length=100, # Since we will truncate table IDs that are too long anyway + description=( + "Table name or ID. " + "Must be unique with at least 1 character and up to 46 characters. " + "Must start with an alphabet or number. " + "Characters in the middle can include space and these symbols: `.?!@#$%^&*_()-`. " + "Must end with an alphabet or number or these symbols: `.?!()-`." 
+ ), + ), +] + + +class TableImportFormData(BaseModel): + file: Annotated[UploadFile, File(description="The Parquet file.")] + table_id_dst: Annotated[ + TableImportName | None, + BeforeValidator(empty_string_to_none), + Field(description="The ID or name of the new table."), + ] = None + blocking: Annotated[ + bool, + Field( + description=( + "If True, waits until import finishes. " + "If False, the task is submitted to a task queue and returns immediately." + ), + ), + ] = True + progress_key: Annotated[ + str, + Field( + default_factory=uuid7_str, + description="The key to use to query progress. Defaults to a random string.", + ), + ] + migrate: Annotated[ + bool, + Field(description="Whether to import in migration mode (maybe removed without notice)."), + ] = False diff --git a/services/api/src/owl/types/db.py b/services/api/src/owl/types/db.py new file mode 100644 index 0000000..69f9b07 --- /dev/null +++ b/services/api/src/owl/types/db.py @@ -0,0 +1,69 @@ +from typing import Annotated + +from pwdlib import PasswordHash +from pydantic import BaseModel, BeforeValidator, Field +from sqlmodel import Field as SqlField + +from jamaibase import types as t +from jamaibase.types import PricePlan_, SanitisedNonEmptyStr +from owl.utils.crypt import decrypt + + +def _decrypt(value: str) -> str: + from owl.configs import ENV_CONFIG + + return decrypt(value, ENV_CONFIG.encryption_key_plain) + + +def _decrypt_external_keys(value: dict[str, str] | BaseModel) -> dict[str, str]: + if isinstance(value, BaseModel): + value = value.model_dump(exclude_unset=True) + return {k: _decrypt(v) for k, v in value.items()} + + +class UserUpdate(t.UserUpdate): + password: SanitisedNonEmptyStr = Field( + "", + max_length=72, + description="Password in plain text.", + ) + + @property + def password_hash(self) -> str | None: + if self.password: + hasher = PasswordHash.recommended() + return hasher.hash(self.password) + return None + + +class UserCreate(t.UserCreate): + @property + def password_hash(self) -> str | None: + if self.password: + hasher = PasswordHash.recommended() + return hasher.hash(self.password) + return None + + +class Organization_(t.Organization_): + def get_external_key(self, provider: str) -> str: + api_key = self.external_keys.get(provider.lower(), "").strip() + return _decrypt(api_key) if api_key else "" + + +class OrganizationRead(Organization_): + price_plan: PricePlan_ | None = Field( + description="Subscribed plan.", + ) + + +class OrganizationReadDecrypt(OrganizationRead): + external_keys: Annotated[dict[str, str], BeforeValidator(_decrypt_external_keys)] = SqlField( + description="Mapping of external service provider to its API key.", + ) + + +class ProjectKeyReadDecrypt(t.ProjectKeyRead): + id: Annotated[str, BeforeValidator(_decrypt)] = Field( + description="The token after decryption.", + ) diff --git a/services/api/src/owl/unstructuredio.py b/services/api/src/owl/unstructuredio.py deleted file mode 100644 index bc9951e..0000000 --- a/services/api/src/owl/unstructuredio.py +++ /dev/null @@ -1,206 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Callable, Iterator - -from langchain_community.document_loaders.base import BaseLoader -from langchain_core.documents import Document -from loguru import logger -from unstructured_client import UnstructuredClient -from unstructured_client.models import shared -from unstructured_client.models.errors import SDKError - - -class UnstructuredBaseLoader(BaseLoader, ABC): - """Base Loader that uses `Unstructured`.""" - - def __init__( 
- self, - mode: str = "single", - post_processors: list[Callable] | None = None, - **unstructured_kwargs: Any, - ): - """Initialize with file path.""" - # try: - # import unstructured # noqa:F401 - # except ImportError: - # raise ValueError( - # "unstructured package not found, please install it with " - # "`pip install unstructured`" - # ) - _valid_modes = {"single", "elements", "paged"} - if mode not in _valid_modes: - raise ValueError(f"Got {mode} for `mode`, but should be one of `{_valid_modes}`") - self.mode = mode - - # if not satisfies_min_unstructured_version("0.5.4"): - # if "strategy" in unstructured_kwargs: - # unstructured_kwargs.pop("strategy") - - self.unstructured_kwargs = unstructured_kwargs - self.post_processors = post_processors or [] - - @abstractmethod - def _get_elements(self) -> list: - """Get elements.""" - - @abstractmethod - def _get_metadata(self) -> dict: - """Get metadata.""" - - def _post_process_elements(self, elements: list) -> list: - """Applies post processing functions to extracted unstructured elements. - Post processing functions are str -> str callables are passed - in using the post_processors kwarg when the loader is instantiated.""" - for element in elements: - for post_processor in self.post_processors: - element.apply(post_processor) - return elements - - def lazy_load(self) -> Iterator[Document]: - """Load file.""" - elements = self._get_elements() - self._post_process_elements(elements) - if self.mode == "elements": - for element in elements: - metadata = element["metadata"] - metadata["page"] = metadata.get("page_number", 1) - # NOTE(MthwRobinson) - the attribute check is for backward compatibility - # with unstructured<0.4.9. The metadata attributed was added in 0.4.9. - if hasattr(element, "metadata"): - metadata.update(element["metadata"]) - if hasattr(element, "type"): - metadata["type"] = element["NarrativeText"] - yield Document(page_content=str(element["text"]), metadata=metadata) - elif self.mode == "paged": - text_dict: dict[int, str] = {} - meta_dict: dict[int, dict] = {} - - for element in elements: - metadata = element["metadata"] - if hasattr(element, "metadata"): - metadata.update(element["metadata"]) - page_number = metadata.get("page_number", 1) - metadata["page"] = page_number - - # Check if this page_number already exists in docs_dict - if page_number not in text_dict: - # If not, create new entry with initial text and metadata - text_dict[page_number] = element["text"] + "\n\n" - meta_dict[page_number] = metadata - else: - # If exists, append to text and update the metadata - text_dict[page_number] += element["text"] + "\n\n" - meta_dict[page_number].update(metadata) - - # Convert the dict to a list of Document objects - for key in text_dict.keys(): - yield Document(page_content=text_dict[key], metadata=meta_dict[key]) - elif self.mode == "single": - metadata = self._get_metadata() - text = "\n\n".join([el["text"] for el in elements]) - yield Document(page_content=text, metadata=metadata) - else: - raise ValueError(f"mode of {self.mode} not supported.") - - -def partition( - filename: str, - unstructuredio_client, - **unstructured_kwargs: Any, -): - languages = unstructured_kwargs.pop("languages", ["en", "cn"]) - - with open(filename, "rb") as f: - # Note that this currently only supports a single file - files = shared.Files( - content=f.read(), - file_name=filename, - ) - - req = shared.PartitionParameters( - files=files, - # Other partition params - languages=languages, - **unstructured_kwargs, - ) - - try: - resp = 
unstructuredio_client.general.partition(req) - return resp.elements - except SDKError as e: - logger.error(f"UnstructuredIO SDK Error: {str(e)}") - return [] - - -class UnstructuredAPIFileLoader(UnstructuredBaseLoader): - """Load files using `Unstructured`. - - Example: - - UnstructuredAPIFileLoader( - "helloworld.txt", - mode="single", - url="http://unstructuredio:6989/general/v0/general", - api_key="ellm", - languages=["en", "cn"] - ) - - """ - - def __init__( - self, - file_path: str | list[str], - mode: str = "single", - url="https://api.unstructured.io/general/v0/general", - api_key: str = "ellm", - **unstructured_kwargs: Any, - ): - """Initialize with file path.""" - self.file_path = file_path - self.url = url - self.api_key = api_key - super().__init__(mode=mode, **unstructured_kwargs) - - def _get_elements(self) -> list: - s = UnstructuredClient(server_url=self.url, api_key_auth=self.api_key) - - if isinstance(self.file_path, list): - elements = [] - for file in self.file_path: - elements.extend( - partition(filename=file, unstructuredio_client=s, **self.unstructured_kwargs) - ) - return elements - else: - return partition( - filename=self.file_path, unstructuredio_client=s, **self.unstructured_kwargs - ) - - def _get_metadata(self) -> dict: - return {"source": self.file_path} - - -if __name__ == "__main__": - filename = "clients/python/tests/files/docx/Recommendation Letter.docx" - doc_loader = UnstructuredAPIFileLoader( - filename, - mode="single", - url="http://localhost:6989/general/v0/general", - api_key="ellm", - languages=["en", "cn"], - ).load() - - doc_loader = UnstructuredAPIFileLoader( - filename, - mode="paged", - url="http://localhost:6989/general/v0/general", - api_key="ellm", - languages=["en", "cn"], - ).load() - - doc_loader = UnstructuredAPIFileLoader( - filename, - mode="elements", - url="http://localhost:6989/general/v0/general", - api_key="ellm", - languages=["en", "cn"], - ).load() diff --git a/services/api/src/owl/utils/__init__.py b/services/api/src/owl/utils/__init__.py index 7bde827..7caf0e8 100644 --- a/services/api/src/owl/utils/__init__.py +++ b/services/api/src/owl/utils/__init__.py @@ -1,77 +1,77 @@ -from datetime import datetime, timezone -from typing import Any -from uuid import UUID - -from uuid_extensions import uuid7str as _uuid7_draft2_str -from uuid_utils import uuid7 as _uuid7 - -from jamaibase.exceptions import ResourceNotFoundError - - -def get_non_empty(mapping: dict[str, Any], key: str, default: Any): - value = mapping.get(key, None) - return value if value else default - - -def datetime_now_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -def uuid7_draft2_str(prefix: str = "") -> str: - return f"{prefix}{_uuid7_draft2_str()}" - - -def uuid7_str(prefix: str = "") -> str: - return f"{prefix}{_uuid7()}" - - -def datetime_str_from_uuid7(uuid7_str: str) -> str: - # Extract the timestamp (first 48 bits) - timestamp = UUID(uuid7_str).int >> 80 - dt = datetime.fromtimestamp(timestamp / 1000.0, tz=timezone.utc) - return dt.isoformat() - - -def datetime_str_from_uuid7_draft2(uuid7_str: str) -> str: - # https://www.ietf.org/archive/id/draft-peabody-dispatch-new-uuid-format-02.html#name-uuidv7-layout-and-bit-order - # Parse the UUID string - uuid_obj = UUID(uuid7_str) - # Extract the unix timestamp (first 36 bits) - unix_ts = uuid_obj.int >> 92 - # Extract the fractional seconds (next 24 bits) - frac_secs = (uuid_obj.int >> 68) & 0xFFFFFF - # Combine unix timestamp and fractional seconds - total_secs = unix_ts + (frac_secs / 
0x1000000) - # Create a datetime object - dt = datetime.fromtimestamp(total_secs, tz=timezone.utc) - return dt.isoformat() - - -def select_external_api_key(external_api_keys, provider: str) -> str: - if provider == "ellm": - return "DUMMY_KEY" +import sqlparse +from loguru import logger +from sqlparse.sql import Comparison, Function, Identifier, Parenthesis, Where + +from jamaibase.utils import ( # noqa: F401 + get_non_empty, + get_ttl_hash, + mask_content, + mask_dict, + mask_string, + merge_dict, + run, + uuid7_draft2_str, + uuid7_str, +) + + +def validate_where_expr(expr: str, *, id_map: dict[str, str] = None) -> str: + sql = sqlparse.split(expr)[0] + sql = sql.replace("\r", " ").replace("\n", " ").replace("\t", " ").strip().rstrip(";") + if "shutdown" in sql.lower(): + raise ValueError("SQL expression contains shutdown.") + if not sql: + raise ValueError("SQL expression is empty.") + tokens = sqlparse.parse(sql)[0].tokens + if any(isinstance(t, Function) for t in tokens) > 0: + raise ValueError(f"SQL expression contains function: `{expr}`") + # Further breakdown Where + if isinstance(tokens[0], Where): + tokens = tokens[0].tokens[1:] + token_types = [] + + def _breakdown(_tokens): + for t in _tokens: + if t.ttype is None: + _breakdown(t) + else: + token_types.append((str(t), list(t.ttype))) + + _breakdown(tokens) + # logger.info(f"`{''.join(str(t) for t in tokens)}` {token_types=} {[type(t) for t in tokens]}") + dml_tokens = [t for t in token_types if t[1][-1] == "DML"] + ddl_tokens = [t for t in token_types if t[1][-1] == "DDL"] + keyword_tokens = [ + t + for t in token_types + if t[1][0] == "Keyword" and t[0].lower() not in ["and", "or", "null", "true", "false"] + ] + comment_tokens = [t for t in token_types if t[1][0] == "Comment"] + if len(dml_tokens) > 0: + raise ValueError(f"SQL expression contains DML: `{expr}`") + if len(ddl_tokens) > 0: + raise ValueError(f"SQL expression contains DDL: `{expr}`") + if len(keyword_tokens) > 0: + raise ValueError(f"SQL expression contains keyword: `{expr}`") + if len(comment_tokens) > 0 or "/*" in sql or "*/" in sql: + raise ValueError(f"SQL expression contains comment: `{expr}`") + if id_map: + mapped_tokens = [] + + def _map(_tokens): + for t in _tokens: + if isinstance(t, (Parenthesis, Comparison)): + _map(t) + elif isinstance(t, Identifier): + t = str(t).strip('"') + t = id_map.get(t, t).strip('"') + mapped_tokens.append(f'"{id_map.get(t, t)}"') + else: + mapped_tokens.append(t) + + _map(tokens) else: - try: - return getattr(external_api_keys, provider) or "DUMMY_KEY" - except AttributeError: - raise ResourceNotFoundError( - f"External API key not found for provider: {provider}" - ) from None - - -def mask_string(x: str | None) -> str | None: - if x is None: - return None - if x.startswith("[ERROR]"): - return x - return f"len={len(x)} str={x[:5]}***{x[-5:]}" - - -def mask_content(x: str | list[dict[str, str]] | None) -> str | list[dict[str, str]] | None: - if isinstance(x, list): - return [mask_content(v) for v in x] - if isinstance(x, dict): - return {k: mask_content(v) for k, v in x.items()} - if isinstance(x, str): - return mask_string(x) - return None + mapped_tokens = tokens + new_sql = "".join(str(t) for t in mapped_tokens).strip().rstrip(";") + logger.info(f"Validated SQL: `{expr}` -> `{new_sql}`") + return new_sql diff --git a/services/api/src/owl/utils/auth.py b/services/api/src/owl/utils/auth.py deleted file mode 100644 index 06f8e9e..0000000 --- a/services/api/src/owl/utils/auth.py +++ /dev/null @@ -1,471 +0,0 @@ -from functools 
import lru_cache -from secrets import compare_digest -from typing import Annotated, AsyncGenerator - -from fastapi import BackgroundTasks, Header, Request, Response -from httpx import RequestError -from loguru import logger -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential - -from jamaibase import JamAIAsync -from jamaibase.exceptions import ( - AuthorizationError, - ForbiddenError, - ResourceNotFoundError, - ServerBusyError, - UnexpectedError, - UpgradeTierError, -) -from jamaibase.protocol import ( - EmbeddingModelConfig, - LLMModelConfig, - ModelDeploymentConfig, - OrganizationRead, - PATRead, - ProjectRead, - RerankingModelConfig, - UserRead, -) -from owl.billing import BillingManager -from owl.configs.manager import CONFIG, ENV_CONFIG -from owl.protocol import ExternalKeys, ModelListConfig -from owl.utils import datetime_now_iso, get_non_empty - -CLIENT = JamAIAsync(token=ENV_CONFIG.service_key_plain, timeout=60.0) -WRITE_METHODS = {"PUT", "PATCH", "POST", "DELETE", "PURGE"} -JAMAI_CLOUD_URL = "https://cloud.jamaibase.com" -NO_PROJECT_ID_MESSAGE = ( - "You didn't provide a project ID. " - 'You need to provide your project ID in an "X-PROJECT-ID" header ' - "(i.e. X-PROJECT-ID: PROJECT_ID). " - f"You can retrieve your project ID via API or from {JAMAI_CLOUD_URL}" -) -NO_TOKEN_MESSAGE = ( - "You didn't provide an authorization token. " - "You need to provide your either your Personal Access Token or organization API key (deprecated) " - 'in an "Authorization" header using Bearer auth (i.e. "Authorization: Bearer TOKEN"). ' - f"You can obtain your token from {JAMAI_CLOUD_URL}" -) -INVALID_TOKEN_MESSAGE = ( - "You provided an invalid authorization token. " - "You need to provide your either your Personal Access Token or organization API key (deprecated) " - 'in an "Authorization" header using Bearer auth (i.e. "Authorization: Bearer TOKEN"). ' - f"You can obtain your token from {JAMAI_CLOUD_URL}" -) -ORG_API_KEY_DEPRECATE_MESSAGE = ( - "Usage of organization API key is deprecated and will be removed soon. " - "Authenticate using your Personal Access Token instead." 
-) - - -@retry( - retry=retry_if_exception_type(RequestError), - wait=wait_random_exponential(multiplier=1, min=0.1, max=3), - stop=stop_after_attempt(3), - reraise=True, -) -async def _get_project_with_retries(project_id: str) -> ProjectRead: - return await CLIENT.admin.organization.get_project(project_id) - - -async def _get_project(request: Request, project_id: str) -> ProjectRead: - try: - return await _get_project_with_retries(project_id) - except ResourceNotFoundError as e: - raise ResourceNotFoundError(f'Project "{project_id}" is not found.') from e - except RequestError as e: - logger.warning( - f'{request.state.id} - Error fetching project "{project_id}" due to {e.__class__.__name__}: {e}' - ) - raise ServerBusyError(f"{e.__class__.__name__}: {e}") from e - except Exception as e: - raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e - - -@retry( - retry=retry_if_exception_type(RequestError), - wait=wait_random_exponential(multiplier=1, min=0.1, max=3), - stop=stop_after_attempt(3), - reraise=True, -) -async def _get_organization_with_retries(org_id_or_token: str) -> OrganizationRead: - return await CLIENT.admin.backend.get_organization(org_id_or_token) - - -async def _get_organization(request: Request, org_id_or_token: str) -> OrganizationRead: - try: - return await _get_organization_with_retries(org_id_or_token) - except ResourceNotFoundError as e: - raise ResourceNotFoundError(f'Organization "{org_id_or_token}" is not found.') from e - except RequestError as e: - logger.warning( - f'{request.state.id} - Error fetching organization "{org_id_or_token}" due to {e.__class__.__name__}: {e}' - ) - raise ServerBusyError(f"{e.__class__.__name__}: {e}") from e - except Exception as e: - raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e - - -@retry( - retry=retry_if_exception_type(RequestError), - wait=wait_random_exponential(multiplier=1, min=0.1, max=3), - stop=stop_after_attempt(3), - reraise=True, -) -async def _get_user_with_retries(user_id_or_token: str) -> UserRead: - return await CLIENT.admin.backend.get_user(user_id_or_token) - - -async def _get_user(request: Request, user_id_or_token: str) -> UserRead: - try: - return await _get_user_with_retries(user_id_or_token) - except ResourceNotFoundError as e: - raise ResourceNotFoundError(f'User "{user_id_or_token}" is not found.') from e - except RequestError as e: - logger.warning( - f'{request.state.id} - Error fetching user "{user_id_or_token}" due to {e.__class__.__name__}: {e}' - ) - raise ServerBusyError(f"{e.__class__.__name__}: {e}") from e - except Exception as e: - raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e - - -@retry( - retry=retry_if_exception_type(RequestError), - wait=wait_random_exponential(multiplier=1, min=0.1, max=3), - stop=stop_after_attempt(3), - reraise=True, -) -async def _get_pat_with_retries(token: str) -> PATRead: - return await CLIENT.admin.backend.get_pat(token) - - -async def _get_pat(request: Request, token: str) -> PATRead: - try: - return await _get_pat_with_retries(token) - except ResourceNotFoundError as e: - raise ResourceNotFoundError(f'PAT "{token}" is not found.') from e - except RequestError as e: - logger.warning( - f'{request.state.id} - Error fetching PAT "{token}" due to {e.__class__.__name__}: {e}' - ) - raise ServerBusyError(f"{e.__class__.__name__}: {e}") from e - except Exception as e: - raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e - - -def _get_external_keys(organization: OrganizationRead) -> ExternalKeys: - ext_keys = 
organization.external_keys - return ExternalKeys( - custom=get_non_empty(ext_keys, "custom", ENV_CONFIG.custom_api_key_plain), - openai=get_non_empty(ext_keys, "openai", ENV_CONFIG.openai_api_key_plain), - anthropic=get_non_empty(ext_keys, "anthropic", ENV_CONFIG.anthropic_api_key_plain), - gemini=get_non_empty(ext_keys, "gemini", ENV_CONFIG.gemini_api_key_plain), - cohere=get_non_empty(ext_keys, "cohere", ENV_CONFIG.cohere_api_key_plain), - groq=get_non_empty(ext_keys, "groq", ENV_CONFIG.groq_api_key_plain), - together_ai=get_non_empty(ext_keys, "together_ai", ENV_CONFIG.together_api_key_plain), - jina=get_non_empty(ext_keys, "jina", ENV_CONFIG.jina_api_key_plain), - voyage=get_non_empty(ext_keys, "voyage", ENV_CONFIG.voyage_api_key_plain), - hyperbolic=get_non_empty(ext_keys, "hyperbolic", ENV_CONFIG.hyperbolic_api_key_plain), - cerebras=get_non_empty(ext_keys, "cerebras", ENV_CONFIG.cerebras_api_key_plain), - sambanova=get_non_empty(ext_keys, "sambanova", ENV_CONFIG.sambanova_api_key_plain), - deepseek=get_non_empty(ext_keys, "deepseek", ENV_CONFIG.deepseek_api_key_plain), - ) - - -async def auth_internal_oss() -> str: - return "" - - -async def auth_internal_cloud( - bearer_token: Annotated[str, Header(alias="Authorization", description="Service key.")] = "", -) -> str: - bearer_token = bearer_token.strip().split("Bearer ") - if len(bearer_token) < 2 or bearer_token[1].strip() == "": - raise AuthorizationError(NO_TOKEN_MESSAGE) - token = bearer_token[1].strip() - if not ( - compare_digest(token, ENV_CONFIG.service_key_plain) - or compare_digest(token, ENV_CONFIG.service_key_alt_plain) - ): - raise AuthorizationError(INVALID_TOKEN_MESSAGE) - return token - - -auth_internal = auth_internal_oss if ENV_CONFIG.is_oss else auth_internal_cloud - - -AuthReturn = tuple[UserRead | None, OrganizationRead | None] - - -async def auth_user_oss() -> AuthReturn: - return None, None - - -async def auth_user_cloud( - request: Request, - response: Response, - bearer_token: Annotated[ - str, - Header( - alias="Authorization", - description="One of: Service key, user PAT or organization API key.", - ), - ] = "", - user_id: Annotated[str, Header(alias="X-USER-ID", description="User ID.")] = "", -) -> AuthReturn: - bearer_token = bearer_token.strip() - bearer_token = bearer_token.split("Bearer ") - if len(bearer_token) < 2 or bearer_token[1].strip() == "": - raise AuthorizationError(NO_TOKEN_MESSAGE) - - # Authenticate - user = org = None - token = bearer_token[1].strip() - if ( - compare_digest(token, ENV_CONFIG.service_key_plain) - or compare_digest(token, ENV_CONFIG.service_key_alt_plain) - or token.startswith("jamai_sk_") - ): - if token.startswith("jamai_sk_"): - org = await _get_organization(request, token) - response.headers["Warning"] = f'299 - "{ORG_API_KEY_DEPRECATE_MESSAGE}"' - if user_id: - user = await _get_user(request, user_id) - - elif token.startswith("jamai_pat_"): - user = await _get_user(request, token) - - elif user := request.session.get("user", None) is not None: - user = UserRead(**user) - - else: - raise AuthorizationError(INVALID_TOKEN_MESSAGE) - return user, org - - -auth_user = auth_user_oss if ENV_CONFIG.is_oss else auth_user_cloud - - -def _get_valid_deployments( - model: LLMModelConfig | EmbeddingModelConfig | RerankingModelConfig, - valid_providers: list[str], -) -> list[ModelDeploymentConfig]: - valid_deployments = [] - for deployment in model.deployments: - if deployment.provider in valid_providers: - valid_deployments.append(deployment) - return valid_deployments - - 
-@lru_cache(maxsize=64) -def _get_valid_modellistconfig(all_models: str, external_keys: str) -> ModelListConfig: - all_models = ModelListConfig.model_validate_json(all_models) - external_keys = ExternalKeys.model_validate_json(external_keys) - # define all possible api providers - available_providers = [ - "openai", - "anthropic", - "together_ai", - "cohere", - "sambanova", - "cerebras", - "hyperbolic", - "deepseek", - ] - # remove providers without credentials - available_providers = [ - provider for provider in available_providers if getattr(external_keys, provider) != "" - ] - # add custom and ellm providers as allow no credentials - available_providers.extend( - [ - "custom", - "ellm", - ] - ) - - # Initialize lists to hold valid models - valid_llm_models = [] - valid_embed_models = [] - valid_rerank_models = [] - - # Iterate over the llm, embed, rerank list - for m in all_models.llm_models: - valid_deployments = _get_valid_deployments(m, available_providers) - if len(valid_deployments) > 0: - m.deployments = valid_deployments - valid_llm_models.append(m) - - for m in all_models.embed_models: - valid_deployments = _get_valid_deployments(m, available_providers) - if len(valid_deployments) > 0: - m.deployments = valid_deployments - valid_embed_models.append(m) - - for m in all_models.rerank_models: - valid_deployments = _get_valid_deployments(m, available_providers) - if len(valid_deployments) > 0: - m.deployments = valid_deployments - valid_rerank_models.append(m) - - # Create a new ModelListConfig with the valid models - valid_model_list_config = ModelListConfig( - llm_models=valid_llm_models, - embed_models=valid_embed_models, - rerank_models=valid_rerank_models, - ) - - return valid_model_list_config - - -async def auth_user_project_oss( - request: Request, - project_id: Annotated[ - str, Header(alias="X-PROJECT-ID", description='Project ID "proj_xxx".') - ] = "default", -) -> AsyncGenerator[ProjectRead, None]: - project_id = project_id.strip() - if project_id == "": - raise AuthorizationError(NO_PROJECT_ID_MESSAGE) - - # Fetch project - project = await _get_project(request, project_id) - organization = project.organization - - # Set some state - request.state.org_id = organization.id - request.state.project_id = project.id - request.state.external_keys = _get_external_keys(organization) - request.state.org_models = ModelListConfig.model_validate(organization.models) - all_models = request.state.org_models + CONFIG.get_model_config() - request.state.all_models = _get_valid_modellistconfig( - all_models.model_dump_json(), request.state.external_keys.model_dump_json() - ) - request.state.billing = BillingManager(request=request) - - yield project - - -async def auth_user_project_cloud( - bg_tasks: BackgroundTasks, - request: Request, - response: Response, - project_id: Annotated[ - str, Header(alias="X-PROJECT-ID", description='Project ID "proj_xxx".') - ] = "", - bearer_token: Annotated[ - str, - Header( - alias="Authorization", - description="One of: Service key, user PAT or organization API key.", - ), - ] = "", - user_id: Annotated[str, Header(alias="X-USER-ID", description="User ID.")] = "", -) -> AsyncGenerator[ProjectRead, None]: - route = request.url.path - project_id = project_id.strip() - bearer_token = bearer_token.strip() - user_id = user_id.strip() - if project_id == "": - raise AuthorizationError(NO_PROJECT_ID_MESSAGE) - - # Fetch project - project = await _get_project(request, project_id) - organization = project.organization - - # Set some state - request.state.org_id = 
organization.id - request.state.project_id = project.id - request.state.external_keys = _get_external_keys(organization) - request.state.org_models = ModelListConfig.model_validate(organization.models) - all_models = request.state.org_models + CONFIG.get_model_config() - request.state.all_models = _get_valid_modellistconfig( - all_models.model_dump_json(), request.state.external_keys.model_dump_json() - ) - # Check if token is provided - bearer_token = bearer_token.split("Bearer ") - if len(bearer_token) < 2 or bearer_token[1].strip() == "": - raise AuthorizationError(NO_TOKEN_MESSAGE) - - user_roles = {u.user_id: u.role for u in organization.members} - # Non-activated orgs can only perform GET requests - if (not organization.active) and (request.method != "GET"): - raise UpgradeTierError(f'Your organization "{organization.id}" is not activated.') - - # Authenticate - token = bearer_token[1].strip() - if compare_digest(token, ENV_CONFIG.service_key_plain) or compare_digest( - token, ENV_CONFIG.service_key_alt_plain - ): - pass - elif token.startswith("jamai_sk_"): - _org = await _get_organization(request, token) - if project.organization.id != _org.id: - raise AuthorizationError( - f'Your provided project "{project.id}" does not belong to organization "{_org.id}".' - ) - response.headers["Warning"] = f'299 - "{ORG_API_KEY_DEPRECATE_MESSAGE}"' - - elif token.startswith("jamai_pat_"): - pat = await _get_pat(request, token) - if pat.expiry != "" and datetime_now_iso() > pat.expiry: - raise AuthorizationError( - "Your Personal Access Token has expired. Please generate a new token." - ) - user_id = pat.user_id - - elif logged_in_user := request.session.get("user", None) is not None: - logged_in_user = UserRead(**logged_in_user) - user_id = logged_in_user.id - - else: - raise AuthorizationError(INVALID_TOKEN_MESSAGE) - - # Role-based access control - if user_id: - user_role = user_roles.get(user_id, None) - if user_role is None: - raise ForbiddenError(f'You do not have access to organization "{organization.id}".') - if user_role == "guest" and request.method in WRITE_METHODS: - raise ForbiddenError( - f'You do not have write access to organization "{organization.id}".' - ) - if user_role != "admin" and "api/admin/org" in route: - raise ForbiddenError( - f'You do not have admin access to organization "{organization.id}".' 
- ) - - # Billing - request.state.billing = BillingManager( - organization=organization, - project_id=project.id, - user_id=user_id, - request=request, - ) - - # If quota ran out then allow read access only - if request.method in WRITE_METHODS: - request.state.billing.check_egress_quota() - request.state.billing.check_db_storage_quota() - request.state.billing.check_file_storage_quota() - - yield project - - # NOTE that billing processing is done in middleware where response headers are available - - # Set project updated at datetime - async def _set_project_updated_at() -> None: - if "gen_tables" in route and request.method in WRITE_METHODS: - try: - await CLIENT.admin.organization.set_project_updated_at(project_id) - except Exception as e: - logger.warning( - f'{request.state.id} - Error setting project "{project_id}" last updated time: {e}' - ) - - # This will run AFTER streaming responses are sent - bg_tasks.add_task(_set_project_updated_at) - - -auth_user_project = auth_user_project_oss if ENV_CONFIG.is_oss else auth_user_project_cloud diff --git a/services/api/src/owl/utils/auth/__init__.py b/services/api/src/owl/utils/auth/__init__.py new file mode 100644 index 0000000..2faa534 --- /dev/null +++ b/services/api/src/owl/utils/auth/__init__.py @@ -0,0 +1,18 @@ +from owl.configs import ENV_CONFIG + +if ENV_CONFIG.is_oss: + from owl.utils.auth.oss import ( # noqa: F401 + auth_service_key, + auth_user, + auth_user_project, + auth_user_service_key, + has_permissions, + ) +else: + from owl.utils.auth.cloud import ( # noqa: F401 + auth_service_key, + auth_user, + auth_user_project, + auth_user_service_key, + has_permissions, + ) diff --git a/services/api/src/owl/utils/auth/oss.py b/services/api/src/owl/utils/auth/oss.py new file mode 100644 index 0000000..3043423 --- /dev/null +++ b/services/api/src/owl/utils/auth/oss.py @@ -0,0 +1,156 @@ +from secrets import compare_digest +from time import perf_counter +from typing import Annotated, AsyncGenerator + +from fastapi import BackgroundTasks, Depends, Header, Request +from loguru import logger + +from owl.configs import ENV_CONFIG +from owl.db import async_session +from owl.db.models.oss import ModelConfig, Project, User +from owl.types import ( + ModelConfigRead, + OrganizationRead, + ProjectRead, + UserAuth, +) +from owl.utils.billing import BillingManager +from owl.utils.dates import now +from owl.utils.exceptions import AuthorizationError, ResourceNotFoundError + +WRITE_METHODS = {"PUT", "PATCH", "POST", "DELETE", "PURGE"} +NO_USER_ID_MESSAGE = ( + 'You didn\'t provide a user ID. You need to provide the user ID in an "X-USER-ID" header.' +) +NO_PROJECT_ID_MESSAGE = ( + "You didn't provide a project ID. " + 'You need to provide the project ID in an "X-PROJECT-ID" header.' +) +NO_TOKEN_MESSAGE = ( + "You didn't provide an authorization token. " + 'You need to provide your PAT in an "Authorization" header using Bearer auth (i.e. "Authorization: Bearer TOKEN").' +) +INVALID_TOKEN_MESSAGE = "You provided an invalid authorization token." 
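+# NOTE: The dependencies defined below are re-exported by `owl.utils.auth.__init__`,
+# which imports either this OSS module or `owl.utils.auth.cloud` at import time based
+# on `ENV_CONFIG.is_oss`, so route code can depend on `auth_user_project` without
+# caring which backend is active.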
+ + +def is_service_key(token: str) -> bool: + return compare_digest(token, ENV_CONFIG.service_key_plain) or compare_digest( + token, ENV_CONFIG.service_key_alt_plain + ) + + +async def auth_service_key( + bearer_token: Annotated[ + str, Header(alias="Authorization", description="Not needed for OSS.") + ] = "", +) -> str: + return bearer_token + + +async def _bearer_auth( + user_id: Annotated[str, Header(alias="X-USER-ID", description="User ID.")] = "", +) -> tuple[UserAuth, None]: + if user_id == "": + user_id = "0" + async with async_session() as session: + user = await session.get(User, user_id) + if user is None: + raise AuthorizationError(f'User "{user_id}" is not found.') + user = UserAuth.model_validate(user) + return user, None + + +async def auth_user_service_key( + request: Request, + user_project: Annotated[tuple[UserAuth, None], Depends(_bearer_auth)], +) -> AsyncGenerator[UserAuth, None]: + t0 = perf_counter() + user = user_project[0] + t1 = perf_counter() + request.state.timing["Auth"] = t1 - t0 + yield user + request.state.timing["Request"] = perf_counter() - t1 + + +auth_user = auth_user_service_key + + +async def _set_project_updated_at( + request: Request, + project_id: str, +) -> None: + if "gen_tables" in request.url.path and request.method in WRITE_METHODS: + try: + async with async_session() as session: + project = await session.get(Project, project_id) + if project is None: + raise ResourceNotFoundError(f'Project "{project_id}" is not found.') + project.updated_at = now() + session.add(project) + await session.commit() + except Exception as e: + logger.warning( + f'{request.state.id} - Error setting project "{project_id}" last updated time: {e}' + ) + + +async def auth_user_project( + request: Request, + bg_tasks: BackgroundTasks, + user_project: Annotated[tuple[UserAuth, None], Depends(_bearer_auth)], + project_id: Annotated[ + str, Header(alias="X-PROJECT-ID", description="Project ID.") + ] = "default", +) -> AsyncGenerator[tuple[UserAuth, ProjectRead, OrganizationRead], None]: + t0 = perf_counter() + user, project = user_project + ### --- Fetch project --- ### + async with async_session() as session: + proj = await session.get(Project, project_id) + if proj is None: + raise AuthorizationError(f'Project "{project_id}" is not found.') + project = ProjectRead.model_validate(proj) + organization = OrganizationRead.model_validate(proj.organization) + models = ( + await ModelConfig.list_( + session=session, + return_type=ModelConfigRead, + organization_id=organization.id, + ) + ).items + ### --- Billing --- ### + request.state.billing = BillingManager( + organization=organization, + project_id=project.id, + user_id=user.id, + request=request, + models=models, + ) + t1 = perf_counter() + request.state.timing["Auth"] = t1 - t0 + yield user, project, organization + request.state.timing["Request"] = perf_counter() - t1 + # This will run BEFORE any responses are sent + + # Background tasks will run AFTER streaming responses are sent + bg_tasks.add_task( + _set_project_updated_at, + request=request, + project_id=project_id, + ) + + +def has_permissions( + user: UserAuth, + requirements: list[str], + *, + organization_id: str | None = None, + project_id: str | None = None, + raise_error: bool = True, +) -> bool: + del user + del requirements + del organization_id + del project_id + del raise_error + return True diff --git a/services/api/src/owl/utils/billing/__init__.py b/services/api/src/owl/utils/billing/__init__.py new file mode 100644 index 0000000..e9f18f3 --- /dev/null 
+++ b/services/api/src/owl/utils/billing/__init__.py @@ -0,0 +1,18 @@ +from owl.configs import ENV_CONFIG + +if ENV_CONFIG.is_oss: + from owl.utils.billing.oss import ( # noqa: F401 + CLICKHOUSE_CLIENT, + OPENTELEMETRY_CLIENT, + STRIPE_CLIENT, + BillingManager, + ClickHouseAsyncClient, + ) +else: + from owl.utils.billing.cloud import ( # noqa: F401 + CLICKHOUSE_CLIENT, + OPENTELEMETRY_CLIENT, + STRIPE_CLIENT, + BillingManager, + ClickHouseAsyncClient, + ) diff --git a/services/api/src/owl/utils/billing/oss.py b/services/api/src/owl/utils/billing/oss.py new file mode 100644 index 0000000..22d4865 --- /dev/null +++ b/services/api/src/owl/utils/billing/oss.py @@ -0,0 +1,788 @@ +import asyncio +from collections import defaultdict +from time import perf_counter +from typing import Any, DefaultDict + +import clickhouse_connect +from cloudevents.conversion import to_dict +from cloudevents.http import CloudEvent +from fastapi import Request +from loguru import logger +from opentelemetry import metrics +from opentelemetry.metrics import Counter, Histogram, _Gauge +from tenacity import retry, stop_after_attempt, wait_exponential + +from owl.configs import CACHE, ENV_CONFIG +from owl.db.gen_table import GenerativeTableCore +from owl.types import ( + DBStorageUsageData, + EgressUsageData, + EmbedUsageData, + FileStorageUsageData, + LlmUsageData, + ModelConfigRead, + OrganizationRead, + ProductType, + RerankUsageData, + UsageData, + UserAgent, +) +from owl.utils.exceptions import ResourceNotFoundError, handle_exception + + +class OpenTelemetryClient: + def __init__(self) -> None: + # resource = Resource.create( + # { + # "service.name": "owl-service", + # "service.instance.id": uuid7_str(), + # } + # ) + # reader = PeriodicExportingMetricReader( + # OTLPMetricExporter(endpoint=endpoint), export_interval_millis=math.inf + # ) + # self.provider = MeterProvider(resource=resource, metric_readers=[reader]) + # metrics.set_meter_provider(self.provider) + self.meter = metrics.get_meter(__name__) + self.counters: DefaultDict[str, Counter] = defaultdict( + lambda: self.meter.create_counter(name="default") + ) + self.histograms: DefaultDict[str, Histogram] = defaultdict( + lambda: self.meter.create_histogram(name="default") + ) + self.gauges: DefaultDict[str, _Gauge] = defaultdict( + lambda: self.meter.create_gauge(name="default") + ) + + def get_counter(self, name) -> Counter: + if name not in self.counters: + self.counters[name] = self.meter.create_counter(name=name) + return self.counters[name] + + def get_histogram(self, name) -> Histogram: + if name not in self.histograms: + self.histograms[name] = self.meter.create_histogram(name=name) + return self.histograms[name] + + def get_gauge(self, name) -> _Gauge: + if name not in self.gauges: + self.gauges[name] = self.meter.create_gauge(name=name) + return self.gauges[name] + + def get_meter(self): + return self.meter + + def force_flush(self): + # self.provider.force_flush() + metrics.get_meter_provider().force_flush() + + +class ClickHouseAsyncClient: + def __init__( + self, + host: str, + username: str, + password: str, + database: str, + port: int, + ) -> None: + self.client = asyncio.run( + clickhouse_connect.get_async_client( + host=host, + username=username, + password=password, + database=database, + port=port, + ) + ) + + def _log_debug(self, message: str): + logger.debug(f"{self.__class__.__name__}: {message}") + + def _log_info(self, message: str): + logger.info(f"{self.__class__.__name__}: {message}") + + def _log_error(self, message: str): + 
logger.error(f"{self.__class__.__name__}: {message}") + + async def query(self, sql: str): + try: + result = await self.client.query(sql) + return result + except Exception as e: + self._log_error(f"Failed to execute query: {sql}. Error: {e}") + raise + + async def insert_llm_usage(self, usages: list[LlmUsageData]): + try: + usages_list = [usage.as_list() for usage in usages] + result = await self.client.insert( + table="llm_usage", + data=usages_list, + column_names=[ + "id", + "org_id", + "proj_id", + "user_id", + "timestamp", + "cost", + "model", + "input_token", + "output_token", + "input_cost", + "output_cost", + ], + settings={ + "async_insert": 1, + "wait_for_async_insert": 1, + "async_insert_busy_timeout_ms": 1000, + "async_insert_use_adaptive_busy_timeout": 1, + }, + ) + return result + except Exception as e: + self._log_error(f"Failed to insert data into table: llm_usage. Error: {e}") + raise + + async def insert_embed_usage(self, usages: list[EmbedUsageData]): + try: + usages_list = [usage.as_list() for usage in usages] + result = await self.client.insert( + table="embed_usage", + data=usages_list, + column_names=[ + "id", + "org_id", + "proj_id", + "user_id", + "timestamp", + "cost", + "model", + "num_token", + ], + settings={ + "async_insert": 1, + "wait_for_async_insert": 1, + "async_insert_busy_timeout_ms": 1000, + "async_insert_use_adaptive_busy_timeout": 1, + }, + ) + return result + except Exception as e: + self._log_error(f"Failed to insert data into table: embed_usage. Error: {e}") + raise + + async def insert_rerank_usage(self, usages: list[RerankUsageData]): + try: + usages_list = [usage.as_list() for usage in usages] + result = await self.client.insert( + table="rerank_usage", + data=usages_list, + column_names=[ + "id", + "org_id", + "proj_id", + "user_id", + "timestamp", + "cost", + "model", + "num_search", + ], + settings={ + "async_insert": 1, + "wait_for_async_insert": 1, + "async_insert_busy_timeout_ms": 1000, + "async_insert_use_adaptive_busy_timeout": 1, + }, + ) + return result + except Exception as e: + self._log_error(f"Failed to insert data into table: rerank_usage. Error: {e}") + raise + + async def insert_egress_usage(self, usages: list[EgressUsageData]): + try: + usages_list = [usage.as_list() for usage in usages] + result = await self.client.insert( + table="egress_usage", + data=usages_list, + column_names=[ + "id", + "org_id", + "proj_id", + "user_id", + "timestamp", + "cost", + "amount_gib", + ], + settings={ + "async_insert": 1, + "wait_for_async_insert": 1, + "async_insert_busy_timeout_ms": 1000, + "async_insert_use_adaptive_busy_timeout": 1, + }, + ) + return result + except Exception as e: + self._log_error(f"Failed to insert data into table: egress_usage. Error: {e}") + raise + + async def insert_file_storage_usage(self, usages: list[FileStorageUsageData]): + try: + usages_list = [usage.as_list() for usage in usages] + result = await self.client.insert( + table="file_storage_usage", + data=usages_list, + column_names=[ + "id", + "org_id", + "proj_id", + "user_id", + "timestamp", + "cost", + "amount_gib", + "snapshot_gib", + ], + settings={ + "async_insert": 1, + "wait_for_async_insert": 1, + "async_insert_busy_timeout_ms": 1000, + "async_insert_use_adaptive_busy_timeout": 1, + }, + ) + return result + except Exception as e: + self._log_error(f"Failed to insert data into table: file_storage_usage. 
Error: {e}") + raise + + async def insert_db_storage_usage(self, usages: list[DBStorageUsageData]): + try: + usages_list = [usage.as_list() for usage in usages] + result = await self.client.insert( + table="db_storage_usage", + data=usages_list, + column_names=[ + "id", + "org_id", + "proj_id", + "user_id", + "timestamp", + "cost", + "amount_gib", + "snapshot_gib", + ], + settings={ + "async_insert": 1, + "wait_for_async_insert": 1, + "async_insert_busy_timeout_ms": 1000, + "async_insert_use_adaptive_busy_timeout": 1, + }, + ) + return result + except Exception as e: + self._log_error(f"Failed to insert data into table: db_storage_usage. Error: {e}") + raise + + @retry( + wait=wait_exponential(multiplier=1, min=2, max=10), + stop=stop_after_attempt(4), + reraise=True, + ) + async def insert_usage(self, usage: UsageData): + llm_result = await self.insert_llm_usage(usage.llm_usage) + embed_result = await self.insert_embed_usage(usage.embed_usage) + rerank_result = await self.insert_rerank_usage(usage.rerank_usage) + egress_result = await self.insert_egress_usage(usage.egress_usage) + file_storage_result = await self.insert_file_storage_usage(usage.file_storage_usage) + db_storage_result = await self.insert_db_storage_usage(usage.db_storage_usage) + return ( + llm_result, + embed_result, + rerank_result, + egress_result, + file_storage_result, + db_storage_result, + ) + + async def bulk_insert_usage(self, usages: list[UsageData]): + all_usages = sum(usages, start=UsageData()) + results = await self.insert_usage(all_usages) + return results + + async def flush_buffer(self): + buffer_key = ENV_CONFIG.clickhouse_buffer_key + buffer_count_key = buffer_key + "_count" + temp_key = buffer_key + "_temp" + lock_key = buffer_key + ":lock" + + async with CACHE.alock(lock_key, blocking=False, expire=5) as lock_acquired: + if lock_acquired: + self._log_debug("Acquired lock to flush buffer.") + else: + self._log_debug("Could not acquire lock to flush buffer.") + return + + # Exit if buffer key not found + if not await CACHE.exists(buffer_key): + self._log_debug("Buffer key not found, skipping insert operation.") + return + + # Move data from buffer to temp key + # TODO: Maybe use async redis + with CACHE._redis.pipeline() as pipe: + pipe.multi() + pipe.rename(buffer_key, temp_key) + temp_count = pipe.get(buffer_count_key) + pipe.delete(buffer_count_key) + pipe.execute() + + buffer_data = CACHE._redis.lrange(temp_key, 0, -1) + if buffer_data: + _t = perf_counter() + usages = [UsageData.model_validate_json(data) for data in buffer_data] + try: + await self.bulk_insert_usage(usages) + # Delete temp key on success + del CACHE[temp_key] + self._log_info( + ( + f"{sum([usage.total_usage_events for usage in usages]):,d} buffered usage data inserted to DB, " + f"time taken: {perf_counter() - _t:,.3} seconds" + ) + ) + except Exception as e: + self._log_error(f"Failed to insert data. 
Error: {e}") + # Move data back to buffer on failure + # Append data back to buffer on failure + with CACHE._redis.pipeline() as pipe: + pipe.multi() + for data in buffer_data: + pipe.rpush(buffer_key, data) + pipe.incrby(buffer_count_key, int(temp_count or 0)) + pipe.execute() + # Delete temp key after appending data back to buffer + del CACHE[temp_key] + + +# OPENTELEMETRY_CLIENT = OpenTelemetryClient( +# endpoint=f"http://{ENV_CONFIG.opentelemetry_host}:{ENV_CONFIG.opentelemetry_port}" +# ) +OPENTELEMETRY_CLIENT = OpenTelemetryClient() +CLICKHOUSE_CLIENT = ClickHouseAsyncClient( + host=ENV_CONFIG.clickhouse_host, + username=ENV_CONFIG.clickhouse_user, + password=ENV_CONFIG.clickhouse_password.get_secret_value(), + database=ENV_CONFIG.clickhouse_db, + port=ENV_CONFIG.clickhouse_port, +) +STRIPE_CLIENT = None + + +def _log_exception(e: Exception, *_, **__): + logger.exception(f"Billing event processing encountered an error: {repr(e)}") + + +class BillingManager: + def __init__( + self, + *, + organization: OrganizationRead, + project_id: str = "", + user_id: str = "", + request: Request | None = None, + models: list[ModelConfigRead] | None = None, + ) -> None: + if not isinstance(organization, OrganizationRead): + raise TypeError( + f"`organization` must be an instance of `OrganizationRead`, received: {type(organization)}" + ) + self.org = organization + self.project_id = project_id + self.user_id = user_id + self.request = request + self.id: str = request.state.id if request else "" + if models and not all(isinstance(m, ModelConfigRead) for m in models): + raise TypeError( + f"`models` must be a list of `ModelConfigRead` instances, received: {models}" + ) + self.models = models + self.model_map = {m.id: m for m in models} if models else {} + if request is None: + self._user_agent = UserAgent(is_browser=False, agent="") + else: + self._user_agent: UserAgent = request.state.user_agent + self._price_plan = None + self._events = [] + self._deltas: dict[ProductType, float] = defaultdict(float) + self._values: dict[ProductType, float] = defaultdict(float) + self._llm_usage_events: list[LlmUsageData] = [] + self._embed_usage_events: list[EmbedUsageData] = [] + self._rerank_usage_events: list[RerankUsageData] = [] + self._egress_usage_events: list[EgressUsageData] = [] + self._file_storage_usage_events: list[FileStorageUsageData] = [] + self._db_storage_usage_events: list[DBStorageUsageData] = [] + self._cost = 0.0 + + @property + def cost(self) -> float: + return self._cost + + @property + def total_balance(self) -> float: + return self.org.credit + self.org.credit_grant + + def _log_info(self, message: str): + logger.info(f"{self.id} - {self.__class__.__name__}: {message}") + + def _log_warning(self, message: str): + logger.warning(f"{self.id} - {self.__class__.__name__}: {message}") + + def _model(self, model_id: str) -> ModelConfigRead: + model = self.model_map.get(model_id, None) + if model is None: + raise ResourceNotFoundError( + f'Model "{self._model_id_or_name(model_id)}" is not found.' 
+ ) + return model + + def _check_project_id(self): + if self.project_id.strip() == "": + raise ValueError("Project ID must be provided.") + + def _cloud_event( + self, + attributes: dict[str, Any], + data: dict[str, Any], + ) -> CloudEvent: + if ( + len(data.get("proj_id", "dummy")) == 0 + ): # Update to proj_id to align with Clickhouse Column + raise ValueError('"proj_id" if provided must not be empty.') + # check if request_count + extra_labels = ( + self._user_agent.model_dump() if attributes.get("type", "") == "request_count" else {} + ) + return CloudEvent( + attributes={ + **attributes, + "source": "owl", + "subject": self.org.id, + }, + data={ + **data, + "org_id": self.org.id, + "user_id": self.user_id, + **extra_labels, + }, + ) + + # --- Generative Table Usage --- # + + def has_gen_table_quota(self, table: GenerativeTableCore) -> bool: + return True + + # --- LLM Usage --- # + + def has_llm_quota(self, model_id: str) -> bool: + return True + + def create_llm_events( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + *, + create_usage: bool = True, + ) -> None: + input_tokens = int(input_tokens) + output_tokens = int(output_tokens) + if input_tokens <= 0 and output_tokens <= 0: + return + self._check_project_id() + # Analytics: Token usage + self._events += [ + self._cloud_event( + {"type": ProductType.LLM_TOKENS}, + { + "model": model_id, + "tokens": v, + "type": t, + "proj_id": self.project_id, # Update to proj_id to align with Clickhouse Column + }, + ) + for t, v in [("input", input_tokens), ("output", output_tokens)] + ] + if create_usage: + self._llm_usage_events.append( + LlmUsageData( + org_id=self.org.id, + proj_id=self.project_id, + user_id=self.user_id, + model=model_id, + input_token=input_tokens, + output_token=output_tokens, + input_cost=0.0, + output_cost=0.0, + cost=0.0, + ) + ) + + # --- Embedding Usage --- # + + def has_embedding_quota(self, model_id: str) -> bool: + return True + + def create_embedding_events( + self, + model_id: str, + token_usage: int, + *, + create_usage: bool = True, + ) -> None: + token_usage = int(token_usage) + if token_usage <= 0: + return + self._check_project_id() + # Analytics: Token usage + self._events += [ + self._cloud_event( + {"type": ProductType.EMBEDDING_TOKENS}, + { + "model": model_id, + "tokens": token_usage, + "proj_id": self.project_id, # Update to proj_id to align with Clickhouse Column + }, + ) + ] + if create_usage: + self._embed_usage_events.append( + EmbedUsageData( + org_id=self.org.id, + proj_id=self.project_id, + user_id=self.user_id, + model=model_id, + token=token_usage, + cost=0.0, + ) + ) + + # --- Reranker Usage --- # + + def has_reranker_quota(self, model_id: str) -> bool: + return True + + def create_reranker_events( + self, + model_id: str, + num_searches: int, + *, + create_usage: bool = True, + ) -> None: + num_searches = int(num_searches) + if num_searches <= 0: + return + self._check_project_id() + # Analytics: Rerank usage + self._events += [ + self._cloud_event( + {"type": ProductType.RERANKER_SEARCHES}, + { + "model": model_id, + "searches": num_searches, + "proj_id": self.project_id, # Update to proj_id to align with Clickhouse Column + }, + ) + ] + if create_usage: + self._rerank_usage_events.append( + RerankUsageData( + org_id=self.org.id, + proj_id=self.project_id, + user_id=self.user_id, + model=model_id, + number_of_search=num_searches, + cost=0.0, + ) + ) + + # --- Egress Usage --- # + + def has_egress_quota(self) -> bool: + return True + + def create_egress_events(self, 
amount_gib: float, *, create_usage: bool = True) -> None: + if amount_gib <= 0 or not self.project_id: + return + # Analytics: Egress usage + self._events += [ + self._cloud_event( + {"type": "bandwidth"}, + { + "amount_gib": amount_gib, + "type": ProductType.EGRESS, + "proj_id": self.project_id, # Update to proj_id to align with Clickhouse Column + }, + ) + ] + if create_usage: + self._egress_usage_events.append( + EgressUsageData( + org_id=self.org.id, + proj_id=self.project_id, + user_id=self.user_id, + amount_gib=amount_gib, + cost=0.0, + ) + ) + + # --- DB Storage Usage --- # + + def has_db_storage_quota(self) -> bool: + return True + + def create_db_storage_events(self, db_usage_gib: float, *, create_usage: bool = True) -> None: + if db_usage_gib <= 0: + return + # Analytics: DB storage usage + self._events += [ + self._cloud_event({"type": "storage"}, {"amount_gib": db_usage_gib, "type": "db"}), + ] + if create_usage: + self._db_storage_usage_events.append( + DBStorageUsageData( + org_id=self.org.id, + proj_id=self.project_id + or "not_applicable", # possible the request is not associated with a project + user_id=self.user_id, + amount_gib=0.0, + cost=0.0, + snapshot_gib=db_usage_gib, + ) + ) + + # --- File Storage Usage --- # + + def has_file_storage_quota(self) -> bool: + return True + + def create_file_storage_events( + self, file_usage_gib: float, *, create_usage: bool = True + ) -> None: + if file_usage_gib <= 0: + return + # Analytics: DB storage usage + self._events += [ + self._cloud_event({"type": "storage"}, {"amount_gib": file_usage_gib, "type": "file"}), + ] + if create_usage: + self._file_storage_usage_events.append( + FileStorageUsageData( + org_id=self.org.id, + proj_id=self.project_id + or "not_applicable", # possible the request is not associated with a project + user_id=self.user_id, + amount_gib=0.0, + cost=0.0, + snapshot_gib=file_usage_gib, + ) + ) + + # --- Process all events --- # + + @handle_exception(handler=_log_exception) + async def process_all(self) -> None: + """ + Process all events. In general, only call this as a BACKGROUND TASK after the response is sent. 
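+
+        Usage events are appended to the Redis buffer; once the buffered count
+        (including this request's events) reaches
+        `ENV_CONFIG.clickhouse_max_buffer_queue_size`, the buffer is flushed and the
+        current usage data is written to ClickHouse directly. OpenTelemetry counters
+        and gauges are then updated from the accumulated events.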
+ """ + + # Push usage to redis for queue if buffer less than 10000 + usage_data = UsageData( + llm_usage=self._llm_usage_events, + embed_usage=self._embed_usage_events, + rerank_usage=self._rerank_usage_events, + egress_usage=self._egress_usage_events, + file_storage_usage=self._file_storage_usage_events, + db_storage_usage=self._db_storage_usage_events, + ) + usage_count = (await CACHE.get_usage_buffer_count()) + usage_data.total_usage_events + if usage_count >= ENV_CONFIG.clickhouse_max_buffer_queue_size: + await CLICKHOUSE_CLIENT.flush_buffer() + # We could use asyncio TaskGroup here if there are other async tasks downstream + # For now there isn't any, so we just simply await it + await CLICKHOUSE_CLIENT.insert_usage(usage_data) + elif usage_data.total_usage_events > 0: + await CACHE.add_usage_to_buffer(usage_data) + + # API request count + req_scope = getattr(self.request, "scope", {}) + req_method: str = getattr(self.request, "method", "") + if req_scope.get("route", None) and req_method and self.project_id: + # https://stackoverflow.com/a/72239186 + path = req_scope.get("root_path", "") + req_scope["route"].path + self._events += [ + self._cloud_event( + {"type": "request_count"}, + { + "method": req_method, + "path": path, + "proj_id": self.project_id, # Update to proj_id to align with Clickhouse Column + }, + ) + ] + # Send OpenTelemetry events + if len(self._events) > 0: + t0 = perf_counter() + for event in self._events: + attributes = to_dict(event) + event_type = attributes["type"] + if event_type == "request_count": + counter = OPENTELEMETRY_CLIENT.get_counter(name="request_count") + counter.add(1, attributes["data"]) + elif event_type == ProductType.LLM_TOKENS: + counter = OPENTELEMETRY_CLIENT.get_counter(name="llm_token_usage") + counter.add( + attributes["data"]["tokens"], + {k: v for k, v in attributes["data"].items() if k != "tokens"}, + ) + elif event_type == ProductType.EMBEDDING_TOKENS: + counter = OPENTELEMETRY_CLIENT.get_counter(name="embedding_token_usage") + counter.add( + attributes["data"]["tokens"], + {k: v for k, v in attributes["data"].items() if k != "tokens"}, + ) + elif event_type == ProductType.RERANKER_SEARCHES: + counter = OPENTELEMETRY_CLIENT.get_counter(name="reranker_search_usage") + counter.add( + attributes["data"]["searches"], + {k: v for k, v in attributes["data"].items() if k != "searches"}, + ) + elif event_type == "bandwidth": + counter = OPENTELEMETRY_CLIENT.get_counter(name="bandwidth_usage") + counter.add( + attributes["data"]["amount_gib"], + {k: v for k, v in attributes["data"].items() if k != "amount_gib"}, + ) + elif event_type == "storage": + gauge = OPENTELEMETRY_CLIENT.get_gauge(name="storage_usage") + gauge.set( + attributes["data"]["amount_gib"], + {k: v for k, v in attributes["data"].items() if k != "amount_gib"}, + ) + elif event_type == "spent": + counter = OPENTELEMETRY_CLIENT.get_counter(name="spent") + counter.add( + attributes["data"]["spent_usd"], + {k: v for k, v in attributes["data"].items() if k != "spent_usd"}, + ) + self._log_info( + ( + f"OpenTelemetry events ingestion: " + f"t={(perf_counter() - t0) * 1e3:,.2f} ms " + f"num_events={len(self._events):,d} " + f"event_types={set(str(e.get_attributes()['type']) for e in self._events)}" + ) + ) + # Force flush + # OPENTELEMETRY_CLIENT.force_flush() + # Clear events + self._events = [] diff --git a/services/api/src/owl/utils/billing_metrics.py b/services/api/src/owl/utils/billing_metrics.py new file mode 100644 index 0000000..4e0b611 --- /dev/null +++ 
b/services/api/src/owl/utils/billing_metrics.py @@ -0,0 +1,742 @@ +from __future__ import annotations + +import re +from collections import namedtuple +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any + +from loguru import logger + +from owl.types import ProductType, Usage, UsageResponse +from owl.utils.billing import ClickHouseAsyncClient +from owl.utils.exceptions import BadInputError + + +############################################################################### +# 1. Column-level registry +############################################################################### +@dataclass(frozen=True, slots=True) +class _BaseTable: + org_col: str = "org_id" + proj_col: str = "proj_id" + user_col: str = "user_id" + model_col: str = "model" + ts_col: str = "timestamp" + ts_interval: str = "timestamp_interval" + + def valid_group_by_cols(self) -> list[str]: + return [ + self.org_col, + self.proj_col, + self.user_col, + self.model_col, + ] + + +@dataclass(frozen=True, slots=True) +class LlmTable(_BaseTable): + table_id: str = "llm_usage" + input_col: str = "input_token" + output_col: str = "output_token" + input_cost_col: str = "input_cost" + output_cost_col: str = "output_cost" + result_input_token: str = "input" + result_output_token: str = "output" + result_total_token: str = "total_token" + category_total: str = "total_cost" + + def valid_group_by_cols(self) -> list[str]: + return _BaseTable().valid_group_by_cols() + ["type"] + + +@dataclass(frozen=True, slots=True) +class EmbedTable(_BaseTable): + table_id: str = "embed_usage" + value_col: str = "num_token" + cost_col: str = "cost" + + +@dataclass(frozen=True, slots=True) +class RerankTable(_BaseTable): + table_id: str = "rerank_usage" + value_col: str = "num_search" + cost_col: str = "cost" + + +@dataclass(frozen=True, slots=True) +class EgressTable(_BaseTable): + table_id: str = "egress_usage" + value_col: str = "amount_gib" + cost_col: str = "cost" + model_col: str = "bandwidth" # Egress actually does not have model + type: str = "egress" # to align with vm + + def valid_group_by_cols(self) -> list[str]: + _valid = _BaseTable().valid_group_by_cols() + ["type"] + _valid.remove(_BaseTable().model_col) + return _valid + + +@dataclass(frozen=True, slots=True) +class FileStorageTable(_BaseTable): + table_id: str = "file_storage_usage" + value_col: str = "amount_gib" + cost_col: str = "cost" + snapshot_col: str = "snapshot_gib" + model_col: str = "file_storage" # FileStorage actually does not have model + type: str = "file" # used for grouping + + def valid_group_by_cols(self) -> list[str]: + _valid = _BaseTable().valid_group_by_cols() + ["type"] + _valid.remove(_BaseTable().model_col) + return _valid + + +# For Storage usage type, = file/db +@dataclass(frozen=True, slots=True) +class DBStorageTable(_BaseTable): + table_id: str = "db_storage_usage" + value_col: str = "amount_gib" + cost_col: str = "cost" + snapshot_col: str = "snapshot_gib" + model_col: str = "db_storage" # DBStorage actually does not have model + type: str = "db" # used for grouping + + def valid_group_by_cols(self) -> list[str]: + _valid = _BaseTable().valid_group_by_cols() + ["type"] + _valid.remove(_BaseTable().model_col) + return _valid + + +# HACK: This is not an actual clickhouse table just to make it work with other parts +# For Storage spent, category = file/db (no type) +@dataclass(frozen=True, slots=True) +class CostTable(_BaseTable): + llm_table: LlmTable = LlmTable() + embed_table: EmbedTable = 
EmbedTable() + rerank_table: RerankTable = RerankTable() + egress_table: EgressTable = EgressTable() + file_storage_table: FileStorageTable = FileStorageTable() + db_storage_table: DBStorageTable = DBStorageTable() + category_total: str = "cost" + category_llm_input: str = "input_cost" + category_llm_output: str = "output_cost" + llm_input_type: str = "input" + llm_output_type: str = "output" + category_llm: str = ProductType.LLM_TOKENS.value + category_embed: str = ProductType.EMBEDDING_TOKENS.value + category_rerank: str = ProductType.RERANKER_SEARCHES.value + category_egress: str = ProductType.EGRESS.value + category_file_storage: str = ProductType.FILE_STORAGE.value + category_db_storage: str = ProductType.DB_STORAGE.value + + # HACK: so the table_id get from with_where_clause + table_id: str = "" + + # HACK: to make this compatible with victoriametrics query + def valid_group_by_cols(self) -> list[str]: + return _BaseTable().valid_group_by_cols() + ["type", "category"] + + def build_table_id(self, where_clause: str = "") -> str: + """Return the table_id with WHERE clause injected into each subquery""" + base_where = f"WHERE {where_clause}" if where_clause else "" + # HACK: the egress table does not have model_col, put model as 'bandwidth' + # HACK: the file_storage and db_storage table does not have model_col, put model as 'file_storage' and 'db_storage' + return f"""( + SELECT {self.llm_table.org_col}, {self.llm_table.proj_col}, {self.llm_table.model_col}, {self.llm_table.ts_col}, {self.llm_table.input_cost_col}, {self.llm_table.output_cost_col}, {self.llm_table.input_cost_col} + {self.llm_table.output_cost_col} as {self.category_llm}, 0 as {self.category_embed}, 0 as {self.category_rerank}, 0 as {self.category_egress}, 0 as {self.category_file_storage}, 0 as {self.category_db_storage} + FROM {self.llm_table.table_id} + {base_where} + UNION ALL + SELECT {self.embed_table.org_col}, {self.embed_table.proj_col}, {self.embed_table.model_col}, {self.embed_table.ts_col}, 0 as {self.llm_table.input_cost_col}, 0 as {self.llm_table.output_cost_col}, 0 as {self.category_llm}, {self.embed_table.cost_col} as {self.category_embed}, 0 as {self.category_rerank}, 0 as {self.category_egress}, 0 as {self.category_file_storage}, 0 as {self.category_db_storage} + FROM {self.embed_table.table_id} + {base_where} + UNION ALL + SELECT {self.rerank_table.org_col}, {self.rerank_table.proj_col}, {self.rerank_table.model_col}, {self.rerank_table.ts_col}, 0 as {self.llm_table.input_cost_col}, 0 as {self.llm_table.output_cost_col}, 0 as {self.category_llm}, 0 as {self.category_embed}, {self.rerank_table.cost_col} as {self.category_rerank}, 0 as {self.category_egress}, 0 as {self.category_file_storage}, 0 as {self.category_db_storage} + FROM {self.rerank_table.table_id} + {base_where} + UNION ALL + SELECT {self.egress_table.org_col}, {self.egress_table.proj_col}, '{self.egress_table.model_col}' as {_BaseTable().model_col}, {self.egress_table.ts_col}, 0 as {self.llm_table.input_cost_col}, 0 as {self.llm_table.output_cost_col}, 0 as {self.category_llm}, 0 as {self.category_embed}, {self.rerank_table.cost_col} as {self.category_rerank}, {self.egress_table.cost_col} as {self.category_egress}, 0 as {self.category_file_storage}, 0 as {self.category_db_storage} + FROM {self.egress_table.table_id} + {base_where} + UNION ALL + SELECT {self.file_storage_table.org_col}, {self.file_storage_table.proj_col}, '{self.file_storage_table.model_col}' as {_BaseTable().model_col}, {self.file_storage_table.ts_col}, 0 as 
{self.llm_table.input_cost_col}, 0 as {self.llm_table.output_cost_col}, 0 as {self.category_llm}, 0 as {self.category_embed}, {self.rerank_table.cost_col} as {self.category_rerank}, 0 as {self.category_egress}, {self.file_storage_table.cost_col} as {self.category_file_storage}, 0 as {self.category_db_storage} + FROM {self.file_storage_table.table_id} + {base_where} + UNION ALL + SELECT {self.db_storage_table.org_col}, {self.db_storage_table.proj_col}, '{self.db_storage_table.model_col}' as {_BaseTable().model_col}, {self.db_storage_table.ts_col}, 0 as {self.llm_table.input_cost_col}, 0 as {self.llm_table.output_cost_col}, 0 as {self.category_llm}, 0 as {self.category_embed}, {self.rerank_table.cost_col} as {self.category_rerank}, 0 as {self.category_egress}, 0 as {self.category_file_storage}, {self.db_storage_table.cost_col} as {self.category_db_storage} + FROM {self.db_storage_table.table_id} + {base_where} + )""" + + def row_is_llm(self, row: dict[str, Any]) -> bool: + # special handling to remove non llm type (when group by with 'model') + if row.get(self.model_col, "") in [ + self.egress_table.model_col, + self.file_storage_table.model_col, + self.db_storage_table.model_col, + ]: + return False + return True + + +############################################################################### +# 2. Helper utilities +############################################################################### +_duration_units = { + "ms": timedelta(milliseconds=1), + "s": timedelta(seconds=1), + "m": timedelta(minutes=1), + "h": timedelta(hours=1), + "d": timedelta(days=1), + "w": timedelta(weeks=1), + "y": timedelta(days=365), +} + +_interval_map = { + "s": "SECOND", + "m": "MINUTE", + "h": "HOUR", + "d": "DAY", + "w": "WEEK", + "y": "YEAR", +} + +MetricDef = namedtuple("MetricDef", ["name", "value_col", "extra_dims", "gb_mask"]) + +_METRICS: tuple[MetricDef, ...] 
= ( + MetricDef( + "embed", CostTable().category_embed, {"category": CostTable().category_embed}, "embed" + ), + MetricDef( + "rerank", CostTable().category_rerank, {"category": CostTable().category_rerank}, "rerank" + ), + MetricDef( + "egress", CostTable().category_egress, {"category": CostTable().category_egress}, "egress" + ), + MetricDef( + "file", + CostTable().category_file_storage, + {"category": CostTable().category_file_storage}, + "file", + ), + MetricDef( + "db", CostTable().category_db_storage, {"category": CostTable().category_db_storage}, "db" + ), + MetricDef( + "llm_input", + CostTable().category_llm_input, + {"category": CostTable().category_llm, "type": CostTable().llm_input_type}, + "common", + ), + MetricDef( + "llm_output", + CostTable().category_llm_output, + {"category": CostTable().category_llm, "type": CostTable().llm_output_type}, + "common", + ), + MetricDef("llm", CostTable().category_llm, {"category": CostTable().category_llm}, "common"), + MetricDef("total", CostTable().category_total, {}, "common"), +) + + +def _parse_duration(duration: str) -> timedelta: + delta = timedelta() + for value, unit in re.findall(r"(\d+)([smhdwy])", duration): + delta += int(value) * _duration_units[unit] + return delta + + +def _parse_interval(window_size: str) -> str: + m = re.fullmatch(r"(\d+)([smhdwy])", window_size) + if not m or m.group(2) not in _interval_map: + raise BadInputError(f"Bad window_size {window_size!r}, expected s/m/h/d/w/y") + + number = m.group(1) + unit = _interval_map[m.group(2)] + return f"{number} {unit}" + + +def _in_filter(col: str, values: list[str] | None) -> str: + if not values: + return "1=1" + quoted = ", ".join(f"'{v}'" for v in values) + return f"{col} IN ({quoted})" + + +def _filter_groupby(group_by: list[str], invalids: list[str] | None = None) -> list[str]: + if invalids is None: + invalids = [] + return [g for g in group_by if g not in invalids] + + +def _build_gb_filters(has_category: bool, has_type: bool, has_model: bool) -> dict[str, list[str]]: + base = [] if has_category else ["category"] + filters = {mask: base.copy() for mask in ("common", "embed", "rerank", "egress", "file", "db")} + if has_type: + for m in ("embed", "rerank", "egress", "file", "db"): + filters[m].append("type") + if has_model: + for m in ("file", "db", "egress"): + filters[m].append("model") + return filters + + +def _get_active_metrics(has_category: bool, has_type: bool) -> list[MetricDef]: + if not has_category: + if has_type: + # not has_category and has_type + return [m for m in _METRICS if m.name in {"llm_input", "llm_output", "total"}] + # has_category and has_type + return [m for m in _METRICS if m.name == "total"] + if has_type: + # has_category and has_type + return [ + m + for m in _METRICS + if m.name in {"embed", "rerank", "egress", "llm_input", "llm_output", "file", "db"} + ] + # has_category and not has_type + return [m for m in _METRICS if m.name in {"embed", "rerank", "egress", "file", "db", "llm"}] + + +############################################################################### +# 3. 
Generic query builder +############################################################################### +def _build_time_bucket_query( + spec: LlmTable + | EmbedTable + | RerankTable + | EgressTable + | FileStorageTable + | DBStorageTable + | CostTable, + org_ids: list[str] | None, + proj_ids: list[str] | None, + from_: datetime, + to: datetime, + group_by: list[str], + window_size: str, +) -> tuple[str, timedelta]: + for group in group_by: + if group not in spec.valid_group_by_cols(): + raise BadInputError( + f"Invalid group_by column: {group}, must be one of {spec.valid_group_by_cols()}" + ) + + org_c = _in_filter(spec.org_col, org_ids) + proj_c = _in_filter(spec.proj_col, proj_ids) + interval = _parse_interval(window_size) + ts_alias = f"toStartOfInterval({spec.ts_col}, INTERVAL {interval}) AS {spec.ts_interval}" + + has_type = "type" in group_by + has_category = "category" in group_by + if has_type: + group_by.remove("type") + if has_category: + group_by.remove("category") + + select_cols = [ts_alias, *group_by] + + # where clause + where_clause = f"""{spec.ts_col} >= '{from_:%Y-%m-%d %H:%M:%S}' + AND {spec.ts_col} < '{to:%Y-%m-%d %H:%M:%S}' + AND {org_c} + AND {proj_c} + """ + # Value expression + if isinstance(spec, LlmTable): + if has_type: + value_expr = f"SUM({spec.input_col}) as {spec.result_input_token}, SUM({spec.output_col}) as {spec.result_output_token}" + else: + value_expr = f"SUM({spec.input_col} + {spec.output_col}) AS {spec.result_total_token}" + elif isinstance(spec, FileStorageTable) or isinstance(spec, DBStorageTable): + value_expr = f"MAX({spec.snapshot_col}) AS {spec.snapshot_col}" + elif isinstance(spec, CostTable): + if has_category: + if has_type: + value_expr = f"SUM({spec.category_llm_input}) AS {spec.category_llm_input}, SUM({spec.category_llm_output}) AS {spec.category_llm_output}, SUM({spec.category_embed}) AS {spec.category_embed}, SUM({spec.category_rerank}) AS {spec.category_rerank}, SUM({spec.category_egress}) AS {spec.category_egress}, SUM({spec.category_file_storage}) AS {spec.category_file_storage}, SUM({spec.category_db_storage}) AS {spec.category_db_storage}" + else: + value_expr = f"SUM({spec.category_llm}) AS {spec.category_llm}, SUM({spec.category_embed}) AS {spec.category_embed}, SUM({spec.category_rerank}) AS {spec.category_rerank}, SUM({spec.category_egress}) AS {spec.category_egress}, SUM({spec.category_file_storage}) AS {spec.category_file_storage}, SUM({spec.category_db_storage}) AS {spec.category_db_storage}" + else: + if has_type: + value_expr = f"SUM({spec.category_llm_input}) as {spec.category_llm_input}, SUM({spec.category_llm_output}) as {spec.category_llm_output}, SUM({spec.category_embed} + {spec.category_rerank} + {spec.category_egress} + {spec.category_file_storage} + {spec.category_db_storage}) AS {spec.category_total}" + else: + value_expr = f"SUM({spec.category_llm} + {spec.category_embed} + {spec.category_rerank} + {spec.category_egress} + {spec.category_file_storage} + {spec.category_db_storage}) AS {spec.category_total}" + else: + value_expr = f"SUM({spec.value_col}) AS {spec.value_col}" + select_cols.append(value_expr) + + group_clause = ", ".join([spec.ts_interval, *group_by]) + sql = f""" + SELECT {", ".join(select_cols)} + FROM {spec.table_id or spec.build_table_id(where_clause)} + WHERE {where_clause} + GROUP BY {group_clause} + ORDER BY {spec.ts_interval} + """ + return sql, _parse_duration(window_size) + + +############################################################################### +# 4. 
Billing service +############################################################################### +class BillingMetrics: + def __init__(self, clickhouse_client: ClickHouseAsyncClient) -> None: + self.client = clickhouse_client + + async def _query(self, sql: str) -> list[dict[str, Any]]: + try: + res = await self.client.query(sql) + logger.info( + f"Query ID {res.summary.get('query_id')} " + f"rows={res.summary.get('result_rows')} " + f"elapsed={res.summary.get('elapsed_ns')}ns" + ) + if res.summary.get("result_rows") == "0": + return [] + return [ + dict(zip(res.column_names, row, strict=True)) + for row in zip(*res.result_columns, strict=True) + ] + except Exception as e: + logger.error(f"Query failed: {sql} – {e}") + raise + + @staticmethod + def _process_group_by(group_by: list[str]) -> list[str]: + # if "organization_id" in group_by: + # group_by.remove("organization_id") + # if "project_id" in group_by: + # group_by.remove("project_id") + # group_by.append("proj_id") + group_by = list(set([_BaseTable().org_col] + group_by)) + return group_by + + # ------------------------------------------------------------------ + # Public API – unchanged signatures + # ------------------------------------------------------------------ + async def query_llm_usage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + table = LlmTable() + to = to or datetime.now(timezone.utc) + group_by = self._process_group_by(group_by) + # group_by might be modified + sql, interval = _build_time_bucket_query( + table, filtered_by_org_id, filtered_by_proj_id, from_, to, group_by.copy(), window_size + ) + rows = await self._query(sql) + if "type" in group_by: + usages = [] + for r in rows: + usages.append( + Usage.from_result( + [ + int((r.get(table.ts_interval) + interval).timestamp()), + r.get(table.result_input_token), + ], + {**r, "type": table.result_input_token}, + interval, + group_by, + ) + ) + usages.append( + Usage.from_result( + [ + int((r.get(table.ts_interval) + interval).timestamp()), + r.get(table.result_output_token), + ], + {**r, "type": table.result_output_token}, + interval, + group_by, + ) + ) + else: + usages = [ + Usage.from_result( + [ + int((r.get(table.ts_interval) + interval).timestamp()), + r.get(table.result_total_token), + ], + r, + interval, + group_by, + ) + for r in rows + ] + return UsageResponse( + windowSize=window_size, + data=usages, + start=from_.strftime("%Y-%m-%dT%H:%M:%SZ"), + end=to.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + + async def query_embedding_usage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + table = EmbedTable() + to = to or datetime.now(timezone.utc) + group_by = self._process_group_by(group_by) + sql, interval = _build_time_bucket_query( + table, filtered_by_org_id, filtered_by_proj_id, from_, to, group_by.copy(), window_size + ) + rows = await self._query(sql) + return UsageResponse( + windowSize=window_size, + data=[ + Usage.from_result( + [ + int((r.get(table.ts_interval) + interval).timestamp()), + r.get(table.value_col), + ], + r, + interval, + group_by, + ) + for r in rows + ], + start=from_.strftime("%Y-%m-%dT%H:%M:%SZ"), + end=to.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + + async def query_reranking_usage( + self, + filtered_by_org_id: list[str] | None, + 
filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + table = RerankTable() + to = to or datetime.now(timezone.utc) + group_by = self._process_group_by(group_by) + sql, interval = _build_time_bucket_query( + table, + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by.copy(), + window_size, + ) + rows = await self._query(sql) + return UsageResponse( + windowSize=window_size, + data=[ + Usage.from_result( + [ + int((r.get(table.ts_interval) + interval).timestamp()), + r.get(table.value_col), + ], + r, + interval, + group_by, + ) + for r in rows + ], + start=from_.strftime("%Y-%m-%dT%H:%M:%SZ"), + end=to.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + + async def query_bandwidth( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + table = EgressTable() + to = to or datetime.now(timezone.utc) + group_by = self._process_group_by(group_by) + has_type = "type" in group_by + sql, interval = _build_time_bucket_query( + table, + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by.copy(), + window_size, + ) + rows = await self._query(sql) + return UsageResponse( + windowSize=window_size, + data=[ + Usage.from_result( + [ + int((r.get(table.ts_interval) + interval).timestamp()), + r.get(table.value_col), + ], + {**r, "type": table.type} if has_type else r, + interval, + group_by, + ) + for r in rows + ], + start=from_.strftime("%Y-%m-%dT%H:%M:%SZ"), + end=to.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + + async def query_storage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + file_table = FileStorageTable() + db_table = DBStorageTable() + to = to or datetime.now(timezone.utc) + group_by = self._process_group_by(group_by) + # group_by might be modified + file_sql, _ = _build_time_bucket_query( + file_table, + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by.copy(), + window_size, + ) + file_rows = await self._query(file_sql) + db_sql, interval = _build_time_bucket_query( + db_table, + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by.copy(), + window_size, + ) + db_rows = await self._query(db_sql) + if "type" in group_by: # to be compatible with VM query + usages = [] + for r in file_rows: + usages.append( + Usage.from_result( + [ + int((r.get(file_table.ts_interval) + interval).timestamp()), + r.get(file_table.snapshot_col), + ], + {**r, "type": file_table.type}, + interval, + group_by, + ) + ) + for r in db_rows: + usages.append( + Usage.from_result( + [ + int((r.get(db_table.ts_interval) + interval).timestamp()), + r.get(db_table.snapshot_col), + ], + {**r, "type": db_table.type}, + interval, + group_by, + ) + ) + else: + usages = [ + Usage.from_result( + [ + int((r.get(file_table.ts_interval) + interval).timestamp()), + r.get(file_table.snapshot_col), + ], + r, + interval, + group_by, + ) + for r in file_rows + ] + [ + Usage.from_result( + [ + int((r.get(db_table.ts_interval) + interval).timestamp()), + r.get(db_table.snapshot_col), + ], + r, + interval, + group_by, + ) + for r in db_rows + ] + return UsageResponse( + windowSize=window_size, + data=usages, + start=from_.strftime("%Y-%m-%dT%H:%M:%SZ"), + end=to.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) + + async def 
query_billing( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + cost_table = CostTable() + to = to or datetime.now(timezone.utc) + group_by = list(set([cost_table.org_col] + group_by)) + has_category = "category" in group_by + has_type = "type" in group_by + has_model = "model" in group_by + sql, interval = _build_time_bucket_query( + cost_table, + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by.copy(), # group_by might be modified + window_size, + ) + rows = await self._query(sql) + usages = [] + + gb_filters = _build_gb_filters(has_category, has_type, has_model) + active_metrics = _get_active_metrics(has_category, has_type) + + usages: list[Usage] = [] + for row in rows: + ts = int((row.get(cost_table.ts_interval) + interval).timestamp()) + for metric in active_metrics: + value = row.get(metric.value_col) + if value <= 0: + continue + + metrics_dict = { + **row, + **metric.extra_dims, + } + usages.append( + Usage.from_result( + [ts, value], + metrics_dict, + interval, + _filter_groupby(group_by, gb_filters[metric.gb_mask]), + ) + ) + return UsageResponse( + windowSize=window_size, + data=usages, + start=from_.strftime("%Y-%m-%dT%H:%M:%SZ"), + end=to.strftime("%Y-%m-%dT%H:%M:%SZ"), + ) diff --git a/services/api/src/owl/utils/cache.py b/services/api/src/owl/utils/cache.py new file mode 100644 index 0000000..4f028e4 --- /dev/null +++ b/services/api/src/owl/utils/cache.py @@ -0,0 +1,267 @@ +from contextlib import asynccontextmanager, suppress +from random import random +from typing import Any, AsyncGenerator, Type, TypeVar + +from pottery import AIORedlock, ReleaseUnlockedLock +from redis import Redis +from redis.asyncio import Redis as RedisAsync +from redis.backoff import EqualJitterBackoff +from redis.exceptions import ConnectionError, TimeoutError +from redis.retry import Retry +from sqlmodel.ext.asyncio.session import AsyncSession + +from owl.types import Organization_, Progress, UsageData + +ProgressType = TypeVar("ProgressType", bound=Progress) + + +class Cache: + def __init__( + self, + *, + redis_url: str, + clickhouse_buffer_key: str, + cache_expiration: int = 5 * 60, # 5 minutes + ): + self._redis_kwargs = dict( + # url=f"redis://[[username]:[password]]@{ENV_CONFIG.redis_host}:{ENV_CONFIG.redis_port}/1", + url=redis_url, + # https://redis.io/kb/doc/22wxq63j93/how-to-manage-client-reconnections-in-case-of-errors-with-redis-py + retry=Retry(EqualJitterBackoff(cap=10, base=1), 5), + retry_on_error=[ConnectionError, TimeoutError, ConnectionResetError], + health_check_interval=15, + decode_responses=True, + ) + self._redis = Redis.from_url(**self._redis_kwargs) + self._redis_async = RedisAsync.from_url(**self._redis_kwargs) + self.clickhouse_buffer_key = clickhouse_buffer_key + self.cache_expiration = int(cache_expiration) + # try: + # self._redis.ping() + # except ConnectionError as e: + # logger.error(f"Failed to connect to Redis: {repr(e)}") + # raise + + def __getitem__(self, key: str) -> str | None: + """ + Getter method. + ``` + cache = Cache(...) + value = cache["key"] + ``` + + Args: + key (str): Key. + + Returns: + value (str | None): Value. + """ + return self._redis.get(key) + + def __setitem__(self, key: str, value: str) -> None: + """ + Setter method. + ``` + cache = Cache(...) + cache["key"] = value + ``` + + Args: + key (str): Key. + value (str): Value. 
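+
+        Raises:
+            TypeError: If `value` is not a str.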
+        """
+        if not isinstance(value, str):
+            raise TypeError(f"`value` must be a str, received: {type(value)}")
+        self._redis.set(key, value)
+
+    def __delitem__(self, key) -> None:
+        """
+        Delete method.
+        ```
+        cache = Cache(...)
+        del cache["key"]
+        ```
+
+        Args:
+            key (str): Key.
+        """
+        self._redis.delete(key)
+
+    def __contains__(self, key) -> bool:
+        return bool(self._redis.exists(key))
+
+    def purge(self):
+        self._redis.flushdb()
+
+    async def aclose(self):
+        self._redis.close()
+        await self._redis_async.aclose()
+
+    async def get(self, key: str) -> str | None:
+        return await self._redis_async.get(key)
+
+    async def set(self, key: str, value: str, **kwargs) -> None:
+        if not isinstance(value, str):
+            raise TypeError(f"`value` must be a str, received: {type(value)}")
+        await self._redis_async.set(key, value, **kwargs)
+
+    async def delete(self, key: str) -> None:
+        await self._redis_async.delete(key)
+
+    async def exists(self, *keys: str) -> int:
+        return await self._redis_async.exists(*keys)
+
+    @asynccontextmanager
+    async def alock(
+        self,
+        key: str,
+        blocking: bool = True,
+        expire: float = 60.0,
+    ) -> AsyncGenerator[bool, None]:
+        lock = AIORedlock(
+            key=key,
+            masters={self._redis_async},
+            auto_release_time=max(1.0, expire),
+        )
+        lock_acquired = await lock.acquire(blocking=blocking)
+        try:
+            yield lock_acquired
+        finally:
+            if lock_acquired:
+                with suppress(ReleaseUnlockedLock):
+                    await lock.release()
+
+    async def add_usage_to_buffer(self, usage: UsageData):
+        await self._redis_async.rpush(self.clickhouse_buffer_key, usage.model_dump_json())
+        await self._redis_async.incrby(
+            self.clickhouse_buffer_key + "_count", usage.total_usage_events
+        )
+
+    # def retrieve_usage_buffer(self) -> list[UsageData]:
+    #     return [
+    #         UsageData.model_validate_json(data)
+    #         for data in self._redis.lrange(self.clickhouse_buffer_key, 0, -1)
+    #     ]
+
+    async def get_usage_buffer_count(self) -> int:
+        return int(await self._redis_async.get(self.clickhouse_buffer_key + "_count") or 0)
+
+    # def reset_buffer_and_count(self):
+    #     # Delete the buffer and count keys
+    #     del self[self.clickhouse_buffer_key]
+    #     del self[self.clickhouse_buffer_key + "_count"]
+
+    @staticmethod
+    def get_capacity_search_keys(deployment_id: str) -> dict[str, str]:
+        queue_key = f"capacity_search_model_queue:{deployment_id}"
+        active_key = f"capacity_search_model_active:{deployment_id}"
+        queue_task_key = f"capacity_search_queue_task:{deployment_id}"
+        active_task_key = f"capacity_search_active_task:{deployment_id}"
+        return {
+            "queue_key": queue_key,
+            "active_key": active_key,
+            "queue_task_key": queue_task_key,
+            "active_task_key": active_task_key,
+        }
+
+    @staticmethod
+    def get_capacity_search_cancellation_key(task_id: str) -> str:
+        return f"capacity_search_cancel:{task_id}"
+
+    async def set_progress(
+        self,
+        prog: Progress,
+        ex: int = 240,
+        nx: bool = False,
+        **kwargs,
+    ) -> bool | None:
+        """
+        Set progress data into Redis at key `prog.key`.
+
+        Args:
+            prog (Progress): Progress instance.
+            ex (int, optional): Expiration time in seconds. Defaults to 240.
+            nx (bool, optional): Set this key only if it does not exist. Defaults to False.
+
+        Returns:
+            response (bool | None): True if set or if `prog.key` is empty, otherwise None.
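A minimal usage sketch of the `Cache` wrapper added here, assuming a locally reachable Redis; the URL, buffer key, and lock name below are illustrative placeholders, not values defined in this patch.

```python
import asyncio

from owl.utils.cache import Cache


async def main() -> None:
    # Placeholder connection details; point these at your own Redis instance.
    cache = Cache(redis_url="redis://localhost:6379/1", clickhouse_buffer_key="ch_buffer")

    # Async get/set mirror the dict-style sync API (cache["key"] = "value").
    await cache.set("greeting", "hello")
    assert await cache.get("greeting") == "hello"

    # Distributed lock: `acquired` is False if another worker already holds it.
    async with cache.alock("locks:reindex:org_123", blocking=False, expire=30.0) as acquired:
        if acquired:
            ...  # exclusive work goes here

    await cache.aclose()


asyncio.run(main())
```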
+ """ + if not prog.key: + return True + # Returns True if set, None if not + return await self._redis_async.set( + prog.key, + prog.model_dump_json(), + ex=ex, + nx=nx, + **kwargs, + ) + + async def get_progress( + self, + key: str, + response_model: Type[ProgressType] | None = Progress, + ) -> ProgressType | dict[str, Any] | None: + """ + Get progress data from Redis at key `key`. + + Args: + key (str): Progress key. + response_model (Type[ProgressType], optional): Response model. Defaults to `Progress`. + + Returns: + response (ProgressType | dict[str, Any] | None): The progress data. + """ + from owl.utils.io import json_loads + + prog = await self._redis_async.get(key) + if response_model is None: + return json_loads(prog) if prog else prog + if prog: + return response_model.model_validate_json(prog) + return response_model(key=key) + + def _ex_jitter(self) -> int: + # Jitter to prevent cache stampede + return int(self.cache_expiration * random() / 2) + + async def clear_all_async(self) -> None: + pipe = self._redis_async.pipeline() + for prefix in ["user", "organization", "project", "models"]: + async for key in self._redis_async.scan_iter(match=f"{prefix}:*"): + pipe.delete(key) + await pipe.execute() + + async def cache_organization_async(self, organization: Organization_) -> None: + await self.set( + f"organization:{organization.id}", + Organization_.model_validate(organization).model_dump_json(), + ex=self.cache_expiration + self._ex_jitter(), + ) + + async def get_organization_async( + self, + organization_id: str, + session: AsyncSession, + ) -> Organization_ | None: + from owl.db.models import Organization + + if data := await self.get(f"organization:{organization_id}"): + return Organization_.model_validate_json(data) + organization = await session.get(Organization, organization_id) + if organization is None: + return None + organization = Organization_.model_validate(organization) + await self.cache_organization_async(organization) + return organization + + async def clear_organization_async(self, organization_id: str) -> None: + await self.delete(f"organization:{organization_id}") + + async def refresh_organization_async( + self, + organization_id: str, + session: AsyncSession, + ) -> Organization_ | None: + await self.clear_organization_async(organization_id) + return await self.get_organization_async(organization_id, session) diff --git a/services/api/src/owl/utils/code.py b/services/api/src/owl/utils/code.py index 76a764e..2ca653d 100644 --- a/services/api/src/owl/utils/code.py +++ b/services/api/src/owl/utils/code.py @@ -1,58 +1,151 @@ import base64 +import pickle +import time import uuid +from contextlib import asynccontextmanager +from typing import Any import filetype import httpx from fastapi import Request from loguru import logger -from owl.configs.manager import ENV_CONFIG -from owl.utils.io import upload_file_to_s3 +from owl.configs import ENV_CONFIG +from owl.types import AUDIO_FILE_EXTENSIONS, IMAGE_FILE_EXTENSIONS, ColumnDtype +from owl.utils.billing import OPENTELEMETRY_CLIENT +from owl.utils.io import s3_upload +REQ_COUNTER = OPENTELEMETRY_CLIENT.get_counter("code_executor_requests_total") +REQ_SECONDS = OPENTELEMETRY_CLIENT.get_histogram("code_executor_duration_seconds") +RES_BYTES = OPENTELEMETRY_CLIENT.get_histogram("code_executor_result_bytes") -async def code_executor(source_code: str, dtype: str, request: Request) -> str | None: - response = None +def _status_class(code: int | None) -> str: + if code is None: + return "none" try: - if dtype == 
"image": - dtype = "file" # for code execution endpoint usage - async with httpx.AsyncClient() as client: - response = await client.post( - f"{ENV_CONFIG.code_executor_endpoint}/execute", - json={"code": source_code}, - ) - response.raise_for_status() - result = response.json() + c = int(code) + except (TypeError, ValueError): + return "none" + return f"{c // 100}xx" if 100 <= c <= 599 else "none" + + +@asynccontextmanager +async def observe_code_execution( + *, + organization_id: str, + project_id: str, + dtype: str, +): + start = time.monotonic() + outcome: str = "ok" + error_type: str | None = None + + rec: dict[str, Any] = { + "result_bytes": 0, + "status_code": None, + } + + class Recorder: + def set_result_bytes(self, n: int) -> None: + rec["result_bytes"] = max(0, int(n)) - if dtype == "file": - if result["type"].startswith("image"): - image_content = base64.b64decode(result["result"]) - content_type = filetype.guess(image_content) - if content_type is None: - raise ValueError("Unable to determine file type") - filename = f"{uuid.uuid4()}.{content_type.extension}" + def set_status_code(self, code: int) -> None: + rec["status_code"] = int(code) + try: + yield Recorder() + except Exception as exc: + outcome = "error" + error_type = exc.__class__.__name__ + raise + finally: + duration = time.monotonic() - start + labels = { + "outcome": outcome, + "error_type": error_type or "", + "status_class": _status_class(rec["status_code"]), + "status_code": rec["status_code"] or 0, + "org_id": organization_id, + "proj_id": project_id, + "dtype": dtype, + } + REQ_COUNTER.add(1, labels) + REQ_SECONDS.record(duration, labels) + RES_BYTES.record(rec["result_bytes"], labels) + + +async def code_executor( + *, + request: Request, + organization_id: str, + project_id: str, + source_code: str, + output_column: str, + row_data: dict | None, + dtype: str, +) -> str: + async with observe_code_execution( + organization_id=organization_id, + project_id=project_id, + dtype=dtype, + ) as rec: + try: + async with httpx.AsyncClient(timeout=ENV_CONFIG.code_timeout_sec) as client: + row_data = base64.b64encode(pickle.dumps(row_data)).decode("utf-8") + response = await client.post( + f"{ENV_CONFIG.code_executor_endpoint}/execute", + json={ + "source_code": source_code, + "output_column": output_column, + "row_data": row_data, + }, + ) + rec.set_status_code(response.status_code) + response.raise_for_status() + result = pickle.loads(base64.b64decode(response.text.strip('"'))) + + # Return early if output column is ColumnDtype.STR + if dtype == ColumnDtype.STR: + rec.set_result_bytes(len(str(result).encode("utf-8"))) + logger.info( + f"Code Executor: {request.state.id} - Python code execution completed for column {output_column}" + ) + return str(result) + + if not isinstance(result, bytes): + raise Exception( + f"Expected type bytes for {dtype}, got {type(result)}:\n\n{str(result)[:100]}" + ) + + rec.set_result_bytes(len(result)) + + content_type = filetype.guess(result) + if not content_type: + raise Exception("Result is bytes but could not determine content type") + + file_extension = f".{content_type.extension}" + + # Handle different data types + if (dtype == ColumnDtype.IMAGE and file_extension in IMAGE_FILE_EXTENSIONS) or ( + dtype == ColumnDtype.AUDIO and file_extension in AUDIO_FILE_EXTENSIONS + ): + filename = f"{uuid.uuid4()}{file_extension}" # Upload the file - uri = await upload_file_to_s3( - organization_id=request.state.org_id, - project_id=request.state.project_id, - content=image_content, + uri = 
await s3_upload( + organization_id=organization_id, + project_id=project_id, + content=result, content_type=content_type.mime, filename=filename, ) - response = uri - else: - logger.warning( - f"Code Executor: {request.state.id} - Unsupported file type: {result['type']}" + logger.info( + f"Code Executor: {request.state.id} - Python code execution completed for column {output_column}" ) - response = None - else: - response = str(result["result"]) - - logger.info(f"Code Executor: {request.state.id} - Python code execution completed") + return uri - except Exception as e: - logger.error(f"Code Executor: {request.state.id} - An unexpected error occurred: {e}") - response = None - - return response + except Exception as e: + logger.error( + f"Code Executor: {request.state.id} - Python code execution encountered error for column {output_column} : {e}" + ) + raise diff --git a/services/api/src/owl/utils/crypt.py b/services/api/src/owl/utils/crypt.py index bcc48b0..e0d43f6 100644 --- a/services/api/src/owl/utils/crypt.py +++ b/services/api/src/owl/utils/crypt.py @@ -7,7 +7,9 @@ import hashlib import secrets from base64 import b64decode, b64encode +from functools import lru_cache from hashlib import blake2b +from typing import Any # Import Union for type annotations from Cryptodome.Cipher import AES from Cryptodome.Random import get_random_bytes @@ -15,7 +17,11 @@ def _encrypt(message: str, password: str, aes_mode: int) -> str: """ - pass + Encrypts a message using AES encryption with the given password and mode. + :param message: The message to encrypt. + :param password: The password to use for encryption. + :param aes_mode: The AES mode to use (either AES.MODE_SIV or AES.MODE_GCM). + :return: The encrypted message as a string. """ if not (aes_mode == AES.MODE_SIV or aes_mode == AES.MODE_GCM): raise ValueError("`aes_mode` can only be `AES.MODE_SIV` or `AES.MODE_GCM`.") @@ -37,16 +43,20 @@ def _encrypt(message: str, password: str, aes_mode: int) -> str: ) # Create cipher config cipher_config = AES.new(private_key, aes_mode) + # Encrypt the message cipher_text, tag = cipher_config.encrypt_and_digest(message.encode("utf-8")) - cipher_text = b64encode(cipher_text).decode("utf-8") - tag = b64encode(tag).decode("utf-8") + # Encode the cipher_text and tag to base64 + cipher_text_b64 = b64encode(cipher_text).decode("utf-8") + tag_b64 = b64encode(tag).decode("utf-8") + # Create final encrypted text if aes_mode == AES.MODE_SIV: - encrypted = f"{cipher_text}*{tag}" + encrypted = f"{cipher_text_b64}*{tag_b64}" else: - salt = b64encode(salt).decode("utf-8") - nonce = b64encode(cipher_config.nonce).decode("utf-8") - encrypted = f"{cipher_text}*{salt}*{nonce}*{tag}" + salt_b64 = b64encode(salt).decode("utf-8") + nonce_b64 = b64encode(cipher_config.nonce).decode("utf-8") + encrypted = f"{cipher_text_b64}*{salt_b64}*{nonce_b64}*{tag_b64}" + return encrypted @@ -57,6 +67,7 @@ def encrypt_random(message: str, password: str) -> str: return _encrypt(message, password, AES.MODE_GCM) +@lru_cache(maxsize=100000) def encrypt_deterministic(message: str, password: str) -> str: """ Deterministic encryption using AES SIV mode with @@ -65,47 +76,60 @@ def encrypt_deterministic(message: str, password: str) -> str: return _encrypt(message, password, AES.MODE_SIV) +@lru_cache(maxsize=100000) def decrypt(encrypted: str, password: str) -> str: + """ + Decrypts an encrypted message using AES decryption with the given password. + + :param encrypted: The encrypted message as a string. 
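For orientation, a sketch of calling the reworked `code_executor` above from a request handler; the IDs, row data, and source snippet are placeholders, and the exact shape expected of `source_code` is not shown in this diff.

```python
from fastapi import Request

from owl.types import ColumnDtype
from owl.utils.code import code_executor


async def run_user_code(request: Request) -> str:
    # Placeholder arguments; all parameters are keyword-only.
    return await code_executor(
        request=request,
        organization_id="org_123",
        project_id="proj_456",
        source_code="...",  # user-supplied Python, executed by the code-executor service
        output_column="total",
        row_data={"a": 1, "b": 2},
        dtype=ColumnDtype.STR,  # str results return directly; image/audio bytes are uploaded to S3
    )
```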
+ :param password: The password used for decryption. + :return: The decrypted message as a string. + """ parts = encrypted.split("*") n_parts = len(parts) # Decode the entries from base64 if n_parts == 4: - cipher_text, salt, nonce, tag = parts - salt = b64decode(salt) - nonce = b64decode(nonce) + cipher_text_b64, salt_b64, nonce_b64, tag_b64 = parts + salt = b64decode(salt_b64) # Decode salt to bytes + nonce = b64decode(nonce_b64) # Decode nonce to bytes elif n_parts == 2: - cipher_text, tag = parts - salt = b"" - nonce = None - # elif n_parts == 1: - # logger.warning(f"Attempting to decrypt string that looks unencrypted: {encrypted}") - # return encrypted + cipher_text_b64, tag_b64 = parts + salt = b"" # Use empty salt for AES.MODE_SIV + nonce = None # No nonce for AES.MODE_SIV else: raise ValueError(f"Encrypted string must have either 2 or 4 parts, received: {n_parts}") - cipher_text = b64decode(cipher_text) - tag = b64decode(tag) + + # Decode cipher_text and tag to bytes + cipher_text = b64decode(cipher_text_b64) + tag = b64decode(tag_b64) + # Generate the private key from the password and salt private_key = hashlib.scrypt( - password.encode(), - salt=salt, + password.encode(), # Encode password to bytes + salt=salt, # salt is already bytes n=2**14, r=8, p=1, dklen=32, ) + # Create the cipher config + cipher: Any # Use Any to avoid issues with inaccessible types if n_parts == 4: - cipher = AES.new(private_key, AES.MODE_GCM, nonce=nonce) + cipher = AES.new(private_key, AES.MODE_GCM, nonce=nonce) # Use GCM mode with nonce else: - cipher = AES.new(private_key, AES.MODE_SIV) + cipher = AES.new(private_key, AES.MODE_SIV) # Use SIV mode + # Decrypt the cipher text - decrypted = cipher.decrypt_and_verify(cipher_text, tag) - return decrypted.decode("UTF-8") + decrypted = cipher.decrypt_and_verify(cipher_text, tag) # Both inputs are bytes + return decrypted.decode("UTF-8") # Decode the decrypted bytes to a string -def hash_string_blake2b(string: str, digest_size: int = 8) -> str: - hasher = blake2b(digest_size=digest_size) +def hash_string_blake2b(string: str, key_length: int = 8) -> str: + if key_length % 2 != 0: + raise ValueError("Key length must be a multiple of 2.") + hasher = blake2b(digest_size=key_length // 2) # 2 characters per byte hasher.update(string.encode()) return hasher.hexdigest() @@ -132,13 +156,13 @@ def generate_key(key_length: int = 48, prefix: str = "") -> str: prefix (str, optional): Prefix of the key. Defaults to "". Raises: - ValueError: If `key_length` is < 16 or not a multiple of 2. + ValueError: If `key_length` is < 8 or not a multiple of 2. Returns: api_key (str): A random key. 
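A small round-trip sketch of the crypt helpers touched in this file; the password and messages are placeholders.

```python
from owl.utils.crypt import decrypt, encrypt_deterministic, encrypt_random, hash_string_blake2b

password = "change-me"  # placeholder secret

# GCM mode uses a fresh salt and nonce per call, so repeated calls differ.
token_a = encrypt_random("hello", password)
token_b = encrypt_random("hello", password)
assert token_a != token_b
assert decrypt(token_a, password) == "hello"

# SIV mode is deterministic (and now memoised via lru_cache), so equal inputs match.
assert encrypt_deterministic("hello", password) == encrypt_deterministic("hello", password)

# digest_size is key_length // 2 bytes, i.e. key_length hex characters.
assert len(hash_string_blake2b("owl", key_length=8)) == 8
```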
""" - if key_length < 16: - raise ValueError("Key length must be at least 16 characters.") + if key_length < 8: + raise ValueError("Key length must be at least 8 characters.") if key_length % 2 != 0: raise ValueError("Key length must be a multiple of 2.") api_key = blake2b(secrets.token_bytes(key_length), digest_size=key_length // 2).hexdigest() diff --git a/services/api/src/owl/utils/dates.py b/services/api/src/owl/utils/dates.py new file mode 100644 index 0000000..4996017 --- /dev/null +++ b/services/api/src/owl/utils/dates.py @@ -0,0 +1,14 @@ +from jamaibase.utils.dates import ( # noqa: F401 + date_to_utc, + date_to_utc_iso, + earliest, + ensure_utc_timezone, + now, + now_iso, + now_tz_naive, + utc_datetime_from_iso, + utc_iso_from_datetime, + utc_iso_from_string, + utc_iso_from_uuid7, + utc_iso_from_uuid7_draft2, +) diff --git a/services/api/src/owl/utils/exceptions.py b/services/api/src/owl/utils/exceptions.py index 991fedd..8844185 100644 --- a/services/api/src/owl/utils/exceptions.py +++ b/services/api/src/owl/utils/exceptions.py @@ -1,21 +1,37 @@ -from functools import wraps +from functools import partial, wraps from inspect import iscoroutinefunction -from typing import Any, Callable, Type, TypeVar, overload +from typing import Any, Callable, TypeVar, overload from fastapi import Request from fastapi.exceptions import RequestValidationError -from filelock import Timeout from loguru import logger -from pydantic import ValidationError from sqlalchemy.exc import IntegrityError -from jamaibase.exceptions import JamaiException, ResourceExistsError, UnexpectedError - - -def check_type(obj: Any, clss: tuple[Type] | Type, mssg: str) -> None: - if not isinstance(obj, clss): - raise TypeError(f"{mssg} Received: {type(obj)}") - +# Import from jamaibase for use within owl +from jamaibase.utils.exceptions import ( # noqa: F401 + AuthorizationError, + BadInputError, + BaseTierCountError, + ContextOverflowError, + ExternalAuthError, + ForbiddenError, + InsufficientCreditsError, + JamaiException, + MethodNotAllowedError, + ModelCapabilityError, + ModelOverloadError, + NoTierError, + RateLimitExceedError, + ResourceExistsError, + ResourceNotFoundError, + ServerBusyError, + UnavailableError, + UnexpectedError, + UnsupportedMediaTypeError, + UpgradeTierError, + UpStreamError, + docstring_message, +) F = TypeVar("F", bound=Callable[..., Any]) @@ -24,29 +40,27 @@ def check_type(obj: Any, clss: tuple[Type] | Type, mssg: str) -> None: def handle_exception( func: F, *, - failure_message: str = "", + handler: Callable[..., Any] | None = None, ) -> F: ... @overload def handle_exception( *, - failure_message: str = "", + handler: Callable[..., Any] | None = None, ) -> Callable[[F], F]: ... def handle_exception( func: F | None = None, *, - failure_message: str = "", handler: Callable[..., Any] | None = None, ) -> Callable[[F], F] | F: - # TODO: Add support for callable as "failure_message" """ A decorator to handle exceptions for both synchronous and asynchronous functions. Its main purpose is to: - - Provide more meaningful error messages for logging. - - Transform certain error classes, for example `RequestValidationError` -> `ValidationError`. + - Produce shorter traceback (160 vs 500 lines) upon unexpected errors (such as `ValueError`). + - Transform certain error classes, for example `IntegrityError` -> `ResourceExistsError`. It also allows you to specify a custom exception handler function. 
The handler function should accept a single positional argument (the exception instance) @@ -57,84 +71,64 @@ def handle_exception( Args: func (F | None): The function to be decorated. This can be either a synchronous or asynchronous function. When used as a decorator, leave this unset. Defaults to `None`. - failure_message (str): Optional message to be logged for timeout and unexpected exceptions. Defaults to "". handler (Callable[..., None] | None): A custom exception handler function. - The handler function should accept a single positional argument (the exception instance) - and all keyword arguments passed to the decorated function. + The handler function should accept a positional argument (the exception instance) + followed by all arguments passed to the decorated function. Returns: func (Callable[[F], F] | F): The decorated function with exception handling applied. Raises: - JamaiException: If the exception is of type JamaiException. - RequestValidationError: If the exception is a FastAPI RequestValidationError. - ValidationError: Wraps Pydantic ValidationError as RequestValidationError. - ResourceExistsError: If an IntegrityError indicates a unique constraint violation in the database. - UnexpectedError: For any other unhandled exceptions. + JamaiException: If `JamaiException` is raised. + RequestValidationError: If `fastapi.exceptions.RequestValidationError` is raised. + ResourceExistsError: If `sqlalchemy.exc.IntegrityError` indicates a unique constraint violation in the database. + UnexpectedError: For all other exception. """ - def decorator(fn: F) -> F: - def _handle_exception(e: Exception, kwargs): - try: - if handler is not None: - return handler(e, **kwargs) - except e.__class__: - pass - except Exception: - logger.warning(f"Exception handler failed for exception: {e}") - - if isinstance(e, JamaiException): - raise - elif isinstance(e, RequestValidationError): - raise - elif isinstance(e, ValidationError): - # Sometimes ValidationError is raised from additional checking code - raise RequestValidationError(errors=e.errors()) from e - elif isinstance(e, IntegrityError): - err_mssg: str = e.args[0] - err_mssg = err_mssg.split("UNIQUE constraint failed:") - if len(err_mssg) > 1: - constraint = err_mssg[1].strip() - raise ResourceExistsError(f'DB item "{constraint}" already exists.') from e - else: - raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e - elif isinstance(e, Timeout): - request: Request | None = kwargs.get("request", None) - mssg = failure_message if failure_message else "Could not acquire lock" - mssg = f"{e.__class__.__name__}: {e} - {mssg} - kwargs={kwargs}" - if request: - logger.warning(f"{request.state.id} - {mssg}") - else: - logger.warning(mssg) - raise + def _default_handler(e: Exception, *args, **kwargs): + if isinstance(e, JamaiException): + raise + elif isinstance(e, RequestValidationError): + raise + # elif isinstance(e, ValidationError): + # raise RequestValidationError(errors=e.errors()) from e + elif isinstance(e, IntegrityError): + err_mssg: str = e.args[0] + err_mssgs = err_mssg.split("UNIQUE constraint failed:") + if len(err_mssgs) > 1: + constraint = err_mssgs[1].strip() + raise ResourceExistsError(f'DB item "{constraint}" already exists.') from e else: - request: Request | None = kwargs.get("request", None) - mssg = failure_message if failure_message else f"Failed to run {fn.__name__}" - mssg = f"{e.__class__.__name__}: {e} - {mssg} - kwargs={kwargs}" - if request: - logger.error(f"{request.state.id} - {mssg}") - else: - 
logger.error(mssg) raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e + else: + request: Request | None = kwargs.get("request", None) + mssg = f"Failed to run {func.__name__}" + mssg = f"{e.__class__.__name__}: {e} - {mssg} - kwargs={kwargs}" + if request: + logger.error(f"{request.state.id} - {mssg}") + else: + logger.error(mssg) + raise UnexpectedError(f"{e.__class__.__name__}: {e}") from e - if iscoroutinefunction(fn): + if handler is None: + handler = _default_handler - @wraps(fn) - async def wrapper(**kwargs): - try: - return await fn(**kwargs) - except Exception as e: - return _handle_exception(e, kwargs) + if iscoroutinefunction(func): - else: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except Exception as e: + return handler(e, *args, **kwargs) - @wraps(fn) - def wrapper(**kwargs): - try: - return fn(**kwargs) - except Exception as e: - return _handle_exception(e, kwargs) + else: - return wrapper + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + return handler(e, *args, **kwargs) - return decorator if func is None else decorator(func) + return partial(handle_exception, handler=handler) if func is None else wrapper diff --git a/services/api/src/owl/utils/handlers.py b/services/api/src/owl/utils/handlers.py new file mode 100644 index 0000000..07cd952 --- /dev/null +++ b/services/api/src/owl/utils/handlers.py @@ -0,0 +1,391 @@ +from typing import Any, Mapping + +import orjson +from fastapi import Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.responses import ORJSONResponse +from loguru import logger +from pydantic import BaseModel +from sqlalchemy.exc import IntegrityError +from starlette.exceptions import HTTPException + +from owl.utils import mask_string +from owl.utils.exceptions import ( + AuthorizationError, + BadInputError, + ContextOverflowError, + ExternalAuthError, + ForbiddenError, + InsufficientCreditsError, + JamaiException, + MethodNotAllowedError, + ModelOverloadError, + RateLimitExceedError, + ResourceExistsError, + ResourceNotFoundError, + ServerBusyError, + UnavailableError, + UnsupportedMediaTypeError, + UpgradeTierError, +) + +INTERNAL_ERROR_MESSAGE = "Oops sorry we ran into an unexpected error. Please try again later." + + +def make_request_log_str(request: Request, status_code: int | None = None) -> str: + """ + Generate a string for logging, given a request object and an HTTP status code. + + Args: + request (Request): Starlette request object. + status_code (int): HTTP error code. + + Returns: + str: A string in the format + ' - " " ' + """ + query = request.url.query + query = f"?{query}" if query else "" + msg = f'{request.state.id} - "{request.method} {request.url.path}{query}"' + if status_code is not None: + msg = f"{msg} {status_code}" + return msg + + +def make_response( + request: Request, + message: str, + error: str, + status_code: int, + *, + detail: str | None = None, + exception: Exception | None = None, + headers: Mapping[str, str] | None = None, + log: bool = True, +) -> ORJSONResponse: + """ + Create a Response object. + + Args: + request (Request): Starlette request object. + message (str): User-friendly error message to be displayed by frontend or SDK. + error (str): Short error name. + status_code (int): HTTP error code. + detail (str | None, optional): Error message with potentially more details. + Defaults to None (message + headers). 
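A sketch of how the revised `handle_exception` decorator above might be applied; `get_thing`, `get_optional_thing`, and the fallback handler are hypothetical.

```python
from fastapi import Request

from owl.utils.exceptions import ResourceNotFoundError, handle_exception


@handle_exception
async def get_thing(*, request: Request, thing_id: str) -> dict:
    # With the default handler, IntegrityError unique-constraint failures surface as
    # ResourceExistsError; anything unrecognised is logged and re-raised as UnexpectedError.
    ...


def _fallback_to_none(e: Exception, *args, **kwargs):
    # Custom handlers receive the exception followed by the call's arguments.
    if isinstance(e, ResourceNotFoundError):
        return None
    raise e


@handle_exception(handler=_fallback_to_none)
async def get_optional_thing(*, request: Request, thing_id: str) -> dict | None:
    ...
```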
+ exception (Exception | None, optional): Exception that occurred. Defaults to None. + headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. + log (bool, optional): Whether to log the response. Defaults to True. + + Returns: + response (ORJSONResponse): Response object. + """ + if detail is None: + detail = f"{message}\nException:{repr(exception)}" + if headers is None: + headers = {} + headers["x-request-id"] = request.headers.get("x-request-id", "") + request_headers = {k.lower(): v for k, v in request.headers.items()} + token = request_headers.get("authorization", "") + if token.startswith("Bearer "): + request_headers["authorization"] = f"Bearer {mask_string(token[7:], include_len=False)}" + else: + request_headers["authorization"] = mask_string(token, include_len=False) + response = ORJSONResponse( + status_code=status_code, + content={ + "object": "error", + "error": error, + "message": message, + "detail": detail, + "request_id": request.state.id, + "exception": exception.__class__.__name__ if exception else None, + "request_headers": request_headers, + }, + headers=headers, + ) + mssg = make_request_log_str(request, response.status_code) + if not log: + return response + if status_code == 500: + log_fn = logger.exception + elif status_code > 500: + log_fn = logger.warning + elif exception is None: + log_fn = logger.info + elif isinstance(exception, (JamaiException, HTTPException)): + log_fn = logger.info + else: + log_fn = logger.warning + if exception: + log_fn(f"{mssg} - {exception.__class__.__name__}: {exception}") + else: + log_fn(mssg) + return response + + +class Wrapper(BaseModel): + body: Any + + +async def _request_validation_exc_handler(request: Request, exc: RequestValidationError): + content = None + try: + logger.info( + f"{make_request_log_str(request, 422)} - RequestValidationError: {exc.errors()}" + ) + errors, messages = [], [] + for i, e in enumerate(exc.errors()): + try: + msg = str(e["ctx"]["error"]).strip() + except Exception: + msg = e["msg"].strip() + if not msg.endswith("."): + msg = f"{msg}." + + path = "" + for j, x in enumerate(e.get("loc", [])): + if isinstance(x, str): + if j > 0: + path += "." + path += x + elif isinstance(x, int): + path += f"[{x}]" + else: + raise TypeError("Unexpected type") + if path: + path += " : " + messages.append(f"{i + 1}. 
{path}{msg}") + error = {k: v for k, v in e.items() if k != "ctx"} + if "ctx" in e: + error["ctx"] = {k: repr(v) if k == "error" else v for k, v in e["ctx"].items()} + if "input" in e: + error["input"] = repr(e["input"]) + errors.append(error) + message = "\n".join(messages) + message = f"Your request contains errors:\n{message}" + content = { + "object": "error", + "error": "validation_error", + "message": message, + "detail": errors, + "request_id": request.state.id, + "exception": "", + **Wrapper(body=exc.body).model_dump(), + } + return ORJSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=content, + ) + except Exception: + if content is None: + content = repr(exc) + logger.exception(f"{request.state.id} - Failed to parse error data: {content}") + message = str(exc) + return ORJSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={ + "object": "error", + "error": "validation_error", + "message": message, + "detail": message, + "request_id": request.state.id, + "exception": exc.__class__.__name__, + }, + ) + + +async def path_not_found_handler(request: Request, e: HTTPException): + return make_response( + request=request, + message=f"The path '{request.url.path}' was not found.", + error="http_error", + status_code=e.status_code, + exception=e, + log=False, + ) + + +async def exception_handler(request: Request, e: Exception): + if isinstance(e, RequestValidationError): + return await _request_validation_exc_handler(request, e) + # elif isinstance(e, ValidationError): + # raise RequestValidationError(errors=e.errors()) from e + elif isinstance(e, AuthorizationError): + return make_response( + request=request, + message=str(e), + error="unauthorized", + status_code=status.HTTP_401_UNAUTHORIZED, + exception=e, + ) + elif isinstance(e, ExternalAuthError): + return make_response( + request=request, + message=str(e), + error="external_authentication_failed", + status_code=status.HTTP_401_UNAUTHORIZED, + exception=e, + ) + elif isinstance(e, PermissionError): + return make_response( + request=request, + message=str(e), + error="resource_protected", + status_code=status.HTTP_403_FORBIDDEN, + exception=e, + ) + elif isinstance(e, ForbiddenError): + return make_response( + request=request, + message=str(e), + error="forbidden", + status_code=status.HTTP_403_FORBIDDEN, + exception=e, + ) + elif isinstance(e, UpgradeTierError): + return make_response( + request=request, + message=str(e), + error="upgrade_tier", + status_code=status.HTTP_403_FORBIDDEN, + exception=e, + ) + elif isinstance(e, InsufficientCreditsError): + return make_response( + request=request, + message=str(e), + error="insufficient_credits", + status_code=status.HTTP_403_FORBIDDEN, + exception=e, + ) + elif isinstance(e, (ResourceNotFoundError, FileNotFoundError)): + return make_response( + request=request, + message=str(e), + error="resource_not_found", + status_code=status.HTTP_404_NOT_FOUND, + exception=e, + ) + elif isinstance(e, (ResourceExistsError, FileExistsError)): + return make_response( + request=request, + message=str(e), + error="resource_exists", + status_code=status.HTTP_409_CONFLICT, + exception=e, + ) + elif isinstance(e, UnsupportedMediaTypeError): + return make_response( + request=request, + message=str(e), + error="unsupported_media_type", + status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + exception=e, + ) + elif isinstance(e, BadInputError): + return make_response( + request=request, + message=str(e), + error="bad_input", + 
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + exception=e, + ) + elif isinstance(e, ContextOverflowError): + return make_response( + request=request, + message=str(e), + error="context_overflow", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + exception=e, + ) + elif isinstance(e, RateLimitExceedError): + retry_after = "30" if e.retry_after is None else str(e.retry_after) + used = str(e.limit) if e.used is None else str(e.used) + meta = "{}" if e.meta is None else orjson.dumps(e.meta).decode("utf-8") + return make_response( + request=request, + message=str(e), + error="rate_limit_exceeded", + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + exception=e, + headers={ + "X-RateLimit-Limit": str(e.limit), + "X-RateLimit-Remaining": str(e.remaining), + "X-RateLimit-Reset": str(e.reset_at), + "Retry-After": retry_after, + "X-RateLimit-Used": used, + "X-RateLimit-Meta": meta, + }, + ) + elif isinstance(e, UnavailableError): + return make_response( + request=request, + message=str(e), + error="not_implemented", + status_code=status.HTTP_501_NOT_IMPLEMENTED, + exception=e, + ) + elif isinstance(e, ServerBusyError): + return make_response( + request=request, + message="The server is currently busy. Please try again later.", + error="busy", + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + exception=e, + headers={"Retry-After": "30"}, + ) + elif isinstance(e, ModelOverloadError): + return make_response( + request=request, + message="The model is overloaded. Please try again later.", + error="busy", + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + exception=e, + headers={"Retry-After": "30"}, + ) + elif isinstance(e, HTTPException): + return make_response( + request=request, + message=e.detail, + error="http_error", + status_code=e.status_code, + exception=e, + log=e.status_code != 404, + ) + elif isinstance(e, IntegrityError): + err_mssg: str = e.args[0] + err_mssgs = err_mssg.split("UNIQUE constraint failed:") + if len(err_mssgs) > 1: + constraint = err_mssgs[1].strip() + return make_response( + request=request, + message=f'DB item "{constraint}" already exists.', + error="resource_exists", + status_code=status.HTTP_409_CONFLICT, + exception=e, + ) + else: + return make_response( + request=request, + message=INTERNAL_ERROR_MESSAGE, + error="unexpected_error", + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + exception=e, + ) + elif isinstance(e, MethodNotAllowedError): + return make_response( + request=request, + message=str(e), + error="method_not_allowed", + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, + exception=e, + ) + else: + return make_response( + request=request, + message=INTERNAL_ERROR_MESSAGE, + error="unexpected_error", + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + exception=e, + ) diff --git a/services/api/src/owl/utils/io.py b/services/api/src/owl/utils/io.py index 91fef1e..8f5149b 100644 --- a/services/api/src/owl/utils/io.py +++ b/services/api/src/owl/utils/io.py @@ -1,78 +1,63 @@ -import asyncio -import contextlib import os -import pathlib -import zipfile +from contextlib import asynccontextmanager +from hashlib import blake2b from io import BytesIO -from os import listdir, walk -from os.path import abspath, dirname, getsize, isdir, islink, join, relpath -from typing import AsyncGenerator, BinaryIO, Generator +from os.path import join, splitext +from pathlib import Path +from typing import AsyncGenerator, BinaryIO import aioboto3 -import aiofiles -import boto3 from botocore.exceptions import ClientError from loguru import logger - -from 
jamaibase.exceptions import BadInputError, ResourceNotFoundError -from jamaibase.utils.io import generate_audio_thumbnail, generate_image_thumbnail -from owl.configs.manager import ENV_CONFIG +from PIL import Image, ImageDraw, ImageFont +from sqlmodel import select + +from jamaibase.utils.io import ( # noqa: F401 + AUDIO_WHITE_LIST, + DOC_WHITE_LIST, + EMBED_WHITE_LIST, + IMAGE_WHITE_LIST, + csv_to_df, + df_to_csv, + dump_json, + dump_pickle, + dump_toml, + dump_yaml, + guess_mime, + json_dumps, + json_loads, + load_pickle, + read_image, + read_json, + read_toml, + read_yaml, +) +from owl.configs import ENV_CONFIG +from owl.types import DBStorageUsage, TableType from owl.utils import uuid7_str +from owl.utils.exceptions import BadInputError, ResourceNotFoundError -if ENV_CONFIG.owl_file_dir.startswith("s3://"): - S3_CLIENT = boto3.client( - "s3", - aws_access_key_id=ENV_CONFIG.s3_access_key_id, - aws_secret_access_key=ENV_CONFIG.s3_secret_access_key_plain, - endpoint_url=ENV_CONFIG.s3_endpoint, - ) - S3_BUCKET_NAME = ENV_CONFIG.owl_file_dir.replace("s3://", "") - LOCAL_FILE_DIR = "" - logger.info(f"Starting with S3 File Storage: {S3_BUCKET_NAME}") +S3_BUCKET_NAME = ENV_CONFIG.file_dir.replace("s3://", "") +ASSET_DIRPATH = Path(__file__).resolve().parent.parent / "assets" +ICON_DIRPATH = ASSET_DIRPATH / "icons" +if ICON_DIRPATH.is_dir() and (ICON_DIRPATH / "csv.webp").is_file(): + logger.info(f'Documents icons will be loaded from "{ICON_DIRPATH}".') else: - S3_CLIENT = None - S3_BUCKET_NAME = "" - LOCAL_FILE_DIR = ENV_CONFIG.owl_file_dir.replace("file://", "") - logger.info(f"Starting with Local File Storage: {LOCAL_FILE_DIR}") - -EMBED_WHITE_LIST = { - "application/pdf": [".pdf"], - "application/xml": [".xml"], - "application/json": [".json"], - "application/jsonl": [".jsonl"], - "application/x-ndjson": [".jsonl"], - "application/json-lines": [".jsonl"], - "application/vnd.openxmlformats-officedocument.wordprocessingml.document": [".docx"], - "application/msword": [".doc"], - "application/vnd.openxmlformats-officedocument.presentationml.presentation": [".pptx"], - "application/vnd.ms-powerpoint": [".ppt"], - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": [".xlsx"], - "application/vnd.ms-excel": [".xls"], - "text/markdown": [".md"], - "text/plain": [".txt"], - "text/html": [".html"], - "text/tab-separated-values": [".tsv"], - "text/csv": [".csv"], - "text/xml": [".xml"], -} -IMAGE_WHITE_LIST = { - "image/jpeg": [".jpg", ".jpeg"], - "image/png": [".png"], - "image/gif": [".gif"], - "image/webp": [".webp"], -} -AUDIO_WHITE_LIST = { - "audio/mpeg": [".mp3"], - "audio/vnd.wav": [".wav"], - "audio/x-wav": [".wav"], - "audio/x-pn-wav": [".wav"], - "audio/wave": [".wav"], - "audio/vnd.wave": [".wav"], -} + ICON_DIRPATH = None + logger.warning( + f'Documents icons not found in "{ICON_DIRPATH}". Falling back to generating text-based thumbnails.' 
+ ) +GiB = 1024**3 + UPLOAD_WHITE_LIST = {**EMBED_WHITE_LIST, **IMAGE_WHITE_LIST, **AUDIO_WHITE_LIST} EMBED_WHITE_LIST_MIME = set(EMBED_WHITE_LIST.keys()) EMBED_WHITE_LIST_EXT = set(ext for exts in EMBED_WHITE_LIST.values() for ext in exts) +DOC_WHITE_LIST_MIME = set(DOC_WHITE_LIST.keys()) +DOC_WHITE_LIST_EXT = set(ext for exts in DOC_WHITE_LIST.values() for ext in exts) +NON_PDF_DOC_WHITE_LIST_EXT = set( + ext for exts in DOC_WHITE_LIST.values() for ext in exts if ext != ".pdf" +) IMAGE_WHITE_LIST_MIME = set(IMAGE_WHITE_LIST.keys()) IMAGE_WHITE_LIST_EXT = set(ext for exts in IMAGE_WHITE_LIST.values() for ext in exts) AUDIO_WHITE_LIST_MIME = set(AUDIO_WHITE_LIST.keys()) @@ -81,86 +66,7 @@ UPLOAD_WHITE_LIST_EXT = set(ext for exts in UPLOAD_WHITE_LIST.values() for ext in exts) -def get_db_usage(db_dir: str) -> float: - """Returns the DB storage used in bytes (B).""" - db_usage = 0.0 - for root, dirs, filenames in walk(abspath(db_dir), topdown=True): - # Don't visit Lance version directories - if root.endswith(".lance") and "_versions" in dirs: - dirs.remove("_versions") - for f in filenames: - fp = join(root, f) - if islink(fp): - continue - db_usage += getsize(fp) - return db_usage - - -def get_storage_usage(db_dir: str) -> dict[str, float]: - """Returns the DB storage used by each organisation in GiB.""" - db_usage = {} - for org_id in listdir(db_dir): - org_dir = join(db_dir, org_id) - if not (isdir(org_dir) and org_id.startswith("org_")): - continue - db_usage[org_id] = get_db_usage(org_dir) - db_usage = {k: v / (1024**3) for k, v in db_usage.items()} - return db_usage - - -def get_file_usage(db_dir: str) -> dict[str, float]: - """Returns the File storage used by each organisation in GiB.""" - file_usage = {} - if S3_CLIENT: - paginator = S3_CLIENT.get_paginator("list_objects_v2") - for org_id in listdir(db_dir): - org_dir = join(db_dir, org_id) - if not (isdir(org_dir) and org_id.startswith("org_")): - continue - - total_size = 0 - for prefix in [f"raw/{org_id}/", f"thumb/{org_id}/"]: - for page in paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=prefix): - for obj in page.get("Contents", []): - total_size += obj["Size"] - - file_usage[org_id] = total_size / (1024**3) # Convert to GiB - else: - for org_id in listdir(db_dir): - org_dir = join(db_dir, org_id) - print(org_id) - if not (isdir(org_dir) and org_id.startswith(("org_", "default"))): - continue - total_size = 0 - for subdir in ["raw", "thumb"]: - file_dir = join(LOCAL_FILE_DIR, subdir, org_id) - print(LOCAL_FILE_DIR) - if os.path.exists(file_dir): - for root, _, files in os.walk(file_dir): - for file in files: - file_path = join(root, file) - total_size += os.path.getsize(file_path) - - file_usage[org_id] = total_size / (1024**3) # Convert to GiB - - return file_usage - - -def zip_directory_content(root_dir: str, output_filepath: str) -> None: - root_dir = abspath(root_dir) - output_filepath = abspath(output_filepath) - if dirname(output_filepath) == root_dir: - raise ValueError("Output directory cannot be the zipped directory.") - with zipfile.ZipFile(output_filepath, "w", zipfile.ZIP_DEFLATED) as f: - for dir_name, _, filenames in walk(root_dir): - for filename in filenames: - filepath = join(dir_name, filename) - # Create a relative path for the file in the zip archive - arcname = relpath(filepath, root_dir) - f.write(filepath, arcname) - - -@contextlib.asynccontextmanager +@asynccontextmanager async def get_s3_aclient(): async with aioboto3.Session().client( "s3", @@ -171,74 +77,224 @@ async def get_s3_aclient(): yield 
aclient -# Synchronous version -@contextlib.contextmanager -def open_uri_sync(uri: str) -> Generator[BinaryIO | BytesIO, None, None]: - if S3_CLIENT: - if uri.startswith("s3://"): - try: - bucket_name, key = uri[5:].split("/", 1) - response = S3_CLIENT.get_object(Bucket=bucket_name, Key=key) - yield response["Body"] - except ClientError as e: - logger.warning(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - except Exception as e: - logger.exception(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - else: - raise ResourceNotFoundError(f'File "{uri}" is not found.') - else: - if uri.startswith("file://"): - try: - local_path = os.path.abspath(uri[7:]) - with open(local_path, "rb") as file: - yield file - except FileNotFoundError as e: - logger.warning(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - except Exception as e: - logger.exception(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - else: - raise ResourceNotFoundError(f'File "{uri}" is not found.') - - # Asynchronous version -@contextlib.asynccontextmanager -async def open_uri_async(uri: str) -> AsyncGenerator[BinaryIO | BytesIO, None]: - if S3_CLIENT: - if uri.startswith("s3://"): - try: - bucket_name, key = uri[5:].split("/", 1) - async with get_s3_aclient() as aclient: - response = await aclient.get_object(Bucket=bucket_name, Key=key) - yield response["Body"] - except ClientError as e: - logger.warning(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - except Exception as e: - logger.exception(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') +@asynccontextmanager +async def open_uri_async(uri: str) -> AsyncGenerator[tuple[BinaryIO | BytesIO, str], None]: + if isinstance(uri, str) and uri.startswith("s3://"): + try: + bucket_name, key = uri[5:].split("/", 1) + async with get_s3_aclient() as aclient: + response = await aclient.get_object(Bucket=bucket_name, Key=key) + yield response["Body"], str(response["ContentType"]) + except ClientError as e: + if "NoSuchKey" in str(e): raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - else: - raise ResourceNotFoundError(f'File "{uri}" is not found.') + logger.warning(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') + raise ResourceNotFoundError(f'File "{uri}" cannot be opened.') from e + except Exception as e: + logger.exception(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') + raise ResourceNotFoundError(f'File "{uri}" cannot be opened.') from e else: - if uri.startswith("file://"): + raise ResourceNotFoundError(f'File "{uri}" cannot be opened.') + + +def get_bytes_size_mb(bytes_content: bytes, decimal_places: int = 3) -> float: + """ + Convert bytes to megabytes (MB). + + Args: + bytes_content (bytes): The content in bytes to be calculated. + decimal_places (int, optional): Number of decimal places to round to. Defaults to 3. + + Returns: + float: The converted value in megabytes (MB) + """ + mb_value = len(bytes_content) / (1024 * 1024) # 1 MB = 1024 KB = 1024 * 1024 bytes + return round(mb_value, decimal_places) + + +def _image_to_webp_bytes(image: Image.Image) -> bytes: + """ + Converts an image to bytes. 
+ + Args: + image (Image.Image): The image. + + Returns: + bytes: The image as bytes (WebP format). + """ + with BytesIO() as f: + image.save( + f, + format="webp", + lossless=False, + quality=60, + alpha_quality=50, + method=6, + exact=False, + ) + return f.getvalue() + + +def generate_image_thumbnail( + file_content: bytes, + size: tuple[float, float] = (450.0, 450.0), +) -> bytes | None: + """ + Generates an image thumbnail. + + Args: + file_content (bytes): The image file content. + size (tuple[float, float]): The desired size of the thumbnail (width, height). + Defaults to (450.0, 450.0). + + Returns: + thumbnail (bytes | None): The thumbnail image as bytes, or None if generation fails. + """ + try: + with Image.open(BytesIO(file_content)) as img: + # Check image mode + if img.mode not in ("RGB", "RGBA"): + img = img.convert("RGB") + # Resize and save + img.thumbnail(size=size) + return _image_to_webp_bytes(img) + except Exception as e: + logger.exception(f"Failed to generate image thumbnail due to {e.__class__.__name__}: {e}") + return None + + +def generate_audio_thumbnail( + file_content: bytes, + duration_ms: int = 30000, +) -> bytes | None: + """ + Generates an audio thumbnail by extracting a segment from the original audio. + + Args: + file_content (bytes): The audio file content. + duration_ms (int): Duration of the thumbnail in milliseconds. + Defaults to 30000 (30 seconds). + + Returns: + thumbnail (bytes | None): The thumbnail audio as bytes, or None if generation fails. + """ + from pydub import AudioSegment + + try: + # Extract the first `duration_ms` milliseconds + audio = AudioSegment.from_file(BytesIO(file_content)) + thumbnail = audio[:duration_ms] + # Export the thumbnail to a bytes object + with BytesIO() as output: + thumbnail.export(output, format="mp3") + return output.getvalue() + except Exception as e: + logger.exception(f"Failed to generate audio thumbnail due to {e.__class__.__name__}: {e}") + return None + + +def generate_pdf_thumbnail( + file_content: bytes, + size: tuple[int, int] = (950, 950), +) -> bytes | None: + """ + Generates a PDF thumbnail image. + + Args: + file_content (bytes): The PDF file content. + size (tuple[int, int]): The desired size of the thumbnail (width, height). + Defaults to (950, 950). + + Returns: + thumbnail (bytes | None): The thumbnail image as bytes, or None if generation fails. 
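A usage sketch for the thumbnail helpers defined here; `example.jpg` is a placeholder input file.

```python
from pathlib import Path

from owl.utils.io import generate_image_thumbnail

content = Path("example.jpg").read_bytes()  # placeholder image file

# Returns WebP bytes scaled to fit 450x450, or None if Pillow cannot decode the input.
thumb = generate_image_thumbnail(content, size=(450.0, 450.0))
if thumb is not None:
    Path("example_thumb.webp").write_bytes(thumb)
```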
+ """ + from pdf2image import convert_from_bytes + + try: + images = convert_from_bytes( + file_content, + dpi=200, + first_page=1, + last_page=1, # process only the first page + ) + if not images: + return b"" + img = images[0] + img.thumbnail(size=size) + thumbnail_bytes = _image_to_webp_bytes(img) + for image in images: + image.close() # release resources + return thumbnail_bytes + + except Exception as e: + logger.exception(f"Failed to generate PDF thumbnail: {e.__class__.__name__}: {e}") + return None + + +def _generate_text_thumbnail(file_extension: str, size: tuple[int, int]) -> bytes: + """Generates a text-based thumbnail (as a fallback).""" + try: + img = Image.new("RGB", size, color=(255, 255, 255)) + draw = ImageDraw.Draw(img) + + text = file_extension + font_size = min(size) // 2 + while font_size > 1: try: - local_path = os.path.abspath(uri[7:]) - async with aiofiles.open(local_path, "rb") as file: - yield file - except FileNotFoundError as e: - logger.warning(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - except Exception as e: - logger.exception(f'Failed to open "{uri}" due to {e.__class__.__name__}: {e}') - raise ResourceNotFoundError(f'File "{uri}" is not found.') from e - else: - raise ResourceNotFoundError(f'File "{uri}" is not found.') - - -def os_path_to_s3_key(path: pathlib.Path | str) -> str: + font_ttf = ASSET_DIRPATH / "Roboto-Regular.ttf" + font = ImageFont.truetype(font_ttf, font_size) + except OSError: + logger.warning("Roboto font not found. Using default fallback font.") + font = ImageFont.load_default() + break + + text_bbox = draw.textbbox((0, 0), text, font=font) + if ( + text_bbox[2] - text_bbox[0] < size[0] * 0.9 + and text_bbox[3] - text_bbox[1] < size[1] * 0.9 + ): + break + font_size -= 1 + + text_bbox = draw.textbbox((0, 0), text, font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + text_x = (size[0] - text_width) // 2 + text_y = (size[1] - text_height) // 2 + + draw.text((text_x, text_y), text, fill=(0, 0, 0), font=font) + + return _image_to_webp_bytes(img) + + except Exception as e: + logger.exception(f"Failed to generate text thumbnail: {e.__class__.__name__}: {e}") + return b"" + + +def generate_extension_name_thumbnail( + file_extension: str, + size: tuple[int, int] = (512, 512), +) -> bytes: + """ + Loads a pre-generated thumbnail based on the file extension. + If no icon is found, falls back to generating a text-based thumbnail. + """ + if ICON_DIRPATH: + icon_path = ICON_DIRPATH / f"{file_extension[1:]}.webp" + try: + with open(icon_path, "rb") as f: + img = Image.open(f) + if img.size != size: + img.thumbnail(size) + return _image_to_webp_bytes(img) + except Exception as e: + logger.exception(f"Error loading pre-generated icon: {repr(e)}") + # Fallback: Generate a text-based thumbnail if the icon is not found or there's an error. 
+ return _generate_text_thumbnail(file_extension, size) + + +def _os_path_to_s3_key(path: Path | str) -> str: # Convert path to string if it's a PathLike object path_str = str(path) # Replace backslashes with forward slashes @@ -247,97 +303,384 @@ def os_path_to_s3_key(path: pathlib.Path | str) -> str: return s3_key.lstrip("/") -async def upload_file_to_s3( +async def s3_upload( organization_id: str, project_id: str, content: bytes, + *, content_type: str, filename: str, + generate_thumbnail: bool = True, + key: str = "", ) -> str: if content_type not in UPLOAD_WHITE_LIST_MIME: raise BadInputError( - f"Unsupported file MIME type: {content_type}. Allowed types are: {', '.join(UPLOAD_WHITE_LIST_MIME)}" + f'Unsupported MIME type "{content_type}" for file "{filename}". Allowed types are: {", ".join(UPLOAD_WHITE_LIST_MIME)}' ) - file_extension = os.path.splitext(filename)[1].lower() + file_extension = splitext(filename)[1].lower() if file_extension not in UPLOAD_WHITE_LIST_EXT: raise BadInputError( - f"Unsupported file extension: {file_extension}. Allowed types are: {', '.join(UPLOAD_WHITE_LIST_EXT)}" + f'Unsupported extension "{file_extension}" for file "{filename}". Allowed types are: {", ".join(UPLOAD_WHITE_LIST_EXT)}' ) else: if ( file_extension in EMBED_WHITE_LIST_EXT - and len(content) > ENV_CONFIG.owl_embed_file_upload_max_bytes + and len(content) > ENV_CONFIG.embed_file_upload_max_bytes ): raise BadInputError( - f"File size exceeds {ENV_CONFIG.owl_embed_file_upload_max_bytes / 1024**2} MB limit: {len(content) / 1024**2} MB" + f"File size exceeds {ENV_CONFIG.embed_file_upload_max_bytes / 1024**2} MB limit: {len(content) / 1024**2} MB" ) elif ( file_extension in AUDIO_WHITE_LIST_EXT - and len(content) > ENV_CONFIG.owl_audio_file_upload_max_bytes + and len(content) > ENV_CONFIG.audio_file_upload_max_bytes ): raise BadInputError( - f"File size exceeds {ENV_CONFIG.owl_audio_file_upload_max_bytes / 1024**2} MB limit: {len(content) / 1024**2} MB" + f"File size exceeds {ENV_CONFIG.audio_file_upload_max_bytes / 1024**2} MB limit: {len(content) / 1024**2} MB" ) elif ( file_extension in IMAGE_WHITE_LIST_EXT - and len(content) > ENV_CONFIG.owl_image_file_upload_max_bytes + and len(content) > ENV_CONFIG.image_file_upload_max_bytes ): raise BadInputError( - f"File size exceeds {ENV_CONFIG.owl_image_file_upload_max_bytes / 1024**2} MB limit: {len(content) / 1024**2} MB" + f"File size exceeds {ENV_CONFIG.image_file_upload_max_bytes / 1024**2} MB limit: {len(content) / 1024**2} MB" ) - - uuid = uuid7_str() - raw_path = os.path.join("raw", organization_id, project_id, uuid, filename) - raw_key = os_path_to_s3_key(raw_path) + # Process key + if key: + key = key.removeprefix(f"s3://{S3_BUCKET_NAME}/").lstrip("/") + if not key.startswith("raw/"): + raise BadInputError( + f'Invalid S3 key "{key}". Must start with one of ["raw/", "s3:///raw/"].' 
+ ) + else: + key = join("raw", organization_id, project_id, uuid7_str(), filename) + raw_key = _os_path_to_s3_key(key) thumb_ext = "mp3" if file_extension in AUDIO_WHITE_LIST_EXT else "webp" - thumb_filename = f"{os.path.splitext(filename)[0]}.{thumb_ext}" - thumb_path = os.path.join("thumb", organization_id, project_id, uuid, thumb_filename) - thumb_key = os_path_to_s3_key(thumb_path) - if file_extension in AUDIO_WHITE_LIST_EXT: - thumbnail_task = asyncio.create_task(asyncio.to_thread(generate_audio_thumbnail, content)) + thumb_key = f"{splitext(raw_key.replace('raw/', 'thumb/', 1))[0]}.{thumb_ext}" + if generate_thumbnail: + if file_extension == ".pdf": + thumbnail = generate_pdf_thumbnail(content) + elif file_extension in NON_PDF_DOC_WHITE_LIST_EXT: + thumbnail = await generate_document_thumbnail(file_extension) + elif file_extension in AUDIO_WHITE_LIST_EXT: + thumbnail = generate_audio_thumbnail(content) + else: + thumbnail = generate_image_thumbnail(content) else: - thumbnail_task = asyncio.create_task(asyncio.to_thread(generate_image_thumbnail, content)) - thumbnail = await thumbnail_task - - if S3_CLIENT: + thumbnail = None + + async with get_s3_aclient() as aclient: + # Upload raw file + await aclient.put_object( + Body=content, + Bucket=S3_BUCKET_NAME, + Key=raw_key, + ContentType=content_type, + ) + if thumbnail is not None: + await aclient.put_object( + Body=thumbnail, + Bucket=S3_BUCKET_NAME, + Key=thumb_key, + ContentType=f"{content_type.split('/')[0]}/{'mpeg' if thumb_ext == 'mp3' else thumb_ext}", + ) + logger.info( + f"File uploaded: [{organization_id}/{project_id}] " + f"Location: s3://{S3_BUCKET_NAME}/{raw_key} " + f"File name: {filename}, MIME type: {content_type}. " + ) + return f"s3://{S3_BUCKET_NAME}/{raw_key}" + + +# async def s3_cache_file( +# content: bytes, +# content_type: str, +# ) -> str: +# content_len = len(content) +# content_hash = blake2b(content).hexdigest() +# s3_key = f"temp/{content_hash}-{content_len}" +# uri = f"s3://{S3_BUCKET_NAME}/{s3_key}" +# # Upload file +# async with get_s3_aclient() as aclient: +# # If file already exists, skip +# try: +# await aclient.head_object(Bucket=S3_BUCKET_NAME, Key=s3_key) +# return uri +# except Exception: +# pass +# # Upload +# await aclient.put_object( +# Body=content, +# Bucket=S3_BUCKET_NAME, +# Key=s3_key, +# ContentType=content_type, +# ) +# logger.info(f"S3 file created: {uri}") +# return uri + + +# async def s3_delete( +# *, +# organization_id: str = "", +# project_id: str = "", +# filename: str = "", +# key: str = "", +# delete_thumbnail: bool = True, +# ) -> str: +# # Process key +# if key: +# key = key.removeprefix(f"s3://{S3_BUCKET_NAME}/").lstrip("/") +# if not key.startswith(("raw/", "temp/")): +# raise BadInputError( +# ( +# f'Invalid S3 key "{key}". Must start with one of ' +# '["raw/", "temp/", "s3:///raw/", "s3:///temp/"].' 
+# ) +# ) +# else: +# key = join("raw", organization_id, project_id, uuid7_str(), filename) +# raw_key = _os_path_to_s3_key(key) +# file_extension = splitext(filename)[1].lower() +# thumb_ext = "mp3" if file_extension in AUDIO_WHITE_LIST_EXT else "webp" +# thumb_key = f"{splitext(raw_key.replace('raw/', 'thumb/', 1))[0]}.{thumb_ext}" + +# async with get_s3_aclient() as aclient: +# # Delete raw file +# await aclient.delete_object(Bucket=S3_BUCKET_NAME, Key=raw_key) +# # Delete thumbnail +# if delete_thumbnail: +# try: +# await aclient.delete_object(Bucket=S3_BUCKET_NAME, Key=thumb_key) +# except Exception as e: +# logger.warning(f'Failed to delete thumbnail "{thumb_key}": {repr(e)}') +# logger.info(f"File deleted: s3://{S3_BUCKET_NAME}/{raw_key}") +# return raw_key + + +@asynccontextmanager +async def s3_temporary_file( + content: bytes, + content_type: str, +) -> AsyncGenerator[str, None]: + from owl.configs import CACHE + + content_len = len(content) + content_hash = blake2b(content).hexdigest() + cache_key = f"temp:{content_hash}-{content_len}" + s3_key = cache_key.replace(":", "/") + # This lock is so that we don't upload the same file twice + async with CACHE.alock(f"{cache_key}:lock", blocking=True, expire=180) as lock_acquired: + if not lock_acquired: + raise BadInputError("Another upload of this file is in progress.") + # Upload file async with get_s3_aclient() as aclient: - # Upload raw file await aclient.put_object( Body=content, Bucket=S3_BUCKET_NAME, - Key=raw_key, + Key=s3_key, ContentType=content_type, ) - if len(thumbnail) > 0: - await aclient.put_object( - Body=thumbnail, - Bucket=S3_BUCKET_NAME, - Key=thumb_key, - ContentType=f"{content_type.split('/')[0]}/{"mpeg" if thumb_ext == "mp3" else thumb_ext}", - ) - logger.info( - f"File Uploaded: [{organization_id}/{project_id}] " - f"Location: s3://{S3_BUCKET_NAME}/{raw_key} " - f"File name: {filename}, MIME type: {content_type}. 
" + uri = f"s3://{S3_BUCKET_NAME}/{s3_key}" + logger.info(f"Temporary S3 file created: {uri}") + try: + yield uri + finally: + # Delete file + try: + async with get_s3_aclient() as aclient: + await aclient.delete_object(Bucket=S3_BUCKET_NAME, Key=s3_key) + logger.info(f"Temporary S3 file deleted: {uri}") + except Exception as e: + logger.warning(f'Failed to delete temporary S3 file "{uri}": {repr(e)}') + + +def get_global_thumbnail_path(extension: str) -> str: + """Returns the path for a global thumbnail based on file extension.""" + return join("thumb", "global", f"{extension[1:]}.webp") + + +async def get_global_thumbnail(extension: str) -> bytes | None: + """Retrieves a global thumbnail if it exists.""" + + try: + thumbnail_path = get_global_thumbnail_path(extension) + async with get_s3_aclient() as aclient: + try: + response = await aclient.get_object(Bucket=S3_BUCKET_NAME, Key=thumbnail_path) + return await response["Body"].read() + except ClientError: + return None + except Exception as e: + logger.warning(f"Failed to get global thumbnail: {e}") + return None + + +async def save_global_thumbnail(extension: str, thumbnail: bytes) -> None: + """Saves a global thumbnail for future use.""" + + try: + thumbnail_path = get_global_thumbnail_path(extension) + async with get_s3_aclient() as aclient: + await aclient.put_object( + Body=thumbnail, + Bucket=S3_BUCKET_NAME, + Key=thumbnail_path, + ContentType="image/webp", + ) + except Exception as e: + logger.warning(f"Failed to save global thumbnail: {e}") + + +async def generate_document_thumbnail( + file_extension: str, + size: tuple[int, int] = (512, 512), +) -> None: + """ + Generates a thumbnail based on the given file extension with global cache. + > if doc and non-pdf, generate global thumbnail, no local thumbnail + > when get thumbnail url, check raw url for extension, get global thumbnail url + + Args: + file_extension (str): The file extension (e.g., ".xlsx"). + size (tuple[int, int]): The desired size (width, height) of the thumbnail. + """ + file_extension = file_extension.lower() + if file_extension not in NON_PDF_DOC_WHITE_LIST_EXT: + raise ValueError(f"Unsupported file extension: {file_extension}") + try: + # Check global cache first + if (await get_global_thumbnail(file_extension)) is not None: + return + # Generate and cache new thumbnail + thumbnail_path = get_global_thumbnail_path(file_extension) + async with get_s3_aclient() as aclient: + await aclient.put_object( + Body=generate_extension_name_thumbnail(file_extension, size), + Bucket=S3_BUCKET_NAME, + Key=thumbnail_path, + ContentType="image/webp", + ) + return + except Exception as e: + logger.exception(f"Failed to generate file thumbnail due to {e.__class__.__name__}: {e}") + + +async def get_file_storage_usage(org_id: str) -> float | None: + """ + Calculates the total file storage used by an organization in the S3 bucket. + + This function iterates through the S3 objects under the standard 'raw/{org_id}/' + and 'thumb/{org_id}/' prefixes, summing their sizes. It includes error + handling to prevent task failure if S3 is unavailable. + + Args: + org_id (str): The ID of the organization to measure. + + Returns: + usage_gib (float | None): The total storage used in GiB. Returns None on error. 
+ """ + try: + async with get_s3_aclient() as aclient: + paginator = aclient.get_paginator("list_objects_v2") + total_size = 0 + for prefix in [f"raw/{org_id}/", f"thumb/{org_id}/"]: + prefix_size = 0 + async for page in paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=prefix): + for obj in page.get("Contents", []): + prefix_size += obj["Size"] + total_size += prefix_size + return total_size / GiB + except Exception as e: + logger.exception( + f'Failed to compute file storage usage for organization "{org_id}": {repr(e)}' ) - return f"s3://{S3_BUCKET_NAME}/{raw_key}" - else: - raw_file_path = os.path.join(LOCAL_FILE_DIR, raw_path) - thumb_file_path = os.path.join(LOCAL_FILE_DIR, thumb_path) + return None + - os.makedirs(os.path.dirname(raw_file_path)) - os.makedirs(os.path.dirname(thumb_file_path)) +async def get_schema_storage_usage_postgres(schema_name: str) -> DBStorageUsage: + """ + Calculates detailed storage usage for a given schema in PostgreSQL. - async with aiofiles.open(raw_file_path, "wb") as out_file: - await out_file.write(content) + This function queries PostgreSQL system tables to get the total size of all + tables and their associated indexes within a specific schema. - if len(thumbnail) > 0: - async with aiofiles.open(thumb_file_path, "wb") as thumb_file: - await thumb_file.write(thumbnail) + Args: + session: The SQLAlchemy session to use for the query. + schema_name: The name of the database schema to measure. - logger.info( - f"File Uploaded: [{organization_id}/{project_id}] " - f"Location: file://{raw_file_path} " - f"File name: {filename}, MIME type: {content_type}. " + Returns: + The total size in GiB. + """ + from owl.db import async_session, cached_text + + usage = DBStorageUsage( + schema_name=schema_name, + table_names=[], + table_sizes=[], + ) + try: + query = cached_text( + """ + SELECT + nspname AS schema_name, + array_agg(c.relname) AS names, + array_agg(pg_total_relation_size(c.oid)::bigint) AS total_relation_sizes, + array_agg(pg_total_relation_size(c.reltoastrelid)::bigint) AS total_toast_sizes + FROM + pg_class c + LEFT JOIN + pg_namespace n ON (n.oid = c.relnamespace) + WHERE + n.nspname = :schema_name + AND c.relkind IN ('r', 'm') -- r = table, m = materialized view + GROUP BY nspname; + """ + ) + async with async_session() as session: + stats = (await session.exec(query, params={"schema_name": schema_name})).one_or_none() + if not stats: + return usage + return DBStorageUsage( + schema_name=stats.schema_name, + table_names=stats.names, + table_sizes=[ + float(rs or 0.0) + float(ts or 0.0) + for rs, ts in zip(stats.total_relation_sizes, stats.total_toast_sizes, strict=True) + ], + ) + except Exception as e: + logger.exception( + f'Failed to compute DB storage usage for schema "{schema_name}": {repr(e)}' + ) + return usage + + +async def get_db_storage_usage(org_id: str) -> float | None: + """ + Calculates the total DB storage used by an organization. + + Args: + org_id (str): The ID of the organization to measure. + + Returns: + usage_gib (float | None): The total storage used in GiB. Returns None on error. 
+ """ + from owl.db import async_session + from owl.db.models import Project + + try: + db_usage = 0.0 + async with async_session() as session: + projects_in_org = ( + await session.exec(select(Project.id).where(Project.organization_id == org_id)) + ).all() + for project_id in projects_in_org: + for table_type in TableType: + schema_name = f"{project_id}_{table_type}" + usage = await get_schema_storage_usage_postgres(schema_name) + db_usage += usage.total_size + return db_usage / GiB + except Exception as e: + logger.exception( + f'Failed to compute DB storage usage for organization "{org_id}": {repr(e)}' ) - return f"file://{raw_file_path}" + return None diff --git a/services/api/src/owl/utils/ip_address.py b/services/api/src/owl/utils/ip_address.py deleted file mode 100644 index 5a34424..0000000 --- a/services/api/src/owl/utils/ip_address.py +++ /dev/null @@ -1,136 +0,0 @@ -import re - - -def is_valid_ipv4(ip): - pattern = re.compile( - r"^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$" - ) - return pattern.match(ip) is not None - - -def is_valid_port(port): - return 0 <= int(port) <= 65535 - - -def expand_port_ranges(port_ranges): - ports = [] - for part in port_ranges.split(","): - if "-" in part: - start, end = part.split("-") - ports.extend(range(int(start), int(end) + 1)) - else: - ports.append(int(part)) - return ports - - -def validate_and_process_ip_address(input_string): - urls = input_string.split("|") - result = [] - - for url in urls: - if not url: - continue - match = re.match(r"http://(\d+\.\d+\.\d+\.\d+):(.+)", url) - if not match: - return "Input is invalid" - - ip, port_ranges = match.groups() - if not is_valid_ipv4(ip): - return "Input is invalid" - - try: - ports = expand_port_ranges(port_ranges) - for port in ports: - if not is_valid_port(port): - return "Input is invalid" - result.append(f"http://{ip}:{port}") - except ValueError: - return "Input is invalid" - - return result - - -if __name__ == "__main__": - - def test_input_vllm_url(): - test_case = [] - - test_case.append( - {"input": "http://192.168.1.1:1234", "expected": ["http://192.168.1.1:1234"]} - ) - - test_case.append( - { - "input": "http://192.168.171.1:1234,1235,1237", - "expected": [ - "http://192.168.171.1:1234", - "http://192.168.171.1:1235", - "http://192.168.171.1:1237", - ], - } - ) - - test_case.append( - { - "input": "http://192.6.171.1:1234-1237", - "expected": [ - "http://192.6.171.1:1234", - "http://192.6.171.1:1235", - "http://192.6.171.1:1236", - "http://192.6.171.1:1237", - ], - } - ) - - test_case.append( - { - "input": "http://10.168.171.1:1234|http://192.168.171.1:1256", - "expected": ["http://10.168.171.1:1234", "http://192.168.171.1:1256"], - } - ) - - test_case.append( - { - "input": "http://192.168.171.6:2345|http://192.168.171.1:1234,1235,1237|", - "expected": [ - "http://192.168.171.6:2345", - "http://192.168.171.1:1234", - "http://192.168.171.1:1235", - "http://192.168.171.1:1237", - ], - } - ) - - test_case.append( - { - "input": "http://192.168.171.1:1234-1237|http://192.168.171.6:2345", - "expected": [ - "http://192.168.171.1:1234", - "http://192.168.171.1:1235", - "http://192.168.171.1:1236", - "http://192.168.171.1:1237", - "http://192.168.171.6:2345", - ], - } - ) - - test_case.append( - { - "input": "http://192.168.171.1:1234-1237|https://192.168.171.6:2345", - "expected": [ - "http://192.168.171.1:1234", - "http://192.168.171.1:1235", - 
"http://192.168.171.1:1236", - "http://192.168.171.1:1237", - "http://192.168.171.6:2345", - ], - } - ) - - return test_case - - valid_test_cases = test_input_vllm_url() - - for test_case in valid_test_cases: - output = validate_and_process_ip_address(test_case["input"]) - assert output == test_case["expected"], f"{output} \n {test_case['expected']}" diff --git a/services/api/src/owl/utils/jwt.py b/services/api/src/owl/utils/jwt.py index b443e57..9551a8c 100644 --- a/services/api/src/owl/utils/jwt.py +++ b/services/api/src/owl/utils/jwt.py @@ -4,13 +4,13 @@ import jwt from loguru import logger -from jamaibase.exceptions import AuthorizationError -from owl.configs.manager import ENV_CONFIG +from owl.configs import ENV_CONFIG +from owl.utils.exceptions import AuthorizationError def encode_jwt(data: dict[str, Any], expiry: datetime) -> str: data.update({"iat": datetime.now(tz=timezone.utc), "exp": expiry}) - token = jwt.encode(data, f"{ENV_CONFIG.owl_encryption_key_plain}_secret", algorithm="HS256") + token = jwt.encode(data, f"{ENV_CONFIG.encryption_key_plain}_secret", algorithm="HS256") return token @@ -23,7 +23,7 @@ def decode_jwt( try: data = jwt.decode( token, - f"{ENV_CONFIG.owl_encryption_key_plain}_secret", + f"{ENV_CONFIG.encryption_key_plain}_secret", algorithms=["HS256"], ) return data diff --git a/services/api/src/owl/utils/kb.py b/services/api/src/owl/utils/kb.py index 8d342a1..8897325 100644 --- a/services/api/src/owl/utils/kb.py +++ b/services/api/src/owl/utils/kb.py @@ -2,7 +2,7 @@ from itertools import chain, pairwise from typing import Any -from owl.protocol import Chunk +from owl.types import Chunk def detect_consecutive_segments(lst: list[tuple[Any, Any]]) -> list[tuple[Any, Any]]: diff --git a/services/api/src/owl/utils/lm.py b/services/api/src/owl/utils/lm.py new file mode 100644 index 0000000..5aa1862 --- /dev/null +++ b/services/api/src/owl/utils/lm.py @@ -0,0 +1,1843 @@ +import asyncio +import itertools +import random +from base64 import b64encode +from contextlib import asynccontextmanager +from copy import deepcopy +from dataclasses import dataclass +from datetime import timedelta +from textwrap import dedent +from time import perf_counter, time +from typing import Any, AsyncGenerator + +import httpx +import litellm +import numpy as np +import openai +from fastapi import Request +from litellm import acompletion, aembedding, arerank +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.types.rerank import RerankResponse +from litellm.types.utils import ( + Choices, + Delta, + Message, + ModelResponse, + ModelResponseStream, + StreamingChoices, + Usage, +) +from litellm.types.utils import ( + EmbeddingResponse as LiteLLMEmbeddingResponse, +) +from loguru import logger +from natsort import natsorted +from openai import AsyncOpenAI +from openai.types.responses import ( + Response, + ResponseCodeInterpreterToolCall, + ResponseCompletedEvent, + ResponseFunctionWebSearch, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseReasoningItem, + ResponseReasoningSummaryTextDeltaEvent, + ResponseTextDeltaEvent, +) +from tenacity import ( + AsyncRetrying, + before_sleep_log, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from jamaibase.types.common import SanitisedStr +from owl.configs import CACHE, ENV_CONFIG +from owl.db import SCHEMA, async_session, cached_text +from owl.db.models import Deployment, ModelConfig +from owl.types import ( + AudioContent, + ChatCompletionChunkResponse, + 
ChatCompletionResponse, + ChatCompletionUsage, + ChatEntry, + ChatRole, + CloudProvider, + CodeInterpreterTool, + CompletionUsageDetails, + Deployment_, + EmbeddingResponse, + EmbeddingResponseData, + EmbeddingUsage, + ImageContent, + ModelCapability, + ModelConfig_, + ModelConfigRead, + ModelProvider, + ModelType, + OnPremProvider, + OrganizationRead, + Project_, + PromptUsageDetails, + RAGParams, + References, + RerankingResponse, + RerankingUsage, + TextContent, + ToolUsageDetails, + WebSearchTool, +) +from owl.utils import mask_content, mask_string +from owl.utils.billing import BillingManager +from owl.utils.dates import now +from owl.utils.exceptions import ( + BadInputError, + ExternalAuthError, + JamaiException, + ModelCapabilityError, + ModelOverloadError, + RateLimitExceedError, + ResourceNotFoundError, + UnavailableError, + UnexpectedError, +) + +litellm.drop_params = True +litellm.set_verbose = False +litellm.suppress_debug_info = True + +WEB_SEARCH_TOOL = WebSearchTool() +CODE_INTERPRETER_TOOL = CodeInterpreterTool() +OPENAI_HOSTED_TOOLS = (WEB_SEARCH_TOOL.type, CODE_INTERPRETER_TOOL.type) + + +class _Logger: + @staticmethod + def log( + log_level: int, + message: str, + **kwargs, + ): + logger.bind(**kwargs).log(log_level, message) + + +@dataclass(slots=True) +class DeploymentContext: + deployment: Deployment + api_key: str + routing_id: str + inference_provider: str + is_reasoning_model: bool + use_openai_responses: bool = False + + +class DeploymentRouter: + def __init__( + self, + *, + request: Request, + config: ModelConfigRead, + organization: OrganizationRead, + cooldown: float = 0.0, # No cooldown by default + is_browser: bool = False, + ) -> None: + self.request = request + self.id: str = request.state.id + if not isinstance(config, ModelConfigRead): + raise TypeError(f"Expected ModelConfigRead, got {type(config)}.") + self.config = config + if not isinstance(organization, OrganizationRead): + raise TypeError(f"Expected OrganizationRead, got {type(organization)}.") + self.organization = organization + self.cooldown = cooldown + self.is_browser = is_browser + self.retry_policy = dict( + retry=retry_if_exception_type((RateLimitExceedError, ModelOverloadError)), + wait=wait_exponential_jitter(initial=0.5, exp_base=1.2, max=5, jitter=0.5), + stop=stop_after_attempt(3), + reraise=True, + before_sleep=before_sleep_log(_Logger(), "WARNING"), + ) + self._model_display_id = self.config.name if is_browser else self.config.id + + @staticmethod + def batch(seq, n): + if n < 1: + raise ValueError("`n` must be > 0") + for i in range(0, len(seq), n): + yield seq[i : i + n] + + def _inference_provider(self, provider: str) -> str: + if provider == CloudProvider.ELLM: + return CloudProvider.ELLM + if provider in ModelProvider: + return ModelProvider(provider) + if provider in OnPremProvider: + return OnPremProvider(provider) + owned_by = self.config.owned_by or "" + return next((p for p in ModelProvider if owned_by.lower() == p.value.lower()), "") + + def _litellm_model_id(self, deployment: Deployment_): + """ + Chat and embedding: + - Known cloud providers: provider/model + - Unknown cloud providers and on-prem: openai/model + + Reranking: + - Known cloud providers and on-prem: provider/model + - Unknown cloud providers: cohere/model + """ + provider = deployment.provider + routing_id = self.config.id if deployment.routing_id == "" else deployment.routing_id + if provider in CloudProvider and provider not in ( + CloudProvider.INFINITY_CLOUD, + CloudProvider.ELLM, + ): + # Standard 
cloud providers + prefix = "hosted_vllm" if provider == CloudProvider.VLLM_CLOUD else provider + return routing_id if routing_id.startswith(f"{prefix}/") else f"{prefix}/{routing_id}" + if self.config.type != ModelType.RERANK: + # Non-standard providers including ELLM + prefix = "openai" + else: + # Reranking + if provider in ( + CloudProvider.INFINITY_CLOUD, + CloudProvider.ELLM, + ): + prefix = "infinity" + elif provider in OnPremProvider: + prefix = provider.split("_")[0] # infinity_cpu -> infinity + else: + prefix = "cohere" + return f"{prefix}/{routing_id}" + + def _log_completion_masked( + self, + messages: list[dict], + **hyperparams, + ): + body = dict( + model=self.config.id, + messages=[mask_content(m) for m in messages], + **hyperparams, + ) + logger.info(f"{self.id} - Generating chat completions: {body}") + + def _map_and_log_exception( + self, + e: Exception, + api_key: str, + *, + messages: list[dict], + **hyperparams, + ) -> Exception: + messages = [mask_content(m) for m in messages] + err_mssg = getattr(e, "message", str(e)) + logger.warning( + f'{self.id} - LLM request to model "{self.config.id}" failed. Exception: {e.__class__}: {err_mssg}' + ) + if isinstance(e, JamaiException): + return e + elif isinstance(e, openai.BadRequestError): + logger.info( + ( + f'{self.id} - LLM request to model "{self.config.id}" failed due to bad request. ' + f"Hyperparameters: {hyperparams} Messages: {messages}" + ) + ) + return BadInputError(err_mssg) + elif isinstance(e, openai.AuthenticationError): + return ExternalAuthError(f"Invalid API key: {mask_string(api_key)}") + elif isinstance(e, openai.RateLimitError): + _header = e.response.headers + limit = int(_header.get("X-RateLimit-Limit", 0)) + remaining = int(_header.get("X-RateLimit-Remaining", 0)) + reset_at = int(_header.get("X-RateLimit-Reset", time() + 30)) + return RateLimitExceedError( + err_mssg, + limit=limit, + remaining=remaining, + reset_at=reset_at, + used=int(_header.get("X-RateLimit-Used", limit - remaining)), + retry_after=int(_header.get("Retry-After", int(reset_at - time()) + 1)), + meta=None, + ) + elif isinstance( + e, + ( + openai.APITimeoutError, + openai.APIError, + httpx.HTTPStatusError, + httpx.TimeoutException, # ReadTimeout, ConnectTimeout, etc + ), + ): + return ModelOverloadError( + f'Model provider for "{self._model_display_id}" is overloaded. Please try again later.' + ) + elif isinstance(e, (BaseLLMException, openai.OpenAIError)): + return BadInputError(err_mssg) + else: + body = dict( + model=self.config.id, + api_key=mask_string(api_key), + messages=messages, + **hyperparams, + ) + logger.exception( + f"{self.id} - {self.__class__.__name__} - Unexpected error !!! {body}" + ) + return UnexpectedError(err_mssg) + + async def _cooldown_deployment(self, deployment: Deployment_, cooldown_time: timedelta): + if cooldown_time.total_seconds() <= 0: + logger.warning( + f"{self.id} - Cooldown time is zero or negative for deployment {deployment.id}. Skipping cooldown." + ) + return + cooldown_until = now() + cooldown_time + logger.warning( + ( + f'{self.id} - Cooling down deployment "{deployment.id}" ' + f"until {cooldown_until} ({cooldown_time.total_seconds()} seconds)." 
+ ) + ) + try: + async with async_session() as session: + await session.exec( + cached_text( + f'UPDATE {SCHEMA}."Deployment" SET cooldown_until = :cooldown_until WHERE id = :deployment_id;' + ), + params={ + "cooldown_until": cooldown_until, + "deployment_id": deployment.id, + }, + ) + await session.commit() + except Exception as exc: + logger.warning(f"{self.id} - Failed to cooldown deployment: {repr(exc)}") + + @asynccontextmanager + async def _get_deployment( + self, + **hyperparams, + ) -> AsyncGenerator[DeploymentContext, None]: + name = self.config.name + # Get deployment + if len(self.config.deployments) == 0: + logger.warning( + f"{self.id} - No deployments attached to model config. Fetching from database." + ) + async with async_session() as session: + deployments = ( + await Deployment.list_( + session=session, + return_type=Deployment_, + filters=dict(model_id=self.config.id), + ) + ).items + if len(deployments) == 0: + raise UnavailableError(f'No deployments found for model "{name}".') + else: + deployments = self.config.deployments + deployments = [d for d in deployments if d.cooldown_until <= now()] + if len(deployments) == 0: + raise UnavailableError(f'All deployments are on cooldown for model "{name}".') + deployment = random.choices(deployments, weights=[d.weight for d in deployments], k=1)[0] + # Get API key + provider = deployment.provider.lower() + api_key = "" + if self.organization.id == "0" or ( + ENV_CONFIG.enable_byok and provider not in OnPremProvider + ): + # Use Organization keys + api_key = self.organization.get_external_key(provider) + if (not api_key) and self.organization.id != "0": + # Use TSP keys + async with async_session() as session: + tsp_org = await CACHE.get_organization_async("0", session) + api_key = "" if tsp_org is None else tsp_org.get_external_key(provider) + if not api_key: + # Use System keys + api_key = ENV_CONFIG.get_api_key(provider) + if not api_key: + api_key = "DUMMY_KEY" + # Get model routing ID + routing_id = self._litellm_model_id(deployment) + # Check if its a reasoning model + can_reason = ModelCapability.REASONING in self.config.capabilities + is_reasoning_model = can_reason or litellm.supports_reasoning(routing_id) + if is_reasoning_model and not can_reason: + logger.warning( + f'Model "{self.config.id}" by provider "{provider}" seems to support reasoning, but it is not labelled as such.' + ) + try: + logger.info( + f'{self.id} - Request started for model "{self.config.id}" ({provider=}, {routing_id=}).' + ) + t0 = perf_counter() + self.request.state.model_start_time = t0 + yield DeploymentContext( + deployment=deployment, + api_key=api_key, + routing_id=routing_id, + inference_provider=self._inference_provider(provider), + is_reasoning_model=is_reasoning_model, + ) + self.request.state.timing["external_call"] = perf_counter() - t0 + logger.info(f'{self.id} - Request completed for model "{self.config.id}".') + except Exception as e: + mapped_e = self._map_and_log_exception(e, api_key, **hyperparams) + if isinstance(mapped_e, (ModelOverloadError, RateLimitExceedError)): + # Cooldown deployment + if len(deployments) > 1: + cooldown_time = timedelta( + seconds=getattr(mapped_e, "retry_after", self.cooldown) + ) + await self._cooldown_deployment(deployment, cooldown_time) + logger.warning( + f"{self.id} - LLM request failed. 
Mapped exception: {mapped_e.__class__}: {str(mapped_e)}" + ) + raise mapped_e from e + + ### --- Chat Completion --- ### + + async def _prepare_chat( + self, + *, + messages: list[ChatEntry], + hyperparams, + **kwargs, + ) -> tuple[list[dict[str, Any]], dict]: + # Prepare messages + if len(messages) == 0: + raise ValueError("`messages` is an empty list.") + elif len(messages) == 1: + # [user] + if messages[0].role == ChatRole.USER: + pass + # [system] + elif messages[0].role == ChatRole.SYSTEM: + messages.append(ChatEntry.user(content=".")) + # [assistant] + else: + messages = [ + ChatEntry.system(content="."), + ChatEntry.user(content="."), + ] + messages + else: + # [user, ...] + if messages[0].role == ChatRole.USER: + pass + # [system, ...] + elif messages[0].role == ChatRole.SYSTEM: + # [system, assistant, ...] + if messages[1].role == ChatRole.ASSISTANT: + messages.insert(1, ChatEntry.user(content=".")) + # [assistant, ...] + else: + messages = [ + ChatEntry.system(content="."), + ChatEntry.user(content="."), + ] + messages + if messages[0].role == ChatRole.SYSTEM and messages[0].content == "": + messages[0].content = "." + messages = [m.model_dump(mode="json", exclude_none=True) for m in messages] + # Prepare hyperparams + if isinstance(hyperparams.get("stop", None), list) and len(hyperparams["stop"]) == 0: + hyperparams["stop"] = None + hyperparams.update(kwargs) + # if self.config.id.startswith("anthropic"): + # hyperparams["extra_headers"] = {"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} + # Log + # self._log_completion_masked(messages, **hyperparams) + return messages, hyperparams + + def _prepare_hyperparams( + self, + ctx: DeploymentContext, + hyperparams: dict[str, Any], + ): + # Handle max_tokens + max_tokens = hyperparams.pop("max_tokens", None) + max_completion_tokens = hyperparams.pop("max_completion_tokens", None) + hyperparams["max_tokens"] = max_completion_tokens or max_tokens + tools: list[dict] = hyperparams.pop("tools", None) or [] + # OpenAI specific + if ctx.inference_provider == CloudProvider.OPENAI: + if ctx.is_reasoning_model or any( + t.get("type", "") in OPENAI_HOSTED_TOOLS for t in tools + ): + ctx.use_openai_responses = True + if ctx.is_reasoning_model: + hyperparams.pop("temperature", None) + hyperparams.pop("top_p", None) + hyperparams["max_output_tokens"] = hyperparams.pop("max_tokens", None) + hyperparams.pop("id", None) + hyperparams.pop("n", None) + hyperparams.pop("presence_penalty", None) + hyperparams.pop("frequency_penalty", None) + hyperparams.pop("logit_bias", None) + hyperparams.pop("stop", None) + else: + hyperparams["max_completion_tokens"] = hyperparams.pop("max_tokens", None) + else: + tools = [t for t in tools if t.get("type", "") not in OPENAI_HOSTED_TOOLS] + + # Anthropic specific + if ctx.inference_provider == CloudProvider.ANTHROPIC: + # Sonnet 4.5 cannot specify both `temperature` and `top_p` + if "sonnet-4-5" in ctx.routing_id: + t = hyperparams.get("temperature", None) + p = hyperparams.get("top_p", None) + if t is not None and p is not None: + hyperparams.pop("top_p", None) # Prioritise temperature + + if tools: + hyperparams["tools"] = tools + + # Handle reasoning params + reasoning_effort: str | None = hyperparams.pop("reasoning_effort", None) + thinking_budget: int | None = hyperparams.pop("thinking_budget", None) + reasoning_summary: str = hyperparams.pop("reasoning_summary", "auto") + if thinking_budget is not None: + thinking_budget = max(thinking_budget, 0) + # Non-reasoning model does not require further processing + 
if not ctx.is_reasoning_model: + return + # Disable reasoning if requested + if ( + reasoning_effort in ("disable", "minimal") + or thinking_budget == 0 + or (reasoning_effort is None and thinking_budget is None) + ): + if ctx.inference_provider == CloudProvider.ELLM: + hyperparams["reasoning_effort"] = "disable" + return + elif ctx.inference_provider == CloudProvider.GEMINI: + # 2.5 Pro cannot disable thinking + if "2.5-pro" in ctx.routing_id: + hyperparams["thinking"] = {"type": "enabled", "budget_tokens": 128} + else: + hyperparams["reasoning_effort"] = "disable" + return + elif ctx.inference_provider == CloudProvider.ANTHROPIC: + hyperparams["thinking"] = {"type": "disabled"} + return + elif ctx.inference_provider == CloudProvider.OPENAI: + if "gpt-5" in ctx.routing_id: + hyperparams["reasoning"] = { + "effort": "minimal", + "summary": reasoning_summary, + } + return + elif "o1" in ctx.routing_id or "o3" in ctx.routing_id or "o4" in ctx.routing_id: + hyperparams["reasoning"] = { + "effort": "low", + "summary": reasoning_summary, + } + return + else: + hyperparams["reasoning"] = { + "effort": "low", + "summary": reasoning_summary, + } + return + elif ctx.inference_provider == OnPremProvider.VLLM: + hyperparams["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}} + return + logger.warning( + ( + f'Disabling reasoning is not supported for model "{self.config.id}" ' + f'by provider "{ctx.inference_provider}". ' + f"(owned_by={self.config.owned_by}, deployment.provider={ctx.deployment.provider})" + ) + ) + return + # Configure reasoning effort + if reasoning_effort not in ("disable", "minimal") or thinking_budget: + if reasoning_effort not in ("low", "medium", "high"): + if thinking_budget <= 1024: + reasoning_effort = "low" + elif thinking_budget <= 4096: + reasoning_effort = "medium" + else: + reasoning_effort = "high" + if ctx.inference_provider == CloudProvider.ELLM: + hyperparams["reasoning_effort"] = reasoning_effort + return + elif ctx.inference_provider in [CloudProvider.GEMINI, CloudProvider.ANTHROPIC]: + if not thinking_budget: + if reasoning_effort == "low": + thinking_budget = 1024 + elif reasoning_effort == "medium": + thinking_budget = 4096 + else: + thinking_budget = 8192 + if ctx.inference_provider == CloudProvider.ANTHROPIC: + hyperparams["temperature"] = 1 + hyperparams["top_p"] = min(max(0.95, hyperparams.pop("top_p", 1.0)), 1.0) + thinking_budget = max(thinking_budget, 1024) + hyperparams["thinking"] = { + "type": "enabled", + "budget_tokens": thinking_budget, + } + elif ctx.inference_provider == CloudProvider.OPENAI: + hyperparams["reasoning"] = { + "effort": reasoning_effort, + "summary": reasoning_summary, + } + else: + logger.warning( + ( + f'Thinking budget is not supported for model "{self.config.id}" ' + f'by provider "{ctx.inference_provider}". 
' + f"(owned_by={self.config.owned_by}, deployment.provider={ctx.deployment.provider})" + ) + ) + + def _stream_delta(self, delta: Delta, finish_reason: Any | None = None) -> ModelResponseStream: + return ModelResponseStream( + id=self.id, + model=self.config.id, + choices=[StreamingChoices(index=0, delta=delta, finish_reason=finish_reason)], + ) + + def _prepare_responses_messages(self, messages: list[dict]) -> list[dict]: + for m in messages: + content: str | list[dict[str, str]] = m["content"] + if not isinstance(content, list): + continue + for c in content: + if c.get("type", None) == "text": + c["type"] = "input_text" + elif c.get("type", None) == "image_url": + c["type"] = "input_image" + c["image_url"] = c["image_url"]["url"] + elif c.get("type", None) == "input_audio": + pass + else: + pass + + async def _openai_responses_stream( + self, + ctx: DeploymentContext, + messages: list[dict], + **hyperparams, + ) -> AsyncGenerator[ModelResponseStream, None]: + self._prepare_responses_messages(messages) + openai_client = AsyncOpenAI(api_key=ctx.api_key) + response_stream = await openai_client.responses.create( + model=ctx.routing_id.split("openai/")[-1], + input=messages, + stream=True, + **hyperparams, + ) + usage_stats = {"web_search_calls": 0, "code_interpreter_calls": 0} + final_usage = None + async for chunk in response_stream: + if isinstance(chunk, ResponseReasoningSummaryTextDeltaEvent): + yield self._stream_delta(Delta(role="assistant", reasoning_content=chunk.delta)) + elif isinstance(chunk, ResponseOutputItemDoneEvent): + if isinstance(chunk.item, ResponseFunctionWebSearch): + usage_stats["web_search_calls"] += 1 + if ( + chunk.item.action + and hasattr(chunk.item.action, "query") + and chunk.item.action.query + ): + yield self._stream_delta( + Delta( + role="assistant", + reasoning_content=f'Searched the web for "{chunk.item.action.query}".', + ) + ) + yield self._stream_delta(Delta(role="assistant", reasoning_content="\n\n")) + elif isinstance(chunk.item, ResponseCodeInterpreterToolCall): + usage_stats["code_interpreter_calls"] += 1 + code_snippet = chunk.item.code + yield self._stream_delta( + Delta( + role="assistant", + reasoning_content=f"Ran Python code:\n\n```python\n{code_snippet}\n```", + ) + ) + yield self._stream_delta(Delta(role="assistant", reasoning_content="\n\n")) + elif isinstance(chunk, ResponseTextDeltaEvent): + yield self._stream_delta(Delta(role="assistant", content=chunk.delta)) + elif isinstance(chunk, ResponseCompletedEvent): + if chunk.response.usage: + final_usage = chunk.response.usage + + if final_usage: + usage = ChatCompletionUsage( + prompt_tokens=final_usage.input_tokens, + completion_tokens=final_usage.output_tokens, + total_tokens=final_usage.total_tokens, + prompt_tokens_details=PromptUsageDetails( + cached_tokens=final_usage.input_tokens_details.cached_tokens + if final_usage.input_tokens_details + else 0 + ), + completion_tokens_details=CompletionUsageDetails( + reasoning_tokens=final_usage.output_tokens_details.reasoning_tokens + if final_usage.output_tokens_details + else 0 + ), + tool_usage_details=ToolUsageDetails(**usage_stats), + ) + else: + # Fallback if usage is not in the final chunk for some reason + usage = ChatCompletionUsage(tool_usage_details=ToolUsageDetails(**usage_stats)) + + final_chunk = self._stream_delta(delta=Delta(), finish_reason="stop") + final_chunk.usage = Usage(**usage.model_dump()) + yield final_chunk + + async def _openai_responses( + self, + ctx: DeploymentContext, + messages: list[dict], + **hyperparams, + ) 
-> ModelResponse: + self._prepare_responses_messages(messages) + openai_client = AsyncOpenAI(api_key=ctx.api_key) + response: Response = await openai_client.responses.create( + model=ctx.routing_id.split("openai/")[-1], + input=messages, + stream=False, + **hyperparams, + ) + reasoning_parts = [] + result_parts = [] + usage_stats = {"web_search_calls": 0, "code_interpreter_calls": 0} + for item in response.output: + if isinstance(item, ResponseReasoningItem): + if item.summary: + summary_text = "\n".join( + part.text for part in item.summary if hasattr(part, "text") + ) + if summary_text: + reasoning_parts.append(summary_text) + elif isinstance(item, ResponseFunctionWebSearch) and item.status == "completed": + usage_stats["web_search_calls"] += 1 + if item.action and hasattr(item.action, "query") and item.action.query: + reasoning_parts.append(f'Searched the web for "{item.action.query}".') + elif isinstance(item, ResponseCodeInterpreterToolCall) and item.status == "completed": + usage_stats["code_interpreter_calls"] += 1 + code_snippet = item.code + reasoning_parts.append(f"Ran Python code:\n\n```python\n{code_snippet}\n```") + elif isinstance(item, ResponseOutputMessage) and item.status == "completed": + text_content = item.content[0].text if item.content else "" + result_parts.append(text_content) + + reasoning_result = "\n\n".join(part for part in reasoning_parts if part) + final_result = "\n\n".join(part for part in result_parts if part) + + if response.usage: + usage = ChatCompletionUsage( + prompt_tokens=response.usage.input_tokens, + completion_tokens=response.usage.output_tokens, + total_tokens=response.usage.total_tokens, + prompt_tokens_details=PromptUsageDetails( + cached_tokens=response.usage.input_tokens_details.cached_tokens + if response.usage.input_tokens_details + else 0 + ), + completion_tokens_details=CompletionUsageDetails( + reasoning_tokens=response.usage.output_tokens_details.reasoning_tokens + if response.usage.output_tokens_details + else 0 + ), + tool_usage_details=ToolUsageDetails(**usage_stats), + ) + else: + usage = ChatCompletionUsage(tool_usage_details=ToolUsageDetails(**usage_stats)) + + return ModelResponse( + id=self.id, + model=self.config.id, + choices=[ + Choices( + index=0, + message=Message( + role="assistant", + content=final_result, + reasoning_content=reasoning_result.strip(), + ), + finish_reason="stop", + ) + ], + usage=Usage(**usage.model_dump()), + created=int(time()), + ) + + async def _completion_stream( + self, + messages: list[dict], + **hyperparams, + ) -> AsyncGenerator[ModelResponseStream, None]: + async for attempt in AsyncRetrying(**self.retry_policy): + with attempt: + async with self._get_deployment(messages=messages, **hyperparams) as ctx: + self._prepare_hyperparams(ctx, hyperparams) + # logger.warning(f"{hyperparams=}") + if ctx.use_openai_responses: + async for chunk in self._openai_responses_stream( + ctx, messages, **hyperparams + ): + yield chunk + else: + response: AsyncGenerator[ModelResponseStream, None] = await acompletion( + timeout=self.config.timeout, + api_key=ctx.api_key, + base_url=ctx.deployment.api_base, + model=ctx.routing_id, + messages=messages, + stream=True, + stream_options={"include_usage": True}, + **hyperparams, + ) + if response is None: + raise ModelOverloadError( + f'Model provider for "{self._model_display_id}" is overloaded. Please try again later.' 
+ ) + # TODO: Investigate why litellm yields role chunks at the end of the stream + # i = 0 + async for chunk in response: + chunk.model = self.config.id + yield chunk + # i += 1 + + async def _completion( + self, + messages: list[dict], + **hyperparams, + ) -> ModelResponse: + async for attempt in AsyncRetrying(**self.retry_policy): + with attempt: + async with self._get_deployment(messages=messages, **hyperparams) as ctx: + self._prepare_hyperparams(ctx, hyperparams) + if ctx.use_openai_responses: + return await self._openai_responses(ctx, messages, **hyperparams) + response = await acompletion( + timeout=self.config.timeout, + api_key=ctx.api_key, + base_url=ctx.deployment.api_base, + model=ctx.routing_id, + messages=messages, + stream=False, + **hyperparams, + ) + if response is None: + raise ModelOverloadError( + f'Model provider for "{self._model_display_id}" is overloaded. Please try again later.' + ) + response.model = self.config.id + return response + + async def chat_completion( + self, + *, + messages: list[ChatEntry], + stream: bool, + **hyperparams, + ) -> ModelResponse | AsyncGenerator[ModelResponse, None]: + if not (isinstance(messages, list) and all(isinstance(m, ChatEntry) for m in messages)): + # We raise TypeError here since this is a programming error + raise TypeError("`messages` must be a list of `ChatEntry`.") + hyperparams.pop("stream_options", None) + messages, hyperparams = await self._prepare_chat( + messages=messages, + hyperparams=hyperparams, + ) + if stream: + return self._completion_stream(messages, **hyperparams) + else: + return await self._completion(messages, **hyperparams) + + ### --- Embedding --- ### + + async def embedding( + self, + *, + texts: list[str], + is_query: bool = True, + encoding_format: str | None = None, + **hyperparams, + ) -> EmbeddingResponse: + async for attempt in AsyncRetrying(**self.retry_policy): + with attempt: + async with self._get_deployment( + texts=texts, encoding_format=encoding_format, **hyperparams + ) as ctx: + # Get output dimensions + dimensions = ( + hyperparams.get("dimensions", None) or self.config.embedding_dimensions + ) + # Maybe transform texts + if self.config.embedding_transform_query is not None: + texts = [self.config.embedding_transform_query + text for text in texts] + # Set batch size and hyperparams + batch_size = 2048 + if ctx.deployment.provider == CloudProvider.COHERE: + if is_query: + hyperparams["input_type"] = "search_query" + else: + hyperparams["input_type"] = "search_document" + batch_size = 96 # limit on cohere server + elif ctx.deployment.provider == CloudProvider.JINA_AI: + batch_size = 128 # don't know limit, but too large will timeout + elif ctx.deployment.provider == CloudProvider.VOYAGE: + batch_size = 128 # limit on voyage server + elif ctx.deployment.provider == CloudProvider.OPENAI: + batch_size = 256 # limited by token per min (10,000,000) + + # self._billing.has_embedding_quota(model_id=self.embedder_config["id"]) + # Call + responses: list[LiteLLMEmbeddingResponse] = await asyncio.gather( + *[ + aembedding( + timeout=self.config.timeout, + api_key=ctx.api_key, + api_base=ctx.deployment.api_base, + model=ctx.routing_id, + input=txt, + dimensions=dimensions, + encoding_format=encoding_format, + **hyperparams, + ) + for txt in self.batch(texts, batch_size) + ] + ) + # Compile from batches + vectors = [ + e["embedding"] for e in itertools.chain(*[r.data for r in responses]) + ] + usage = EmbeddingUsage( + prompt_tokens=sum(getattr(r.usage, "prompt_tokens", 1) for r in responses), + 
total_tokens=sum(getattr(r.usage, "total_tokens", 1) for r in responses), + ) + # Might need to encode into base64 + if encoding_format == "base64" and isinstance(vectors[0], list): + logger.warning( + "`encoding_format` is `base64` but vectors are not base64 encoded." + ) + vectors = [ + b64encode(np.asarray(v, dtype=np.float32).tobytes()).decode("ascii") + for v in vectors + ] + embeddings = EmbeddingResponse( + data=[ + EmbeddingResponseData(embedding=v, index=i) + for i, v in enumerate(vectors) + ], + model=self.config.id, + usage=usage, + ) + return embeddings + + ### --- Reranking --- ### + + async def reranking( + self, + *, + query: str, + documents: list[str], + top_n: int | None = None, + **hyperparams, + ) -> RerankingResponse: + if len(documents) == 0: + raise ValueError("There are no documents to rerank.") + async for attempt in AsyncRetrying(**self.retry_policy): + with attempt: + async with self._get_deployment( + query=query, documents=documents, **hyperparams + ) as ctx: + batch_size = 100 + # self._billing.has_embedding_quota(model_id=self.embedder_config["id"]) + # Call + batches = list(self.batch(documents, batch_size)) + responses: list[RerankResponse] = await asyncio.gather( + *[ + arerank( + timeout=self.config.timeout, + api_key=ctx.api_key, + api_base=ctx.deployment.api_base, + model=ctx.routing_id, + query=query, + documents=docs, + top_n=top_n, + return_documents=False, + **hyperparams, + ) + for docs in batches + ] + ) + responses = [r.model_dump(exclude_unset=True) for r in responses] + # Compile results from batches + results = [ + { + "index": res["index"] + if i == 0 + else res["index"] + i * len(batches[i - 1]), + "relevance_score": res["relevance_score"], + } + for i, response in enumerate(responses) + for res in response["results"] + ] + results = sorted(results, key=lambda x: x["relevance_score"], reverse=True) + # Compile usage from batches + metas = [r.get("meta", {}) for r in responses] + billed_units = [m.get("billed_units", {}) for m in metas] + tokens = [m.get("tokens", {}) for m in metas] + billed_units = { + k: sum(d.get(k, 0) or 0 for d in billed_units) + for k in set().union(*billed_units) + } + tokens = { + k: sum(d.get(k, 0) or 0 for d in tokens) for k in set().union(*tokens) + } + usage = deepcopy(tokens) + usage["documents"] = len(documents) + # Generate final response + try: + response = responses[0] + except IndexError: + logger.error( + f"No responses from reranking!!! 
{batches=} {documents=} {batch_size=}" + ) + raise + response["results"] = results + response["usage"] = usage + response["meta"]["model"] = self.config.id + if len(billed_units) > 0: + response["meta"]["billed_units"] = billed_units + if len(tokens) > 0: + response["meta"]["tokens"] = tokens + return RerankingResponse.model_validate(response) + + +class LMEngine: + def __init__( + self, + *, + organization: OrganizationRead, + project: Project_, + request: Request, + ) -> None: + self.organization = organization + self.project = project + self.request = request + self.id: str = request.state.id + self.is_browser: bool = request.state.user_agent.is_browser + self.billing: BillingManager | None = getattr(request.state, "billing", None) + self._models: list[ModelConfigRead] | None = getattr(self.billing, "models", None) + self._chat_usage = ChatCompletionUsage() + self._embed_usage = EmbeddingUsage() + self._rerank_usage = RerankingUsage(documents=0) + + async def _get_models(self, capabilities: list[str] | None = None) -> list[ModelConfigRead]: + if self._models is None: + logger.warning( + f"{self.id} - No models found in BillingManager. Fetching from database." + ) + async with async_session() as session: + models = ( + await ModelConfig.list_( + session=session, + return_type=ModelConfigRead, + organization_id=self.organization.id, + capabilities=capabilities, + exclude_inactive=True, + ) + ).items + self._models = models + else: + models = [m for m in self._models if m.is_active] + # Filter by capability + if capabilities is not None: + for capability in capabilities: + models = [m for m in models if capability in m.capabilities] + if len(models) == 0: + raise ResourceNotFoundError( + f"No model found with capabilities: {list(map(str, capabilities))}." 
+ ) + return models + + async def _get_model(self, model: str) -> ModelConfigRead: + model = model.strip() + model_configs = await self._get_models() + model_config = next((m for m in model_configs if m.id == model), None) + if model_config is None: + raise ResourceNotFoundError(f'Model "{model}" is not found.') + return model_config + + @staticmethod + def pick_best_model( + model_configs: list[ModelConfig_], + capabilities: list[ModelCapability], + ) -> ModelConfig_: + def _sort_key_with_priority(m: ModelConfig_) -> tuple[int, int, str]: + return ( + int(not m.id.startswith("ellm")), + int(ModelCapability.AUDIO in m.capabilities), # De-prioritise audio models + len(m.capabilities_set - set(capabilities)), + -m.priority, + m.name, + ) + + model_configs = natsorted(model_configs, key=_sort_key_with_priority) + return model_configs[0] + + ### --- Chat Completion --- ### + + @staticmethod + def _check_messages_type(messages: list[ChatEntry]): + if not (isinstance(messages, list) and all(isinstance(m, ChatEntry) for m in messages)): + # We raise TypeError here since this is a programming error + raise TypeError("`messages` must be a list of `ChatEntry`.") + + async def _get_default_model( + self, + model: str, + capabilities: list[ModelCapability], + ) -> ModelConfigRead: + capabilities_set = set(capabilities) + # If model is empty string, we try to get a suitable model + if model == "": + # Error will be raised if no suitable model is found + model_configs = await self._get_models(capabilities) + model_config = self.pick_best_model(model_configs, capabilities) + else: + model_config = await self._get_model(model) + if len(lack := (capabilities_set - model_config.capabilities_set)) > 0: + raise ModelCapabilityError( + f'Model "{model_config.name if self.is_browser else model}" lack these capabilities: {", ".join(lack)}' + ) + return model_config + + @asynccontextmanager + async def _setup_chat( + self, + model: str, + messages: list[ChatEntry], + ): + # Validate model capability + self._check_messages_type(messages) + capabilities = [str(ModelCapability.CHAT)] + if any(m.has_image for m in messages): + capabilities.append(str(ModelCapability.IMAGE)) + if any(m.has_audio for m in messages): + capabilities.append(str(ModelCapability.AUDIO)) + # If model is empty string, we try to get a suitable model + model_config = await self._get_default_model(model, capabilities) + model = model_config.id + # Setup rate limiting + # rpm_limiter = CascadeRateLimiter( + # org_hpm=ENV_CONFIG.llm_requests_per_minute, + # proj_hpm=ENV_CONFIG.llm_requests_per_minute, + # organization_id=self.organization.id, + # project_id=self.project.id, + # key=f"{model}:rpm", + # name="RPM", + # ) + # tpm_limiter = CascadeRateLimiter( + # org_hpm=ENV_CONFIG.llm_tokens_per_minute, + # proj_hpm=ENV_CONFIG.llm_tokens_per_minute, + # organization_id=self.organization.id, + # project_id=self.project.id, + # key=f"{model}:tpm", + # name="TPM", + # ) + # # Test rate limits + # await asyncio.gather(rpm_limiter.test(), tpm_limiter.test(max_tokens)) + router = DeploymentRouter( + request=self.request, + config=model_config, + organization=self.organization, + is_browser=self.is_browser, + ) + try: + yield router + finally: + # # Consume rate limits + # await asyncio.gather(rpm_limiter.hit(), tpm_limiter.hit(self._chat_usage.total_tokens)) + if self.billing is not None: + try: + self.billing.create_llm_events( + model_id=model, + input_tokens=self._chat_usage.prompt_tokens, + output_tokens=self._chat_usage.completion_tokens, + ) + 
except Exception as e: + logger.warning(f"Failed to create LLM events due to error: {repr(e)}") + + async def chat_completion_stream( + self, + *, + model: str, + messages: list[ChatEntry], + **hyperparams, + ) -> AsyncGenerator[ChatCompletionChunkResponse, None]: + """ + Generate streaming chat completions. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model based on message content. + messages (list[ChatEntry]): List of messages. + **hyperparams (Any): Keyword arguments. + + Yields: + chunk (ChatCompletionChunkResponse): A chat chunk. + """ + hyperparams.pop("stream", None) + async with self._setup_chat(model, messages) as router: + completion: AsyncGenerator[ModelResponse, None] = await router.chat_completion( + messages=messages, + stream=True, + **hyperparams, + ) + async for chunk in completion: + if hasattr(chunk, "usage"): + self._chat_usage = ChatCompletionUsage.model_validate(chunk.usage.model_dump()) + yield ChatCompletionChunkResponse( + **chunk.model_dump(exclude_unset=True, exclude_none=True) + ) + + async def chat_completion( + self, + *, + model: str, + messages: list[ChatEntry], + **hyperparams, + ) -> ChatCompletionResponse: + """ + Generate chat completions. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model based on message content. + messages (list[ChatEntry]): List of messages. + **hyperparams (Any): Keyword arguments. + + Returns: + response (ChatCompletionResponse): The chat response. + """ + hyperparams.pop("stream", None) + async with self._setup_chat(model, messages) as router: + completion: ModelResponse = await router.chat_completion( + messages=messages, + stream=False, + **hyperparams, + ) + completion = ChatCompletionResponse.model_validate( + completion.model_dump(exclude_unset=True, exclude_none=True) + ) + self._chat_usage = completion.usage + return completion + + async def generate_title( + self, + *, + excerpt: str, + model: str = "", + **hyperparams, + ) -> str: + system_prompt = dedent("""\ + You are a professional document analyst. Your primary goal is to extract the most accurate and complete title from the document's first page. + + Analyze the page using the following prioritized steps: + + 1. **PRIORITY 1: EXTRACT VERBATIM TITLE:** + - First, attempt to identify and extract the main, verbatim title. This is typically the most prominent text block (e.g., largest font, bold, centered) + at the top of the page, common in academic papers, reports, or articles. This is the preferred method. + - Prominent text block may not always represent the title, so read the entire page to understand the context. Append the suitable subtitle if it exists. + - Append the purpose of the document based on the page content. + + 2. **PRIORITY 2: ASSEMBLE FROM COMPONENTS:** + If no single, clear verbatim title exists (common in forms or structured plans), + then construct a title by extracting and combining these components: + - Primary Entity: The main company/organization. + - Document Type: The official name of the document (e.g., Insurance Plan, Agreement). + - Key Identifiers: Extract ALL unique codes and levels. This includes a master identifier (like a Policy or Group Number) + AND specific sub-identifier (like a Plan Name, Plan Level, Tier, Date and/or Year). + + 3. 
**UNIVERSAL RULE: INCLUDE IDENTIFIERS:** + Regardless of whether the title is extracted verbatim (Priority 1) or assembled (Priority 2), + append both the master identifier and the specific plan level if both are present. + Append the date or year if it is part of the title or relevant to the document's context. + + 4. **OUTPUT:** Output only the final, single-line title. (Max 20 words). + """) + prompt = dedent(f"""\ + Analyze the Page content below and output the most representative title based on your core instructions. + - DO NOT THINK, OUTPUT ONLY THE FINAL TITLE + + **Page Context:** + {excerpt} + """) + # Override hyperparams + hyperparams.update( + temperature=0.01, + top_p=0.01, + max_tokens=500, + stream=False, + reasoning_effort="minimal", + ) + try: + completion = ( + await self.chat_completion( + model=model, + messages=[ChatEntry.system(system_prompt), ChatEntry.user(prompt)], + **hyperparams, + ) + ).content + title = completion.strip().strip('"') + except Exception as e: + logger.warning( + f"{hyperparams.get('id', '')} - Title extraction failed for excerpt: \n{excerpt}\n, error: {e}" + ) + title = "" + if not title: + title = "Document" + return title + + async def generate_chat_title( + self, + *, + user_content: str, + assistant_content: str, + model: str = "", + **hyperparams, + ) -> SanitisedStr: + system_prompt = "Generate a concise, descriptive title for a chat message." + prompt = dedent(f"""\ + + {user_content} + + + + {assistant_content} + + + Do not think. Generate a short, concise title of no more than 5 words for the conversation. + """) + # Override hyperparams + hyperparams.update( + temperature=0.01, + top_p=0.01, + max_tokens=500, + stream=False, + reasoning_effort="minimal", + ) + default_title = "New Chat" + try: + completion = ( + await self.chat_completion( + model=model, + messages=[ChatEntry.system(system_prompt), ChatEntry.user(prompt)], + **hyperparams, + ) + ).content + title = completion.strip().strip('"') + if not title: + title = default_title + except Exception as e: + logger.warning( + f"{hyperparams.get('id', '')} - Title generation failed for the chat message: {user_content}, error: {e}" + ) + title = default_title + + # Replace non-printable characters with space + return " ".join("".join(c if c.isprintable() else " " for c in title).split()) + + def _rewrite_prompts_for_fts_query(self, input_prompt: str) -> str: + system_prompt = dedent("""\ + You are an advanced search query generation system. Your purpose is to translate user questions and conversational context into precise query components optimized for an information retrieval system using both keyword-based Full-Text Search (FTS) with pgroonga. + + Your primary tasks are: + 1. **Analyze Intent:** Deeply understand the user's information need expressed in their query and any relevant conversation history (if provided). + 2. **Extract Key Information:** Identify critical keywords, named entities (people, places, organizations, dates, etc.), specific technical terms, and core concepts. + 3. **Disambiguate:** Resolve ambiguities based on context. + 4. **Generate Direct Query Output:** Produce a direct answer containing the distinct query strings: + *Optimized for keyword precision and recall in pgroonga. Focus on essential nouns, verbs, entities, and specific identifiers. Should be concise. + + Focus on generating queries that, when used together in their respective search engines, will yield the most relevant results. 
Accuracy, relevance, and appropriate optimization for each search type are paramount. + """) + prompt = dedent(f"""\ + "user_query": "{input_prompt}", + "current_datetime": "{now().isoformat()}" + + Instructions: + Analyze the user_query, considering the current_datetime for temporal references. Generate a direct query string containing the rewritten query optimized for pgroonga FTS, keeping in mind that **stemming is active (at least for English)**. Follow these steps precisely: + + 1. **Identify Core Concepts:** Extract the most important terms representing the subject, action/intent, and key context from the user_query. Include essential nouns, verbs, entities, codes, and specific identifiers. Since stemming is active, focus on the root concepts. + 2. **Handle Phrases:** Identify multi-word terms crucial to the meaning (e.g., "machine learning", "API key", "user acceptance testing"). Enclose these exact phrases in double quotes (`"`). Stemming does not preserve word order, making phrase matching critical. + 3. **Use Synonyms/Alternatives (OR - Strategically):** + * Use `OR` *only* for genuinely distinct synonyms or alternative concepts that **will likely not stem to the same root** (e.g., `bug OR defect`, `UI OR "user interface"`). + * **Do NOT** use `OR` for simple word variations handled by stemming (e.g., do not write `database OR databases`, `configure OR configuration`, `run OR running` - the stemmer handles these). + * Use OR sparingly, focusing on high-value alternatives to improve recall for distinct concepts. + 4. **Convert Dates:** Use the `current_datetime` to resolve relative temporal references (e.g., "last year", "yesterday") into absolute numeric formats (YYYY or YYYY-MM-DD). For ranges like "last 2 years", list the specific years space-separated (e.g., based on 2025-04-16, "last 2 years" -> `2023 2024`). + 5. **Combine Terms:** Join individual keywords (prefer base/stemmed forms where natural), quoted phrases, and `OR` groups primarily with spaces (implying an AND relationship between distinct concepts). + 6. **Filter Noise but Preserve Meaning:** Remove generic filler words (like "the", "a", "is", "how to") UNLESS they are part of an essential quoted phrase. Prioritize terms likely to appear verbatim (or their stems) in relevant documents, but do not discard terms crucial for understanding the query's specific intent (e.g., keep words like "compare", "impact", "migrate" if central). + 7. **Conciseness and Completeness:** Aim for a query that is concise yet captures the full essential meaning of the original user query, leveraging the stemmer's capabilities. + 8. **Multi-Word Terms:** Use double quotes for terms composed of multiple words that should be treated as a single unit. Example: United Kingdom -> "United Kingdom" + + **Examples:** + + * **User Query:** What's the meaning of USG? + **FTS Query:** USG meaning OR definition + + * **User Query:** In 2024 how many database outage happened? + **FTS Query:** database outage OR failure 2024 count + + * **User Query:** How can I configure the connection pool for the main transaction database? + **FTS Query:** configure OR setup "connection pool" "main transaction database" + + * **User Query:** Any issues reported for the payment gateway integration last month? (Given Datetime: 2025-04-16) + **FTS Query:** issue OR problem OR error "payment gateway integration" 2025-03 + + * **User Query:** Compare performance impact of Redis vs Memcached deployment in production last year. 
(Given Datetime: 2025-04-16)
+            **FTS Query:** compare performance impact Redis Memcached deployment production 2024
+
+            * **User Query:** What's the weather in Japan 3 months ago? (Given Datetime: 2025-04-16)
+            **FTS Query:** weather Japan 2025-01
+
+            * **User Query:** カワは何年に生まれますか？
+            **FTS Query:** カワ 生まれる OR 誕生
+
+            Reply ONLY with the generated FTS query string. Do not think. Do not include explanations, reasoning, markdown formatting, or any text outside the final FTS Query, and do not wrap the entire result in quotes. Reply in the original query language.
+
+            Now generate the query:
+            """)
+        return system_prompt, prompt
+
+    def _rewrite_prompts_for_vs_query(self, input_prompt: str) -> tuple[str, str]:
+        system_prompt = dedent("""\
+            You are an advanced search query generation system. Your purpose is to translate user questions and conversational context into precise query components optimized for an information retrieval system using semantic Vector Search (VS).
+
+            Your primary tasks are:
+            1. **Analyze Intent:** Deeply understand the user's information need expressed in their query and any relevant conversation history (if provided).
+            2. **Extract Key Information:** Identify critical keywords, named entities (people, places, organizations, dates, etc.), specific technical terms, and core concepts.
+            3. **Disambiguate:** Resolve ambiguities based on context.
+            4. **Generate Direct Query Output:** Produce a direct answer containing the query string:
+            *Optimized for capturing semantic meaning and nuance for vector embedding similarity search. Should be a well-formed natural language sentence or question reflecting the user's core intent.
+
+            Focus on generating queries that, when used together in their respective search engines, will yield the most relevant results. Accuracy, relevance, and appropriate optimization for each search type are paramount.
+            """)
+        prompt = dedent(f"""\
+            "user_query": "{input_prompt}",
+            "current_datetime": "{now().isoformat()}"
+
+            Instructions:
+            Analyze the user_query, considering the current_datetime for temporal references. Generate a direct query string containing the vector query for vector search.
+
+            1. **vector_query**:
+                * Create a natural language sentence or question that captures the core semantic meaning and intent of the user_query.
+                * This query should be suitable for generating an embedding for vector similarity search.
+                * Retain natural language phrasing for concepts, including relative time expressions (e.g., "last year", "next quarter") if they better represent the user's intent semantically.
+                * Example style: How to fix database connection timeout errors when configuring pgroonga, especially issues seen recently?
+
+            Reply ONLY with the generated VS query. Do not think. Do not include explanations, reasoning, markdown formatting, or any text outside the final VS Query.
+ + Now generate the query: + """) + return system_prompt, prompt + + @staticmethod + def _extract_text_prompt( + messages: list[ChatEntry], + ) -> tuple[list[ChatEntry], str, list[ImageContent | AudioContent] | None]: + # Make a deep copy to avoid side effects + messages = deepcopy(messages) + # The message list should end with user message + if messages[-1].role == ChatRole.USER: + pass + elif messages[-2].role == ChatRole.USER: + messages = messages[:-1] + else: + raise BadInputError("The message list should end with user or assistant message.") + content = messages[-1].content + if isinstance(content, str): + prompt = content + multimodal_contents = None + else: + prompt = messages[-1].text_content + multimodal_contents = [c for c in content if not isinstance(c, TextContent)] + return messages, prompt, multimodal_contents + + async def _generate_search_query( + self, + *, + model: str, + messages: list[ChatEntry], + type: str, + **hyperparams, + ) -> str: + messages, prompt, multimodal_contents = self._extract_text_prompt(messages) + # Retrieved system and user prompt, updated as of 2025-04-17 + if type == "fts": + system_prompt, new_prompt = self._rewrite_prompts_for_fts_query(prompt) + elif type == "vs": + system_prompt, new_prompt = self._rewrite_prompts_for_vs_query(prompt) + else: + raise BadInputError( + f"Rewrite prompt only works for type: FTS or VS. Invalid type: {type}" + ) + + if messages[0].role == ChatRole.SYSTEM: + # Suggest to just override system prompt, 2025-04-17 + messages[0].content = system_prompt + else: + messages.insert(0, ChatEntry.system(system_prompt)) + # Override hyperparams + hyperparams.update( + temperature=0.01, + top_p=0.01, + max_tokens=1000, + stream=False, + reasoning_effort="minimal", + ) + if multimodal_contents is not None: + new_prompt = multimodal_contents + [TextContent(text=new_prompt)] + messages[-1] = ChatEntry.user(new_prompt) + completion = ( + await self.chat_completion( + model=model, + messages=messages, + **hyperparams, + ) + ).content + if completion is None: + new_prompt = prompt + else: + new_prompt = completion.strip() + if new_prompt.startswith('"') and new_prompt.endswith('"'): + new_prompt = new_prompt[1:-1] + return new_prompt + + async def generate_search_query( + self, + *, + model: str, + messages: list[ChatEntry], + rag_params: RAGParams, + **hyperparams, + ) -> tuple[str, str]: + """ + Generate search query for RAG. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model based on message content. + messages (list[ChatEntry]): List of messages. + rag_params (RAGParams): RAG parameters. + **hyperparams (Any): Keyword arguments. + + Raises: + TypeError: If `rag_params` is not an instance of `RAGParams`. + BadInputError: If the message list does not end with user or assistant message. + + Returns: + fts_query (str): The fts search query. + vs_query (str): The vs search query. 
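+
+        Note:
+            If `rag_params.search_query` is non-empty, it is used verbatim for both queries;
+            otherwise the missing FTS and VS queries are generated concurrently. A minimal
+            illustrative call (argument values are hypothetical):
+
+                fts_query, vs_query = await self.generate_search_query(
+                    model="",
+                    messages=messages,
+                    rag_params=RAGParams(search_query=""),
+                )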
+ """ + self._check_messages_type(messages) + if not isinstance(rag_params, RAGParams): + raise TypeError("`rag_params` must be an instance of `RAGParams`.") + # Generate missing queries in parallel + queries = { + "fts": rag_params.search_query.strip(), + "vs": rag_params.search_query.strip(), + } + to_generate = [q_type for q_type, query in queries.items() if not query] + if to_generate: + generated = await asyncio.gather( + *[ + self._generate_search_query( + model=model, + messages=messages, + type=q_type, + **hyperparams, + ) + for q_type in to_generate + ] + ) + # Update the queries dict with generated values + for q_type, generated_query in zip(to_generate, generated, strict=True): + queries[q_type] = generated_query + return queries["fts"], queries["vs"] + + async def generate_rag_prompt( + self, + *, + messages: list[ChatEntry], + references: References, + inline_citations: bool = False, + ) -> str | list[TextContent | ImageContent | AudioContent]: + _, prompt, multimodal_contents = self._extract_text_prompt(messages) + documents = "\n\n".join( + dedent(f"""\ + + + {chunk.title} + {i} + {chunk.page} + + {"\n".join(f"## {k}: {v}" for k, v in chunk.context.items())} + + ## Text:\n{chunk.text} + + + + + """) + for i, chunk in enumerate(references.chunks) + ) + context_prompt = f"\n\n{documents}\n\n\n\n" + if inline_citations: + prompt += ( + "\n" + "When any sentence in your answer is supported by or refers to one or more documents inside , " + "append inline citations using Pandoc-style `[@]` for each supporting document at the end of that sentence, " + "immediately before the sentence-ending punctuation. " + "Use the exact from each and never invent IDs. " + "Arrange the citations from most to least relevant. " + "If multiple documents support the sentence, include multiple citations delimited by semicolons `[@; @]`. " + "Always separate the text and citations with one space, ie ` [@]`. " + "Do not cite for general knowledge, your own reasoning, or content not found in the provided documents. 
" + "\n" + "For example:" + "\n" + '- "London is the capital of England."\n' + '- "The merger was completed in Q3 [@4]."\n' + '- "Revenue was $8.2 million [@7; @1]."\n' + ) + if multimodal_contents is None: + multimodal_contents = [] + prompt = ( + [TextContent(text=context_prompt)] + multimodal_contents + [TextContent(text=prompt)] + ) + return prompt + + ### --- Embedding --- ### + + @asynccontextmanager + async def _setup_embedding(self, model: str): + # Validate model capability + capabilities = [str(ModelCapability.EMBED)] + model_config = await self._get_default_model(model, capabilities) + model = model_config.id + # Setup rate limiting + # rpm_limiter = CascadeRateLimiter( + # org_hpm=ENV_CONFIG.embed_requests_per_minute, + # proj_hpm=ENV_CONFIG.embed_requests_per_minute, + # organization_id=self.organization.id, + # project_id=self.project.id, + # key=f"{model}:rpm", + # name="RPM", + # ) + # tpm_limiter = CascadeRateLimiter( + # org_hpm=ENV_CONFIG.embed_tokens_per_minute, + # proj_hpm=ENV_CONFIG.embed_tokens_per_minute, + # organization_id=self.organization.id, + # project_id=self.project.id, + # key=f"{model}:tpm", + # name="TPM", + # ) + # # Test rate limits + # await asyncio.gather(rpm_limiter.test(), tpm_limiter.test()) + router = DeploymentRouter( + request=self.request, + config=model_config, + organization=self.organization, + is_browser=self.is_browser, + ) + try: + yield router + finally: + # # Consume rate limits + # await asyncio.gather( + # rpm_limiter.hit(), tpm_limiter.hit(self._embed_usage.total_tokens) + # ) + if self.billing is not None: + try: + self.billing.create_embedding_events( + model_id=model, + token_usage=self._embed_usage.total_tokens, + ) + except Exception as e: + logger.warning(f"Failed to create embedding events due to error: {repr(e)}") + + async def embed_documents( + self, + *, + model: str, + texts: list[str], + encoding_format: str | None = None, + **hyperparams, + ) -> EmbeddingResponse: + """ + Embed documents. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model. + texts (list[str]): List of strings to embed as documents. + encoding_format (str | None, optional): Vector encoding format. Defaults to None. + + Returns: + response (EmbeddingResponse): The embedding response. + """ + if len(texts) == 0: + raise BadInputError("There is no text or content to embed.") + # TODO: Do we need to truncate based on context length? + # encoding = tiktoken.get_encoding(encoding_name) + # encoded_text = encoding.encode(text) + # if len(encoded_text) <= max_context_length: + # return text + # truncated_encoded = encoded_text[:max_context_length] + # truncated_text = encoding.decode(truncated_encoded) + async with self._setup_embedding(model) as router: + embeddings = await router.embedding( + texts=texts, + is_query=False, + encoding_format=encoding_format, + **hyperparams, + ) + self._embed_usage = embeddings.usage + return embeddings + + async def embed_queries( + self, + model: str, + texts: list[str], + encoding_format: str | None = None, + **hyperparams, + ) -> EmbeddingResponse: + """ + Embed documents. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model. + texts (list[str]): List of strings to embed as queries. + encoding_format (str | None, optional): Vector encoding format. Defaults to None. + + Returns: + response (EmbeddingResponse): The embedding response. + """ + # TODO: Do we need to truncate based on context length? 
+ async with self._setup_embedding(model) as router: + embeddings = await router.embedding( + texts=texts, + is_query=True, + encoding_format=encoding_format, + **hyperparams, + ) + self._embed_usage = embeddings.usage + return embeddings + + async def embed_query_as_vector( + self, + model: str, + text: str, + **hyperparams, + ) -> list[float]: + """ + Embed documents. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model. + text (str): A string to embed as query. + + Returns: + vector (list[float]): The embedding vector. + """ + response = await self.embed_queries( + model=model, + texts=[text], + encoding_format="float", + **hyperparams, + ) + return response.data[0].embedding + + ### --- Reranking --- ### + + @asynccontextmanager + async def _setup_reranking(self, model: str): + # Validate model capability + capabilities = [str(ModelCapability.RERANK)] + model_config = await self._get_default_model(model, capabilities) + model = model_config.id + # Setup rate limiting + # rpm_limiter = CascadeRateLimiter( + # org_hpm=ENV_CONFIG.rerank_requests_per_minute, + # proj_hpm=ENV_CONFIG.rerank_requests_per_minute, + # organization_id=self.organization.id, + # project_id=self.project.id, + # key=f"{model}:rpm", + # name="RPM", + # ) + # spm_limiter = CascadeRateLimiter( + # org_hpm=ENV_CONFIG.rerank_searches_per_minute, + # proj_hpm=ENV_CONFIG.rerank_searches_per_minute, + # organization_id=self.organization.id, + # project_id=self.project.id, + # key=f"{model}:spm", + # name="SPM", + # ) + # # Test rate limits + # await asyncio.gather(rpm_limiter.test(), spm_limiter.test()) + router = DeploymentRouter( + request=self.request, + config=model_config, + organization=self.organization, + is_browser=self.is_browser, + ) + try: + yield router + finally: + # # Consume rate limits + # await asyncio.gather(rpm_limiter.hit(), spm_limiter.hit(self._rerank_usage.documents)) + if self.billing is not None: + try: + self.billing.create_reranker_events( + model_id=model, + num_searches=self._rerank_usage.documents, + ) + except Exception as e: + logger.warning(f"Failed to create reranker events due to error: {repr(e)}") + + async def rerank_documents( + self, + *, + model: str, + query: str, + documents: list[str], + top_n: int | None = None, + **hyperparams, + ) -> RerankingResponse: + """ + Rerank documents. + + Args: + model (str): Model ID. Can be empty in which case we try to get a suitable model. + query (str): Query string. + documents (list[str]): List of strings to rerank. + top_n (int | None, optional): Only return `top_n` results. Defaults to None. + + Returns: + response (RerankingResponse): The rerank response. 
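+
+        Example:
+            A minimal illustrative call (values are hypothetical):
+
+                rerankings = await self.rerank_documents(
+                    model="",
+                    query="connection pool sizing",
+                    documents=retrieved_chunks,
+                    top_n=5,
+                )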
+ """ + if len(query.strip()) == 0: + raise BadInputError("Query cannot be empty.") + if len(documents) == 0: + raise BadInputError("There are no documents to rerank.") + async with self._setup_reranking(model) as router: + rerankings = await router.reranking( + query=query, + documents=documents, + top_n=top_n, + **hyperparams, + ) + self._rerank_usage = rerankings.usage + return rerankings diff --git a/services/api/src/owl/utils/logging.py b/services/api/src/owl/utils/logging.py index 6812766..3d2a998 100644 --- a/services/api/src/owl/utils/logging.py +++ b/services/api/src/owl/utils/logging.py @@ -7,10 +7,15 @@ import inspect import logging import sys +from typing import Any +import httpx from loguru import logger -from owl.configs.manager import ENV_CONFIG +from owl.client import VictoriaMetricsAsync +from owl.configs import ENV_CONFIG +from owl.types import LogQueryResponse +from owl.utils.io import json_loads class InterceptHandler(logging.Handler): @@ -75,20 +80,22 @@ def suppress_logging_handlers(names: list[str], include_submodules: bool = True) lgg.setLevel("ERROR") -def setup_logger_sinks(log_filepath: str = f"{ENV_CONFIG.owl_log_dir}/owl.log"): +def setup_logger_sinks(log_filepath: str | None = f"{ENV_CONFIG.log_dir}/owl.log"): logger.remove() logger.level("INFO", color="") - logger.configure( - handlers=[ - { - "sink": sys.stderr, - "level": "INFO", - "serialize": False, - "backtrace": False, - "diagnose": True, - "enqueue": True, - "catch": True, - }, + handlers = [ + { + "sink": sys.stderr, + "level": "INFO", + "serialize": False, + "backtrace": False, + "diagnose": True, + "enqueue": True, + "catch": True, + }, + ] + if log_filepath is not None: + handlers.append( { "sink": log_filepath, "level": "INFO", @@ -101,5 +108,92 @@ def setup_logger_sinks(log_filepath: str = f"{ENV_CONFIG.owl_log_dir}/owl.log"): "delay": False, "watch": False, }, - ], - ) + ) + logger.configure(handlers=handlers) + + +class VictoriaLogClient(VictoriaMetricsAsync): + __QUERY_ENDPOINT = "/select/logsql/query" + + def _construct_query( + self, + time: str = None, + severity: str = None, + org_ids: list[str] = None, + proj_ids: list[str] = None, + user_ids: list[str] = None, + ) -> str: + """ + Constructs a query string for the VictoriaMetrics log query. + + Args: + time (str, optional): The time range for the query defaults to 5m. + severity (str, optional): The severity level of the logs. + org_ids (list[str], optional): organization IDs. + proj_ids (list[str], optional): project IDs. + user_ids (list[str], optional): user IDs. + + Returns: + str: A query string starting with '_time:5m' if no parameters are provided, + otherwise a string of key:value pairs joined by ' AND '. + """ + + query_params = { + "_time": time or "5m", + "severity": severity.upper() if severity else None, + } + + query_parts = [ + f"{key}:{value}" for key, value in query_params.items() if value is not None + ] + + if org_ids: + org_values = " OR ".join(org_ids) + query_parts.append(f"org_id:({org_values})") + + if proj_ids: + proj_values = " OR ".join(proj_ids) + query_parts.append(f"proj_id:({proj_values})") + + if user_ids: + user_values = " OR ".join(user_ids) + query_parts.append(f"user_id:({user_values})") + + return " AND ".join(query_parts) + + async def query_logs( + self, + time: str = None, + severity: str = None, + org_ids: list[str] = None, + proj_ids: list[str] = None, + user_ids: list[str] = None, + ) -> LogQueryResponse: + """ + Queries logs from VictoriaMetrics using the constructed query parameters. 
+ + Args: + time (str, optional): The time range for the query. + severity (str, optional): The severity level of the logs. + org_ids (list[str], optional): organization IDs. + proj_ids (list[str], optional): project IDs. + user_ids (list[str], optional): user IDs. + + Returns: + LogQueryResponse: A list of JSON objects representing the logs. + """ + params = {"query": self._construct_query(time, severity, org_ids, proj_ids, user_ids)} + response = await self._fetch_victoria_metrics(self.__QUERY_ENDPOINT, params) + return LogQueryResponse(logs=self._process_logs(response)) + + def _process_logs(self, response: httpx.Response) -> list[dict[str, Any]]: + """ + Processes the HTTP response from VictoriaMetrics and extracts log entries. + + Args: + response (httpx.Response): The HTTP response object. + + Returns: + list: A list of JSON objects parsed from the response. + """ + return [json_loads(line) for line in response.iter_lines() if line] diff --git a/services/api/src/owl/utils/loguru_otlp_handler.py b/services/api/src/owl/utils/loguru_otlp_handler.py new file mode 100644 index 0000000..20bf75b --- /dev/null +++ b/services/api/src/owl/utils/loguru_otlp_handler.py @@ -0,0 +1,290 @@ +""" +Adapted from: https://github.com/s71m/opentelemetry-loguru-telegram/blob/master/utils/loguru_otlp_handler.py +""" + +import atexit +import queue +import signal +import sys +import threading +import time +import traceback +from time import time_ns +from typing import Any, ClassVar, Dict + +from loguru import logger +from opentelemetry import trace +from opentelemetry._logs import SeverityNumber +from opentelemetry.sdk._logs._internal import LoggerProvider, LogRecord +from opentelemetry.sdk._logs._internal.export import BatchLogRecordProcessor, LogExporter +from opentelemetry.sdk.resources import Resource + +# Constants +MAX_QUEUE_SIZE = 10000 + +# Simplified severity mapping +SEVERITY_MAPPING = { + 10: SeverityNumber.DEBUG, + 20: SeverityNumber.INFO, + 30: SeverityNumber.WARN, + 40: SeverityNumber.ERROR, + 50: SeverityNumber.FATAL, +} + + +class OTLPHandler: + _instances: ClassVar[list["OTLPHandler"]] = [] # Changed from set to list for safe iteration + _shutdown_lock: ClassVar[threading.Lock] = threading.Lock() + _is_shutting_down: ClassVar[bool] = False + + def __init__( + self, + service_name: str, + exporter: LogExporter, + max_queue_size: int = MAX_QUEUE_SIZE, + batch_size: int = 100, + export_interval_ms: int = 1000, + ): + self._resource = Resource( + { + "service.name": service_name, + # "service.instance.id": uuid7_str(), + } + ) + self._queue: queue.Queue[Dict[str, Any]] = queue.Queue(maxsize=max_queue_size) + self._shutdown_event = threading.Event() + # self._flush_complete = threading.Event() + + # Initialize logger provider with resource + self._logger_provider = LoggerProvider(resource=self._resource) + self._logger_provider.add_log_record_processor( + BatchLogRecordProcessor( + exporter, + max_export_batch_size=batch_size, + schedule_delay_millis=export_interval_ms, + export_timeout_millis=5000, + ) + ) + self._logger = self._logger_provider.get_logger(service_name) + + # Start worker thread + self._worker = threading.Thread(target=self._process_queue, name="loguru_otlp_worker") + self._worker.daemon = True + self._worker.start() + + # Register this instance + with self._shutdown_lock: + self.__class__._instances.append(self) + + # Register shutdown handlers only once + if len(self._instances) == 1: + atexit.register(self._shutdown_all_handlers) + signal.signal(signal.SIGINT, 
self._signal_handler) + signal.signal(signal.SIGTERM, self._signal_handler) + + def _get_trace_context(self) -> tuple: + """Get the current trace context.""" + span_context = trace.get_current_span().get_span_context() + return ( + span_context.trace_id if span_context.is_valid else 0, + span_context.span_id if span_context.is_valid else 0, + span_context.trace_flags if span_context.is_valid else 0, + ) + + def _get_severity(self, level_no: int) -> tuple: + """Map Loguru level to OpenTelemetry severity.""" + base_level = (level_no // 10) * 10 + return ( + SEVERITY_MAPPING.get(base_level, SeverityNumber.UNSPECIFIED), + "CRITICAL" + if level_no >= 50 + else "ERROR" + if level_no >= 40 + else "WARNING" + if level_no >= 30 + else "INFO" + if level_no >= 20 + else "DEBUG", + ) + + def _extract_attributes(self, record: Dict[str, Any]) -> Dict[str, Any]: + """Extract attributes from the record.""" + attributes = { + "code.filepath": record["file"].path, + "code.function": record["function"], + "code.lineno": record["line"], + "filename": record["file"].name, + } + + # Add extra attributes if present + extra = record.get("extra", {}) + if isinstance(extra, dict): + for k, v in extra.items(): + if isinstance(v, (str, int, float, bool)): + attributes[k] = v + else: + attributes[k] = repr(v) + else: + pass + + # Handle exception information + if "exception" in record and record["exception"]: + exc_type, exc_value, exc_tb = record["exception"] + if exc_type: + attributes.update( + { + "exception.type": exc_type.__name__, + "exception.message": str(exc_value) if exc_value else "No message", + "exception.stacktrace": "".join( + traceback.format_exception(exc_type, exc_value, exc_tb) + ) + if exc_tb + else "No stacktrace", + } + ) + + return attributes + + def _create_log_record(self, record: Dict[str, Any]) -> LogRecord: + """Create an OpenTelemetry LogRecord.""" + severity_number, severity_text = self._get_severity(record["level"].no) + trace_id, span_id, trace_flags = self._get_trace_context() + + if "exception" in record and record["exception"]: + severity_number = SeverityNumber.FATAL + severity_text = "CRITICAL" + + return LogRecord( + timestamp=int(record["time"].timestamp() * 1e9), + observed_timestamp=time_ns(), + trace_id=trace_id, + span_id=span_id, + trace_flags=trace_flags, + severity_text=severity_text, + severity_number=severity_number, + body=record["message"], + resource=self._logger.resource, + attributes=self._extract_attributes(record), + ) + + @classmethod + def _shutdown_all_handlers(cls): + """Shutdown all handler instances safely.""" + with cls._shutdown_lock: + if cls._is_shutting_down: + return + cls._is_shutting_down = True + + # Create a copy of instances for safe iteration + handlers = cls._instances.copy() + + # Shutdown each handler + for handler in handlers: + try: + handler.shutdown() + except Exception as e: + logger.warning(f"Error shutting down handler: {e}", file=sys.stderr) + + # Clear the instances list + with cls._shutdown_lock: + cls._instances.clear() + + @classmethod + def _signal_handler(cls, signum, frame): + """Handle termination signals.""" + logger.info("\nShutting down logger...", file=sys.stderr) + cls._shutdown_all_handlers() + sys.exit(0) + + def _process_queue(self) -> None: + """Process logs from the queue until shutdown.""" + while not self._shutdown_event.is_set() or not self._queue.empty(): + try: + try: + record = self._queue.get(timeout=0.1) + except queue.Empty: + continue + + if record is None: + self._queue.task_done() + continue + + 
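+                # Convert the raw Loguru record dict into an OTLP LogRecord (severity, trace
+                # context, and exception details are mapped in _create_log_record) and hand it
+                # to the batching processor, which exports on its own schedule.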
log_record = self._create_log_record(record) + self._logger.emit(log_record) + self._queue.task_done() + + except Exception as e: + logger.warning(f"Error processing log record: {e}", file=sys.stderr) + + def sink(self, message) -> None: + """Add log message to queue.""" + if self._shutdown_event.is_set(): + return + + try: + self._queue.put_nowait(message.record) + except queue.Full: + logger.warning("Warning: Log queue full, dropping message", file=sys.stderr) + + def shutdown(self) -> None: + """Graceful shutdown of the handler.""" + if self._shutdown_event.is_set(): + return + + try: + # Signal shutdown + self._shutdown_event.set() + # Wait for queue to empty + try: + # Wait with timeout + if not self._queue.empty(): + # Give some time for the queue to process + timeout = 5.0 # 5 seconds timeout + start_time = time.time() + + while not self._queue.empty() and (time.time() - start_time) < timeout: + time.sleep(0.1) + + if not self._queue.empty(): + logger.warning("Warning: Queue not empty after timeout", file=sys.stderr) + + # Force flush remaining logs + # self._logger_provider.force_flush(timeout_millis=5000) + + # Wait for flush completion + # self._flush_complete.wait(timeout=1.0) + + except Exception as e: + logger.warning(f"Error during queue processing: {e}", file=sys.stderr) + + # Final shutdown of logger provider + self._logger_provider.shutdown() + + except Exception as e: + logger.warning(f"Error during shutdown: {e}", file=sys.stderr) + + @classmethod + def create( + cls, + service_name: str, + exporter: LogExporter, + development_mode: bool = False, + export_interval_ms: int = 1000, + ) -> "OTLPHandler": + """Factory method with environment-specific configurations.""" + if development_mode: + return cls( + service_name=service_name, + exporter=exporter, + max_queue_size=1000, + batch_size=50, + export_interval_ms=500, + ) + + return cls( + service_name=service_name, + exporter=exporter, + max_queue_size=MAX_QUEUE_SIZE, + batch_size=100, + export_interval_ms=export_interval_ms, + ) diff --git a/services/api/src/owl/utils/mcp/__init__.py b/services/api/src/owl/utils/mcp/__init__.py new file mode 100644 index 0000000..3ade0f3 --- /dev/null +++ b/services/api/src/owl/utils/mcp/__init__.py @@ -0,0 +1,35 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, FastAPI, Request +from fastapi.responses import ORJSONResponse + +from owl.types import UserAuth +from owl.utils.auth import auth_user_service_key +from owl.utils.exceptions import handle_exception +from owl.utils.mcp.helpers import _get_mcp_server +from owl.utils.mcp.server import MCP_TOOL_TAG # noqa: F401 + +router = APIRouter() + + +def get_mcp_router(app: FastAPI) -> APIRouter: + """Get the MCP router.""" + mcp_server = _get_mcp_server(app) + import owl.utils.mcp.custom_tools # noqa: F401 + + @router.get("/v1/mcp/http", summary="MCP Streamable HTTP endpoint") + @router.post("/v1/mcp/http", summary="MCP Streamable HTTP endpoint") + @handle_exception + async def mcp_streamable( + request: Request, + user: Annotated[UserAuth, Depends(auth_user_service_key)], + ) -> ORJSONResponse: + if request.method == "GET": + return await mcp_server.get() + return await mcp_server.post( + user=user, + body=await request.json(), + headers=dict(request.headers), + ) + + return router diff --git a/services/api/src/owl/utils/mcp/custom_tools.py b/services/api/src/owl/utils/mcp/custom_tools.py new file mode 100644 index 0000000..44c6a33 --- /dev/null +++ b/services/api/src/owl/utils/mcp/custom_tools.py @@ -0,0 +1,6 @@ 
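+# Example custom MCP tool. Importing this module registers the function with the MCP server
+# via the @mcp_tool decorator (see owl.utils.mcp.helpers), alongside the tools derived from
+# the OpenAPI schema.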
+from owl.utils.mcp.helpers import mcp_tool + + +@mcp_tool +async def sum(a: int, b: int) -> int: + return a + b diff --git a/services/api/src/owl/utils/mcp/helpers.py b/services/api/src/owl/utils/mcp/helpers.py new file mode 100644 index 0000000..34f161f --- /dev/null +++ b/services/api/src/owl/utils/mcp/helpers.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING, Callable + +from owl.utils.mcp.server import MCPServer + +if TYPE_CHECKING: + from fastapi import FastAPI + +_mcp_singleton: MCPServer | None = None + + +def _get_mcp_server(app: "FastAPI") -> MCPServer: + global _mcp_singleton + if _mcp_singleton is None: + _mcp_singleton = MCPServer(app) + return _mcp_singleton + + +def mcp_tool(fn: Callable[..., object]) -> Callable[..., object]: + """ + Module-level decorator that forwards to MCPServer.tool + Must be used *after* `get_mcp_router` has been called once. + """ + if _mcp_singleton is None: + raise RuntimeError( + "MCP server not initialized yet. " + "Make sure get_mcp_router(...) is called before decorating." + ) + return _mcp_singleton.tool(fn) diff --git a/services/api/src/owl/utils/mcp/server.py b/services/api/src/owl/utils/mcp/server.py new file mode 100644 index 0000000..cdbc621 --- /dev/null +++ b/services/api/src/owl/utils/mcp/server.py @@ -0,0 +1,483 @@ +import asyncio +import inspect +from collections import defaultdict +from functools import cached_property +from typing import Any, Callable, get_type_hints +from urllib.parse import quote + +from fastapi import FastAPI +from fastapi.openapi.utils import get_openapi +from fastapi.responses import ORJSONResponse +from loguru import logger +from pydantic import BaseModel, ValidationError, create_model + +from jamaibase.types.db import RankedRole, UserAuth +from jamaibase.types.mcp import ( + CallToolRequest, + CallToolResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCEmptyResponse, + JSONRPCError, + JSONRPCErrorCode, + JSONRPCResponse, + ListToolsResult, + ServerCapabilities, + TextContent, + ToolAPI, + ToolAPIInfo, + ToolInputSchema, +) +from owl.client import JamaiASGIAsync +from owl.utils.auth import has_permissions +from owl.utils.exceptions import ( + BadInputError, + ForbiddenError, + JamaiException, + MethodNotAllowedError, + ResourceNotFoundError, +) +from owl.utils.handlers import INTERNAL_ERROR_MESSAGE + +MCP_TOOL_TAG = "mcp_tool" + + +class MCPServer: + _custom_tools: list[ToolAPI] = [] + _custom_callables: dict[str, Callable[..., Any]] = {} + _custom_models: dict[str, BaseModel] = {} + + def __init__( + self, + app: FastAPI, + *, + include_headers_in_input: bool = False, + ): + self.app = app + self.include_headers_in_input = include_headers_in_input + self.openapi_schema = get_openapi( + title=self.app.title, + version=self.app.version, + description=self.app.description, + routes=self.app.routes, + ) + self.init_result = InitializeResult( + capabilities=ServerCapabilities(), + serverInfo=Implementation( + name=self.app.title, + version=self.app.version, + ), + ) + self.client = JamaiASGIAsync(app=self.app) + _ = self.tools + + def tool(self, fn: Callable[..., Any]) -> Callable[..., Any]: + """ + Decorator that turns any sync/async function into an MCP tool. + Arguments are validated through a dynamically created Pydantic model. + """ + sig = inspect.signature(fn) + type_hints = get_type_hints(fn) + + # Build Pydantic model from signature + fields = {} + required = [] + for name, param in sig.parameters.items(): + annotation = type_hints.get(name, Any) + default = ... 
if param.default is inspect.Parameter.empty else param.default + fields[name] = (annotation, default) + if param.default is inspect.Parameter.empty: + required.append(name) + + Model = create_model(f"{fn.__name__}Schema", **fields) + + # Register + tool = ToolAPI( + name=fn.__name__, + description=fn.__doc__ or "", + inputSchema=ToolInputSchema( + properties=Model.model_json_schema()["properties"], + required=required, + ), + api_info=None, + ) + + self._custom_tools.append(tool) + self._custom_callables[fn.__name__] = fn + self._custom_models[fn.__name__] = Model + return fn + + @cached_property + def tools(self) -> list[ToolAPI]: + tools = [] + operation_ids = set() # Track operation IDs to detect duplicates + + # dump_json(openapi_schema, "openapi_schema.json") + paths: dict[str, Any] = self.openapi_schema.get("paths", {}) + if len(paths) == 0: + logger.warning("Failed to extract paths from OpenAPI schema.") + schemas: dict[str, Any] = self.openapi_schema.get("components", {}).get("schemas", {}) + if len(schemas) == 0: + logger.warning("Failed to extract schemas from OpenAPI schema.") + # Extract tools + for path, methods in paths.items(): + args_types: dict[str, str] = {} + for method, method_info in methods.items(): + tags = method_info.get("tags", []) + if MCP_TOOL_TAG not in tags: + continue + + # Check for duplicate operation IDs + operation_id = method_info.get("operationId", method_info.get("summary", "")) + assert operation_id not in operation_ids, ( + f"Duplicate operation ID found: '{operation_id}' in {method.upper()} {path}" + ) + operation_ids.add(operation_id) + + schema = { + "title": operation_id, + "type": "object", + "properties": {}, + "required": [], + } + # Process path and query parameters (optional headers) + parameters: dict[str, Any] = method_info.get("parameters", {}) + for param in parameters: + if param["in"] == "header" and not self.include_headers_in_input: + continue + param_name = param["name"] + schema["properties"][param_name] = param["schema"] + # self._add_schema_to_properties( + # properties=schema["properties"], + # schema=param["schema"], + # name=param_name, + # ) + if param.get("required", False): + schema["required"].append(param_name) + args_types[param_name] = param["in"] + # Process body + body_schema_ref: str = ( + method_info.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) + .get("$ref", "") + ) + body = schemas.get(body_schema_ref.replace("#/components/schemas/", ""), {}) + for param_name, param_schema in body.get("properties", {}).items(): + # Maybe need to resolve reference + if "$ref" in param_schema: + param_schema = schemas.get( + body_schema_ref.replace("#/components/schemas/", ""), {} + ) + schema["properties"][param_name] = param_schema + # self._add_schema_to_properties( + # properties=schema["properties"], + # schema=param_schema, + # name=param_name, + # ) + args_types[param_name] = "body" + schema["required"] += body.get("required", []) + # Create the tool definition + summary = method_info.get("summary", schema.get("title", "")).strip() + if not summary.endswith("."): + summary += "." 
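+                # Build the MCP tool description from the endpoint summary plus any longer
+                # OpenAPI description, and append a deprecation note where the endpoint is
+                # marked deprecated.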
+ description = method_info.get("description", None) + description = summary if description is None else f"{summary}\n{description}" + if method_info.get("deprecated", False): + description += " (Deprecated)" + tool = ToolAPI( + name=schema["title"], + description=description, + inputSchema=ToolInputSchema( + properties=schema["properties"], + required=schema["required"], + ), + api_info=ToolAPIInfo( + path=path, + method=method, + args_types=args_types, + method_info=method_info, + ), + ) + # logger.info(f"{tool=}") + tools.append(tool) + return tools + + @cached_property + def tools_map(self) -> dict[str, ToolAPI]: + return {tool.name: tool for tool in self.tools} + + @cached_property + def permission_tool_map(self) -> dict[frozenset[str], list[ToolAPI]]: + # {frozenset(["system.models", "organization.models"]): [ToolAPI(name="list_models", ...)]} + tool_map = defaultdict(list) + for t in self.tools: + permissions = [ + permission + for permission in t.api_info.method_info["tags"] + if permission.startswith(("system", "organization", "project")) + ] + key = frozenset(permissions) + tool_map[key].append(t) + return tool_map + + def list_tools( + self, + *, + user: UserAuth, + ) -> ListToolsResult: + has_sys_membership = has_permissions(user, ["system"], raise_error=False) + has_org_membership = len(user.org_memberships) > 0 + has_proj_membership = len(user.proj_memberships) > 0 + + org_permission = ( + max([r.role.rank for r in user.org_memberships]) + if has_org_membership + else RankedRole.GUEST + ) # Guest has basically no permissions + proj_permission = ( + max([r.role.rank for r in user.proj_memberships]) + if has_proj_membership + else RankedRole.GUEST + ) # Guest has basically no permissions + tool_list: list[ToolAPI] = [] + for permissions, tools in self.permission_tool_map.items(): + if has_sys_membership and "system" in permissions: + tool_list.extend(tools) + elif has_org_membership and "organization" in permissions: + tool_list.extend(tools) + elif has_proj_membership and "project" in permissions: + tool_list.extend(tools) + else: + for permission in permissions: + if ( + permission.startswith(("system.", "organization.")) + and RankedRole[permission.split(".")[1]] <= org_permission + ): + tool_list.extend(tools) + break + elif ( + permission.startswith("project.") + and RankedRole[permission.split(".")[1]] <= proj_permission + ): + tool_list.extend(tools) + break + # include all custom tools + tool_list.extend(self._custom_tools) + return ListToolsResult(tools=tool_list) + + async def call_tool( + self, + body: CallToolRequest, + *, + headers: dict[str, Any] | None = None, + ) -> CallToolResult: + # Call custom tools + if body.params.name in self._custom_models: + return await self._call_custom_tool( + tool_name=body.params.name, + tool_args=body.params.arguments, + headers=headers, + ) + tool = self.tools_map.get(body.params.name, None) + if tool is None: + raise ResourceNotFoundError(f'Tool "{body.params.name}" is not found.') + # Call the tool + path = tool.api_info.path + args_types = tool.api_info.args_types + args = body.params.arguments + # Process parameters + query_params = None + body_params = None + if args is not None: + if headers is None: + headers = {} + query_params = {} + body_params = {} + for arg_name, arg_value in args.items(): + args_type = args_types.get(arg_name, "") + # Path parameters + if args_type == "path" and f"{{{arg_name}}}" in path: + path = path.replace(f"{{{arg_name}}}", quote(arg_value)) + # Headers + elif args_type == "header": + 
headers[arg_name] = arg_value + # Query parameters + elif args_type == "query": + query_params[arg_name] = arg_value + # Body parameters + elif args_type == "body": + body_params[arg_name] = arg_value + if len(headers) == 0: + headers = None + if len(query_params) == 0: + query_params = None + if len(body_params) == 0: + body_params = None + if body.params.name == "chat_completion" and len(body_params) != 0: + body_params["stream"] = False + + response = await self.client.request( + tool.api_info.method, + path, + headers=headers, + params=query_params, + body=body_params, + ) + return CallToolResult(content=[TextContent(text=response.text)]) + + async def _call_custom_tool( + self, + tool_name: str, + *, + tool_args: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + ) -> CallToolResult: + Model = self._custom_models[tool_name] + fn = self._custom_callables[tool_name] + + try: + validated = Model.model_validate(tool_args) + except ValidationError as e: + raise BadInputError(errors=e.errors()) from e + + kwargs = validated.model_dump() + if asyncio.iscoroutinefunction(fn): + out = await fn(**kwargs) + else: + out = fn(**kwargs) + return CallToolResult(content=[TextContent(text=str(out))]) + + async def _handle_request( + self, + user: UserAuth, + body: dict[str, Any], + *, + headers: dict[str, Any] | None = None, + ) -> JSONRPCResponse | JSONRPCEmptyResponse | JSONRPCError | None: + request_id = body.get("id", "") + method = body.get("method", "") + try: + if method.startswith("notifications/"): + return None + elif method == "ping": + response = JSONRPCEmptyResponse( + id=request_id, + ) + elif method == "initialize": + response = JSONRPCResponse[InitializeResult]( + id=request_id, + result=self.init_result, + ) + elif method == "tools/list": + response = JSONRPCResponse[ListToolsResult]( + id=request_id, + result=self.list_tools(user=user), + ) + elif method == "tools/call": + body = CallToolRequest.model_validate(body) + response = JSONRPCResponse[CallToolResult]( + id=request_id, + result=await self.call_tool(body, headers=headers), + ) + else: + response = JSONRPCError( + id=request_id, + error=ErrorData( + code=JSONRPCErrorCode.METHOD_NOT_FOUND, + message=f'Method "{method}" is not supported.', + data=None, + ), + ) + except BadInputError as e: + response = JSONRPCError( + id=request_id, + error=ErrorData( + code=JSONRPCErrorCode.INVALID_PARAMS, + message=str(e), + data=None, + ), + ) + except ForbiddenError as e: + response = JSONRPCError( + id=request_id, + error=ErrorData( + code=JSONRPCErrorCode.FORBIDDEN, + message=str(e), + data=None, + ), + ) + except JamaiException as e: + response = JSONRPCError( + id=request_id, + error=ErrorData( + code=JSONRPCErrorCode.INVALID_REQUEST, + message=str(e), + data=None, + ), + ) + except ValidationError as e: + logger.error(f"Failed to parse JSON-RPC body: {repr(e)}") + response = JSONRPCError( + id=request_id, + error=ErrorData( + code=JSONRPCErrorCode.PARSE_ERROR, + message=str(e), + data=None, + ), + ) + except Exception as e: + logger.exception(f"Unexpected error: {repr(e)}") + response = JSONRPCError( + id=request_id, + error=ErrorData( + code=JSONRPCErrorCode.INTERNAL_ERROR, + message=INTERNAL_ERROR_MESSAGE, + data=None, + ), + ) + return response + + async def get(self): + """Return 405 for GET requests to /mcp""" + raise MethodNotAllowedError("SSE is not supported.") + + async def post( + self, + user: UserAuth, + body: dict[str, Any] | list[dict[str, Any]], + *, + headers: dict[str, Any] | None = None, + ) -> 
ORJSONResponse: + logger.debug("MCP request: {body}", body=body) + if isinstance(body, list): + response = [ + await self._handle_request(user=user, body=req, headers=headers) for req in body + ] + if any(r is None for r in response): + return ORJSONResponse( + status_code=202, + content={}, + media_type="application/json", + ) + else: + return ORJSONResponse( + status_code=200, + content=[ + r.model_dump(mode="json", by_alias=True, exclude_none=True) + for r in response + ], + media_type="application/json", + ) + else: + response = await self._handle_request(user=user, body=body, headers=headers) + if response is None: + return ORJSONResponse(status_code=202, content={}) + else: + return ORJSONResponse( + status_code=200, + content=response.model_dump(mode="json", by_alias=True, exclude_none=True), + media_type="application/json", + ) diff --git a/services/api/src/owl/utils/metrics.py b/services/api/src/owl/utils/metrics.py new file mode 100644 index 0000000..2fc7298 --- /dev/null +++ b/services/api/src/owl/utils/metrics.py @@ -0,0 +1,424 @@ +import re +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Any, Sequence + +import httpx +from loguru import logger + +from owl.client import VictoriaMetricsAsync +from owl.types import Host, Metric, Usage, UsageResponse + +http_client = httpx.Client(timeout=5) + + +def filter_hostnames( + metric_sequence: Sequence[Metric], host_name_sequence: Sequence[str] +) -> list[Metric]: + return [metric for metric in metric_sequence if metric.hostname in host_name_sequence] + + +def group_metrics_by_hostname( + metrics: list[Metric], hostnames: list[str] | None = None +) -> list[Host]: + # If hostnames filter is provided, filter the metrics + if hostnames: + metrics = filter_hostnames(metrics, hostnames) + + # Group metrics by hostname + hostname_dict = defaultdict(list) + for metric in metrics: + hostname_dict[metric.hostname].append(metric) + + # Create list of hosts + hosts = [ + Host(name=hostname, metrics=metric_list) for hostname, metric_list in hostname_dict.items() + ] + + return hosts + + +class Telemetry(VictoriaMetricsAsync): + __QUERY_ENDPOINT = "/vm/prometheus/api/v1/query" + __QUERY_RANGE_ENDPOINT = "/vm/prometheus/api/v1/query_range" + + # __METRIC_QUERY_TUPLE = ( + # "label_set(sum(rate(container_cpu_usage_seconds_total{container!='',}[1m])) by (instance,job) / sum(machine_cpu_cores{}) by (instance,job) * 100, '__name__', 'cpu_util')", + # "label_set(sum(container_memory_working_set_bytes{container!='',}) by (instance,job) / sum(container_spec_memory_limit_bytes{container!='',}) by (instance,job) * 100,'__name__','memory_util')", + # "label_set(sum(rate(container_fs_reads_bytes_total{container!='',}[15s])) by (instance,job),'__name__','disk_read_bytes')", + # "label_set(sum(rate(container_fs_writes_bytes_total{container!='',}[15s])) by (instance,job),'__name__','disk_write_bytes')", + # "label_set(sum(rate(container_network_receive_bytes_total{container!='',}[15s])) by (instance,job),'__name__','network_receive_bytes')", + # "label_set(sum(rate(container_network_transmit_bytes_total{container!='',}[15s])) by (instance,job),'__name__','network_transmit_bytes')", + # "sort(topk(1, gpu_clock{clock_type='GPU_CLOCK_TYPE_SYSTEM',}))", + # "gpu_clock{clock_type='GPU_CLOCK_TYPE_MEMORY',}", + # "gpu_edge_temperature{}", + # "gpu_memory_temperature{}", + # "gpu_power_usage{}", + # "gpu_gfx_activity{}", + # "gpu_umc_activity{}", + # "gpu_free_vram{}", + # "used_memory{}", + # "DCGM_FI_DEV_SM_CLOCK{}", + # 
"DCGM_FI_DEV_MEM_CLOCK{}", + # "DCGM_FI_DEV_GPU_TEMP{}", + # "DCGM_FI_DEV_MEMORY_TEMP{}", + # "DCGM_FI_DEV_POWER_USAGE{}", + # "DCGM_FI_DEV_GPU_UTIL{}", + # "DCGM_FI_DEV_MEM_COPY_UTIL{}", + # "DCGM_FI_DEV_FB_FREE{}", + # "DCGM_FI_DEV_FB_USED{}", + # ) + + def _construct_metrics_query(self, queries: list[str]) -> str: + """Construct a metrics retrieval query string. + + Args: + queries (list[str]): A list of fields to query. + + Returns: + str: The constructed query string. + """ + return "union(" + ", ".join(f"{i}" for i in queries) + ")" + + def _construct_usage_query( + self, + range_func: str, + aggregate_func: str, + subject_id: str, + query_filter: list[str], + group_by: list[str], + window_size: str, + ) -> str: + """Construct a usage retrieval query string. + + Args: + range_func (str): The range function to use for the query, ex: max_over_time, increase_pure, etc.. + aggregate_func (str): The aggregate function to use for the query, ex: max, sum, etc.. + subject_id (str): The metric ID to query from, ex: owl_spent_total, owl_llm_token_usage_total, etc.. + group_by (list[str]): The group by fields for the query, ex: ["org_id", "proj_id"], etc.. + window_size (str): The window size to use for the query. + + Returns: + str: The constructed query string. + """ + return f"{aggregate_func}({range_func}({subject_id}{{{','.join(query_filter)}}}[{window_size}])) by ({', '.join(group_by)})" + + def _process_metrics(self, response: list[dict[str, Any]]) -> list[Metric]: + """Process the metrics received from response. + + Args: + response (list[dict[str, Any]]): JSON data from metrics provider. + + Returns: + list[Metric]: A list of processed metrics. + """ + metrics = [] + for metric in response: + try: + # Ensure "metric" key exists + if "metric" not in metric or not isinstance(metric["metric"], dict): + raise KeyError('"metric" key is missing or not a dictionary') + + # Safely retrieve the "__name__" field + metric_name = metric["metric"].get("__name__") + if metric_name == "gpu_clock": + # Safely retrieve the "clock_type" field + clock_type = metric["metric"].get("clock_type") + if clock_type == "GPU_CLOCK_TYPE_MEMORY": + metric["metric"]["__name__"] = "gpu_memory_clock" + + # Process the metric + metrics.append(Metric.from_response(metric)) + except (KeyError, TypeError) as e: + # Log the error and skip the problematic metric + logger.warning( + f"Skipping metric due to missing fields or invalid structure: {metric}. Error: {e}" + ) + continue + return metrics + + def _parse_duration(self, duration_str: str) -> timedelta: + """Parse a duration string into a timedelta object. + + The duration string is expected to be in the format of a sequence of + decimal numbers followed by a unit character. The unit characters + supported are 'ms', 's', 'm', 'h', 'd', 'w', 'y', which represent + milliseconds, seconds, minutes, hours, days, weeks, and years, + respectively. + + Args: + duration_str (str): The duration string to parse. + + Returns: + timedelta: The parsed timedelta object. 
+ """ + pattern = r"(?P\d+)(?P[smhdwy])" + matches = re.findall(pattern, duration_str) + + delta = timedelta() + unit_multipliers = { + "ms": timedelta(milliseconds=1), + "s": timedelta(seconds=1), + "m": timedelta(minutes=1), + "h": timedelta(hours=1), + "d": timedelta(days=1), + "w": timedelta(weeks=1), + "y": timedelta(days=365), + } + + for value, unit in matches: + if unit == "ms": + delta += int(value) * unit_multipliers[unit] + else: + delta += int(value) * unit_multipliers[unit] + + return delta + + async def query_metrics( + self, + queries: list[str] | None = None, + hostnames: list[str] | None = None, + ) -> list[Host]: + """Retrieve the latest metrics from VictoriaMetrics. + + Args: + queries (list[str] | None, optional): A list of fields to query. Defaults to None (which means self.__METRIC_QUERY_TUPLE will be used). + hostnames (list[str] | None, optional): A list of hostnames to filter the results. If None, no filtering will be applied. + + Returns: + list[Host]: A list of Host(s) each contains a name and list[Metric]. + """ + queries = queries or self.__METRIC_QUERY_TUPLE + logger.info(self._construct_metrics_query(queries)) + params = {"query": self._construct_metrics_query(queries)} + response = await self._fetch_victoria_metrics(self.__QUERY_ENDPOINT, params) + response = response.json()["data"]["result"] + + if not response: + return [] + + metrics = self._process_metrics(response) + return group_metrics_by_hostname(metrics, hostnames) + + def _process_usage( + self, + usage: list[dict[str, Any]], + data_interval: timedelta, + group_by: list[str], + ) -> list[Usage]: + """Process usage data into a list of Usage objects. + + Args: + usage (list[dict[str, Any]]): The raw usage data from the query. + data_interval (timedelta): The data interval to adjust the window range. + group_by (list[str]): The group-by fields for the query. + + Returns: + list[Usage]: a list of the usage metrics. + """ + return [ + Usage.from_result(value, result["metric"], data_interval, group_by) + for result in usage + for value in result["values"] + ] + + async def query_usage( + self, + range_func: str, + aggregate_func: str, + subject_id: str, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + timeout_value: int = 5, + ) -> UsageResponse: + """ + Query VictoriaMetrics/Prometheus for usage metrics. + + Args: + range_func (str): The range function to use for the query, e.g., max_over_time, increase_pure, etc. + aggregate_func (str): The aggregate function to use for the query, e.g., max, sum, etc. + subject_id (str): The metric ID to query from, e.g., owl_spent_total, owl_llm_token_usage_total, etc. + filtered_by_org_id (list[str] | None): The organization IDs to filter by. None means no filtering. + filtered_by_proj_id (list[str] | None): The project IDs to filter by. None means no filtering. + from_ (datetime): The start time of the query. + to (datetime | None): The end time of the query. + group_by (list[str]): The group-by fields for the query, e.g., ["org_id", "proj_id"], etc. + window_size (str): The window size to use for the query. + timeout_value (int, optional): The timeout value in seconds. Defaults to 5. + + Returns: + UsageResponse: A response containing windowSize and a list of the usage metrics. 
+ """ + # if "organization_id" in group_by: + # group_by.remove("organization_id") + # if "project_id" in group_by: + # group_by.remove("project_id") + # group_by.append("proj_id") + group_by = list(set(["org_id"] + group_by)) + query_filter = [ + "service.name=~'(owl|starling)'" + ] # always filter service by owl or starling + if filtered_by_org_id: + query_filter.append(f"org_id=~'{'|'.join(filtered_by_org_id)}'") + if filtered_by_proj_id: + query_filter.append( + f"proj_id=~'{'|'.join(filtered_by_proj_id)}'" + ) # Update to proj_id to align with Clickhouse Column + + # Convert datetime to Prometheus timestamp format + data_interval = self._parse_duration(window_size) + # Query VictoriaMetrics/Prometheus + # In VictoriaMetrics/Prometheus max_over_time and increase are rollup functions, + # which calculate the value over raw samples on the given lookbehind window d per each time series returned from the given series_selector. + # Example: start time 2024-12-01 with step 1d means data is from 2024-11-30 to 2024-12-01. + # Thus, the window_start is 2024-11-30 and window_end is 2024-12-01. + # During the query, we add data_interval to start_time (so that the first datapoint is [2024-12-01, 2024-12-02]). + # Otherwise, the first datetime will be [2024-11-30, 2024-12-01], which is not what we want. + params = { + "query": self._construct_usage_query( + range_func, aggregate_func, subject_id, query_filter, group_by, window_size + ), + "start": (from_ + data_interval).timestamp(), + "end": to.timestamp() if to else None, + "step": window_size, + "timeout": timeout_value, + } + response = await self._fetch_victoria_metrics(self.__QUERY_RANGE_ENDPOINT, params) + response = response.json()["data"]["result"] + + return UsageResponse( + windowSize=window_size, + data=self._process_usage(response, data_interval, group_by), + start=(from_ + data_interval).strftime("%Y-%m-%dT%H:%M:%SZ") if from_ else {}, + end=to.strftime("%Y-%m-%dT%H:%M:%SZ") if to else {}, + ) + + async def query_llm_usage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + return await self.query_usage( + "increase_pure", + "sum", + "llm_token_usage", + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by, + window_size, + ) + + async def query_embedding_usage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + return await self.query_usage( + "increase_pure", + "sum", + "embedding_token_usage", + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by, + window_size, + ) + + async def query_reranking_usage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + return await self.query_usage( + "increase_pure", + "sum", + "reranker_search_usage", + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by, + window_size, + ) + + async def query_billing( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + return await self.query_usage( + "increase_pure", + "sum", + "spent", + filtered_by_org_id, + filtered_by_proj_id, + 
from_, + to, + group_by, + window_size, + ) + + def query_bandwidth( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + return self.query_usage( + "increase_pure", + "sum", + "bandwidth_usage", + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by, + window_size, + ) + + def query_storage( + self, + filtered_by_org_id: list[str] | None, + filtered_by_proj_id: list[str] | None, + from_: datetime, + to: datetime | None, + group_by: list[str], + window_size: str, + ) -> UsageResponse: + return self.query_usage( + "max_over_time", + "max", + "storage_usage", + filtered_by_org_id, + filtered_by_proj_id, + from_, + to, + group_by, + window_size, + ) diff --git a/services/api/src/owl/utils/openapi.py b/services/api/src/owl/utils/openapi.py deleted file mode 100644 index 10707c9..0000000 --- a/services/api/src/owl/utils/openapi.py +++ /dev/null @@ -1,6 +0,0 @@ -from fastapi.routing import APIRoute - - -def custom_generate_unique_id(route: APIRoute): - # return f"{route.tags[0]}-{route.name}" - return f"{route.name}" diff --git a/services/api/src/owl/utils/responses.py b/services/api/src/owl/utils/responses.py deleted file mode 100644 index b4a95ca..0000000 --- a/services/api/src/owl/utils/responses.py +++ /dev/null @@ -1,360 +0,0 @@ -from typing import Mapping - -from fastapi import Request, status -from fastapi.responses import ORJSONResponse -from loguru import logger -from starlette.exceptions import HTTPException - -from jamaibase.exceptions import JamaiException - -INTERNAL_ERROR_MESSAGE = "Opss sorry we ran into an unexpected error. Please try again later." - - -def make_request_log_str(request: Request, status_code: int) -> str: - """ - Generate a string for logging, given a request object and an HTTP status code. - - Args: - request (Request): Starlette request object. - status_code (int): HTTP error code. - - Returns: - str: A string in the format - ' - " " ' - """ - query = request.url.query - query = f"?{query}" if query else "" - org_id = "" - project_id = "" - try: - org_id = request.state.org_id - project_id = request.state.project_id - except Exception: - pass - return ( - f"{request.state.id} - " - f'"{request.method} {request.url.path}{query}" {status_code} - ' - f"org_id={org_id} project_id={project_id}" - ) - - -def make_response( - request: Request, - message: str, - error: str, - status_code: int, - *, - detail: str | None = None, - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, - log: bool = True, -) -> ORJSONResponse: - """ - Create a Response object. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - error (str): Short error name. - status_code (int): HTTP error code. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - log (bool, optional): Whether to log the response. Defaults to True. - - Returns: - response (ORJSONResponse): Response object. 
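The usage wrappers above (`query_llm_usage`, `query_embedding_usage`, `query_reranking_usage`, `query_billing`) all fix `range_func="increase_pure"` and `aggregate_func="sum"` over their respective metrics, while `query_storage` uses `max_over_time`/`max`. The sketch below only restates, with illustrative values, the start-time shift documented in `query_usage`: because the range query's first datapoint covers `[from_, from_ + step]`, the request's `start` is sent as `from_` plus one window. It is not a call into the project's client.

```python
from datetime import datetime, timedelta, timezone

# Illustrative values only (mirroring the 2024-12-01 example in the comments above).
window_size = "1d"
data_interval = timedelta(days=1)  # what the duration parser yields for "1d"
from_ = datetime(2024, 12, 1, tzinfo=timezone.utc)
to = datetime(2024, 12, 8, tzinfo=timezone.utc)

params = {
    # First window becomes [2024-12-01, 2024-12-02] instead of [2024-11-30, 2024-12-01].
    "start": (from_ + data_interval).timestamp(),
    "end": to.timestamp(),
    "step": window_size,
    "timeout": 5,
}
print(params)
```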
- """ - if detail is None: - detail = f"{message}\nException:{repr(exception)}" - request_headers = dict(request.headers) - if "authorization" in request_headers: - request_headers["authorization"] = ( - f'{request_headers["authorization"][:2]}*****{request_headers["authorization"][-1:]}' - ) - response = ORJSONResponse( - status_code=status_code, - content={ - "object": "error", - "error": error, - "message": message, - "detail": detail, - "request_id": request.state.id, - "exception": exception.__class__.__name__ if exception else None, - "headers": request_headers, - }, - headers=headers, - ) - mssg = make_request_log_str(request, response.status_code) - if not log: - return response - if status_code == 500: - log_fn = logger.exception - elif status_code > 500: - log_fn = logger.warning - elif exception is None: - log_fn = logger.info - elif isinstance(exception, (JamaiException, HTTPException)): - log_fn = logger.info - else: - log_fn = logger.warning - if exception: - log_fn(f"{mssg} - {exception.__class__.__name__}: {exception}") - else: - log_fn(mssg) - return response - - -def unauthorized_response( - request: Request, - message: str, - *, - detail: str | None = None, - error: str = "unauthorized", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 401. - The client should provide or correct their authentication information. - Often used when a user is not logged in or their session has expired. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "unauthorized". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. - """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_401_UNAUTHORIZED, - detail=detail, - exception=exception, - headers=headers, - ) - - -def forbidden_response( - request: Request, - message: str, - *, - detail: str | None = None, - error: str = "forbidden", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 403. - The client does not have access rights to the content. - Authentication will not help, as the client is not allowed to perform the requested action. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "forbidden". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. 
- """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_403_FORBIDDEN, - detail=detail, - exception=exception, - headers=headers, - ) - - -def resource_not_found_response( - request: Request, - message: str, - *, - detail: str | None = None, - error: str = "resource_not_found", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 404. - The server can not find the requested resource. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "resource_not_found". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. - """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - exception=exception, - headers=headers, - ) - - -def resource_exists_response( - request: Request, - message: str, - *, - detail: str | None = None, - error: str = "resource_exists", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 409. - The request cannot be processed because it conflicts with the current state of the resource. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "resource_exists". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. - """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_409_CONFLICT, - detail=detail, - exception=exception, - headers=headers, - ) - - -def bad_input_response( - request: Request, - message: str, - *, - detail: str | None = None, - error: str = "bad_input", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 422. - The request contains errors and cannot be processed. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "bad_input". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. 
- """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=detail, - exception=exception, - headers=headers, - ) - - -def internal_server_error_response( - request: Request, - message: str = INTERNAL_ERROR_MESSAGE, - *, - detail: str | None = None, - error: str = "unexpected_error", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 500. - The server encountered an unexpected condition that prevented it from fulfilling the request. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "unexpected_error". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. - """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail, - exception=exception, - headers=headers, - ) - - -def server_busy_response( - request: Request, - message: str, - *, - detail: str | None = None, - error: str = "busy", - exception: Exception | None = None, - headers: Mapping[str, str] | None = None, -) -> ORJSONResponse: - """ - HTTP 503. - The server is currently unable to handle the request due to a temporary overloading or maintenance. - - Args: - request (Request): Starlette request object. - message (str): User-friendly error message to be displayed by frontend or SDK. - detail (str | None, optional): Error message with potentially more details. - Defaults to None (message + headers). - error (str, optional): Short error name. Defaults to "busy". - exception (Exception | None, optional): Exception that occurred. Defaults to None. - headers (Mapping[str, str] | None, optional): Response headers. Defaults to None. - - Returns: - response (ORJSONResponse): Response object. 
- """ - return make_response( - request=request, - message=message, - error=error, - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=detail, - exception=exception, - headers=headers, - ) diff --git a/services/api/src/owl/utils/tasks.py b/services/api/src/owl/utils/tasks.py deleted file mode 100644 index 4df67bc..0000000 --- a/services/api/src/owl/utils/tasks.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Repeated tasks implemented by dmontagu - -https://github.com/dmontagu/fastapi-utils/blob/3ef27a6f67ac10fae6a8b4816549c0c44567a451/fastapi_utils/tasks.py -""" - -from __future__ import annotations - -import asyncio -import logging -import time -from asyncio import ensure_future -from functools import wraps -from time import perf_counter -from traceback import format_exception -from typing import Any, Callable, Coroutine, Union - -from starlette.concurrency import run_in_threadpool - -NoArgsNoReturnFuncT = Callable[[], None] -NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]] -NoArgsNoReturnDecorator = Callable[ - [Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT]], NoArgsNoReturnAsyncFuncT -] - - -def repeat_every( - *, - seconds: float, - wait_first: bool = False, - logger: logging.Logger | None = None, - raise_exceptions: bool = False, - max_repetitions: int | None = None, -) -> NoArgsNoReturnDecorator: - """ - This function returns a decorator that modifies a function so it is periodically re-executed after its first call. - - The function it decorates should accept no arguments and return nothing. If necessary, this can be accomplished - by using `functools.partial` or otherwise wrapping the target function prior to decoration. - - Parameters - ---------- - seconds: float - The number of seconds to wait between repeated calls - wait_first: bool (default False) - If True, the function will wait for a single period before the first call - logger: Optional[logging.Logger] (default None) - The logger to use to log any exceptions raised by calls to the decorated function. - If not provided, exceptions will not be logged by this function (though they may be handled by the event loop). - raise_exceptions: bool (default False) - If True, errors raised by the decorated function will be raised to the event loop's exception handler. - Note that if an error is raised, the repeated execution will stop. - Otherwise, exceptions are just logged and the execution continues to repeat. - See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.set_exception_handler for more info. - max_repetitions: Optional[int] (default None) - The maximum number of times to call the repeated function. If `None`, the function is repeated forever. - """ - - def decorator( - func: NoArgsNoReturnAsyncFuncT | NoArgsNoReturnFuncT, - ) -> NoArgsNoReturnAsyncFuncT: - """ - Converts the decorated function into a repeated, periodically-called version of itself. 
- """ - is_coroutine = asyncio.iscoroutinefunction(func) - - @wraps(func) - async def wrapped() -> None: - repetitions = 0 - - async def loop() -> None: - nonlocal repetitions - if wait_first: - await asyncio.sleep(seconds) - while max_repetitions is None or repetitions < max_repetitions: - try: - if is_coroutine: - await func() # type: ignore - else: - await run_in_threadpool(func) - repetitions += 1 - except Exception as exc: - if logger is not None: - formatted_exception = "".join( - format_exception(type(exc), exc, exc.__traceback__) - ) - logger.error(formatted_exception) - if raise_exceptions: - raise exc - await asyncio.sleep(seconds) - - ensure_future(loop()) - - return wrapped - - return decorator - - -def repeat_every_blocking( - *, - seconds: float, - wait_first: bool = False, -): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - if wait_first: - time.sleep(seconds) - while True: - t0 = perf_counter() - func(*args, **kwargs) - t = perf_counter() - t0 - time.sleep(max(0, seconds - t)) - - return wrapper - - return decorator diff --git a/services/api/src/owl/utils/test.py b/services/api/src/owl/utils/test.py new file mode 100644 index 0000000..997a66a --- /dev/null +++ b/services/api/src/owl/utils/test.py @@ -0,0 +1,1084 @@ +import os +from collections import defaultdict +from contextlib import contextmanager +from datetime import datetime +from functools import lru_cache +from os.path import basename, join +from typing import Any, Generator, Self, TypeVar + +from loguru import logger +from pydantic import BaseModel, model_validator + +from jamaibase import JamAI +from jamaibase.types import ( + ActionTableSchemaCreate, + CellCompletionResponse, + CellReferencesResponse, + ChatCompletionChunkResponse, + ChatCompletionResponse, + ChatCompletionUsage, + ChatTableSchemaCreate, + ColumnSchemaCreate, + ConversationCreateRequest, + ConversationMetaResponse, + DeploymentCreate, + DeploymentRead, + FileUploadResponse, + KnowledgeTableSchemaCreate, + LLMGenConfig, + ModelConfigCreate, + ModelConfigRead, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowRegenRequest, + OkResponse, + OrganizationCreate, + OrganizationRead, + Page, + PasswordLoginRequest, + PricePlanCreate, + PricePlanRead, + Products, + ProjectCreate, + ProjectRead, + References, + RowCompletionResponse, + StripePaymentInfo, + TableDataImportRequest, + TableMetaResponse, + UserCreate, + UserRead, +) +from owl.configs import ENV_CONFIG +from owl.db.models import BASE_PLAN_ID +from owl.types import CloudProvider, ModelCapability, ModelType, TableType +from owl.utils.crypt import generate_key +from owl.utils.dates import utc_iso_from_uuid7_draft2 + +EMAIL = "carl@up.com" +DS_PARAMS = dict(argvalues=["clickhouse", "victoriametrics"], ids=["ch", "vm"]) + + +def get_file_map(test_file_dir: str) -> dict[str, str]: + _files = [join(root, f) for root, _, files in os.walk(test_file_dir) for f in files] + file_map = {basename(f): f for f in _files} + if not len(_files) == len(file_map): + raise ValueError(f'There are duplicate file names in "{test_file_dir}"') + return file_map + + +@contextmanager +def register_password( + body: dict[str, Any], + *, + token: str = ENV_CONFIG.service_key_plain, +): + user = JamAI(token=token).auth.register_password(UserCreate(**body)) + try: + assert isinstance(user, UserRead) + assert user.email == body["email"] + assert user.name == body["name"] + if "password" in body: + assert user.password_hash == "***" + else: + assert user.password_hash is None + yield user + 
finally: + try: + JamAI(user_id=user.id, token=token).users.delete_user() + except Exception as e: + logger.error(f"User cleanup failed: {repr(e)}") + + +@contextmanager +def create_plan( + body: dict[str, Any], + *, + user_id: str, + token: str = ENV_CONFIG.service_key_plain, +): + client = JamAI(user_id=user_id, token=token) + plan = client.prices.create_price_plan(body) + try: + yield plan + finally: + client.prices.delete_price_plan(plan.id, missing_ok=True) + + +@contextmanager +def create_user( + body: dict[str, Any] | None = None, + *, + token: str = ENV_CONFIG.service_key_plain, +): + if body is None: + body = dict(email=EMAIL, name="System Admin") + user = JamAI(token=token).users.create_user(UserCreate(**body)) + try: + assert isinstance(user, UserRead) + assert user.email == body["email"] + assert user.name == body["name"] + if "password" in body: + assert user.password_hash == "***", f"{user.password_hash=}" + # Test password login + user = JamAI(token=token).auth.login_password( + PasswordLoginRequest(email=body["email"], password=body["password"]) + ) + assert isinstance(user, UserRead) + else: + assert user.password_hash is None + yield user + finally: + try: + JamAI(user_id=user.id, token=token).users.delete_user() + except Exception as e: + logger.error(f"User cleanup failed: {repr(e)}") + + +@contextmanager +def create_organization( + body: OrganizationCreate | dict | None = None, + *, + user_id: str, + token: str = ENV_CONFIG.service_key_plain, + subscribe_plan: bool = True, +): + client = JamAI(user_id=user_id, token=token) + if body is None: + body = OrganizationCreate(name="Clubhouse") + # Create org + org = client.organizations.create_organization(body) + try: + assert isinstance(org, OrganizationRead) + assert org.created_by == user_id, f"{org.created_by=}, {user_id=}" + # Try to create price plan + if ENV_CONFIG.is_cloud: + plans = client.prices.list_price_plans() + if plans.total <= 1: + client.prices.create_price_plan( + PricePlanCreate( + id="pro", + name="Pro plan", + stripe_price_id_live="price_223", + stripe_price_id_test="price_1RT2EdCcpbd72IcYeAFWrbxw", + flat_cost=25.0, + credit_grant=15.0, + max_users=None, + products=Products.unlimited(), + ) + ) + client.prices.create_price_plan( + PricePlanCreate( + id="team", + name="Team plan", + stripe_price_id_live="price_323", + stripe_price_id_test="price_1RT2FfCcpbd72IcYPGIGyXmj", + flat_cost=250.0, + credit_grant=150.0, + max_users=None, + products=Products.unlimited(), + ) + ) + base_plan = next((p for p in plans.items if p.id == BASE_PLAN_ID), None) + assert isinstance(base_plan, PricePlanRead) + assert base_plan.flat_cost == 0.0 + if subscribe_plan and org.price_plan_id is None: + response = client.organizations.subscribe_plan(org.id, base_plan.id) + assert isinstance(response, StripePaymentInfo) + org.price_plan_id = base_plan.id + org.price_plan = base_plan + response = JamAI(user_id="0", token=token).organizations.set_credit_grant( + organization_id=org.id, amount=150 + ) + assert isinstance(response, OkResponse) + if isinstance(body, BaseModel): + body = body.model_dump() + assert org.name == body["name"] + yield org + finally: + try: + client.organizations.delete_organization(org.id) + except Exception as e: + logger.error(f"Organization cleanup failed: {repr(e)}") + + +@contextmanager +def create_project( + body: dict[str, Any] | None = None, + *, + user_id: str = "0", + organization_id: str = "0", + token: str = ENV_CONFIG.service_key_plain, +): + client = JamAI(user_id=user_id, token=token) + if body 
is None: + body = dict(name="Mickey 17") + body["organization_id"] = organization_id + project = client.projects.create_project(ProjectCreate(**body)) + try: + assert isinstance(project, ProjectRead) + assert project.created_by == user_id, f"{project.created_by=}, {user_id=}" + assert project.name.startswith(body["name"]) + yield project + finally: + try: + client.projects.delete_project(project.id) + except Exception as e: + logger.error(f"Project cleanup failed: {repr(e)}") + + +@contextmanager +def create_model_config( + body: ModelConfigCreate, + *, + user_id: str = "0", + token: str = ENV_CONFIG.service_key_plain, +): + client = JamAI(user_id=user_id, token=token) + model = client.models.create_model_config(body) + try: + assert isinstance(model, ModelConfigRead) + yield model + finally: + try: + client.models.delete_model_config(model.id) + except Exception as e: + logger.error(f"Model cleanup failed: {repr(e)}") + + +@contextmanager +def create_deployment( + body: DeploymentCreate | dict, + *, + user_id: str = "0", + token: str = ENV_CONFIG.service_key_plain, +): + client = JamAI(user_id=user_id, token=token) + deployment = client.models.create_deployment(body) + try: + assert isinstance(deployment, DeploymentRead) + yield deployment + finally: + try: + client.models.delete_deployment(deployment.id) + except Exception as e: + logger.error(f"Deployment cleanup failed: {repr(e)}") + + +class OrgContext(BaseModel): + superuser: UserRead + user: UserRead + superorg: OrganizationRead + org: OrganizationRead + + +@contextmanager +def setup_organizations(): + with ( + create_user() as superuser, + create_user(dict(email=f"russell-{generate_key(8)}@up.com", name="User")) as user, + ): + assert user.id != "0" + with ( + create_organization( + OrganizationCreate(name="System"), user_id=superuser.id + ) as superorg, + create_organization(OrganizationCreate(name="Clubhouse"), user_id=user.id) as org, + ): + assert superorg.id == "0" + assert org.id != "0" + yield OrgContext(superuser=superuser, user=user, superorg=superorg, org=org) + + +class ProjectContext(OrgContext): + projects: list[ProjectRead] + + +@contextmanager +def setup_projects(): + with setup_organizations() as ctx: + with ( + create_project(user_id=ctx.superuser.id, organization_id=ctx.superorg.id) as p0, + create_project(user_id=ctx.user.id, organization_id=ctx.org.id) as p1, + ): + assert p0.organization_id == ctx.superorg.id + assert p1.organization_id == ctx.org.id + # Using `**model_dump()` leads to serialization warnings + yield ProjectContext( + projects=[p0, p1], + superuser=ctx.superuser, + user=ctx.user, + superorg=ctx.superorg, + org=ctx.org, + ) + + +SMOL_LM2_CONFIG = ModelConfigCreate( + id="ellm/smollm2:135m", + name="ELLM SmolLM2 135M", + type=ModelType.LLM, + capabilities=[ModelCapability.CHAT], + context_length=4096, + owned_by="ellm", +) +CLAUDE_HAIKU_CONFIG = ModelConfigCreate( + id="anthropic/claude-3-5-haiku-latest", + name="Anthropic Claude 3.5 Haiku", + type=ModelType.LLM, + capabilities=[ + ModelCapability.CHAT, + ModelCapability.IMAGE, + ModelCapability.TOOL, + ], + context_length=128000, + languages=["en"], +) +GPT_41_MINI_CONFIG = ModelConfigCreate( + id="openai/gpt-4.1-mini", + name="OpenAI GPT-4.1 mini", + type=ModelType.LLM, + capabilities=[ + ModelCapability.CHAT, + ModelCapability.IMAGE, + ModelCapability.TOOL, + ], + context_length=1047576, + languages=["en"], +) +GPT_41_NANO_CONFIG = ModelConfigCreate( + id="openai/gpt-4.1-nano", + name="OpenAI GPT-4.1 nano", + type=ModelType.LLM, + 
capabilities=[ + ModelCapability.CHAT, + ModelCapability.IMAGE, + ModelCapability.TOOL, + ], + context_length=1047576, + languages=["en"], +) +GPT_4O_MINI_CONFIG = ModelConfigCreate( + id="openai/gpt-4o-mini", + name="OpenAI GPT-4o mini", + type=ModelType.LLM, + capabilities=[ + ModelCapability.CHAT, + ModelCapability.IMAGE, + ModelCapability.TOOL, + ], + context_length=128000, + languages=["en"], +) +GPT_5_MINI_CONFIG = ModelConfigCreate( + id="openai/gpt-5-mini", + name="OpenAI GPT-5 mini", + type=ModelType.LLM, + capabilities=[ + ModelCapability.CHAT, + ModelCapability.REASONING, + ModelCapability.IMAGE, + ModelCapability.TOOL, + ], + context_length=1280000, + languages=["en"], +) +OPENAI_O4_MINI_CONFIG = ModelConfigCreate( + id="openai/o4-mini", + name="OpenAI o4 mini", + type=ModelType.LLM, + capabilities=[ + ModelCapability.CHAT, + ModelCapability.REASONING, + ModelCapability.IMAGE, + ModelCapability.TOOL, + ], + context_length=1280000, + languages=["en"], +) +ELLM_DESCRIBE_CONFIG = ModelConfigCreate( + id="ellm/describe", + name="Describe Message", + type=ModelType.LLM, + capabilities=[ + ModelCapability.CHAT, + ModelCapability.IMAGE, + ModelCapability.AUDIO, + ], + context_length=128000, + languages=["en"], + owned_by="ellm", +) +TEXT_EMBEDDING_3_SMALL_CONFIG = ModelConfigCreate( + id="openai/text-embedding-3-small", + name="OpenAI Text Embedding 3 Small", + type=ModelType.EMBED, + capabilities=[ModelCapability.EMBED], + context_length=8192, + embedding_size=1536, + embedding_dimensions=256, + languages=["en"], +) +ELLM_EMBEDDING_CONFIG = ModelConfigCreate( + id="ellm/embed-dim-256", + name="Mock Embedding (256-dim)", + type=ModelType.EMBED, + capabilities=[ModelCapability.EMBED], + context_length=8192, + embedding_size=256, + embedding_dimensions=256, + languages=["en"], + owned_by="ellm", +) +RERANK_ENGLISH_v3_SMALL_CONFIG = ModelConfigCreate( + id="cohere/rerank-english-v3.0", + name="Cohere Rerank English v3.0", + type=ModelType.RERANK, + capabilities=[ModelCapability.RERANK], + context_length=512, + languages=["en"], +) + +CLAUDE_HAIKU_DEPLOYMENT = DeploymentCreate( + model_id=CLAUDE_HAIKU_CONFIG.id, + name=f"{CLAUDE_HAIKU_CONFIG.name} Deployment", + provider=CloudProvider.ANTHROPIC, + routing_id=CLAUDE_HAIKU_CONFIG.id, + api_base="", +) +GPT_41_MINI_DEPLOYMENT = DeploymentCreate( + model_id=GPT_41_MINI_CONFIG.id, + name=f"{GPT_41_MINI_CONFIG.name} Deployment", + provider=CloudProvider.OPENAI, + routing_id=GPT_41_MINI_CONFIG.id, + api_base="", +) +GPT_41_NANO_DEPLOYMENT = DeploymentCreate( + model_id=GPT_41_NANO_CONFIG.id, + name=f"{GPT_41_NANO_CONFIG.name} Deployment", + provider=CloudProvider.OPENAI, + routing_id=GPT_41_NANO_CONFIG.id, + api_base="", +) +GPT_4O_MINI_DEPLOYMENT = DeploymentCreate( + model_id=GPT_4O_MINI_CONFIG.id, + name=f"{GPT_4O_MINI_CONFIG.name} Deployment", + provider=CloudProvider.OPENAI, + routing_id=GPT_4O_MINI_CONFIG.id, + api_base="", +) +GPT_5_MINI_DEPLOYMENT = DeploymentCreate( + model_id=GPT_5_MINI_CONFIG.id, + name=f"{GPT_5_MINI_CONFIG.name} Deployment", + provider=CloudProvider.OPENAI, + routing_id=GPT_5_MINI_CONFIG.id, + api_base="", +) +OPENAI_O4_MINI_DEPLOYMENT = DeploymentCreate( + model_id=OPENAI_O4_MINI_CONFIG.id, + name=f"{OPENAI_O4_MINI_CONFIG.name} Deployment", + provider=CloudProvider.OPENAI, + routing_id=OPENAI_O4_MINI_CONFIG.id, + api_base="", +) +ELLM_DESCRIBE_DEPLOYMENT = DeploymentCreate( + model_id=ELLM_DESCRIBE_CONFIG.id, + name=f"{ELLM_DESCRIBE_CONFIG.name} Deployment", + provider="custom", + routing_id=ELLM_DESCRIBE_CONFIG.id, 
+ api_base=ENV_CONFIG.test_llm_api_base, +) +TEXT_EMBEDDING_3_SMALL_DEPLOYMENT = DeploymentCreate( + model_id=TEXT_EMBEDDING_3_SMALL_CONFIG.id, + name=f"{TEXT_EMBEDDING_3_SMALL_CONFIG.name} Deployment", + provider=CloudProvider.OPENAI, + routing_id=TEXT_EMBEDDING_3_SMALL_CONFIG.id, + api_base="", +) +ELLM_EMBEDDING_DEPLOYMENT = DeploymentCreate( + model_id=ELLM_EMBEDDING_CONFIG.id, + name=f"{ELLM_EMBEDDING_CONFIG.name} Deployment", + provider=CloudProvider.VLLM_CLOUD, + routing_id=ELLM_EMBEDDING_CONFIG.id, + api_base=ENV_CONFIG.test_llm_api_base, +) +RERANK_ENGLISH_v3_SMALL_DEPLOYMENT = DeploymentCreate( + model_id=RERANK_ENGLISH_v3_SMALL_CONFIG.id, + name=f"{RERANK_ENGLISH_v3_SMALL_CONFIG.name} Deployment", + provider=CloudProvider.COHERE, + routing_id=RERANK_ENGLISH_v3_SMALL_CONFIG.id, + api_base="", +) + + +@lru_cache(maxsize=1000) +def upload_file_cached(user_id: str, project_id: str, file_path: str) -> FileUploadResponse: + return JamAI(user_id=user_id, project_id=project_id).file.upload_file(file_path) + + +def upload_file(client: JamAI, file_path: str) -> FileUploadResponse: + return upload_file_cached( + user_id=client.user_id, + project_id=client.project_id, + file_path=file_path, + ) + + +STREAM_PARAMS = dict(argvalues=[True, False], ids=["stream", "non-stream"]) +TABLE_TYPES = list(TableType) +TEXTS = { + "EN": '"Arrival" is a 2016 film.', + "ZH-CN": "《é™ä¸´ã€‹æ˜¯ä¸€éƒ¨ 2016 年科幻片。", + "ZH-TW": "《異星入境》是2016年的電影。", + "JA": "「メッセージã€ã¯2016å¹´ã®æ˜ ç”»ã§ã™ã€‚", + "KR": '"컨íƒíЏ"는 2016ë…„ ì˜í™”입니다.', + "ES": '"La llegada" es una película de 2016.', + "IT": '"Arrival" è un film del 2016.', + "IS": '"Arrival" er kvikmynd frá 2016.', + "AR": '"الوصول" هو Ùيلم من عام 2016.', +} + + +@contextmanager +def create_table( + client: JamAI, + table_type: TableType, + table_id: str = "", + *, + cols: list[ColumnSchemaCreate] | None = None, + chat_cols: list[ColumnSchemaCreate] | None = None, + chat_model: str = "", + embedding_model: str = "", +): + try: + if cols is None: + dtypes = ["int", "float", "bool", "str", "image", "audio", "document"] + cols = [ColumnSchemaCreate(id=dtype, dtype=dtype) for dtype in dtypes] + cols += [ + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=chat_model, + system_prompt="", + prompt="", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + if table_type == TableType.CHAT: + if chat_cols is None: + cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="You are a wacky assistant.", + temperature=0.001, + top_p=0.001, + max_tokens=5, + ), + ), + ] + cols + else: + cols = chat_cols + cols + + # info_col_ids = ["ID", "Updated at"] + # input_col_ids = [ + # col.id for col in cols if col.gen_config is None and col.id not in info_col_ids + # ] + # # output_col_ids = [col.id for col in cols if col.gen_config is not None] + # default_sys_prompt_col_ids = [ + # col.id + # for col in cols + # if isinstance(col.gen_config, LLMGenConfig) and col.gen_config.system_prompt == "" + # ] + # default_prompt_col_ids = [ + # col.id + # for col in cols + # if isinstance(col.gen_config, LLMGenConfig) and col.gen_config.prompt == "" + # ] + + if not table_id: + table_id = generate_key(80, "table-") + if table_type == TableType.ACTION: + table = client.table.create_action_table( + ActionTableSchemaCreate(id=table_id, cols=cols), + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.create_knowledge_table( + 
KnowledgeTableSchemaCreate(id=table_id, cols=cols, embedding_model=embedding_model) + ) + elif table_type == TableType.CHAT: + table = client.table.create_chat_table( + ChatTableSchemaCreate(id=table_id, cols=cols), + ) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + assert table.id == table_id + # col_map = {col.id: col for col in table.cols} + # # Check default system prompt + # default_sys_phrase = ( + # "You are a versatile data generator. " + # "Your task is to process information from input data and generate appropriate responses based on the specified column name and input data." + # ) + # for col_id in default_sys_prompt_col_ids: + # gen_config = col_map[col_id].gen_config + # assert default_sys_phrase in gen_config.system_prompt + # # Check default prompt + # input_col_refs = ["${" + col + "}" for col in input_col_ids] + # for col_id in default_prompt_col_ids: + # gen_config = col_map[col_id].gen_config + # for ref in input_col_refs: + # assert ref in gen_config.prompt, f"Missing '{ref}' in '{gen_config.prompt}'" + # assert "${ID}" not in gen_config.prompt # Info columns + # assert "${Updated at}" not in gen_config.prompt # Info columns + # if table_type == TableType.KNOWLEDGE: + # assert "${Title Embed}" not in gen_config.prompt # Vector columns + # assert "${Text Embed}" not in gen_config.prompt # Vector columns + # elif table_type == TableType.CHAT: + # assert "${User}" in gen_config.prompt + yield table + finally: + try: + client.table.delete_table(table_type, table_id, missing_ok=True) + except Exception as e: + logger.error(f"Table cleanup failed: {repr(e)}") + + +def list_tables( + client: JamAI, + table_type: TableType, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "updated_at", + order_ascending: bool = True, + created_by: str | None = None, + parent_id: str | None = None, + search_query: str = "", + count_rows: bool = False, + **kwargs, +) -> Page[TableMetaResponse]: + tables = client.table.list_tables( + table_type, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + created_by=created_by, + parent_id=parent_id, + search_query=search_query, + count_rows=count_rows, + **kwargs, + ) + assert isinstance(tables, Page) + assert isinstance(tables.items, list) + assert all(isinstance(t, TableMetaResponse) for t in tables.items) + return tables + + +def compile_and_check_row_responses( + response: ( + MultiRowCompletionResponse + | Generator[CellReferencesResponse | CellCompletionResponse, None, None] + ), + *, + table_type: TableType, + stream: bool, + regen: bool, + check_usage: bool = True, +) -> MultiRowCompletionResponse: + if stream: + responses: list[CellReferencesResponse | CellCompletionResponse] = [r for r in response] + # dump_json( + # [r.model_dump(mode="json") for r in responses], f"stream-{table_type.value}.json" + # ) + for r in responses: + if isinstance(r, CellReferencesResponse): + assert r.object == "gen_table.references" + elif isinstance(r, CellCompletionResponse): + assert r.object == "gen_table.completion.chunk" + assert r.usage is None or isinstance(r.usage, ChatCompletionUsage) + assert isinstance(r.prompt_tokens, int) + assert isinstance(r.completion_tokens, int) + assert isinstance(r.total_tokens, int) + else: + raise ValueError(f"Unexpected response type: {type(r)}") + # Construct MultiRowCompletionResponse + row_chunks_map: dict[str, list[CellCompletionResponse]] = defaultdict(list) + refs_map: dict[tuple[str, str], 
CellReferencesResponse] = {} + for r in responses: + if isinstance(r, CellReferencesResponse): + refs_map[(r.row_id, r.output_column_name)] = r + continue + row_chunks_map[r.row_id].append(r) + rows = [] + for row_id, row_chunks in row_chunks_map.items(): + col_chunks_map: dict[str, list[CellCompletionResponse]] = defaultdict(list) + for c in row_chunks: + col_chunks_map[c.output_column_name].append(c) + columns = {col_id: chunks[0] for col_id, chunks in col_chunks_map.items()} + for col_id, chunks in col_chunks_map.items(): + content = "".join( + getattr(c.choices[0].message, "content", "") or "" for c in chunks + ) + reasoning_content = "".join( + getattr(c.choices[0].message, "reasoning_content", "") or "" for c in chunks + ) + columns[col_id].choices[0].message.content = content + columns[col_id].choices[0].message.reasoning_content = reasoning_content + columns[col_id].choices[0].delta = None + columns[col_id].usage = chunks[-1].usage # Last chunk should have usage data + columns[col_id].references = refs_map.get((row_id, col_id), None) + # columns[col_id] = ChatCompletionResponse.model_validate( + # columns[col_id].model_dump(exclude={"object", "references.object"}) + # ) + rows.append(RowCompletionResponse(columns=columns, row_id=row_id)) + response = MultiRowCompletionResponse(rows=rows) + # dump_json(response.model_dump(mode="json"), f"stream-{table_type.value}-converted.json") + # else: + # dump_json(response.model_dump(mode="json"), f"nonstream-{table_type.value}-converted.json") + assert isinstance(response, MultiRowCompletionResponse) + assert response.object == "gen_table.completion.rows" + for row in response.rows: + assert isinstance(row, RowCompletionResponse) + assert row.object == "gen_table.completion.chunks" + # if table_type == TableType.CHAT: + # assert "AI" in row.columns + # Check completion lengths + for completion in row.columns.values(): + assert isinstance(completion, (ChatCompletionChunkResponse, ChatCompletionResponse)) + # assert len(completion.content) > 0, f"{completion=}" + # Check usage + if check_usage and not completion.content.startswith("[ERROR] "): + assert isinstance(completion.usage, ChatCompletionUsage), f"{completion.usage=}" + assert isinstance(completion.prompt_tokens, int) + assert isinstance(completion.completion_tokens, int) + assert isinstance(completion.total_tokens, int) + # Regen will return zero usage for "RUN_BEFORE", "RUN_AFTER", "RUN_SELECTED" + min_value = 0 if regen else 1 + assert completion.prompt_tokens >= min_value, f"{completion.content=} {completion.usage=}" # fmt: off + assert completion.completion_tokens >= min_value, f"{completion.content=} {completion.usage=}" # fmt: off + assert completion.usage.total_tokens >= min_value, f"{completion.content=} {completion.usage=}" # fmt: off + # Check references + if isinstance(completion.references, References): + assert isinstance(completion.references.chunks, list) + else: + assert completion.references is None, ( + f"Unexpected type: {type(completion.references)=}" + ) + return response + + +def add_table_rows( + client: JamAI, + table_type: TableType, + table_name: str, + data: list[dict, Any], + *, + stream: bool, + check_usage: bool = True, +) -> MultiRowCompletionResponse: + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest(table_id=table_name, data=data, stream=stream), + ) + return compile_and_check_row_responses( + response, + table_type=table_type, + stream=stream, + regen=False, + check_usage=check_usage, + ) + + +def regen_table_rows( + client: 
JamAI, + table_type: TableType, + table_name: str, + row_ids: list[str], + *, + stream: bool, + check_usage: bool = True, + **kwargs: Any, +) -> MultiRowCompletionResponse: + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest(table_id=table_name, row_ids=row_ids, stream=stream, **kwargs), + ) + return compile_and_check_row_responses( + response, + table_type=table_type, + stream=stream, + regen=True, + check_usage=check_usage, + ) + + +def import_table_data( + client: JamAI, + table_type: TableType, + table_name: str, + file_path: str, + *, + stream: bool, + delimiter: str = ",", + check_usage: bool = True, + **kwargs: Any, +) -> MultiRowCompletionResponse: + response = client.table.import_table_data( + table_type, + TableDataImportRequest( + file_path=file_path, + table_id=table_name, + stream=stream, + delimiter=delimiter, + **kwargs, + ), + ) + return compile_and_check_row_responses( + response, + table_type=table_type, + stream=stream, + regen=False, + check_usage=check_usage, + ) + + +def assert_is_vector_or_none(x: Any, *, allow_none: bool): + if allow_none and x is None: + return + assert isinstance(x, list), f"Not a list: {x}" + assert len(x) > 0, f"List is empty: {x}" + assert all(isinstance(v, float) for v in x), f"Not a list of floats: {x}" + + +T = TypeVar("T") + + +class RowPage(Page[T]): + # For easier testing + values: list[dict[str, Any]] = [] + originals: list[dict[str, Any]] = [] + references: list[dict[str, References | Any]] = [] + + @model_validator(mode="after") + def flatten_row_data(self) -> Self: + rows: list[dict[str, Any]] = self.items + self.values = [ + # `value` key must be present + {c: v["value"] if isinstance(v, dict) else v for c, v in r.items()} + for r in rows + ] + self.originals = [ + # `original` key may be absent + {c: v.get("original", None) if isinstance(v, dict) else None for c, v in r.items()} + for r in rows + ] + references = [ + # `references` key may be absent + {c: v.get("references", None) if isinstance(v, dict) else None for c, v in r.items()} + for r in rows + ] + self.references = [ + {c: References.model_validate(v) if v else None for c, v in r.items()} + for r in references + ] + return self + + +def _check_fetched_row( + row: dict[str, Any], + *, + table_type: TableType, + vec_decimals: int = 0, + columns: list[str] | None = None, +): + assert isinstance(row, dict) + # Check info columns + assert isinstance(row["ID"], str) + assert isinstance(row["Updated at"], str) + id_datetime = datetime.fromisoformat(utc_iso_from_uuid7_draft2(row["ID"])) + updated_at = datetime.fromisoformat(row["Updated at"]) + time_diff = abs( + (id_datetime.replace(tzinfo=None) - updated_at.replace(tzinfo=None)).total_seconds() + ) + assert time_diff < (60 * 60), ( + f"ID datetime: {id_datetime}, Updated at: {updated_at}, Diff: {time_diff}" + ) + # Check vector columns + if table_type == TableType.KNOWLEDGE: + if vec_decimals < 0: + # Vector columns should be removed + assert "Text Embed" not in row + assert "Title Embed" not in row + else: + if columns is None or "Text Embed" in columns: + assert_is_vector_or_none(row["Text Embed"]["value"], allow_none=True) + if columns is None or "Title Embed" in columns: + assert_is_vector_or_none(row["Title Embed"]["value"], allow_none=True) + + +def list_table_rows( + client: JamAI, + table_type: TableType, + table_name: str, + *, + offset: int = 0, + limit: int = 100, + order_by: str = "ID", + order_ascending: bool = True, + columns: list[str] | None = None, + where: str = "", + search_query: 
str = "", + search_columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, +) -> RowPage[dict[str, Any]]: + rows = client.table.list_table_rows( + table_type, + table_name, + offset=offset, + limit=limit, + order_by=order_by, + order_ascending=order_ascending, + columns=columns, + where=where, + search_query=search_query, + search_columns=search_columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) + assert isinstance(rows, Page) + assert isinstance(rows.items, list) + assert rows.offset == offset + assert rows.limit == limit + if len(rows.items) > 0: + row = rows.items[0] + _check_fetched_row( + row, + table_type=table_type, + vec_decimals=vec_decimals, + columns=columns, + ) + rows = RowPage[dict[str, Any]].model_validate(rows.model_dump()) + return rows + + +def get_table_row( + client: JamAI, + table_type: TableType, + table_name: str, + row_id: str, + *, + columns: list[str] | None = None, + float_decimals: int = 0, + vec_decimals: int = 0, + **kwargs, +) -> dict[str, Any]: + row = client.table.get_table_row( + table_type, + table_name, + row_id, + columns=columns, + float_decimals=float_decimals, + vec_decimals=vec_decimals, + **kwargs, + ) + _check_fetched_row( + row, + table_type=table_type, + vec_decimals=vec_decimals, + columns=columns, + ) + return row + + +def check_rows( + rows: list[dict[str, Any]], + data: list[dict[str, Any]], + *, + info_cols_equal: bool = True, +): + assert len(rows) == len(data), f"Row count mismatch: {len(rows)=} != {len(data)=}" + for row, d in zip(rows, data, strict=True): + for col in d: + if col in ["ID", "Updated at"] and not info_cols_equal: + assert row[col] != d[col], f'Column "{col}" is not regenerated: {d[col]=}' + continue + if d[col] is None or d[col] == "": + assert row[col] is None, f'Column "{col}" mismatch: {row[col]=} != {d[col]=}' + else: + assert row[col] == d[col], f'Column "{col}" mismatch: {row[col]=} != {d[col]=}' + + +def create_conversation( + client: JamAI, + agent_id: str, + data: dict[str, Any], + title: str | None = None, +) -> list[ConversationMetaResponse | CellReferencesResponse | CellCompletionResponse]: + chunks = client.conversations.create_conversation( + ConversationCreateRequest(agent_id=agent_id, data=data, title=title) + ) + return list(chunks) + + +# class ModelContext(ProjectContext): +# tier: ModelTierRead +# model: ModelConfigRead +# deployment: DeploymentRead + + +# @contextmanager +# def setup_model(): +# async with setup_projects() as ctx: +# async with ( +# create_model_tier( +# dict( +# id="test-tier", +# name="Test PriceTier", +# llm_requests_per_minute=100, +# llm_tokens_per_minute=1000, +# ) +# ) as tier, +# create_model_config( +# ModelConfigCreate( +# id="openai/gpt-4o-mini", +# name="OpenAI GPT-4o mini", +# capabilities=["chat", "image"], +# context_length=128000, +# type=ModelType.LLM, +# languages=["en"], +# ) +# ) as model, +# create_deployment( +# DeploymentCreate( +# model_id=model.id, +# name="Test Deployment", +# provider=CloudProvider.OPENAI, +# routing_id="openai/gpt-4o-mini", +# ) +# ) as deployment, +# ): +# assert tier.id == "test-tier" +# assert model.id == "openai/gpt-4o-mini" +# assert deployment.model_id == "openai/gpt-4o-mini" +# # Using `**model_dump()` leads to serialization warnings +# yield ModelContext( +# tier=tier, +# model=model, +# deployment=deployment, +# projects=ctx.projects, +# superuser=ctx.superuser, +# user=ctx.user, +# superorg=ctx.superorg, +# org=ctx.org, +# ) diff --git 
a/services/api/src/owl/utils/types.py b/services/api/src/owl/utils/types.py new file mode 100644 index 0000000..8f0481b --- /dev/null +++ b/services/api/src/owl/utils/types.py @@ -0,0 +1,44 @@ +import base64 + +import orjson +from sqlalchemy import TypeDecorator +from sqlmodel import JSON + +from jamaibase.utils.types import ( # noqa: F401 + CLI, + get_enum_validator, +) +from owl.configs import ENV_CONFIG + + +class RqliteJSON(TypeDecorator): + impl = JSON + + def process_bind_param(self, value, dialect): + if value is not None: + # Encode JSON data as Base64 before storing it + return base64.b64encode(orjson.dumps(value)).decode("utf-8") + return value + + def process_result_value(self, value, dialect): + if value is not None: + # Handle empty strings explicitly + if value == "": + return None # or return an empty dict {} depending on your use case + # If the value is already a dictionary, return it directly + if isinstance(value, (list, dict)): + return value + # Ensure the value is a string before decoding + if isinstance(value, bytes): + value = value.decode("utf-8") + # Decode Base64 data back to JSON + return orjson.loads(base64.b64decode(value.encode("utf-8"))) + return value + + +if ENV_CONFIG.db_dialect == "rqlite": + JSON = RqliteJSON +elif ENV_CONFIG.db_dialect == "postgresql": + from sqlalchemy.dialects.postgresql import JSONB + + JSON = JSONB diff --git a/services/api/src/owl/utils/victoriametrics.py b/services/api/src/owl/utils/victoriametrics.py new file mode 100644 index 0000000..bde9d45 --- /dev/null +++ b/services/api/src/owl/utils/victoriametrics.py @@ -0,0 +1,45 @@ +import httpx +from loguru import logger + +http_client = httpx.Client(timeout=5) + + +class VictoriaMetricsClient: + def __init__(self, host: str, port: int, user: str = None, password: str = None): + """Initialize a class for communicating with Victoria Metrics server. + + Args: + host (str): The hostname or IP address of the VictoriaMetrics server. + port (int): The port number of the VictoriaMetrics server. + user (str | None, optional): The username for authentication. + password (str | None, optional): The password for authentication. + """ + self.endpoint = f"http://{host}:{port}" + self.user = user or "" + self.password = password or "" + + def _fetch_victoria_metrics( + self, endpoint: str, params: dict | None = None + ) -> httpx.Response | None: + """Send a GET request to the specified VictoriaMetrics API endpoint. + + Args: + endpoint (str): The API endpoint to send the request to. + params (dict | None, optional): Query parameters to include in the request. + + Returns: + httpx.Response | None: The HTTP response object if the request is successful, or None if the request fails. + + Raises: + httpx.HTTPError: If the HTTP request returns an error status code. 
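To make the `RqliteJSON` type decorator introduced in `types.py` above concrete, the following standalone sketch shows the round trip its `process_bind_param`/`process_result_value` hooks perform: values are stored as Base64-encoded orjson bytes and decoded back to Python objects on read. It assumes `orjson` is installed and uses a made-up sample value; it is not the SQLAlchemy wiring itself.

```python
import base64

import orjson

# Hypothetical sample value; any JSON-serializable object works.
value = {"org_id": "0", "quotas": [1, 2, 3]}

stored = base64.b64encode(orjson.dumps(value)).decode("utf-8")  # process_bind_param
restored = orjson.loads(base64.b64decode(stored.encode("utf-8")))  # process_result_value
assert restored == value
```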
+ + """ + try: + response = http_client.get( + f"{self.endpoint}{endpoint}", params=params, auth=(self.user, self.password) + ) + response.raise_for_status() + return response + except httpx.HTTPError as e: + logger.warning(f"Error querying VictoriaMetrics: {e}") + return None diff --git a/services/api/src/owl/version.py b/services/api/src/owl/version.py index 6a9beea..3d18726 100644 --- a/services/api/src/owl/version.py +++ b/services/api/src/owl/version.py @@ -1 +1 @@ -__version__ = "0.4.0" +__version__ = "0.5.0" diff --git a/services/api/tests/README.md b/services/api/tests/README.md new file mode 100644 index 0000000..6de8e1a --- /dev/null +++ b/services/api/tests/README.md @@ -0,0 +1,6 @@ +# API Server Tests + +Some tests are split into two files: + +- `test_.py` for OSS and Cloud tests. Cloud-only tests must be marked with `pytest.mark.cloud`. +- `test__cloud.py` for Cloud-only tests. Usually Cloud-only modules are imported. These files will be removed when running OSS tests. diff --git "a/services/api/tests/docling_ground_truth/GitHub \350\241\250\345\215\225\346\236\266\346\236\204\350\257\255\346\263\225 - GitHub \346\226\207\346\241\243.json" "b/services/api/tests/docling_ground_truth/GitHub \350\241\250\345\215\225\346\236\266\346\236\204\350\257\255\346\263\225 - GitHub \346\226\207\346\241\243.json" new file mode 100644 index 0000000..5dad414 --- /dev/null +++ "b/services/api/tests/docling_ground_truth/GitHub \350\241\250\345\215\225\346\236\266\346\236\204\350\257\255\346\263\225 - GitHub \346\226\207\346\241\243.json" @@ -0,0 +1,14 @@ +{ + "document": { + "filename": "GitHub è¡¨å•æž¶æž„语法 - GitHub 文档.pdf", + "md_content": "\n\n建设社区 / 问题和 PR æ¨¡æ¿ / GitHub è¡¨å•æž¶æž„的语法\n\n## GitHub è¡¨å•æž¶æž„的语法\n\n您å¯ä»¥ä½¿â½¤ GitHub çš„è¡¨å•æž¶æž„æ¥é…置⽀æŒçš„功能。\n\n本⽂内容\n\n关于 GitHub çš„è¡¨å•æž¶æž„\n\n密钥\n\n延伸阅读\n\n## 注æ„\n\nGitHub çš„è¡¨å•æž¶æž„⽬å‰ä¸ºå…¬å…±é¢„览版,å¯èƒ½ä¼šæ›´æ”¹ã€‚\n\n## 关于 GitHub çš„è¡¨å•æž¶æž„\n\n您å¯ä»¥ä½¿â½¤ GitHub çš„è¡¨å•æž¶æž„æ¥é…置⽀æŒçš„功能。有关详细信æ¯ï¼Œè¯·å‚阅 为仓库é…ç½®è®®é¢˜æ¨¡æ¿ ã€‚ ' '\n\nè¡¨å•æ˜¯è¯·æ±‚⽤⼾输⼊的⼀组元素。您å¯ä»¥é€šè¿‡åˆ›å»º YAML 表å•定义(这是⼀个表å•元素阵列)æ¥é…置表 å•。æ¯ä¸ªè¡¨å•元素是⼀组确定元素类型ã€å…ƒç´ å±žæ€§ä»¥åŠè¦åº”⽤于元素的约æŸçš„键值对。对于æŸäº›é”®ï¼Œå€¼æ˜¯ å¦â¼€ç»„键值对。\n\n例如,以下表å•定义包括四ç§è¡¨å•元素:⽤于æä¾›â½¤â¼¾æ“作系统的⽂本区域ã€â½¤äºŽé€‰æ‹©â½¤â¼¾è¿â¾çš„软件版 本的下拉èœå•ã€â½¤äºŽç¡®è®¤â¾ä¸ºå‡†åˆ™çš„å¤é€‰æ¡†ä»¥åŠæ„Ÿè°¢â½¤â¼¾å®Œæˆè¡¨å•çš„ Markdown 。\n\n\n\n```\nplaceholder: \"Example: macOS Big Sur\" value: operating system validations: required: true - type: dropdown attributes: label: Version description: What version of our software are you running? multiple: false options: - 1.0.2 (Default) - 1.0.3 (Edge) default: 0 validations: required: true - type: checkboxes attributes: label: Code of Conduct description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it. 
options: - label: I agree to follow this project's [Code of Conduct](link/to/coc) required: true - type: markdown attributes: value: \"Thanks for completing our form!\"\n```\n\n## Keys\n\nFor each form element, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|------|-------------|----------|--------|---------|--------------|\n| type | The type of element that you want to define. | | String | | checkboxes, dropdown, input, markdown, textarea |\n| id | The identifier for the element, except when type is set to markdown. Can only use alphanumeric characters, -, and _. Must be unique in the form definition. If provided, the id is the canonical identifier for the field in URL query parameter prefills. | | String | | |\n| attributes | A set of key-value pairs that define the properties of the element. | | Map | | |\n| validations | A set of key-value pairs that set constraints on the element. | | Map | | |\n\nYou can choose from the following types of form elements. Each type has unique attributes and validations.\n\n| Type | Description |\n|----------|-------------|\n| markdown | Markdown text that is displayed in the form to provide extra context to the user, but is not submitted. |\n| textarea | A multi-line text field. |\n| input | A single-line text field. |\n| dropdown | A dropdown menu. |\n\n## checkboxes\n\nA set of checkboxes.\n\n## markdown\n\nYou can use a markdown element to display Markdown in your form that provides extra context to the user, but is not submitted.\n\nAttributes for markdown\n\nFor the value of the attributes key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|-------|-------------|----------|--------|---------|--------------|\n| value | The text that is rendered. Markdown formatting is supported. | | String | | |\n\n## Tip\n\nYAML processing will treat the hash symbol as a comment. To insert Markdown headers, wrap your text in quotes.\n\nFor multi-line text, you can use a pipe operator.\n\nExample of markdown\n\n\n\n## textarea\n\nYou can use a textarea element to add a multi-line text field to your form. Contributors can also attach files in textarea fields.\n\n## Attributes for textarea\n\nFor the value of the attributes key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|-------------|-------------|----------|--------|--------------|--------------|\n| label | A brief description of the expected user input, also displayed in the form. | | String | | |\n| description | A description of the text area to provide context or guidance, which is displayed in the form. | | String | Empty string | |\n| placeholder | A semi-opaque placeholder that renders in the text area when empty. | | String | Empty string | |\n| value | Text that is pre-filled in the text area. | | String | | |\n| render | If a value is provided, submitted text will be formatted into a code block. When this key is provided, the text area will not expand for file attachments or Markdown editing. | | String | | Languages known to GitHub. For more information, see the languages YAML file. |\n\n## Validations for textarea\n\nFor the value of the validations key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|----------|-------------|----------|---------|-------|--------------|\n| required | Prevents form submission until the element is completed. Only applies to public repositories. | | Boolean | false | |\n\n## Example of textarea\n\n\n\n```\nattributes: label: Reproduction steps description: \"How do you trigger this bug? Please walk us through it step by step.\" value: | 1. 2. 3. ... render: bash validations: required: true\n```\n\n## input\n\nYou can use an input element to add a single-line text field to your form.\n\n## Attributes for input\n\nFor the value of the attributes key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|-------------|-------------|----------|--------|--------------|--------------|\n| label | A brief description of the expected user input, also displayed in the form. | | String | | |\n| description | A description of the field to provide context or guidance, which is displayed in the form. | | String | Empty string | |\n| placeholder | A semi-opaque placeholder that renders in the field when empty. | | String | Empty string | |\n| value | Text that is pre-filled in the field. | | String | | |\n\n## Validations for input\n\nFor the value of the validations key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|----------|-------------|----------|---------|-------|--------------|\n| required | Prevents form submission until the element is completed. Only applies to public repositories. | | Boolean | false | |\n\n## Example of input\n\n```\nYAML body: - type: input id: prevalence attributes: label: Bug prevalence description: \"How often do you or others encounter this bug?\" placeholder: \"Example: Whenever I visit the personal account page (1-2 times a week)\" validations: required: true\n```\n\n## dropdown\n\nYou can use a dropdown element to add a dropdown menu in your form.\n\nAttributes for dropdown\n\nFor the value of the attributes key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|-------------|-------------|----------|---------|--------------|--------------|\n| label | A brief description of the expected user input, displayed in the form. | | String | | |\n| description | A description of the dropdown to provide context or guidance, which is displayed in the form. | | String | Empty string | |\n| multiple | Determines whether the user can select more than one option. | | Boolean | false | |\n| options | An array of options the user can choose from. Cannot be empty and all choices must be distinct. | | String array | | |\n| default | The index of the preselected option in the options array. When a default option is specified, you cannot include 'None' or 'n/a' as options. | | Integer | | |\n\n## Validations for dropdown\n\nFor the value of the validations key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|----------|-------------|----------|---------|-------|--------------|\n| required | Prevents form submission until the element is completed. Only applies to public repositories. | | Boolean | false | |\n\n## Example of dropdown\n\n\n\n```\n- MacPorts - apt-get default: 0 validations: required: true\n```\n\n## checkboxes\n\nYou can use a checkboxes element to add a set of checkboxes to your form.\n\nAttributes for checkboxes\n\nFor the value of the attributes key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|-------------|-------------|----------|--------|--------------|--------------|\n| label | A brief description of the expected user input, displayed in the form. | | String | | |\n| description | A description of the set of checkboxes, displayed in the form. Supports Markdown formatting. | | String | Empty string | |\n| options | An array of checkboxes that the user can select. For syntax, see below. | | Array | | |\n\nFor each value in the options array, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Options |\n|----------|-------------|----------|---------|---------|---------|\n| label | The identifier for the option, displayed in the form. Markdown is supported for bold or italic text formatting and hyperlinks. | | String | | |\n| required | Prevents form submission until the element is completed. Only applies to public repositories. | | Boolean | false | |\n\n## Validations for checkboxes\n\nFor the value of the validations key, you can set the following keys.\n\n| Key | Description | Required | Type | Default | Valid values |\n|----------|-------------|----------|---------|-------|--------------|\n| required | Prevents form submission until the element is completed. Only applies to public repositories. | | Boolean | false | |\n\n## Example of checkboxes\n\n\n\n| YAML |\n|------|\n| body: |\n| - type: checkboxes id: operating-systems attributes: label: Which operating systems have you used? description: You may select more than one. options: - label: macOS - label: Windows - label: Linux |\n\n## Further reading\n\n## YAML\n\nSome of the content in this document may be machine-translated or AI-translated.\n\n© 2025 GitHub, Inc. Terms Privacy Status Pricing Expert services Blog", + "json_content": null, + "html_content": null, + "text_content": null, + "doctags_content": null + }, + "status": "success", + "errors": [], + "processing_time": 9.824623378925025, + "timings": {} +} diff --git a/services/api/tests/docling_ground_truth/Swire_AR22_e_230406_sample.json b/services/api/tests/docling_ground_truth/Swire_AR22_e_230406_sample.json new file mode 100644 index 0000000..0f20836 --- /dev/null +++ b/services/api/tests/docling_ground_truth/Swire_AR22_e_230406_sample.json @@ -0,0 +1,14 @@ +{ + "document": { + "filename": "Swire_AR22_e_230406_sample.pdf", + "md_content": "\n\n\n\n2 0 2 2 ANNUAL REPORT\n\n## CONTENTS\n\n- 1 Corporate Statement\n- 3 2022 Performance Highlights\n- 4 Chairman's Statement\n\n## MANAGEMENT DISCUSSION AND ANALYSIS\n\n- 10 2022 Performance Review and Outlook\n- 59 Financial Review\n- 69 Financing\n\n## CORPORATE GOVERNANCE & SUSTAINABILITY\n\n- 79 Corporate Governance Report\n- 94 Risk Management\n- 98 Directors and Officers\n- 100 Directors' Report\n- 109 Sustainable Development Review\n\n## FINANCIAL STATEMENTS\n\n| 117 | Independent Auditor's Report |\n|-------|----------------------------------------------------------------|\n| 125 | Consolidated Statement of Profit or Loss |\n| 126 | Consolidated Statement of Other Comprehensive Income |\n| 127 | Consolidated Statement of Financial Position |\n| 128 | Consolidated Statement of Cash Flows |\n| 129 | Consolidated Statement of Changes in Equity |\n| 130 | Notes to the Financial Statements |\n| 205 | Principal Accounting Policies |\n| 208 | Principal Subsidiary, Joint Venture and Associated Companies |\n| 218 | Cathay Pacific Airways Limited - Abridged Financial Statements |\n\n## SUPPLEMENTARY INFORMATION\n\n- 220 Summary of Past Performance\n- 222 Schedule of Principal Group Properties\n- 232 Group Structure Chart\n- 234 Glossary\n- 236 Financial Calendar and Information for Investors\n- 236 Disclaimer\n\nNote: Definitions of the terms and ratios used in this report can be found in the Glossary.\n\n## CORPORATE STATEMENT\n\n## SUSTAINABLE GROWTH\n\nSwire Pacific is a Hong Kong-based international conglomerate with a diversified portfolio of market leading businesses. The Company has a long history in Greater China, where the name Swire or 太古 has been established for over 150 years.\n\nOur aims are to deliver sustainable growth in shareholder value, achieved through sound returns on equity over the long term, and to return value to shareholders through sustainable growth in ordinary dividends. Our strategy is focused on Greater China and South East Asia, where we seek to grow our core Property, Beverages and Aviation divisions.
New areas of growth, such as healthcare and sustainable foods, are being targeted.\n\n## Our Values\n\nIntegrity, endeavour, excellence, humility, teamwork, continuity.\n\n## Our Core Principles\n\n- - We focus on Asia, principally Greater China, because of its strong growth potential and because it is where the Group has long experience, deep knowledge and strong relationships.\n- - We mobilise capital, talent and ideas across the Group. Our scale and diversity increase our access to investment opportunities.\n- - We are prudent financial managers. This enables us to execute long-term investment plans irrespective of shortterm financial market volatility.\n- - We recruit the best people and invest heavily in their training and development. The welfare of our people is critical to our operations.\n- - We build strong and lasting relationships, based on mutual benefit, with those with whom we do business.\n- - We invest in sustainable development, because it is the right thing to do and because it supports long-term growth through innovation and improved efficiency.\n- - We are committed to the highest standards of corporate governance and to the preservation and development of the Swire brand and reputation.\n\n## Our Investment Principles\n\n- - We aim to build a portfolio of businesses that collectively deliver a steady dividend stream over time.\n- - We are long-term investors. We prefer to have controlling interests in our businesses and to manage them for longterm growth. We do not rule out minority investments in appropriate circumstances.\n- - We concentrate on businesses where we can contribute expertise, and where our expertise can add value.\n- - We invest in businesses that provide high-quality products and services and that are leaders in their markets.\n- - We divest from businesses which have reached their full potential under our ownership, and recycle the capital released into existing or new businesses.\n\n## Fleet profile*\n\n| | Number at 31st December 2022 | Number at 31st December 2022 | Number at 31st December 2022 | | | Orders | Orders | Orders | | | | | | | |\n|----------------------|--------------------------------|--------------------------------|--------------------------------|-------|-------------|----------|----------|----------------|-------|-----|-----|-----|-----|-----|----------------|\n| Aircraft type | Owned | Finance | Operating | Total | Average age | '23 | '24 | '25 and beyond | Total | '23 | '24 | '25 | '26 | '27 | '28 and beyond |\n| Cathay Pacific: | | | | | | | | | | | | | | | |\n| A320-200 | 4 | | | 4 | 19.3 | | | | | | | | | | |\n| A321-200 | 2 | | 1 | 3 | 19.8 | | | | | 1 | | | | | |\n| A321-200neo | | 2 | 5 | 7 | 1.4 | 5 (a) | 4 | | 9 | | | | | | 5 |\n| A330-300 | 31 | 8 | 4 | 43 | 14.3 | | | | | | | 2 | 2 | | |\n| A350-900 | 19 | 7 | 2 | 28 | 5.1 | 2 | | | 2 | | | | | | 2 |\n| A350-1000 | 11 | 7 | | 18 | 3.1 | | | | | | | | | | |\n| 747-400ERF | 6 | | | 6 | 14.0 | | | | | | | | | | |\n| 747-8F | 3 | 11 | | 14 | 9.9 | | | | | | | | | | |\n| 777-300 | 17 | | | 17 | 21.2 | | | | | | | | | | |\n| 777-300ER | 28 | 2 | 11 | 41 | 10.2 | | | | | 2 | 3 | 2 | 4 | | |\n| 777-9 | | | | | | | | 21 | 21 | | | | | | |\n| Total | 121 | 37 | 23 | 181 | 10.8 | 7 | 4 | 21 | 32 | 3 | 3 | 4 | 6 | | 7 |\n| HK Express: | | | | | | | | | | | | | | | |\n| A320-200 | | | 5 | 5 | 10.5 | | | | | 1 | 4 | | | | |\n| A320-200neo | | | 10 | 10 | 3.8 | | | | | | | | | | 10 |\n| A321-200 | | | 11 | 11 | 5.2 | | | | | | | 1 | 2 | | 8 |\n| A321-200neo | | | | | | 4 | 8 | 4 | 16 
| | | | | | |\n| Total | | | 26 | 26 | 5.7 | 4 | 8 | 4 | 16 | 1 | 4 | 1 | 2 | | 18 |\n| Air Hong Kong*** (b) | : | | | | | | | | | | | | | | |\n| A300-600F | | | 9 | 9 | 18.6 | | | | | 7 | 2 | | | | |\n| A330-243F | | | 2 | 2 | 11.0 | | | | | | | | 2 | | |\n| A330-300P2F | | | 4 | 4 | 13.7 | | | | | | | | 3 | | 1 |\n| Total | | | 15 | 15 | 16.3 | | | | | 7 | 2 | | 5 | | 1 |\n| Grand total | 121 | 37 | 64 | 222 | 10.6 | 11 | 12 | 25 | 48 | 11 | 9 | 5 | 13 | | 26 |\n\n* The table does not reflect aircraft movements after 31st December 2022.\n\n- ** Leases previously classified as operating leases are accounted for in a similar manner to finance leases under accounting standards. The majority of operating leases in the above table are within the scope of HKFRS 16.\n\n*** The contractual arrangements relating to the freighters operated by Air Hong Kong do not constitute leases in accordance with HKFRS 16.\n\n(a) Two Airbus A321-200neo aircraft were delivered in February 2023.\n\n(b) The plan is to return the nine A300-600F aircraft between 2023 and 2024 and to replace them with nine second-hand A330F aircraft. This allows the Air Hong Kong fleet to remain the same (at 15), at least until 2024.\n\n## Responsibilities of Directors\n\nOn appointment, the Directors receive information about the Group including:\n\n- - the role of the Board and the matters reserved for its attention\n- - the role and terms of reference of Board Committees\n- - the Group's corporate governance practices and procedures\n- - the powers delegated to management and\n- - the latest financial information.\n\nDirectors update their skills, knowledge and understanding of the Company's businesses through their participation at meetings of the Board and its committees and through regular meetings with management at the head office and in the divisions. Directors are regularly updated by the Company Secretary on their legal and other duties as Directors of a listed company.\n\nThrough the Company Secretary, Directors are able to obtain appropriate professional training and advice.\n\nEach Director ensures that he/she can give sufficient time and attention to the affairs of the Group. All Directors disclose to the Board on their first appointment their interests as a Director or otherwise in other companies or organisations and such declarations of interests are updated regularly. No Director was a director of more than five other listed companies (excluding the Company) at 31st December 2022.\n\n\n\nDetails of Directors' other appointments are shown in their biographies in the section of this annual report headed Directors and Officers.\n\nAgendas and accompanying Board papers are circulated with sufficient time to allow the Directors to prepare before meetings.\n\n## Board Processes\n\nAll committees of the Board follow the same processes as the full Board.\n\nThe dates of the 2022 Board meetings were determined in 2021 and any amendments to this schedule were notified to Directors at least 14 days before regular meetings. Appropriate arrangements are in place to allow Directors to include items in the agenda for regular Board meetings.\n\nThe Board met seven times in 2022, including two strategy sessions. The attendance of individual Directors at meetings of the Board and its committees is set out in the table on page 83. Attendance at Board meetings was 100%. 
All Directors attended Board meetings in person or through electronic means of communication during the year.\n\nThe Chairman takes the lead to ensure that the Board acts in the best interests of the Company, that there is effective communication with the shareholders and that their views are communicated to the Board as a whole.\n\nBoard decisions are made by vote at Board meetings and supplemented by the circulation of written resolutions between Board meetings.\n\nMinutes of Board meetings are taken by the Company Secretary and, together with any supporting papers, are made available to all Directors. The minutes record the matters considered by the Board, the decisions reached, and any concerns raised or dissenting views expressed by Directors. Draft and final versions of the minutes are sent to all Directors for their comment and records respectively.", + "json_content": null, + "html_content": null, + "text_content": null, + "doctags_content": null + }, + "status": "success", + "errors": [], + "processing_time": 6.495824097888544, + "timings": {} +} diff --git a/clients/python/tests/files/bmp/cifar10-deer.bmp b/services/api/tests/files/bmp/cifar10-deer.bmp similarity index 100% rename from clients/python/tests/files/bmp/cifar10-deer.bmp rename to services/api/tests/files/bmp/cifar10-deer.bmp diff --git a/clients/python/tests/files/csv/company-profile.csv b/services/api/tests/files/csv/company-profile.csv similarity index 100% rename from clients/python/tests/files/csv/company-profile.csv rename to services/api/tests/files/csv/company-profile.csv diff --git a/clients/python/tests/files/csv/empty.csv b/services/api/tests/files/csv/empty.csv similarity index 100% rename from clients/python/tests/files/csv/empty.csv rename to services/api/tests/files/csv/empty.csv diff --git a/clients/python/tests/files/csv/weather_observations_long.csv b/services/api/tests/files/csv/weather_observations_long.csv similarity index 100% rename from clients/python/tests/files/csv/weather_observations_long.csv rename to services/api/tests/files/csv/weather_observations_long.csv diff --git a/clients/python/tests/files/doc/Recommendation Letter.doc b/services/api/tests/files/doc/Recommendation Letter.doc similarity index 100% rename from clients/python/tests/files/doc/Recommendation Letter.doc rename to services/api/tests/files/doc/Recommendation Letter.doc diff --git a/clients/python/tests/files/docx/Recommendation Letter.docx b/services/api/tests/files/docx/Recommendation Letter.docx similarity index 100% rename from clients/python/tests/files/docx/Recommendation Letter.docx rename to services/api/tests/files/docx/Recommendation Letter.docx diff --git a/services/api/tests/files/exports/export-v0.4-action.parquet b/services/api/tests/files/exports/export-v0.4-action.parquet new file mode 100644 index 0000000..ef076fd Binary files /dev/null and b/services/api/tests/files/exports/export-v0.4-action.parquet differ diff --git a/services/api/tests/files/exports/export-v0.4-chat-agent-1.parquet b/services/api/tests/files/exports/export-v0.4-chat-agent-1.parquet new file mode 100644 index 0000000..9b00c71 Binary files /dev/null and b/services/api/tests/files/exports/export-v0.4-chat-agent-1.parquet differ diff --git a/services/api/tests/files/exports/export-v0.4-chat-agent.parquet b/services/api/tests/files/exports/export-v0.4-chat-agent.parquet new file mode 100644 index 0000000..86e2ff2 Binary files /dev/null and b/services/api/tests/files/exports/export-v0.4-chat-agent.parquet differ diff --git 
a/services/api/tests/files/exports/export-v0.4-knowledge.parquet b/services/api/tests/files/exports/export-v0.4-knowledge.parquet new file mode 100644 index 0000000..b878b71 Binary files /dev/null and b/services/api/tests/files/exports/export-v0.4-knowledge.parquet differ diff --git a/services/api/tests/files/exports/export-v0.4-project-long-name.parquet b/services/api/tests/files/exports/export-v0.4-project-long-name.parquet new file mode 100644 index 0000000..2ab3e14 Binary files /dev/null and b/services/api/tests/files/exports/export-v0.4-project-long-name.parquet differ diff --git a/services/api/tests/files/exports/export-v0.4-project.parquet b/services/api/tests/files/exports/export-v0.4-project.parquet new file mode 100644 index 0000000..be3e8e9 Binary files /dev/null and b/services/api/tests/files/exports/export-v0.4-project.parquet differ diff --git a/clients/python/tests/files/gif/rabbit_cifar10-deer.gif b/services/api/tests/files/gif/rabbit_cifar10-deer.gif similarity index 100% rename from clients/python/tests/files/gif/rabbit_cifar10-deer.gif rename to services/api/tests/files/gif/rabbit_cifar10-deer.gif diff --git a/services/api/tests/files/gif/rabbit_cifar10-deer.gif.thumb.webp b/services/api/tests/files/gif/rabbit_cifar10-deer.gif.thumb.webp new file mode 100644 index 0000000..d523765 Binary files /dev/null and b/services/api/tests/files/gif/rabbit_cifar10-deer.gif.thumb.webp differ diff --git a/clients/python/tests/files/html/RAG and LLM Integration Guide.html b/services/api/tests/files/html/RAG and LLM Integration Guide.html similarity index 100% rename from clients/python/tests/files/html/RAG and LLM Integration Guide.html rename to services/api/tests/files/html/RAG and LLM Integration Guide.html diff --git a/clients/python/tests/files/html/multilingual-code-examples.html b/services/api/tests/files/html/multilingual-code-examples.html similarity index 100% rename from clients/python/tests/files/html/multilingual-code-examples.html rename to services/api/tests/files/html/multilingual-code-examples.html diff --git a/clients/python/tests/files/html/table.html b/services/api/tests/files/html/table.html similarity index 100% rename from clients/python/tests/files/html/table.html rename to services/api/tests/files/html/table.html diff --git a/clients/python/tests/files/jpeg/cifar10-deer.jpg b/services/api/tests/files/jpeg/cifar10-deer.jpg similarity index 100% rename from clients/python/tests/files/jpeg/cifar10-deer.jpg rename to services/api/tests/files/jpeg/cifar10-deer.jpg diff --git a/services/api/tests/files/jpeg/cifar10-deer.jpg.thumb.webp b/services/api/tests/files/jpeg/cifar10-deer.jpg.thumb.webp new file mode 100644 index 0000000..fef7087 Binary files /dev/null and b/services/api/tests/files/jpeg/cifar10-deer.jpg.thumb.webp differ diff --git a/services/api/tests/files/jpeg/doe.jpg b/services/api/tests/files/jpeg/doe.jpg new file mode 100644 index 0000000..1f23ebd Binary files /dev/null and b/services/api/tests/files/jpeg/doe.jpg differ diff --git a/clients/python/tests/files/jpeg/rabbit.jpeg b/services/api/tests/files/jpeg/rabbit.jpeg similarity index 100% rename from clients/python/tests/files/jpeg/rabbit.jpeg rename to services/api/tests/files/jpeg/rabbit.jpeg diff --git a/clients/python/tests/files/json/company-profile.json b/services/api/tests/files/json/company-profile.json similarity index 100% rename from clients/python/tests/files/json/company-profile.json rename to services/api/tests/files/json/company-profile.json diff --git 
a/clients/python/tests/files/jsonl/ChatMed_TCM-v0.2-5records.jsonl b/services/api/tests/files/jsonl/ChatMed_TCM-v0.2-5records.jsonl similarity index 100% rename from clients/python/tests/files/jsonl/ChatMed_TCM-v0.2-5records.jsonl rename to services/api/tests/files/jsonl/ChatMed_TCM-v0.2-5records.jsonl diff --git a/clients/python/tests/files/jsonl/llm-models.jsonl b/services/api/tests/files/jsonl/llm-models.jsonl similarity index 100% rename from clients/python/tests/files/jsonl/llm-models.jsonl rename to services/api/tests/files/jsonl/llm-models.jsonl diff --git a/clients/python/tests/files/md/creative-story.md b/services/api/tests/files/md/creative-story.md similarity index 100% rename from clients/python/tests/files/md/creative-story.md rename to services/api/tests/files/md/creative-story.md diff --git a/services/api/tests/files/mp3/grand-scheme.mp3 b/services/api/tests/files/mp3/grand-scheme.mp3 new file mode 100644 index 0000000..4bfad52 Binary files /dev/null and b/services/api/tests/files/mp3/grand-scheme.mp3 differ diff --git a/services/api/tests/files/mp3/gutter.mp3 b/services/api/tests/files/mp3/gutter.mp3 new file mode 100644 index 0000000..0c7e5ad Binary files /dev/null and b/services/api/tests/files/mp3/gutter.mp3 differ diff --git a/services/api/tests/files/mp3/gutter.mp3.thumb.mp3 b/services/api/tests/files/mp3/gutter.mp3.thumb.mp3 new file mode 100644 index 0000000..d5c6d69 Binary files /dev/null and b/services/api/tests/files/mp3/gutter.mp3.thumb.mp3 differ diff --git a/services/api/tests/files/mp3/stars.mp3 b/services/api/tests/files/mp3/stars.mp3 new file mode 100644 index 0000000..8e58ed4 Binary files /dev/null and b/services/api/tests/files/mp3/stars.mp3 differ diff --git a/clients/python/tests/files/mp3/turning-a4-size-magazine.mp3 b/services/api/tests/files/mp3/turning-a4-size-magazine.mp3 similarity index 100% rename from clients/python/tests/files/mp3/turning-a4-size-magazine.mp3 rename to services/api/tests/files/mp3/turning-a4-size-magazine.mp3 diff --git a/services/api/tests/files/mp3/turning-a4-size-magazine.mp3.thumb.mp3 b/services/api/tests/files/mp3/turning-a4-size-magazine.mp3.thumb.mp3 new file mode 100644 index 0000000..b27e456 Binary files /dev/null and b/services/api/tests/files/mp3/turning-a4-size-magazine.mp3.thumb.mp3 differ diff --git a/clients/python/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf b/services/api/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf similarity index 100% rename from clients/python/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf rename to services/api/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf diff --git a/services/api/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf.thumb.webp b/services/api/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf.thumb.webp new file mode 100644 index 0000000..81c4ff5 Binary files /dev/null and b/services/api/tests/files/pdf/1970_PSS_ThAT_mechanism.pdf.thumb.webp differ diff --git a/clients/python/tests/files/pdf/1982_PRB_phonon-assisted_tunnel_ionization.pdf b/services/api/tests/files/pdf/1982_PRB_phonon-assisted_tunnel_ionization.pdf similarity index 100% rename from clients/python/tests/files/pdf/1982_PRB_phonon-assisted_tunnel_ionization.pdf rename to services/api/tests/files/pdf/1982_PRB_phonon-assisted_tunnel_ionization.pdf diff --git "a/services/api/tests/files/pdf/GitHub \350\241\250\345\215\225\346\236\266\346\236\204\350\257\255\346\263\225 - GitHub \346\226\207\346\241\243.pdf" "b/services/api/tests/files/pdf/GitHub \350\241\250\345\215\225\346\236\266\346\236\204\350\257\255\346\263\225 - GitHub 
\346\226\207\346\241\243.pdf" new file mode 100644 index 0000000..bf1459d Binary files /dev/null and "b/services/api/tests/files/pdf/GitHub \350\241\250\345\215\225\346\236\266\346\236\204\350\257\255\346\263\225 - GitHub \346\226\207\346\241\243.pdf" differ diff --git a/clients/python/tests/files/pdf/Large Language Models as Optimizers [DeepMind ; 2023].pdf b/services/api/tests/files/pdf/LLMs as Optimizers [DeepMind ; 2023].pdf similarity index 100% rename from clients/python/tests/files/pdf/Large Language Models as Optimizers [DeepMind ; 2023].pdf rename to services/api/tests/files/pdf/LLMs as Optimizers [DeepMind ; 2023].pdf diff --git a/clients/python/tests/files/pdf/Swire_AR22_e_230406_sample.pdf b/services/api/tests/files/pdf/Swire_AR22_e_230406_sample.pdf similarity index 100% rename from clients/python/tests/files/pdf/Swire_AR22_e_230406_sample.pdf rename to services/api/tests/files/pdf/Swire_AR22_e_230406_sample.pdf diff --git a/clients/python/tests/files/pdf/System Design Blueprint - The Ultimate Guide.pdf b/services/api/tests/files/pdf/System Design Blueprint - The Ultimate Guide.pdf similarity index 100% rename from clients/python/tests/files/pdf/System Design Blueprint - The Ultimate Guide.pdf rename to services/api/tests/files/pdf/System Design Blueprint - The Ultimate Guide.pdf diff --git a/clients/python/tests/files/pdf/Vehicle Detail - MyPUSPAKOM.pdf b/services/api/tests/files/pdf/Vehicle Detail - MyPUSPAKOM.pdf similarity index 100% rename from clients/python/tests/files/pdf/Vehicle Detail - MyPUSPAKOM.pdf rename to services/api/tests/files/pdf/Vehicle Detail - MyPUSPAKOM.pdf diff --git a/clients/python/tests/files/pdf/ag-energy-round-up-2017-02-24.pdf b/services/api/tests/files/pdf/ag-energy-round-up-2017-02-24.pdf similarity index 100% rename from clients/python/tests/files/pdf/ag-energy-round-up-2017-02-24.pdf rename to services/api/tests/files/pdf/ag-energy-round-up-2017-02-24.pdf diff --git a/clients/python/tests/files/pdf/background-checks.pdf b/services/api/tests/files/pdf/background-checks.pdf similarity index 100% rename from clients/python/tests/files/pdf/background-checks.pdf rename to services/api/tests/files/pdf/background-checks.pdf diff --git a/clients/python/tests/files/pdf/ca-warn-report.pdf b/services/api/tests/files/pdf/ca-warn-report.pdf similarity index 100% rename from clients/python/tests/files/pdf/ca-warn-report.pdf rename to services/api/tests/files/pdf/ca-warn-report.pdf diff --git a/clients/python/tests/files/pdf/empty.pdf b/services/api/tests/files/pdf/empty.pdf similarity index 100% rename from clients/python/tests/files/pdf/empty.pdf rename to services/api/tests/files/pdf/empty.pdf diff --git a/clients/python/tests/files/pdf/empty_3pages.pdf b/services/api/tests/files/pdf/empty_3pages.pdf similarity index 100% rename from clients/python/tests/files/pdf/empty_3pages.pdf rename to services/api/tests/files/pdf/empty_3pages.pdf diff --git a/services/api/tests/files/pdf/hello.pdf b/services/api/tests/files/pdf/hello.pdf new file mode 100644 index 0000000..3520fd0 Binary files /dev/null and b/services/api/tests/files/pdf/hello.pdf differ diff --git "a/clients/python/tests/files/pdf/salary \346\200\273\347\273\223.pdf" "b/services/api/tests/files/pdf/salary \346\200\273\347\273\223.pdf" similarity index 100% rename from "clients/python/tests/files/pdf/salary \346\200\273\347\273\223.pdf" rename to "services/api/tests/files/pdf/salary \346\200\273\347\273\223.pdf" diff --git a/clients/python/tests/files/pdf/sample_tables.pdf 
b/services/api/tests/files/pdf/sample_tables.pdf similarity index 100% rename from clients/python/tests/files/pdf/sample_tables.pdf rename to services/api/tests/files/pdf/sample_tables.pdf diff --git a/clients/python/tests/files/pdf/san-jose-pd-firearm-sample.pdf b/services/api/tests/files/pdf/san-jose-pd-firearm-sample.pdf similarity index 100% rename from clients/python/tests/files/pdf/san-jose-pd-firearm-sample.pdf rename to services/api/tests/files/pdf/san-jose-pd-firearm-sample.pdf diff --git a/clients/python/tests/files/pdf/statement_card.pdf b/services/api/tests/files/pdf/statement_card.pdf similarity index 100% rename from clients/python/tests/files/pdf/statement_card.pdf rename to services/api/tests/files/pdf/statement_card.pdf diff --git a/clients/python/tests/files/pdf/statement_ewallet.pdf b/services/api/tests/files/pdf/statement_ewallet.pdf similarity index 100% rename from clients/python/tests/files/pdf/statement_ewallet.pdf rename to services/api/tests/files/pdf/statement_ewallet.pdf diff --git a/clients/python/tests/files/pdf_mixed/digital_scan_combined.pdf b/services/api/tests/files/pdf_mixed/digital_scan_combined.pdf similarity index 100% rename from clients/python/tests/files/pdf_mixed/digital_scan_combined.pdf rename to services/api/tests/files/pdf_mixed/digital_scan_combined.pdf diff --git a/clients/python/tests/files/pdf_scan/1978_APL_FP_detrapping.PDF b/services/api/tests/files/pdf_scan/1978_APL_FP_detrapping.PDF similarity index 100% rename from clients/python/tests/files/pdf_scan/1978_APL_FP_detrapping.PDF rename to services/api/tests/files/pdf_scan/1978_APL_FP_detrapping.PDF diff --git a/services/api/tests/files/pdf_scan/uuk_bangunan_seragam_pindaan_2017.pdf b/services/api/tests/files/pdf_scan/uuk_bangunan_seragam_pindaan_2017.pdf new file mode 100644 index 0000000..91bedc6 Binary files /dev/null and b/services/api/tests/files/pdf_scan/uuk_bangunan_seragam_pindaan_2017.pdf differ diff --git a/clients/python/tests/files/png/cifar10-deer.png b/services/api/tests/files/png/cifar10-deer.png similarity index 100% rename from clients/python/tests/files/png/cifar10-deer.png rename to services/api/tests/files/png/cifar10-deer.png diff --git a/clients/python/tests/files/png/github-mark-white.png b/services/api/tests/files/png/github-mark-white.png similarity index 100% rename from clients/python/tests/files/png/github-mark-white.png rename to services/api/tests/files/png/github-mark-white.png diff --git a/clients/python/tests/files/png/rabbit.png b/services/api/tests/files/png/rabbit.png similarity index 100% rename from clients/python/tests/files/png/rabbit.png rename to services/api/tests/files/png/rabbit.png diff --git a/services/api/tests/files/png/rabbit.png.thumb.webp b/services/api/tests/files/png/rabbit.png.thumb.webp new file mode 100644 index 0000000..53e61a8 Binary files /dev/null and b/services/api/tests/files/png/rabbit.png.thumb.webp differ diff --git a/clients/python/tests/files/ppt/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).ppt b/services/api/tests/files/ppt/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).ppt similarity index 100% rename from clients/python/tests/files/ppt/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).ppt rename to services/api/tests/files/ppt/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).ppt diff --git a/clients/python/tests/files/pptx/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).pptx b/services/api/tests/files/pptx/(2017.06.30) NMT in Linear 
Time (ByteNet).pptx similarity index 100% rename from clients/python/tests/files/pptx/(2017.06.30) Neural Machine Translation in Linear Time (ByteNet).pptx rename to services/api/tests/files/pptx/(2017.06.30) NMT in Linear Time (ByteNet).pptx diff --git a/clients/python/tests/files/tiff/cifar10-deer.tiff b/services/api/tests/files/tiff/cifar10-deer.tiff similarity index 100% rename from clients/python/tests/files/tiff/cifar10-deer.tiff rename to services/api/tests/files/tiff/cifar10-deer.tiff diff --git a/clients/python/tests/files/tiff/rabbit.tiff b/services/api/tests/files/tiff/rabbit.tiff similarity index 100% rename from clients/python/tests/files/tiff/rabbit.tiff rename to services/api/tests/files/tiff/rabbit.tiff diff --git a/clients/python/tests/files/tsv/weather_observations.tsv b/services/api/tests/files/tsv/weather_observations.tsv similarity index 100% rename from clients/python/tests/files/tsv/weather_observations.tsv rename to services/api/tests/files/tsv/weather_observations.tsv diff --git a/clients/python/tests/files/txt/creative-story.txt b/services/api/tests/files/txt/creative-story.txt similarity index 100% rename from clients/python/tests/files/txt/creative-story.txt rename to services/api/tests/files/txt/creative-story.txt diff --git a/clients/python/tests/files/txt/empty.txt b/services/api/tests/files/txt/empty.txt similarity index 100% rename from clients/python/tests/files/txt/empty.txt rename to services/api/tests/files/txt/empty.txt diff --git a/clients/python/tests/files/txt/weather.txt b/services/api/tests/files/txt/weather.txt similarity index 100% rename from clients/python/tests/files/txt/weather.txt rename to services/api/tests/files/txt/weather.txt diff --git a/services/api/tests/files/wav/gutter.wav b/services/api/tests/files/wav/gutter.wav new file mode 100644 index 0000000..3b79a4f Binary files /dev/null and b/services/api/tests/files/wav/gutter.wav differ diff --git a/services/api/tests/files/wav/gutter.wav.thumb.mp3 b/services/api/tests/files/wav/gutter.wav.thumb.mp3 new file mode 100644 index 0000000..57c95a1 Binary files /dev/null and b/services/api/tests/files/wav/gutter.wav.thumb.mp3 differ diff --git a/clients/python/tests/files/wav/turning-a4-size-magazine.wav b/services/api/tests/files/wav/turning-a4-size-magazine.wav similarity index 100% rename from clients/python/tests/files/wav/turning-a4-size-magazine.wav rename to services/api/tests/files/wav/turning-a4-size-magazine.wav diff --git a/services/api/tests/files/wav/turning-a4-size-magazine.wav.thumb.mp3 b/services/api/tests/files/wav/turning-a4-size-magazine.wav.thumb.mp3 new file mode 100644 index 0000000..a9d5e8a Binary files /dev/null and b/services/api/tests/files/wav/turning-a4-size-magazine.wav.thumb.mp3 differ diff --git a/clients/python/tests/files/webp/rabbit_cifar10-deer.webp b/services/api/tests/files/webp/rabbit_cifar10-deer.webp similarity index 100% rename from clients/python/tests/files/webp/rabbit_cifar10-deer.webp rename to services/api/tests/files/webp/rabbit_cifar10-deer.webp diff --git a/services/api/tests/files/webp/rabbit_cifar10-deer.webp.thumb.webp b/services/api/tests/files/webp/rabbit_cifar10-deer.webp.thumb.webp new file mode 100644 index 0000000..e0ab8db Binary files /dev/null and b/services/api/tests/files/webp/rabbit_cifar10-deer.webp.thumb.webp differ diff --git a/clients/python/tests/files/xls/Claims Form.xls b/services/api/tests/files/xls/Claims Form.xls similarity index 100% rename from clients/python/tests/files/xls/Claims Form.xls rename to 
services/api/tests/files/xls/Claims Form.xls diff --git a/clients/python/tests/files/xlsx/Claims Form.xlsx b/services/api/tests/files/xlsx/Claims Form.xlsx similarity index 100% rename from clients/python/tests/files/xlsx/Claims Form.xlsx rename to services/api/tests/files/xlsx/Claims Form.xlsx diff --git a/services/api/tests/files/xlsx/Claims Form.xlsx.thumb.gen.webp b/services/api/tests/files/xlsx/Claims Form.xlsx.thumb.gen.webp new file mode 100644 index 0000000..7ec017b Binary files /dev/null and b/services/api/tests/files/xlsx/Claims Form.xlsx.thumb.gen.webp differ diff --git a/services/api/tests/files/xlsx/Claims Form.xlsx.thumb.webp b/services/api/tests/files/xlsx/Claims Form.xlsx.thumb.webp new file mode 100644 index 0000000..517019c Binary files /dev/null and b/services/api/tests/files/xlsx/Claims Form.xlsx.thumb.webp differ diff --git a/clients/python/tests/files/xml/weather-forecast-service.xml b/services/api/tests/files/xml/weather-forecast-service.xml similarity index 100% rename from clients/python/tests/files/xml/weather-forecast-service.xml rename to services/api/tests/files/xml/weather-forecast-service.xml diff --git a/services/api/tests/gen_table/test_empty_db.py b/services/api/tests/gen_table/test_empty_db.py new file mode 100644 index 0000000..0dd88cd --- /dev/null +++ b/services/api/tests/gen_table/test_empty_db.py @@ -0,0 +1,32 @@ +import pytest + +from jamaibase import JamAI +from jamaibase.types import OrganizationCreate, TableType +from owl.utils.exceptions import ResourceNotFoundError +from owl.utils.test import ( + create_organization, + create_project, + create_user, + list_tables, +) + + +def test_get_list_tables_no_schema(): + with ( + create_user() as superuser, + create_organization( + body=OrganizationCreate(name="Clubhouse"), user_id=superuser.id + ) as superorg, + # Create project + create_project( + dict(name="Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + super_client = JamAI(user_id=superuser.id, project_id=p0.id) + # No gen table schema + for table_type in TableType: + tables = list_tables(super_client, table_type) + assert len(tables.items) == 0 + assert tables.total == 0 + with pytest.raises(ResourceNotFoundError, match="Table .+ is not found."): + super_client.table.get_table(table_type, "123") diff --git a/services/api/tests/gen_table/test_import_export.py b/services/api/tests/gen_table/test_import_export.py new file mode 100644 index 0000000..12fd1b9 --- /dev/null +++ b/services/api/tests/gen_table/test_import_export.py @@ -0,0 +1,1025 @@ +import builtins +from dataclasses import dataclass +from os.path import dirname, join, realpath +from tempfile import TemporaryDirectory +from types import NoneType +from typing import Any + +import httpx +import pandas as pd +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + ColumnReorderRequest, + ColumnSchemaCreate, + EmbedGenConfig, + GetURLResponse, + LLMGenConfig, + OkResponse, + OrganizationCreate, + TableImportRequest, + TableMetaResponse, + TableType, +) +from owl.utils.exceptions import ( + BadInputError, +) +from owl.utils.io import csv_to_df, df_to_csv +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + ELLM_EMBEDDING_CONFIG, + ELLM_EMBEDDING_DEPLOYMENT, + STREAM_PARAMS, + TABLE_TYPES, + TEXTS, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + add_table_rows, + assert_is_vector_or_none, + check_rows, + create_deployment, + create_model_config, + create_organization, + create_project, + 
create_table, + create_user, + get_file_map, + import_table_data, + list_table_rows, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) + +FILE_COLUMNS = ["image", "audio", "document", "File ID"] + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str + superorg_id: str + project_id: str + embedding_size: int + image_uri: str + audio_uri: str + document_uri: str + chat_model_id: str + embed_model_id: str + rerank_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. + """ + with ( + create_user() as superuser, + create_organization( + body=OrganizationCreate(name="Superorg"), user_id=superuser.id + ) as superorg, + create_project( + dict(name="Superorg Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + assert superorg.id == "0" + # Create models + with ( + create_model_config(ELLM_DESCRIBE_CONFIG) as desc_llm_config, + create_model_config(ELLM_EMBEDDING_CONFIG) as embed_config, + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_config, + ): + # Create deployments + with ( + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(ELLM_EMBEDDING_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + client = JamAI(user_id=superuser.id, project_id=p0.id) + image_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + audio_uri = upload_file(client, FILES["gutter.mp3"]).uri + document_uri = upload_file( + client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"] + ).uri + yield ServingContext( + superuser_id=superuser.id, + superorg_id=superorg.id, + project_id=p0.id, + embedding_size=embed_config.final_embedding_size, + image_uri=image_uri, + audio_uri=audio_uri, + document_uri=document_uri, + chat_model_id=desc_llm_config.id, + embed_model_id=embed_config.id, + rerank_model_id=rerank_config.id, + ) + + +@dataclass(slots=True) +class Data: + data_list: list[dict[str, Any]] + action_data_list: list[dict[str, Any]] + knowledge_data: dict[str, Any] + chat_data: dict[str, Any] + extra_data: dict[str, Any] + + +def _default_data(setup: ServingContext): + action_data_list = [ + { + "ID": str(i), + "Updated at": "1990-05-13T09:01:50.010756+00:00", + "int": 1 if i % 2 == 0 else (1.0 if i % 4 == 1 else None), + "float": -1.25 if i % 2 == 0 else (5 if i % 4 == 1 else None), + "bool": True if i % 2 == 0 else (False if i % 4 == 1 else None), + "str": t, + "image": setup.image_uri if i % 2 == 0 else None, + "audio": setup.audio_uri if i % 2 == 0 else None, + "document": setup.document_uri if i % 2 == 0 else None, + "summary": t if i % 2 == 0 else ("" if i % 4 == 1 else None), + } + for i, t in enumerate(TEXTS.values()) + ] + # Assert integers and floats contain a mix of int, float, None + _ints = [type(d["int"]) for d in action_data_list] + assert int in _ints + assert float in _ints + assert NoneType in _ints + _floats = [type(d["float"]) for d in action_data_list] + assert int in _floats + assert float in _floats + assert NoneType in _floats + # Assert booleans contain a mix of True, False, None + _bools = [d["bool"] for d in action_data_list] + assert True in _bools + assert False in _bools + assert None in _bools + # Assert strings contain a mix of empty string and None + _summaries = [d["summary"] for d in action_data_list] + assert None in _summaries + assert "" in _summaries + knowledge_data = { + "Title": "Dune: Part Two.", + "Text": '"Dune: Part 
Two" is a film.', + # We use values that can be represented exactly as IEEE floats to ease comparison + "Title Embed": [-1.25] * setup.embedding_size, + "Text Embed": [0.25] * setup.embedding_size, + "File ID": setup.document_uri, + } + chat_data = dict(User=".", AI=".") + extra_data = dict(good=True, words=5) + return Data( + data_list=[ + dict(**d, **knowledge_data, **chat_data, **extra_data) for d in action_data_list + ], + action_data_list=action_data_list, + knowledge_data=knowledge_data, + chat_data=chat_data, + extra_data=extra_data, + ) + + +def _default_dtype( + data: list[dict[str, Any]], + *, + cast_to_string: bool = False, +) -> dict[str, pd.Int64Dtype | pd.Float32Dtype | pd.BooleanDtype | pd.StringDtype]: + cols = set() + for row in data: + cols |= set(row.keys()) + dtype = { + "ID": pd.StringDtype(), + "Updated at": pd.StringDtype(), + "int": pd.Int64Dtype() if not cast_to_string else pd.StringDtype(), + "float": pd.Float32Dtype() if not cast_to_string else pd.StringDtype(), + "bool": pd.BooleanDtype() if not cast_to_string else pd.StringDtype(), + "str": pd.StringDtype(), + "image": pd.StringDtype(), + "audio": pd.StringDtype(), + "document": pd.StringDtype(), + "summary": pd.StringDtype(), + "Title": pd.StringDtype(), + "Text": pd.StringDtype(), + "Title Embed": object, + "Text Embed": object, + "File ID": pd.StringDtype(), + "User": pd.StringDtype(), + "AI": pd.StringDtype(), + "good": pd.BooleanDtype() if not cast_to_string else pd.StringDtype(), + "words": pd.Int64Dtype() if not cast_to_string else pd.StringDtype(), + } + return {k: v for k, v in dtype.items() if k in cols} + + +def _as_df( + data: list[dict[str, Any]], + *, + cast_to_string: bool = False, +) -> pd.DataFrame: + dtype = _default_dtype(data, cast_to_string=cast_to_string) + if cast_to_string: + data = [{k: None if v is None else str(v) for k, v in d.items()} for d in data] + df = pd.DataFrame.from_dict(data).astype(dtype) + return df + + +def _check_rows( + rows: list[dict[str, Any]], + data: list[dict[str, Any]], +): + return check_rows(rows, data, info_cols_equal=False) + + +def _check_knowledge_chat_data( + table_type: TableType, + rows: list[dict[str, Any]], + data: Data, +): + if table_type == TableType.KNOWLEDGE: + _check_rows(rows, [data.knowledge_data] * len(data.data_list)) + elif table_type == TableType.CHAT: + _check_rows(rows, [data.chat_data] * len(data.data_list)) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_complete( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + """ + Test table data import. + - All column types including vector + - Ensure "ID" and "Updated at" columns are regenerated + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream or not. + delimiter (str): Delimiter. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_import_complete.csv") + data = _default_data(setup) + df = _as_df(data.data_list) + df_to_csv(df, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + _check_rows(rows.values, data.action_data_list) + _check_knowledge_chat_data(table_type, rows.values, data) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_dtype_coercion( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + """ + Test table data import. + - Column dtype coercion (nulls, int <=> float, bool <=> int) + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream or not. + delimiter (str): Delimiter. + """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_import_dtype_coercion.csv") + header = ["int", "float", "bool", "str", "image", "audio", "document", "summary", "AI"] + data = [ + # Base case + [1, 2.0, True, '""', '""', '""', '""', '""', '""'], + # Coercion + [1.0, 2, 1, '""', "", "", "", "", ""], + [-1.0, -2, 0, "", "", "", "", "", ""], + ["", "", "", "", "", "", "", "", ""], + ] + with open(file_path, "w", encoding="utf-8") as f: + f.write(f"{delimiter.join(header)}\n") + f.write("\n".join(delimiter.join(map(str, d)) for d in data)) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data) + assert all(len(r.columns) == 0 for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data) + assert rows.total == len(data) + # All strings are null + for col in ["str", "image", "audio", "document", "summary", "AI"]: + assert all(v.get(col, None) is None for v in rows.values) + # Check values + for col in ["int", "float", "bool"]: + for v, d in zip(rows.values, data, strict=True): + if d[header.index(col)] in ["", '""']: + assert v[col] is None + else: + assert v[col] == d[header.index(col)] + assert isinstance(v[col], getattr(builtins, col)) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_cast_to_string( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + dtypes = ["int", "float", "bool", "str", "image", "audio", "document"] + cols = [ColumnSchemaCreate(id=dtype, 
dtype="str") for dtype in dtypes] + cols += [ + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig(model="", system_prompt="", prompt=""), + ), + ] + with create_table(client, table_type, cols=cols) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_import_cast_to_string.csv") + data = _default_data(setup) + df = _as_df(data.data_list) + # Assert some columns are not string type + assert not all(d == pd.StringDtype() for d in df.dtypes.tolist()) + df_to_csv(df, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + action_data_list = [ + { + k: None + if v is None + else str(int(v) if k == "int" else (float(v) if k == "float" else v)) + for k, v in d.items() + } + for d in data.action_data_list + ] + _check_rows(rows.values, action_data_list) + _check_knowledge_chat_data(table_type, rows.values, data) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_cast_from_string( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_import_cast_from_string.csv") + data = _default_data(setup) + df = _as_df(data.data_list, cast_to_string=True) + # Assert all columns (except embedding) are string type + assert all( + v == pd.StringDtype() for k, v in df.dtypes.to_dict().items() if "Embed" not in k + ), df.dtypes.to_dict() + df_to_csv(df, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + _check_rows(rows.values, data.action_data_list) + _check_knowledge_chat_data(table_type, rows.values, data) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_missing_input_column( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_import_missing_input_column.csv") + data = _default_data(setup) + df = _as_df(data.data_list) + df = df.drop(columns=["int"]) + df_to_csv(df, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + 
delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + _check_rows( + rows.values, + [{k: v for k, v in d.items() if k != "int"} for d in data.action_data_list], + ) + _check_knowledge_chat_data(table_type, rows.values, data) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_with_generation( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_import_with_generation.csv") + data = _default_data(setup) + df = _as_df(data.data_list) + df = df.drop(columns=["summary"]) + df_to_csv(df, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # LLM is called + assert len(response.rows) == len(data.data_list) + assert all(len(r.columns) == 1 for r in response.rows) + assert all("summary" in r.columns for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + _check_rows( + rows.values, + [{k: v for k, v in d.items() if k != "summary"} for d in data.action_data_list], + ) + _check_knowledge_chat_data(table_type, rows.values, data) + # Check LLM generation + summaries = [row["summary"] for row in rows.values] + assert all("There is a text" in s for s in summaries) + assert sum("There is an image with MIME type [image/jpeg]" in s for s in summaries) > 0 + assert sum("There is an audio with MIME type [audio/mpeg]" in s for s in summaries) > 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_import_empty( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + # Empty file + file_path = join(tmp_dir, "empty.csv") + with open(file_path, "w", encoding="utf-8") as f: + f.write("") + with pytest.raises(BadInputError, match="is empty"): + import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # No rows + file_path = join(tmp_dir, "no_rows.csv") + with open(file_path, "w", encoding="utf-8") as f: + f.write(delimiter.join(c.id for c in table.cols) + "\n") + with pytest.raises(BadInputError, match="no rows"): + import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == 0 + + +def _export_table_rows( + client: JamAI, + table_type: TableType, + table: TableMetaResponse, + *, + data: Data, + delimiter: str, + columns: 
list[str] | None = None, +) -> tuple[list[dict[str, Any]], pd.DataFrame]: + csv_bytes = client.table.export_table_data( + table_type, + table.id, + delimiter=delimiter, + ) + dtype = _default_dtype(data.data_list, cast_to_string=False) + if columns is None: + columns = [c.id for c in table.cols] + csv_df = csv_to_df( + csv_bytes.decode("utf-8"), + sep=delimiter, + keep_default_na=True, + ).astype({k: v for k, v in dtype.items() if k in columns}) + exported_rows = csv_df.to_dict(orient="records") + assert len(exported_rows) == len(data.data_list) + assert all(isinstance(row, dict) for row in exported_rows) + assert all("ID" in row for row in exported_rows) + assert all("Updated at" in row for row in exported_rows) + return exported_rows, csv_df + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_data_export( + setup: ServingContext, + table_type: TableType, + stream: bool, + delimiter: str, +): + """ + Test table data export. + - Export all columns (round trip) + - Export subset of columns (round trip) + - Export after column reorder (check column order, round trip) + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream or not. + delimiter (str): Delimiter. + """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_data_export.csv") + data = _default_data(setup) + df_original = _as_df(data.data_list) + df_to_csv(df_original, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + # Check imported data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + row_ids = [r["ID"] for r in rows.items] + + ### --- Export all columns, round trip --- ### + exported_rows, _ = _export_table_rows( + client, + table_type, + table, + data=data, + delimiter=delimiter, + ) + # Check row order + exported_row_ids = [r["ID"] for r in exported_rows] + assert row_ids == exported_row_ids + # Check row content + _check_rows(exported_rows, data.action_data_list) + + ### --- Export subset of columns --- ### + columns = [c.id for c in table.cols][:2] + assert len(columns) < len(table.cols) + exported_rows, _ = _export_table_rows( + client, + table_type, + table, + data=data, + delimiter=delimiter, + columns=columns, + ) + assert len(exported_rows) == len(data.data_list) + _check_rows( + exported_rows, + [{k: v for k, v in d.items() if k in columns} for d in data.action_data_list], + ) + + ### --- Export after column reorder --- ### + new_order = ["int", "float", "bool", "str", "image", "audio", "document"][::-1] + new_order += ["summary"] + if table_type == TableType.KNOWLEDGE: + new_order = [ + "Title", + "Title Embed", + "Text", + "Text Embed", + "File ID", + "Page", + ] + new_order + elif table_type == TableType.CHAT: + new_order = ["User", "AI"] + new_order + table = client.table.reorder_columns( + table_type=table_type, + request=ColumnReorderRequest(table_id=table.id, column_names=new_order), + ) + assert 
isinstance(table, TableMetaResponse) + exported_rows, exported_df = _export_table_rows( + client, + table_type, + table, + data=data, + delimiter=delimiter, + ) + _check_rows(exported_rows, data.action_data_list) + # Check column order + expected_columns = ["ID", "Updated at"] + new_order + assert expected_columns == list(exported_df.columns) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("blocking", [True, False], ids=["blocking", "non-blocking"]) +def test_table_import_export( + setup: ServingContext, + table_type: TableType, + blocking: bool, +): + """ + Test table import and export. + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + """ + stream = False + delimiter = "," + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + # Export empty table + pq_data = client.table.export_table(table_type, table.id) + assert len(pq_data) > 0 + # Add data + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_table_import_export.csv") + data = _default_data(setup) + df_original = _as_df(data.data_list) + df_to_csv(df_original, file_path, delimiter) + response = import_table_data( + client, + table_type, + table.id, + file_path, + stream=stream, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + + ### --- Export table --- ### + table_id_dst = f"{table.id}_import" + try: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, f"{table.id}.parquet") + with open(file_path, "wb") as f: + f.write(client.table.export_table(table_type, table.id)) + + ### --- Import table --- ### + # Bad name + with pytest.raises(BadInputError): + client.table.import_table( + table_type, + TableImportRequest( + file_path=file_path, + table_id_dst=f"_{table_id_dst}", + blocking=blocking, + ), + ) + # OK + response = client.table.import_table( + table_type, + TableImportRequest( + file_path=file_path, + table_id_dst=table_id_dst, + blocking=blocking, + ), + ) + if blocking: + table_dst = response + else: + # Poll progress + assert isinstance(response, OkResponse) + assert isinstance(response.progress_key, str) + assert len(response.progress_key) > 0 + prog = client.tasks.poll_progress(response.progress_key, max_wait=30) + assert isinstance(prog, dict) + table_dst = TableMetaResponse.model_validate(prog["data"]["table_meta"]) + assert isinstance(table_dst, TableMetaResponse) + assert table_dst.id == table_id_dst + # Source data + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == len(data.data_list) + assert rows.total == len(data.data_list) + # Destination data + rows_dst = list_table_rows(client, table_type, table_dst.id, vec_decimals=2) + # Compare + for row, row_dst in zip(rows.items, rows_dst.items, strict=True): + assert len(row) == len(row_dst) + for col in row: + if col in FILE_COLUMNS: + # File columns should not match due to different S3 URI, unless it is None + value_ori = row[col]["value"] + value_dst = row_dst[col]["value"] + if value_ori is None: + assert value_dst is None + else: + assert value_ori != value_dst + # But content should match + urls = client.file.get_raw_urls([value_ori, value_dst]) + assert isinstance(urls, GetURLResponse) + file_ori = httpx.get(urls.urls[0]).content + file_dst = httpx.get(urls.urls[1]).content + assert 
file_ori == file_dst + else: + # Regular columns should match exactly (including info columns) + assert row[col] == row_dst[col] + # All "File ID" values should be populated + if table_type == TableType.KNOWLEDGE: + for row_dst in rows_dst.values: + assert isinstance(row_dst["File ID"], str) + assert len(row_dst["File ID"]) > 0 + assert len(set(r["File ID"] for r in rows_dst.values)) == 1 + finally: + client.table.delete_table(table_type, table_id_dst) + + +@pytest.mark.parametrize("delimiter", [","], ids=["comma"]) +def test_table_import_wrong_type( + setup: ServingContext, + delimiter: str, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, TableType.ACTION) as table: + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_table_import_wrong_type.csv") + data = _default_data(setup) + df = _as_df(data.data_list) + df_to_csv(df, file_path, delimiter) + response = import_table_data( + client, + TableType.ACTION, + table.id, + file_path, + stream=False, + delimiter=delimiter, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == len(data.data_list) + assert all(len(r.columns) == 0 for r in response.rows) + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, "test_table_import_wrong_type.parquet") + # Export + with open(file_path, "wb") as f: + f.write(client.table.export_table(TableType.ACTION, table.id)) + table_id_dst = f"{table.id}_import" + # Import as knowledge + with pytest.raises(BadInputError): + client.table.import_table( + TableType.KNOWLEDGE, + TableImportRequest( + file_path=file_path, + table_id_dst=table_id_dst, + ), + ) + # Import as chat + with pytest.raises(BadInputError): + client.table.import_table( + TableType.CHAT, + TableImportRequest( + file_path=file_path, + table_id_dst=table_id_dst, + ), + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("version", ["v0.4"]) +def test_table_import_parquet( + setup: ServingContext, + table_type: TableType, + version: str, +): + """ + Test table import from an existing Parquet file. + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + ### --- Basic tables --- ### + if table_type == TableType.CHAT: + parquet_filepath = FILES[f"export-{version}-chat-agent.parquet"] + else: + parquet_filepath = FILES[f"export-{version}-{table_type}.parquet"] + # Embedding model cannot be swapped for another model + if table_type == TableType.KNOWLEDGE: + with pytest.raises(BadInputError, match="Embedding model .+ is not found"): + client.table.import_table( + table_type, + TableImportRequest(file_path=parquet_filepath, table_id_dst=None), + ) + # Add the required embedding model + embed_model = "ellm/BAAI/bge-m3" + model = ELLM_EMBEDDING_CONFIG.model_copy(update=dict(id=embed_model, owned_by="ellm")) + deployment = ELLM_EMBEDDING_DEPLOYMENT.model_copy(update=dict(model_id=embed_model)) + with create_model_config(model), create_deployment(deployment): + table = client.table.import_table( + table_type, + TableImportRequest(file_path=parquet_filepath, table_id_dst=None), + ) + try: + assert isinstance(table, TableMetaResponse) + ### Table ID should be derived from the Parquet data + if table_type == TableType.CHAT: + assert table.id == "test-agent" + else: + assert table.id == f"test-{table_type}" + assert table.parent_id is None + col_map = {c.id: c for c in table.cols} + embed_cols = ["Title Embed", "Text Embed"] + + ### Check gen config + if table_type == TableType.ACTION: + gen_config = col_map["answer"].gen_config + assert isinstance(gen_config, LLMGenConfig) + assert gen_config.model == setup.chat_model_id + elif table_type == TableType.KNOWLEDGE: + for c in embed_cols: + gen_config = col_map[c].gen_config + assert isinstance(gen_config, EmbedGenConfig) + assert gen_config.embedding_model == embed_model + else: + gen_config = col_map["AI"].gen_config + assert isinstance(gen_config, LLMGenConfig) + assert gen_config.model == setup.chat_model_id + assert gen_config.multi_turn is True + + ### List rows + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == 1 + assert rows.total == 1 + row = rows.values[0] + if table_type == TableType.ACTION: + # Check text content + assert row["question"] == "What is this" + assert row["answer"] == "This is a deer." + assert row["null"] == "" + # Check image content + urls = client.file.get_raw_urls([row["image"]]) + assert isinstance(urls, GetURLResponse) + image = httpx.get(urls.urls[0]).content + with open(FILES["cifar10-deer.jpg"], "rb") as f: + assert image == f.read() + elif table_type == TableType.KNOWLEDGE: + # Check text content + assert row["Title"] == "Gunicorn: A Python WSGI HTTP Server" + assert row["Text"] == "Gunicorn is a Python WSGI HTTP Server." + # Check vector content + for c in embed_cols: + assert_is_vector_or_none(row[c], allow_none=False) + else: + # Check text content + assert row["User"] == "Hi" + assert row["AI"] == ( + "Hello! How can I assist you today? " + "Let me know what you're looking for, and I'll do my best to help. 
😊" + ) + + ### Try generation + if table_type == TableType.ACTION: + response = add_table_rows( + client, table_type, table.id, [{"question": "Why"}], stream=False + ) + assert len(response.rows) == 1 + assert "There is a text" in response.rows[0].columns["answer"].content + elif table_type == TableType.KNOWLEDGE: + response = add_table_rows(client, table_type, table.id, [{}], stream=False) + assert len(response.rows) == 1 + else: + response = add_table_rows( + client, table_type, table.id, [{"User": "Hi"}], stream=False + ) + assert len(response.rows) == 1 + assert "There is a text" in response.rows[0].columns["AI"].content + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == 2 + assert rows.total == 2 + finally: + client.table.delete_table(table_type, table.id) + + ### --- Chat table (child table) --- ### + if table_type == TableType.CHAT: + table = client.table.import_table( + table_type, + TableImportRequest( + file_path=FILES[f"export-{version}-chat-agent-1.parquet"], table_id_dst=None + ), + ) + try: + assert isinstance(table, TableMetaResponse) + # Table ID should be derived from the Parquet data + assert table.id == "test-agent-1" + # TODO: Perhaps need to handle missing parent and RAG table + assert table.parent_id == "test-agent" + # List rows + rows = list_table_rows(client, table_type, table.id, vec_decimals=2) + assert len(rows.items) == 2 + assert rows.total == 2 + # Check text content + assert rows.values[0]["User"] == "Hi" + assert rows.values[0]["AI"].startswith("Hello! How can I assist you today?") + assert rows.values[1]["User"] == "What is 美洲驼?" + assert rows.values[1]["AI"].startswith( + "**美洲驼** (MÄ›izhÅu tuó) 是以下两ç§å—美洲骆驼科动物的中文统称: \n\n1. **羊驼**" + ) + rows_r = list_table_rows( + client, table_type, table.id, order_ascending=False, vec_decimals=2 + ) + assert all(rr == r for rr, r in zip(rows_r.values[::-1], rows.values, strict=True)) + finally: + client.table.delete_table(table_type, table.id) diff --git a/services/api/tests/gen_table/test_row_ops.py b/services/api/tests/gen_table/test_row_ops.py new file mode 100644 index 0000000..8f13efb --- /dev/null +++ b/services/api/tests/gen_table/test_row_ops.py @@ -0,0 +1,2206 @@ +import re +from contextlib import contextmanager +from dataclasses import dataclass +from decimal import Decimal +from os.path import basename, dirname, join, realpath +from tempfile import TemporaryDirectory +from time import sleep +from typing import Generator + +import httpx +import pandas as pd +import pytest +from flaky import flaky + +from jamaibase import JamAI +from jamaibase.types import ( + ActionTableSchemaCreate, + AddActionColumnSchema, + AddChatColumnSchema, + AddKnowledgeColumnSchema, + CellCompletionResponse, + ChatTableSchemaCreate, + ChatThreadResponse, + CodeInterpreterTool, + ColumnReorderRequest, + ColumnSchema, + ColumnSchemaCreate, + DeploymentCreate, + GenConfigUpdateRequest, + KnowledgeTableSchemaCreate, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowDeleteRequest, + MultiRowRegenRequest, + MultiRowUpdateRequest, + OkResponse, + RowCompletionResponse, + SearchRequest, + TableMetaResponse, + WebSearchTool, +) +from jamaibase.utils.io import df_to_csv +from owl.types import ( + ChatRole, + CloudProvider, + LLMGenConfig, + ModelCapability, + RegenStrategy, + Role, + TableType, +) +from owl.utils.exceptions import ( + BadInputError, + JamaiException, + ResourceNotFoundError, +) +from owl.utils.test import ( + ELLM_EMBEDDING_CONFIG, + ELLM_EMBEDDING_DEPLOYMENT, + 
GPT_4O_MINI_CONFIG,
+    GPT_4O_MINI_DEPLOYMENT,
+    GPT_5_MINI_CONFIG,
+    GPT_5_MINI_DEPLOYMENT,
+    OPENAI_O4_MINI_CONFIG,
+    OPENAI_O4_MINI_DEPLOYMENT,
+    STREAM_PARAMS,
+    RERANK_ENGLISH_v3_SMALL_CONFIG,
+    RERANK_ENGLISH_v3_SMALL_DEPLOYMENT,
+    add_table_rows,
+    create_deployment,
+    create_model_config,
+    create_organization,
+    create_project,
+    create_user,
+    get_file_map,
+    list_table_rows,
+    regen_table_rows,
+    upload_file,
+)
+
+TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files")
+FILES = get_file_map(TEST_FILE_DIR)
+
+TABLE_TYPES = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT]
+TABLE_ID_A = "table_a"
+TABLE_ID_B = "table_b"
+TABLE_ID_C = "table_c"
+TABLE_ID_X = "table_x"
+TEXT = '"Arrival" is a 2016 American science fiction drama film directed by Denis Villeneuve and adapted by Eric Heisserer.'
+TEXT_CN = (
+    '"Arrival" 《降临》是一部 2016 年美国科幻剧情片,由丹尼斯·维伦纽瓦执导,埃里克·海瑟尔改编。'
+)
+TEXT_JP = '"Arrival" 「Arrival」は、ドゥニ・ヴィルヌーヴが監督し、エリック・ハイセラーが脚色した2016年のアメリカのSFドラマ映画です。'
+
+EMBED_WHITE_LIST_EXT = [
+    "application/pdf",  # pdf
+    "text/markdown",  # md
+    "text/plain",  # txt
+    "text/html",  # html
+    "text/xml",  # xml
+    "application/xml",  # xml
+    "application/json",  # json
+    "application/jsonl",  # jsonl
+    "application/x-ndjson",  # alternative for jsonl
+    "application/json-lines",  # another alternative for jsonl
+    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # docx
+    "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # pptx
+    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",  # xlsx
+    "text/tab-separated-values",  # tsv
+    "text/csv",  # csv
+]
+
+
+@dataclass(slots=True)
+class ServingContext:
+    superuser_id: str
+    user_id: str
+    org_id: str
+    project_id: str
+
+
+@pytest.fixture(scope="module")
+def setup():
+    """
+    Fixture to set up the necessary organization and project for row operation tests.
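+
+    It creates a superuser and a regular test user, an organization and a project
+    (the test user joins both as admin), then registers the chat, reasoning, audio,
+    embedding and reranking model configs and deployments used by these tests,
+    yielding the resulting IDs as a ServingContext.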
+ """ + with ( + # Create superuser + create_user() as superuser, + # Create user + create_user({"email": "testuser@example.com", "name": "Test User"}) as user, + # Create organization + create_organization(user_id=superuser.id) as org, + # Create project + create_project(dict(name="Bucket A"), user_id=superuser.id, organization_id=org.id) as p0, + ): + assert org.id == "0" + client = JamAI(user_id=superuser.id) + # Join organization and project + client.organizations.join_organization( + user_id=user.id, organization_id=org.id, role=Role.ADMIN + ) + client.projects.join_project(user_id=user.id, project_id=p0.id, role=Role.ADMIN) + + # Create models + with ( + create_model_config(GPT_4O_MINI_CONFIG), + create_model_config(GPT_5_MINI_CONFIG), + create_model_config(OPENAI_O4_MINI_CONFIG), + create_model_config( + { + # "id": "openai/Qwen/Qwen-2-Audio-7B", + "id": "openai/gpt-4o-mini-audio-preview", + "type": "llm", + # "name": "ELLM Qwen2 Audio (7B)", + "name": "OpenAI GPT-4o Mini Audio Preview", + "capabilities": ["chat", "audio"], + "context_length": 128000, + "languages": ["en"], + } + ) as llm_config_audio, + create_model_config(ELLM_EMBEDDING_CONFIG), + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG), + ): + # Create deployments + with ( + create_deployment(GPT_4O_MINI_DEPLOYMENT), + create_deployment(GPT_5_MINI_DEPLOYMENT), + create_deployment(OPENAI_O4_MINI_DEPLOYMENT), + create_deployment( + DeploymentCreate( + model_id=llm_config_audio.id, + # name="ELLM Qwen2 Audio (7B) Deployment", + name="OpenAI GPT-4o Mini Audio Preview Deployment", + # provider=CloudProvider.ELLM, + provider=CloudProvider.OPENAI, + routing_id=llm_config_audio.id, + # api_base="https://llmci.embeddedllm.com/audio/v1", + api_base="", + ) + ), + create_deployment(ELLM_EMBEDDING_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + yield ServingContext( + superuser_id=superuser.id, + user_id=user.id, + org_id=org.id, + project_id=p0.id, + ) + + +def _get_chat_model(client: JamAI) -> str: + models = client.model_ids(prefer="openai/gpt-4o-mini", capabilities=["chat"]) + return models[0] + + +def _get_reasoning_model(client: JamAI) -> str: + models = client.model_ids(prefer="openai/gpt-5-mini", capabilities=["reasoning"]) + return models[0] + + +def _get_reranking_model(client: JamAI) -> str: + models = client.model_ids(capabilities=["rerank"]) + return models[0] + + +@contextmanager +def _create_table( + client: JamAI, + table_type: TableType, + table_id: str = TABLE_ID_A, + cols: list[ColumnSchemaCreate] | None = None, + chat_cols: list[ColumnSchemaCreate] | None = None, + embedding_model: str | None = None, +): + try: + if cols is None: + cols = [ + ColumnSchemaCreate(id="good", dtype="bool"), + ColumnSchemaCreate(id="words", dtype="int"), + ColumnSchemaCreate(id="stars", dtype="float"), + ColumnSchemaCreate(id="inputs", dtype="str"), + ColumnSchemaCreate(id="photo", dtype="image"), + ColumnSchemaCreate(id="audio", dtype="audio"), + ColumnSchemaCreate(id="paper", dtype="document"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + # Interpolate string and non-string input columns + prompt="Summarise this in ${words} words:\n\n${inputs}", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ColumnSchemaCreate( + id="captioning", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="You are a concise assistant.", + # Interpolate file input column + 
prompt="${photo} \n\nWhat's in the image?", + temperature=0.001, + top_p=0.001, + max_tokens=20, + ), + ), + ColumnSchemaCreate( + id="narration", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt="${audio} \n\nWhat happened?", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ColumnSchemaCreate( + id="concept", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt="${paper} \n\nTell the main concept of the paper in 5 words.", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + if chat_cols is None: + chat_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a wacky assistant.", + temperature=0.001, + top_p=0.001, + max_tokens=5, + ), + ), + ] + + if table_type == TableType.ACTION: + table = client.table.create_action_table( + ActionTableSchemaCreate(id=table_id, cols=cols) + ) + elif table_type == TableType.KNOWLEDGE: + if embedding_model is None: + embedding_model = "" + table = client.table.create_knowledge_table( + KnowledgeTableSchemaCreate(id=table_id, cols=cols, embedding_model=embedding_model) + ) + elif table_type == TableType.CHAT: + table = client.table.create_chat_table( + ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + ) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + yield table + finally: + client.table.delete_table(table_type, table_id) + + +def _add_row( + client: JamAI, + table_type: TableType, + stream: bool, + table_name: str = TABLE_ID_A, + data: dict | None = None, + knowledge_data: dict | None = None, + chat_data: dict | None = None, +): + if data is None: + data = dict( + good=True, + words=5, + stars=7.9, + inputs=TEXT, + photo=upload_file(client, FILES["rabbit.jpeg"]).uri, + audio=upload_file(client, FILES["turning-a4-size-magazine.mp3"]).uri, + paper=upload_file(client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"]).uri, + ) + + if knowledge_data is None: + knowledge_data = dict( + Title="Dune: Part Two.", + Text='"Dune: Part Two" is a 2024 American epic science fiction film.', + ) + if chat_data is None: + chat_data = dict(User="Tell me a joke.") + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + data.update(knowledge_data) + elif table_type == TableType.CHAT: + data.update(chat_data) + else: + raise ValueError(f"Invalid table type: {table_type}") + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), + ) + if stream: + return response + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 1 + return response.rows[0] + + +def _collect_reasoning( + responses: MultiRowCompletionResponse | Generator[CellCompletionResponse, None, None], + col: str, +): + if isinstance(responses, MultiRowCompletionResponse): + return "".join(r.columns[col].reasoning_content for r in responses.rows) + return "".join(r.reasoning_content for r in responses if r.output_column_name == col) + + +def _collect_text( + responses: MultiRowCompletionResponse | Generator[CellCompletionResponse, None, None], + col: str, +): + if isinstance(responses, MultiRowCompletionResponse): + return "".join(r.columns[col].content for r in responses.rows) + return "".join(r.content for r in responses if r.output_column_name == col) + + +def _get_exponent(x: float) -> int: + return 
Decimal(str(x)).as_tuple().exponent + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_full_text_search( + setup: ServingContext, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ColumnSchemaCreate(id="text", dtype="str")] + with _create_table(client, "action", cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Add data + texts = [ + '"Dune: Part Two" 2024 is Denis\'s science-fiction film.', + '"Dune: Part Two" 2024 is Denis\'s film.', + '"Arrival" 《é™ä¸´ã€‹æ˜¯ä¸€éƒ¨ 2016 年美国科幻剧情片,由丹尼斯·维伦纽瓦执导。', + '"Arrival" 『デューン: パート 2ã€2024 ã¯ãƒ‡ãƒ‹ã‚¹ã®æ˜ ç”»ã§ã™ã€‚', + ] + response = client.table.add_table_rows( + "action", + MultiRowAddRequest( + table_id=table.id, data=[{"text": t} for t in texts], stream=stream + ), + ) + if stream: + # Must wait until stream ends + responses = [r for r in response] + assert all(isinstance(r, CellCompletionResponse) for r in responses) + else: + assert isinstance(response, MultiRowCompletionResponse) + + # Search + def _search(query: str): + return client.table.hybrid_search( + "action", SearchRequest(table_id=table.id, query=query) + ) + + assert len(_search("AND")) == 0 # SQL-like statements should still work + assert len(_search("《")) == 1 + assert len(_search("scien*")) == 1 + assert len(_search("film")) == 2 + assert len(_search("science -fiction")) == 0 # Not supported + assert len(_search("science-fiction")) == 1 + assert len(_search("science -fiction\n2016")) == 1 + assert len(_search("美国")) == 1 + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_conversation_starter( + setup: ServingContext, + stream: bool, +): + table_type = TableType.CHAT + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You help remember facts.", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ColumnSchemaCreate(id="words", dtype="int"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are an assistant", + temperature=0.001, + top_p=0.001, + max_tokens=5, + ), + ), + ] + with _create_table(client, table_type, cols=[], chat_cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Add the starter + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, data=[dict(AI="Jim has 5 apples.")], stream=stream + ), + ) + if stream: + # Must wait until stream ends + responses = [r for r in response] + assert all(isinstance(r, CellCompletionResponse) for r in responses) + else: + assert isinstance(response.rows[0], RowCompletionResponse) + # Chat with it + response = add_table_rows( + client, + table_type, + table.id, + [dict(User="How many apples does Jim have?")], + stream=stream, + ) + assert len(response.rows) == 1 + row = response.rows[0] + assert "summary" in row.columns + answer = row.columns["AI"].content + assert "5" in answer or "five" in answer.lower() + + +@pytest.mark.timeout(180) +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize( + "doc", + [ + FILES["salary 总结.pdf"], + # FILES["1978_APL_FP_detrapping.PDF"], + # FILES["digital_scan_combined.pdf"], + FILES["creative-story.md"], + FILES["creative-story.txt"], + FILES["multilingual-code-examples.html"], + 
FILES["weather-forecast-service.xml"], + FILES["ChatMed_TCM-v0.2-5records.jsonl"], + FILES["Recommendation Letter.docx"], + FILES["(2017.06.30) NMT in Linear Time (ByteNet).pptx"], + FILES["Claims Form.xlsx"], + FILES["weather_observations.tsv"], + FILES["weather_observations_long.csv"], + ], + ids=lambda x: basename(x), +) +def test_add_row_document_dtype( + setup: ServingContext, + table_type: TableType, + stream: bool, + doc: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="doc", dtype="document"), + ColumnSchemaCreate( + id="content", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt="Document: \n${doc} \n\nReply 0 if document received, else -1. Omit any explanation, only answer 0 or -1.", + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + upload_response = upload_file(client, doc) + response = add_table_rows( + client, + table_type, + table.id, + [dict(doc=upload_response.uri)], + stream=stream, + ) + assert len(response.rows) == 1 + row = response.rows[0] + assert "content" in row.columns + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["doc"] == upload_response.uri, row["doc"] + assert "0" in row["content"] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_regen_with_reordered_columns( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="number", dtype="int"), + ColumnSchemaCreate( + id="col1-english", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt=( + "Number: ${number} \n\nTell the 'Number' in English, " + "only output the answer in uppercase without explanation." + ), + ), + ), + ColumnSchemaCreate( + id="col2-malay", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt=( + "Number: ${number} \n\nTell the 'Number' in Malay, " + "only output the answer in uppercase without explanation." + ), + ), + ), + ColumnSchemaCreate( + id="col3-mandarin", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt=( + "Number: ${number} \n\nTell the 'Number' in Mandarin (Chinese Character), " + "only output the answer in uppercase without explanation." + ), + ), + ), + ColumnSchemaCreate( + id="col4-roman", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt=( + "Number: ${number} \n\nTell the 'Number' in Roman Numerals, " + "only output the answer in uppercase without explanation." 
+ ), + ), + ), + ] + + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + row = _add_row( + client, + table_type, + False, + data=dict(number=1), + ) + assert isinstance(row, RowCompletionResponse) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + row = rows.values[0] + _id = row["ID"] + assert row["number"] == 1, row["number"] + assert row["col1-english"] == "ONE", row["col1-english"] + assert row["col2-malay"] == "SATU", row["col2-malay"] + assert row["col3-mandarin"] in ("一", "壹"), row["col3-mandarin"] + assert row["col4-roman"] == "I", row["col4-roman"] + + # Update Input + Regen + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={_id: dict(number=2)}, + ), + ) + + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=table.id, + row_ids=[_id], + regen_strategy=RegenStrategy.RUN_ALL, + stream=stream, + ), + ) + if stream: + _ = [r for r in response] + + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["number"] == 2, row["number"] + assert row["col1-english"] == "TWO", row["col1-english"] + assert row["col2-malay"] == "DUA", row["col2-malay"] + assert row["col3-mandarin"] == "二", row["col3-mandarin"] + assert row["col4-roman"] == "II", row["col4-roman"] + + # Reorder + Update Input + Regen + # [1, 2, 3, 4] -> [3, 1, 4, 2] + new_cols = [ + "ID", + "Updated at", + "number", + "col3-mandarin", + "col1-english", + "col4-roman", + "col2-malay", + ] + if table_type == TableType.KNOWLEDGE: + new_cols += ["Title", "Text", "Title Embed", "Text Embed", "File ID", "Page"] + elif table_type == TableType.CHAT: + new_cols += ["User", "AI"] + client.table.reorder_columns( + table_type=table_type, + request=ColumnReorderRequest( + table_id=TABLE_ID_A, + column_names=new_cols, + ), + ) + # RUN_SELECTED + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={_id: dict(number=5)}, + ), + ) + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=TABLE_ID_A, + row_ids=[_id], + regen_strategy=RegenStrategy.RUN_SELECTED, + output_column_id="col1-english", + stream=stream, + ), + ) + if stream: + _ = [r for r in response] + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["number"] == 5, row["number"] + assert row["col3-mandarin"] == "二", row["col3-mandarin"] + assert row["col1-english"] == "FIVE", row["col1-english"] + assert row["col4-roman"] == "II", row["col4-roman"] + assert row["col2-malay"] == "DUA", row["col2-malay"] + + # RUN_BEFORE + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={_id: dict(number=6)}, + ), + ) + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=TABLE_ID_A, + row_ids=[_id], + regen_strategy=RegenStrategy.RUN_BEFORE, + output_column_id="col4-roman", + stream=stream, + ), + ) + if stream: + _ = [r for r in response] + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["number"] == 6, row["number"] + assert row["col3-mandarin"] == "å…­", row["col3-mandarin"] + assert row["col1-english"] == "SIX", row["col1-english"] + assert row["col4-roman"] == "VI", row["col4-roman"] + assert row["col2-malay"] == "DUA", row["col2-malay"] + + 
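+        # Note: regen strategies follow the current (reordered) column order:
+        # RUN_BEFORE regenerates the selected output column and those before it,
+        # while RUN_AFTER regenerates the selected output column and those after
+        # it; columns outside that range keep their previous values.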
# RUN_AFTER + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={_id: dict(number=7)}, + ), + ) + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=TABLE_ID_A, + row_ids=[_id], + regen_strategy=RegenStrategy.RUN_AFTER, + output_column_id="col4-roman", + stream=stream, + ), + ) + if stream: + _ = [r for r in response] + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["number"] == 7, row["number"] + assert row["col3-mandarin"] == "å…­", row["col3-mandarin"] + assert row["col1-english"] == "SIX", row["col1-english"] + assert row["col4-roman"] == "VII", row["col4-roman"] + assert row["col2-malay"] == "TUJUH", row["col2-malay"] + + +# @pytest.mark.parametrize("table_type", TABLE_TYPES) +# @pytest.mark.parametrize("stream", [True, False]) +# def test_add_row_file_type_output_column( +# setup: ServingContext, +# table_type: TableType, +# stream: bool, +# ): +# client = JamAI(user_id=setup.user_id, project_id=setup.project_id) +# cols = [ +# ColumnSchemaCreate(id="photo", dtype="image"), +# ColumnSchemaCreate(id="question", dtype="str"), +# ColumnSchemaCreate( +# id="captioning", +# dtype="file", +# gen_config=LLMGenConfig(model="", prompt="${photo} What's in the image?"), +# ), +# ColumnSchemaCreate( +# id="answer", +# dtype="file", +# gen_config=LLMGenConfig( +# model="", +# prompt="${photo} ${question}?", +# ), +# ), +# ColumnSchemaCreate( +# id="compare", +# dtype="image", +# gen_config=LLMGenConfig( +# model="", +# prompt="Compare ${captioning} and ${answer}.", +# ), +# ), +# ] +# with _create_table(client, table_type, cols=cols) as table: +# assert isinstance(table, TableMetaResponse) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_add_row_output_column_referred_image_input_with_chat_model( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="photo", dtype="image"), + ColumnSchemaCreate( + id="captioning", + dtype="str", + gen_config=LLMGenConfig(model="", prompt="${photo} What's in the image?"), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + with create_model_config( + { + "id": "openai/Qwen/Qwen2.5-7B-Instruct", + "type": "llm", + "name": "OpenAI GPT-4o Mini", + "capabilities": ["chat"], + "context_length": 32000, + "languages": ["en"], + } + ) as llm_config_chat_only_model: + with create_deployment( + DeploymentCreate( + model_id=llm_config_chat_only_model.id, + name="ELLM Qwen2.5 (7B) Deployment", + provider=CloudProvider.OPENAI, + routing_id=llm_config_chat_only_model.id, + api_base="http://192.168.80.2:9192/v1", + ) + ): + # Add output column that referred to image file, but using chat model + # (Notes: chat model can be set due to default prompt was added afterward) + chat_only_model = llm_config_chat_only_model.id + cols = [ + ColumnSchemaCreate( + id="captioning2", + dtype="str", + gen_config=LLMGenConfig(model=chat_only_model), + ), + ] + with pytest.raises(BadInputError): + if table_type == TableType.ACTION: + client.table.add_action_columns( + AddActionColumnSchema(id=table.id, cols=cols) + ) + elif table_type == TableType.KNOWLEDGE: + client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=cols) + ) + elif table_type == TableType.CHAT: + client.table.add_chat_columns(AddChatColumnSchema(id=table.id, 
cols=cols)) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", [True, False]) +def test_add_row_sequential_completion_with_error( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input", dtype="str"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt="Summarise ${input}.", + ), + ), + ColumnSchemaCreate( + id="rephrase", + dtype="str", + gen_config=LLMGenConfig( + model="", + prompt="Rephrase ${summary}", + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + response = add_table_rows( + client, + table_type, + table.id, + [dict(input="a" * 10000000)], + stream=stream, + ) + assert len(response.rows) == 1 + row = response.rows[0] + assert "summary" in row.columns + assert "rephrase" in row.columns + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["summary"].startswith("[ERROR] ") + second_output = (row["rephrase"]).upper() + if stream: + assert second_output.startswith("[ERROR] ") + else: + assert "WARNING" in second_output or "ERROR" in second_output + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize( + "img_filename", + [ + "s3://image-bucket/bmp/cifar10-deer.bmp", + "s3://image-bucket/tiff/cifar10-deer.tiff", + "file://image-bucket/tiff/rabbit.tiff", + ], +) +def test_add_row_image_file_column_invalid_extension( + setup: ServingContext, + table_type: TableType, + stream: bool, + img_filename: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + with pytest.raises( + BadInputError, + match=re.compile( + f"^.*{re.escape('Unsupported file type. 
Make sure the file belongs to one of the following formats:')}.*" + f"{re.escape('[Image File Types]:')}.*" + f"{re.escape('[Audio File Types]:')}.*" + f"{re.escape('[Document File Types]:')}.*$" + ), + ): + response = _add_row( + client, + table_type, + stream, + data=dict(photo=img_filename), + ) + if stream: + _ = [r for r in response] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_add_row_wrong_dtype( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + response = add_table_rows( + client, + table_type, + table.id, + [ + dict( + good=True, + words=5, + stars=7.9, + inputs=TEXT, + photo=upload_file(client, FILES["rabbit.jpeg"]).uri, + audio=upload_file(client, FILES["turning-a4-size-magazine.mp3"]).uri, + paper=upload_file( + client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"] + ).uri, + ) + ], + stream=stream, + ) + assert len(response.rows) == 1 + row = response.rows[0] + assert "summary" in row.columns + assert "captioning" in row.columns + assert "narration" in row.columns + assert "concept" in row.columns + + # Test adding data with wrong dtype + response = add_table_rows( + client, + table_type, + table.id, + [dict(good="dummy1", words="dummy2", stars="dummy3", inputs=TEXT)], + stream=stream, + ) + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 2 + row = rows.items[-1] + assert row["good"]["value"] is None, row["good"] + assert row["good"]["original"] == "dummy1", row["good"] + assert row["words"]["value"] is None, row["words"] + assert row["words"]["original"] == "dummy2", row["words"] + assert row["stars"]["value"] is None, row["stars"] + assert row["stars"]["original"] == "dummy3", row["stars"] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_add_row_missing_columns( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + response = add_table_rows( + client, + table_type, + table.id, + [ + dict( + good=True, + words=5, + stars=7.9, + inputs=TEXT, + photo=upload_file(client, FILES["rabbit.jpeg"]).uri, + audio=upload_file(client, FILES["turning-a4-size-magazine.mp3"]).uri, + paper=upload_file( + client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"] + ).uri, + ) + ], + stream=stream, + ) + assert len(response.rows) == 1 + row = response.rows[0] + assert "summary" in row.columns + assert "captioning" in row.columns + assert "narration" in row.columns + assert "concept" in row.columns + + # Test adding data with missing column + response = _add_row( + client, + table_type, + stream, + TABLE_ID_A, + data=dict(good="dummy1", inputs=TEXT), + ) + if stream: + responses = [r for r in response] + assert all(isinstance(r, CellCompletionResponse) for r in responses) + else: + assert isinstance(response, RowCompletionResponse) + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 2 + row = rows.items[-1] + assert row["good"]["value"] is None, row["good"] + assert row["good"]["original"] == "dummy1", row["good"] + assert row["words"]["value"] is None, row["words"] + assert "original" not in row["words"], 
row["words"] + assert row["stars"]["value"] is None, row["stars"] + assert "original" not in row["stars"], row["stars"] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_add_rows_all_input( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="0", dtype="int"), + ColumnSchemaCreate(id="1", dtype="float"), + ColumnSchemaCreate(id="2", dtype="bool"), + ColumnSchemaCreate(id="3", dtype="str"), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[ + {"0": 1, "1": 2.0, "2": False, "3": "days"}, + {"0": 0, "1": 1.0, "2": True, "3": "of"}, + ], + stream=stream, + ), + ) + if stream: + responses = [r for r in response if r.output_column_name != "AI"] + assert len(responses) == 0 + else: + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 2 + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 2 + + +@flaky(max_runs=5, min_passes=1) +@pytest.mark.timeout(180) +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("reasoning_model", ["openai/gpt-5-mini", "openai/o4-mini"][:1]) +def test_reasoning_model_with_reasoning_effort( + setup: ServingContext, + table_type: TableType, + stream: bool, + reasoning_model: str, +): + """ + Tests that different `reasoning.effort` levels produce different outputs + when using a reasoning model with the Responses API. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + system_prompt = "You are a brilliant logician. Always think twice and give your reasoning before answering!" + prompt = ( + "Solve this riddle: " + "If a plane crashes on the border between the USA and Canada, " + "where do you bury the survivors? 
" + ) + + cols = [ + ColumnSchemaCreate(id="Riddle", dtype="str"), + ColumnSchemaCreate( + id="LowEffortAnswer", + dtype="str", + gen_config=LLMGenConfig( + model=reasoning_model, + system_prompt=system_prompt, + prompt=prompt, + reasoning_effort="low", + reasoning_summary="auto", + ), + ), + ColumnSchemaCreate( + id="MediumEffortAnswer", + dtype="str", + gen_config=LLMGenConfig( + model=reasoning_model, + system_prompt=system_prompt, + prompt=prompt, + reasoning_effort="medium", + reasoning_summary="auto", + ), + ), + ColumnSchemaCreate( + id="MinimalEffortAnswer", + dtype="str", + gen_config=LLMGenConfig( + model=reasoning_model, + system_prompt=system_prompt, + prompt=prompt, + reasoning_effort="minimal", + reasoning_summary="auto", + ), + ), + ] + + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + response = add_table_rows( + client, + table_type, + table.id, + [dict(Riddle="Trigger")], + stream=stream, + ) + + low_effort_reasoning = _collect_reasoning(response, "LowEffortAnswer").lower() + medium_effort_reasoning = _collect_reasoning(response, "MediumEffortAnswer").lower() + minimal_effort_reasoning = _collect_reasoning(response, "MinimalEffortAnswer").lower() + assert "survivors" in low_effort_reasoning + assert "survivors" in medium_effort_reasoning + + assert (len(medium_effort_reasoning) > len(low_effort_reasoning)) or ( + len(low_effort_reasoning) > len(minimal_effort_reasoning) + ) + + low_effort_result = _collect_text(response, "LowEffortAnswer").lower() + medium_effort_result = _collect_text(response, "MediumEffortAnswer").lower() + minimal_effort_result = _collect_text(response, "MinimalEffortAnswer").lower() + + assert "bury" in low_effort_result and "survivors" in low_effort_result + assert "bury" in medium_effort_result and "survivors" in medium_effort_result + if reasoning_model == "openai/gpt-5-mini": + assert "bury" in minimal_effort_result and "survivors" in minimal_effort_result + else: + assert "'minimal' is not supported" in minimal_effort_result + + assert response.rows[0].columns["LowEffortAnswer"].usage is not None + assert response.rows[0].columns["MediumEffortAnswer"].usage is not None + assert response.rows[0].columns["MinimalEffortAnswer"].usage is not None + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.parametrize("table_type", TABLE_TYPES[:1]) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("capability", [ModelCapability.CHAT, ModelCapability.REASONING]) +def test_agentic_column_with_web_search( + setup: ServingContext, + table_type: TableType, + stream: bool, + capability: ModelCapability, +): + """ + Tests an agentic column that uses web_search to perform a fact-checking task. + Also validates usage metrics. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="Claim", dtype="str"), + ColumnSchemaCreate( + id="FactCheck", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client) + if capability == ModelCapability.CHAT + else _get_reasoning_model(client), + prompt="You are a meticulous fact-checker. Your goal is to verify the following claim: `${Claim}`. 
" + "Use web search to determine if the claim is true or false and provide a brief explanation.", + tools=[WebSearchTool()], + reasoning_effort="low" if capability == ModelCapability.REASONING else None, + ), + ), + ] + + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + response = add_table_rows( + client, + table_type, + table.id, + [dict(Claim="The sun revolves around the Earth.")], + stream=stream, + ) + + reasoning = _collect_reasoning(response, "FactCheck") + assert "Searched the web for " in reasoning and "Ran Python code:" not in reasoning + reasoning = reasoning.lower() + assert "earth" in reasoning + assert "sun" in reasoning + assert "revolve" in reasoning or "orbit" in reasoning + + result = _collect_text(response, "FactCheck").lower() + assert result is not None + assert "false" in result + assert "earth" in result + assert "sun" in result + assert "revolve" in result or "orbit" in result + + usage = response.rows[0].columns["FactCheck"].usage + assert usage is not None + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.tool_usage_details is not None + assert usage.tool_usage_details.web_search_calls > 0 + assert usage.tool_usage_details.code_interpreter_calls == 0 + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.timeout(120) +@pytest.mark.parametrize("table_type", TABLE_TYPES[:1]) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("capability", [ModelCapability.CHAT, ModelCapability.REASONING]) +def test_agentic_column_with_code_interpreter( + setup: ServingContext, + table_type: TableType, + stream: bool, + capability: ModelCapability, +): + """ + Tests an agentic column that reads numerical data from other columns + and uses the code_interpreter to perform a calculation. Also validates usage metrics. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="Revenue", dtype="int"), + ColumnSchemaCreate(id="Expenses", dtype="int"), + ColumnSchemaCreate( + id="ProfitMargin", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client) + if capability == ModelCapability.CHAT + else _get_reasoning_model(client), + prompt="You are a financial analyst. Check the Revenue: `${Revenue}` and Expenses: `${Expenses}`." + "Then, use the code interpreter to calculate the profit margin percentage. " + "The formula is `(Revenue - Expenses) / Revenue * 100`. 
" + "Return only the final numerical answer, formatted as a percentage string like '25.0%'.", + tools=[CodeInterpreterTool()], + reasoning_effort="low" if capability == ModelCapability.REASONING else None, + ), + ), + ] + + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + response = add_table_rows( + client, + table_type, + table.id, + [dict(Revenue=200000, Expenses=50000)], + stream=stream, + ) + + reasoning = _collect_reasoning(response, "ProfitMargin") + assert "Ran Python code:" in reasoning and "Searched the web for " not in reasoning + assert "200000" in reasoning + assert "50000" in reasoning + + result = _collect_text(response, "ProfitMargin") + assert result is not None + assert "75" in result # 150000 / 200000 = 0.75 + assert "%" in result + + usage = response.rows[0].columns["ProfitMargin"].usage + assert usage is not None + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.tool_usage_details is not None + assert usage.tool_usage_details.web_search_calls == 0 + assert usage.tool_usage_details.code_interpreter_calls > 0 + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.timeout(180) +@pytest.mark.parametrize("table_type", TABLE_TYPES[:1]) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("capability", [ModelCapability.CHAT, ModelCapability.REASONING]) +def test_agentic_column_with_multiple_tools( + setup: ServingContext, + table_type: TableType, + stream: bool, + capability: ModelCapability, +): + """ + Tests an agentic column that requires chaining multiple tools (web search and code interpreter) + to complete its goal, and validates the usage metrics for both. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="Country", dtype="str"), + ColumnSchemaCreate( + id="PopulationDensityReport", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client) + if capability == ModelCapability.CHAT + else _get_reasoning_model(client), + system_prompt="You are a geography research assistant. Always give short and concise answers.", + prompt="Your task is for the country '${Country}'. " + "1. First, use web search to find its current estimated population. " + "2. Second, use web search to find its total land area in square kilometers. " + "3. Third, use the code interpreter to calculate the population density (population / area). " + "4. 
Finally, report the result in a single sentence, including the calculated density.", + tools=[WebSearchTool(), CodeInterpreterTool()], + reasoning_effort="low" if capability == ModelCapability.REASONING else None, + ), + ), + ] + + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + response = add_table_rows( + client, + table_type, + table.id, + [dict(Country="Japan")], + stream=stream, + ) + + reasoning = _collect_reasoning(response, "PopulationDensityReport") + assert "Searched the web for " in reasoning and "Ran Python code:" in reasoning + reasoning = reasoning.lower() + assert "japan" in reasoning + assert "population" in reasoning + assert "density" in reasoning + + result = _collect_text(response, "PopulationDensityReport").lower() + assert result is not None + assert "japan" in result + assert "population" in result + assert "density" in result + # Check for a number, which would be the calculated density + assert any(char.isdigit() for char in result) + + usage = response.rows[0].columns["PopulationDensityReport"].usage + assert usage is not None + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.tool_usage_details is not None + assert usage.tool_usage_details.web_search_calls > 0 + assert usage.tool_usage_details.code_interpreter_calls > 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_regen_rows( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + + image_upload_response = upload_file(client, FILES["rabbit.jpeg"]) + audio_upload_response = upload_file(client, FILES["turning-a4-size-magazine.mp3"]) + response = _add_row( + client, + table_type, + False, + data=dict( + good=True, + words=10, + stars=9.9, + inputs=TEXT, + photo=image_upload_response.uri, + audio=audio_upload_response.uri, + ), + ) + assert isinstance(response, RowCompletionResponse) + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + _id = row["ID"] + original_ts = row["Updated at"] + assert "arrival" in row["summary"].lower() + # Regen + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={ + _id: dict( + inputs="Dune: Part Two is a 2024 American epic science fiction film directed and produced by Denis Villeneuve" + ) + }, + ), + ) + response = regen_table_rows(client, table_type, table.id, [_id], stream=stream) + row = response.rows[0] + assert "summary" in row.columns + assert "captioning" in row.columns + assert "narration" in row.columns + assert "concept" in row.columns + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["good"] is True + assert row["words"] == 10 + assert row["stars"] == 9.9 + assert row["photo"] == image_upload_response.uri + assert row["audio"] == audio_upload_response.uri + assert row["Updated at"] > original_ts + assert "dune" in row["summary"].lower() + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_regen_rows_all_input( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, 
project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="0", dtype="int"), + ColumnSchemaCreate(id="1", dtype="float"), + ColumnSchemaCreate(id="2", dtype="bool"), + ColumnSchemaCreate(id="3", dtype="str"), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[ + {"0": 1, "1": 2.0, "2": False, "3": "days"}, + {"0": 0, "1": 1.0, "2": True, "3": "of"}, + ], + stream=False, + ), + ) + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 2 + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 2 + # Regen + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=table.id, row_ids=[r["ID"] for r in rows.items], stream=stream + ), + ) + if stream: + responses = [r for r in response if r.output_column_name != "AI"] + assert len(responses) == 0 + else: + assert isinstance(response, MultiRowCompletionResponse) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_delete_rows( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + data = dict(good=True, words=5, stars=9.9, inputs=TEXT, summary="dummy") + _add_row(client, table_type, False, data=data) + _add_row(client, table_type, False, data=data) + _add_row(client, table_type, False, data=data) + _add_row(client, table_type, False, data=data) + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=7.9, inputs=TEXT_CN), + ) + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=7.9, inputs=TEXT_JP), + ) + ori_rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(ori_rows.items) == 6 + delete_id = ori_rows.values[0]["ID"] + + # Delete one row + response = client.table.delete_table_row(table_type, TABLE_ID_A, delete_id) + assert isinstance(response, OkResponse) + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 5 + row_ids = set(r["ID"] for r in rows.values) + assert delete_id not in row_ids + # Delete multiple rows + delete_ids = [r["ID"] for r in ori_rows.values[1:4]] + response = client.table.delete_table_rows( + table_type, + MultiRowDeleteRequest( + table_id=TABLE_ID_A, + row_ids=delete_ids, + ), + ) + assert isinstance(response, OkResponse) + rows = list_table_rows(client, table_type, TABLE_ID_A) + assert len(rows.items) == 2 + row_ids = set(r["ID"] for r in rows.values) + assert len(set(row_ids) & set(delete_ids)) == 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_column_interpolate( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + cols = [ + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + prompt='Say "Jan has 5 apples.".', + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ColumnSchemaCreate(id="input0", dtype="int"), + ColumnSchemaCreate( + id="output1", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + prompt=( + "1. ${output0}\n2. 
Jan has ${input0} apples.\n\n" + "Do the statements agree with each other? Reply Yes or No." + ), + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + def _add_row_wrapped(stream, data): + return _add_row( + client, + table_type=table_type, + stream=stream, + table_name=table.id, + data=data, + knowledge_data=None, + chat_data=dict(User='Say "Jan has 5 apples.".'), + ) + + # Streaming + response = list(_add_row_wrapped(True, dict(input0=5))) + output0 = _collect_text(response, "output0") + ai = _collect_text(response, "AI") + answer = _collect_text(response, "output1") + assert "yes" in answer.lower(), f'output0="{output0}" ai="{ai}" answer="{answer}"' + response = list(_add_row_wrapped(True, dict(input0=6))) + output0 = _collect_text(response, "output0") + ai = _collect_text(response, "AI") + answer = _collect_text(response, "output1") + assert "no" in answer.lower(), f'output0="{output0}" ai="{ai}" answer="{answer}"' + # Non-streaming + response = _add_row_wrapped(False, dict(input0=5)) + answer = response.columns["output1"].content + assert "yes" in answer.lower(), f'columns={response.columns} answer="{answer}"' + response = _add_row_wrapped(False, dict(input0=6)) + answer = response.columns["output1"].content + assert "no" in answer.lower(), f'columns={response.columns} answer="{answer}"' + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_history_and_sequential_add( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input", dtype="str"), + ColumnSchemaCreate( + id="output", + dtype="str", + gen_config=LLMGenConfig( + system_prompt="You are a calculator.", + prompt="${input}", + multi_turn=True, + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Initialise chat thread and set output format + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[ + dict(input="x = 0", output="0"), + dict(input="Add 1", output="1"), + dict(input="Add 1", output="2"), + dict(input="Add 1", output="3"), + dict(input="Add 1", output="4"), + ], + stream=False, + ), + ) + # Test adding one row + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[dict(input="Add 1")], + stream=stream, + ), + ) + output = _collect_text(response, "output") + assert "5" in output, output + # Test adding multiple rows + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[ + dict(input="Add 1"), + dict(input="Add 2"), + dict(input="Add 1"), + ], + stream=stream, + ), + ) + output = _collect_text(response, "output") + assert "6" in output, output + assert "8" in output, output + assert "9" in output, output + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_history_and_sequential_regen( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input", dtype="str"), + ColumnSchemaCreate( + id="output", + dtype="str", + gen_config=LLMGenConfig( + 
system_prompt="You are a calculator.", + prompt="${input}", + multi_turn=True, + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Initialise chat thread and set output format + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[ + dict(input="x = 0", output="0"), + dict(input="Add 1", output="1"), + dict(input="Add 1", output="2"), + dict(input="Add 2", output="9"), # Wrong answer on purpose + dict(input="Add 1", output="9"), # Wrong answer on purpose + dict(input="Add 3", output="9"), # Wrong answer on purpose + ], + stream=False, + ), + ) + row_ids = sorted([r.row_id for r in response.rows]) + # Test regen one row + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=table.id, + row_ids=row_ids[3:4], + stream=stream, + ), + ) + output = _collect_text(response, "output") + assert "4" in output, output + # Test regen multiple rows + # Also test if regen proceeds in correct order from earliest row to latest + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=table.id, + row_ids=row_ids[3:][::-1], + stream=stream, + ), + ) + output = _collect_text(response, "output") + assert "4" in output, output + assert "5" in output, output + assert "8" in output, output + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_convert_into_multi_turn( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input", dtype="str"), + ColumnSchemaCreate( + id="output", + dtype="str", + gen_config=LLMGenConfig( + system_prompt="You are a calculator.", + prompt="${input}", + multi_turn=False, + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Initialise chat thread and set output format + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[ + dict(input="x = 0", output="0"), + dict(input="x += 1", output="1"), + dict(input="x += 1", output="2"), + dict(input="x += 1", output="3"), + ], + stream=False, + ), + ) + # Test adding one row as single-turn + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, + data=[dict(input="x += 1")], + stream=stream, + ), + ) + output = _collect_text(response, "output") + assert "4" not in output, output + # Convert into multi-turn + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict( + output=LLMGenConfig( + system_prompt="You are a calculator.", + prompt="${input}", + multi_turn=True, + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ), + ) + assert isinstance(table, TableMetaResponse) + # Regen + rows = list_table_rows(client, table_type, table.id) + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest( + table_id=table.id, + row_ids=[rows.values[-1]["ID"]], + stream=stream, + ), + ) + output = _collect_text(response, "output") + assert "4" in output, output + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_get_conversation_thread( + setup: ServingContext, + table_type: TableType, +): + client = 
JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input", dtype="str"), + ColumnSchemaCreate( + id="output", + dtype="str", + gen_config=LLMGenConfig( + system_prompt="You are a calculator.", + prompt="${input}", + multi_turn=True, + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Initialise chat thread and set output format + data = [ + dict(input="x = 0", output="0"), + dict(input="Add 1", output="1"), + dict(input="Add 2", output="3"), + dict(input="Add 3", output="6"), + ] + response = client.table.add_table_rows( + table_type, MultiRowAddRequest(table_id=table.id, data=data, stream=False) + ) + row_ids = sorted([r.row_id for r in response.rows]) + + def _check_thread(_chat): + assert isinstance(_chat, ChatThreadResponse) + for i, message in enumerate(_chat.thread): + assert isinstance(message.content, str) + assert len(message.content) > 0 + if i == 0: + assert message.role == ChatRole.SYSTEM + elif i % 2 == 1: + assert message.role == ChatRole.USER + assert message.content == data[(i - 1) // 2]["input"] + else: + assert message.role == ChatRole.ASSISTANT + assert message.content == data[(i // 2) - 1]["output"] + + # --- Fetch complete thread --- # + chat = client.table.get_conversation_threads( + table_type, + table.id, + ["output"], + ).threads["output"] + _check_thread(chat) + assert len(chat.thread) == 9 + assert chat.thread[-1].content == "6" + # --- Row ID filtering --- # + # Filter (include = True) + chat = client.table.get_conversation_threads( + table_type, + table.id, + ["output"], + row_id=row_ids[2], + ).threads["output"] + _check_thread(chat) + assert len(chat.thread) == 7 + assert chat.thread[-1].content == "3" + # Filter (include = False) + chat = client.table.get_conversation_threads( + table_type, + table.id, + ["output"], + row_id=row_ids[2], + include_row=False, + ).threads["output"] + _check_thread(chat) + assert len(chat.thread) == 5 + assert chat.thread[-1].content == "1" + # --- Non-existent column --- # + with pytest.raises( + ResourceNotFoundError, + match="Column .*x.* is not found. Available multi-turn columns:.*output.*", + ): + client.table.get_conversation_threads(table_type, table.id, ["x"]) + # --- Invalid column --- # + with pytest.raises( + ResourceNotFoundError, + match="Column .*input.* is not a multi-turn LLM column. 
Available multi-turn columns:.*output.*", + ): + client.table.get_conversation_threads(table_type, table.id, ["input"]) + + +def test_hybrid_search( + setup: ServingContext, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + table_type = TableType.KNOWLEDGE + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + data = dict(good=True, words=5, stars=9.9, inputs=TEXT, summary="dummy") + rows = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=TABLE_ID_A, + data=[dict(Title="Resume 2012", Text="Hi there, I am a farmer.", **data)], + stream=False, + ), + ) + assert isinstance(rows, MultiRowCompletionResponse) + rows = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=TABLE_ID_A, + data=[dict(Title="Resume 2013", Text="Hi there, I am a carpenter.", **data)], + stream=False, + ), + ) + assert isinstance(rows, MultiRowCompletionResponse) + rows = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=TABLE_ID_A, + data=[ + dict( + Title="Byte Pair Encoding", + Text="BPE is a subword tokenization method.", + **data, + ) + ], + stream=False, + ), + ) + assert isinstance(rows, MultiRowCompletionResponse) + sleep(1) # Optional, give it some time to index + # Rely on embedding + rows = client.table.hybrid_search( + table_type, + SearchRequest( + table_id=TABLE_ID_A, + query="language", + reranking_model=_get_reranking_model(client), + limit=2, + ), + ) + assert len(rows) == 2 + assert "BPE" in rows[0]["Text"]["value"], rows + # Rely on FTS + rows = client.table.hybrid_search( + table_type, + SearchRequest( + table_id=TABLE_ID_A, + query="candidate 2013", + reranking_model=_get_reranking_model(client), + limit=2, + ), + ) + assert len(rows) == 2 + assert "2013" in rows[0]["Title"]["value"], rows + # hybrid_search without reranker (RRF only) + rows = client.table.hybrid_search( + table_type, + SearchRequest( + table_id=TABLE_ID_A, + query="language", + reranking_model=None, + limit=2, + ), + ) + assert len(rows) == 2 + assert "BPE" in rows[0]["Text"]["value"], rows + + +FILE_PAGES = { + FILES["salary 总结.pdf"]: 1, + FILES["Swire_AR22_e_230406_sample.pdf"]: 5, + FILES["1978_APL_FP_detrapping.PDF"]: 4, + FILES["digital_scan_combined.pdf"]: 15, + FILES["(2017.06.30) NMT in Linear Time (ByteNet).pptx"]: 3, + FILES["Claims Form.xlsx"]: 2, +} + + +@pytest.mark.parametrize( + "file_path", + [ + FILES["salary 总结.pdf"], + FILES["Swire_AR22_e_230406_sample.pdf"], + # FILES["1978_APL_FP_detrapping.PDF"], + # FILES["digital_scan_combined.pdf"], + FILES["creative-story.md"], + FILES["creative-story.txt"], + FILES["RAG and LLM Integration Guide.html"], + FILES["multilingual-code-examples.html"], + FILES["table.html"], + FILES["weather-forecast-service.xml"], + FILES["company-profile.json"], + FILES["llm-models.jsonl"], + FILES["ChatMed_TCM-v0.2-5records.jsonl"], + FILES["Recommendation Letter.docx"], + FILES["(2017.06.30) NMT in Linear Time (ByteNet).pptx"], + FILES["Claims Form.xlsx"], + FILES["weather_observations.tsv"], + FILES["company-profile.csv"], + FILES["weather_observations_long.csv"], + ], + ids=lambda x: basename(x), +) +def test_embed_file( + setup: ServingContext, + file_path: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + table_type = TableType.KNOWLEDGE + with _create_table(client, table_type, cols=[]) as table: + assert isinstance(table, TableMetaResponse) + assert 
all(isinstance(c, ColumnSchema) for c in table.cols) + response = client.table.embed_file(file_path, table.id) + assert isinstance(response, OkResponse) + rows = list_table_rows(client, table_type, table.id) + assert rows.total > 0 + assert rows.offset == 0 + assert rows.limit == 100 + assert len(rows.items) > 0 + for r in rows.values: + assert isinstance(r["Title"], str) + assert len(r["Title"]) > 0 + assert isinstance(r["Text"], str) + assert len(r["Text"]) > 0 + assert r["Page"] > 0 + assert isinstance(r["Title Embed"], list) + assert len(r["Title Embed"]) > 0 + assert all(isinstance(v, float) for v in r["Title Embed"]) + assert isinstance(r["Text Embed"], list) + assert len(r["Text Embed"]) > 0 + assert all(isinstance(v, float) for v in r["Text Embed"]) + if file_path in FILE_PAGES: + assert r["Page"] == FILE_PAGES[file_path] + else: + assert r["Page"] == 1 + + +@pytest.mark.parametrize( + "file_path", + [ + FILES["empty.pdf"], + FILES["empty_3pages.pdf"], + FILES["empty.txt"], + FILES["empty.csv"], + ], + ids=lambda x: basename(x), +) +def test_embed_empty_file( + setup: ServingContext, + file_path: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + table_type = TableType.KNOWLEDGE + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + with pytest.raises(BadInputError, match="is empty"): + response = client.table.embed_file(file_path, table.id) + assert isinstance(response, OkResponse) + + +@pytest.mark.parametrize( + "file_path", + [ + FILES["rabbit.jpeg"], + join(dirname(dirname(TEST_FILE_DIR)), "pyproject.toml"), + ], + ids=lambda x: basename(x), +) +def test_embed_file_invalid_file_type( + setup: ServingContext, + file_path: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + table_type = TableType.KNOWLEDGE + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + with pytest.raises(JamaiException, match=r"File type .+ is unsupported"): + client.table.embed_file(file_path, table.id) + + +def test_embed_file_options(setup: ServingContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + response = client.table.embed_file_options() + + assert isinstance(response, httpx.Response) + assert response.status_code == 200 + + assert "Allow" in response.headers + assert "POST" in response.headers["Allow"] + assert "OPTIONS" in response.headers["Allow"] + + assert "Accept" in response.headers + for content_type in EMBED_WHITE_LIST_EXT: + assert content_type in response.headers["Accept"] + + assert "Access-Control-Allow-Methods" in response.headers + assert "POST" in response.headers["Access-Control-Allow-Methods"] + assert "OPTIONS" in response.headers["Access-Control-Allow-Methods"] + + assert "Access-Control-Allow-Headers" in response.headers + assert "Content-Type" in response.headers["Access-Control-Allow-Headers"] + + # Ensure the response body is empty + assert response.content == b"" + + +def test_embed_long_file( + setup: ServingContext, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, "knowledge", cols=[]) as table: + assert isinstance(table, TableMetaResponse) + with TemporaryDirectory() as tmp_dir: + # Create a long CSV + data = [ + {"bool": True, "float": 0.0, "int": 0, "str": ""}, + {"bool": False, "float": -1.0, "int": -2, "str": 
"testing"}, + {"bool": None, "float": None, "int": None, "str": None}, + ] + file_path = join(tmp_dir, "long.csv") + df_to_csv(pd.DataFrame.from_dict(data * 100), file_path) + # Embed the CSV + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + response = client.table.embed_file(file_path, table.id) + assert isinstance(response, OkResponse) + rows = list_table_rows(client, "knowledge", table.id) + assert rows.total == 300 + assert rows.offset == 0 + assert rows.limit == 100 + assert len(rows.items) == 100 + assert all(isinstance(r["Title"], str) for r in rows.values) + assert all(len(r["Title"]) > 0 for r in rows.values) + assert all(isinstance(r["Text"], str) for r in rows.values) + assert all(len(r["Text"]) > 0 for r in rows.values) + assert all(r["Page"] > 0 for r in rows.values) diff --git a/services/api/tests/gen_table/test_row_ops_v2.py b/services/api/tests/gen_table/test_row_ops_v2.py new file mode 100644 index 0000000..6133eb8 --- /dev/null +++ b/services/api/tests/gen_table/test_row_ops_v2.py @@ -0,0 +1,2097 @@ +import re +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal +from os.path import basename, dirname, join, realpath +from tempfile import TemporaryDirectory +from types import NoneType +from typing import Any + +import httpx +import pytest +from flaky import flaky + +from jamaibase import JamAI +from jamaibase.types import ( + CITATION_PATTERN, + CellCompletionResponse, + ChatCompletionChunkResponse, + ChatCompletionResponse, + ColumnSchemaCreate, + GenConfigUpdateRequest, + GetURLResponse, + LLMGenConfig, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowUpdateRequest, + OkResponse, + OrganizationCreate, + PythonGenConfig, + RAGParams, + References, + RowCompletionResponse, + S3Content, + TextContent, + WebSearchTool, +) +from owl.configs import ENV_CONFIG +from owl.types import ( + ModelCapability, + ModelType, + RegenStrategy, + TableType, +) +from owl.utils.exceptions import BadInputError, ResourceNotFoundError +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + GPT_5_MINI_CONFIG, + GPT_5_MINI_DEPLOYMENT, + GPT_41_MINI_CONFIG, + GPT_41_MINI_DEPLOYMENT, + STREAM_PARAMS, + TABLE_TYPES, + TEXT_EMBEDDING_3_SMALL_CONFIG, + TEXT_EMBEDDING_3_SMALL_DEPLOYMENT, + TEXTS, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + add_table_rows, + assert_is_vector_or_none, + create_deployment, + create_model_config, + create_organization, + create_project, + create_table, + create_user, + get_file_map, + get_table_row, + list_table_rows, + regen_table_rows, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) + +EMBED_WHITE_LIST_EXT = [ + "application/pdf", # pdf + "text/markdown", # md + "text/plain", # txt + "text/html", # html + "text/xml", # xml + "application/xml", # xml + "application/json", # json + "application/jsonl", # jsonl + "application/x-ndjson", # alternative for jsonl + "application/json-lines", # another alternative for jsonl + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", # docx + "application/vnd.openxmlformats-officedocument.presentationml.presentation", # pptx + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", # xlsx + "text/tab-separated-values", # tsv + "text/csv", # csv +] + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str 
+ superorg_id: str + project_id: str + embedding_size: int + image_uri: str + audio_uri: str + document_uri: str + gpt_llm_model_id: str + gpt_llm_reasoning_config_id: str + desc_llm_model_id: str + lorem_llm_model_id: str + short_llm_model_id: str + echo_model_id: str + embed_model_id: str + rerank_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. + """ + with ( + create_user() as superuser, + create_organization( + body=OrganizationCreate(name="Superorg"), user_id=superuser.id + ) as superorg, + create_project( + dict(name="Superorg Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + assert superorg.id == "0" + # Create models + with ( + create_model_config(GPT_41_MINI_CONFIG) as gpt_llm_config, + create_model_config(GPT_5_MINI_CONFIG) as gpt_llm_reasoning_config, + create_model_config(ELLM_DESCRIBE_CONFIG) as desc_llm_config, + create_model_config( + dict( + id="ellm/lorem-ttft-20-tpot-10", # TTFT 20 ms, TPOT 10 ms + type=ModelType.LLM, + name="ELLM Lorem Ipsum Generator", + capabilities=[ + ModelCapability.CHAT, + ModelCapability.IMAGE, + ModelCapability.AUDIO, + ], + context_length=128000, + languages=["en"], + owned_by="ellm", + ) + ) as lorem_llm_config, + create_model_config( + dict( + # Max context length = 10 + id="ellm/lorem-context-10", + type=ModelType.LLM, + name="Short-Context Chat Model", + capabilities=[ModelCapability.CHAT], + context_length=5, + languages=["en"], + owned_by="ellm", + ) + ) as short_llm_config, + create_model_config( + dict( + id="ellm/echo-prompt", + type=ModelType.LLM, + name="Echo Prompt Model", + capabilities=[ModelCapability.CHAT], + context_length=1000000, + languages=["en"], + owned_by="ellm", + ) + ) as echo_config, + create_model_config(TEXT_EMBEDDING_3_SMALL_CONFIG) as embed_config, + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_config, + ): + # Create deployments + with ( + create_deployment(GPT_41_MINI_DEPLOYMENT), + create_deployment(GPT_5_MINI_DEPLOYMENT), + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment( + dict( + model_id=lorem_llm_config.id, + name=f"{lorem_llm_config.name} Deployment", + provider="custom", + routing_id=lorem_llm_config.id, + api_base=ENV_CONFIG.test_llm_api_base, + ) + ), + create_deployment( + dict( + model_id=short_llm_config.id, + name="Short chat Deployment", + provider="custom", + routing_id=short_llm_config.id, + api_base=ENV_CONFIG.test_llm_api_base, + ) + ), + create_deployment( + dict( + model_id=echo_config.id, + name="Echo Prompt Deployment", + provider="custom", + routing_id=echo_config.id, + api_base=ENV_CONFIG.test_llm_api_base, + ) + ), + create_deployment(TEXT_EMBEDDING_3_SMALL_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + client = JamAI(user_id=superuser.id, project_id=p0.id) + image_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + audio_uri = upload_file(client, FILES["gutter.mp3"]).uri + document_uri = upload_file( + client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"] + ).uri + yield ServingContext( + superuser_id=superuser.id, + superorg_id=superorg.id, + project_id=p0.id, + embedding_size=embed_config.final_embedding_size, + image_uri=image_uri, + audio_uri=audio_uri, + document_uri=document_uri, + gpt_llm_model_id=gpt_llm_config.id, + gpt_llm_reasoning_config_id=gpt_llm_reasoning_config.id, + desc_llm_model_id=desc_llm_config.id, + lorem_llm_model_id=lorem_llm_config.id, + 
short_llm_model_id=short_llm_config.id, + echo_model_id=echo_config.id, + embed_model_id=embed_config.id, + rerank_model_id=rerank_config.id, + ) + + +@dataclass(slots=True) +class Data: + data_list: list[dict[str, Any]] + action_data_list: list[dict[str, Any]] + knowledge_data: dict[str, Any] + chat_data: dict[str, Any] + extra_data: dict[str, Any] + + +INPUT_COLUMNS = ["int", "float", "bool", "str", "image", "audio", "document"] +FILE_COLUMNS = ["image", "audio", "document"] +OUTPUT_COLUMNS = ["summary (1.0)", "summary (2.0)"] + + +def _default_data(setup: ServingContext): + action_data_list = [ + { + "ID": str(i), + "Updated at": "1990-05-13T09:01:50.010756+00:00", + "int": 1 if i % 2 == 0 else (1.0 if i % 4 == 1 else None), + "float": -1.25 if i % 2 == 0 else (5 if i % 4 == 1 else None), + "bool": True if i % 2 == 0 else (False if i % 4 == 1 else None), + # `str` will sort in opposite order to ID + "str": f"{100 - i:04d}: {t}", + "image": setup.image_uri if i % 2 == 0 else None, + "audio": setup.audio_uri if i % 2 == 0 else None, + "document": setup.document_uri if i % 2 == 0 else None, + } + for i, t in enumerate(list(TEXTS.values()) + ["", None]) + ] + # Assert integers and floats contain a mix of int, float, None + _ints = [type(d["int"]) for d in action_data_list] + assert int in _ints + assert float in _ints + assert NoneType in _ints + _floats = [type(d["float"]) for d in action_data_list] + assert int in _floats + assert float in _floats + assert NoneType in _floats + # Assert booleans contain a mix of True, False, None + _bools = [d["bool"] for d in action_data_list] + assert True in _bools + assert False in _bools + assert None in _bools + # # Assert strings contain a mix of empty string and None + # _summaries = [d["str"] for d in action_data_list] + # assert None in _summaries + # assert "" in _summaries + knowledge_data = { + "Title": "Dune: Part Two.", + "Text": '"Dune: Part Two" is a film.', + # We use values that can be represented exactly as IEEE floats to ease comparison + "Title Embed": [-1.25] * setup.embedding_size, + "Text Embed": [0.25] * setup.embedding_size, + } + chat_data = dict(User=".") + extra_data = dict(good=True, words=5) + return Data( + data_list=[ + dict(**d, **knowledge_data, **chat_data, **extra_data) for d in action_data_list + ], + action_data_list=action_data_list, + knowledge_data=knowledge_data, + chat_data=chat_data, + extra_data=extra_data, + ) + + +def _add_row_default_data( + setup: ServingContext, + client: JamAI, + *, + table_type: TableType, + table_name: str, + stream: bool, +) -> tuple[MultiRowCompletionResponse, Data]: + data = _default_data(setup) + response = add_table_rows(client, table_type, table_name, data.data_list, stream=stream) + # Check returned chunks / response + for row in response.rows: + for col_name, col_value in row.columns.items(): + assert isinstance(col_name, str) + assert isinstance(col_value, (ChatCompletionResponse, ChatCompletionChunkResponse)) + assert isinstance(col_value.content, str) + assert len(col_value.content) > 0 + assert len(response.rows) == len(data.data_list) + # Check expected output columns + expected_columns = set(OUTPUT_COLUMNS) + if table_type == TableType.CHAT: + expected_columns |= {"AI"} + assert all(set(r.columns.keys()) == expected_columns for r in response.rows), ( + f"{response.rows[0].columns.keys()=}" + ) + return response, data + + +def _check_rows( + rows: list[dict[str, Any]], + data: list[dict[str, Any]], +): + assert len(rows) == len(data), f"Row count mismatch: {len(rows)=} != 
{len(data)=}" + for row, d in zip(rows, data, strict=True): + assert row["image"] is None or row["image"].endswith("/rabbit.jpeg"), row["image"] + assert row["audio"] is None or row["audio"].endswith("/gutter.mp3"), row["audio"] + assert row["document"] is None or row["document"].endswith( + "/LLMs as Optimizers [DeepMind ; 2023].pdf" + ), row["document"] + for col in d: + if col in ["ID", "Updated at"]: + assert row[col] != d[col], f'Column "{col}" is not regenerated: {d[col]=}' + continue + if col in FILE_COLUMNS: + continue + if d[col] not in [None, ""] or col == "str": + assert row[col] == d[col], f'Column "{col}" mismatch: {row[col]=} != {d[col]=}' + else: + assert row[col] is None, f'Column "{col}" mismatch: {row[col]=} != {d[col]=}' + + +def _check_knowledge_chat_data( + table_type: TableType, + rows: list[dict[str, Any]], + data: Data, +): + if table_type == TableType.KNOWLEDGE: + _check_rows(rows, [data.knowledge_data] * len(data.data_list)) + elif table_type == TableType.CHAT: + _check_rows(rows, [data.chat_data] * len(data.data_list)) + + +def _check_columns( + table_type: TableType, + rows: list[dict[str, Any]], +): + expected_cols = set(["ID", "Updated at"] + INPUT_COLUMNS + OUTPUT_COLUMNS) + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + else: + raise ValueError(f"Invalid table type: {table_type}") + assert all(isinstance(r, dict) for r in rows) + assert all(set(r.keys()) == expected_cols for r in rows), [list(r.keys()) for r in rows] + + +def _get_exponent(x: float) -> int: + return Decimal(str(x)).as_tuple().exponent + + +def _extract_number(text: str) -> int: + match = re.search(r"\[(\d+)\]", text) + return int(match.group(1)) if match else 0 + + +def _assert_dict_equal(d1: dict[str, Any], d2: dict[str, Any], exclude: list[str] | None = None): + if exclude is None: + exclude = [] + d1 = {k: v for k, v in d1.items() if k not in exclude} + d2 = {k: v for k, v in d2.items() if k not in exclude} + assert d1 == d2 + + +# TODO: Test add row with complete data including output columns + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_multi_image_input( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + image_uris = [ + upload_file(client, FILES["rabbit.jpeg"]).uri, + upload_file(client, FILES["doe.jpg"]).uri, + ] + cols = [ + ColumnSchemaCreate(id="file", dtype="file"), # Test `file` dtype compatibility + ColumnSchemaCreate(id="image", dtype="image"), + ColumnSchemaCreate( + id="o1", + dtype="str", + gen_config=LLMGenConfig(model=setup.desc_llm_model_id), + ), + ColumnSchemaCreate( + id="o2", + dtype="str", + gen_config=LLMGenConfig(model=setup.desc_llm_model_id, prompt="${image} ${o1}"), + ), + ] + with create_table(client, table_type, cols=cols) as table: + # Add rows + data = [ + dict(file=image_uris[0], image=image_uris[1]), + dict(file=image_uris[0], image=image_uris[1], o1="yeah"), + ] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + rows = {r.row_id: {k: v.content for k, v in r.columns.items()} for r in response.rows} + for row in response.rows: + o2 = row.columns["o2"].content + assert "image with MIME type [image/jpeg], shape [(307, 205, 3)]" in o2 + if "o1" in row.columns: + assert "text with [47] tokens" in o2 + o1 = row.columns["o1"].content + assert "image with MIME type [image/jpeg], shape [(1200, 1600, 3)]" in o1 + assert "image with MIME type [image/jpeg], shape [(307, 205, 3)]" in o1 + else: + assert "text with [1] tokens" in o2 + # List rows + _rows = list_table_rows(client, table_type, table.id) + assert len(_rows.items) == 2 + for row in _rows.values: + assert row["file"] == image_uris[0] + assert row["image"] == image_uris[1] + assert row["o1"] == rows[row["ID"]].get("o1", "yeah") + assert row["o2"] == rows[row["ID"]]["o2"] + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_reasoning_model_and_agentic_tools( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Tests reasoning and non-reasoning models, with and without web search tool. + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="Question", dtype="str"), + ColumnSchemaCreate( + id="Reasoning Model", + dtype="str", + gen_config=LLMGenConfig( + model=setup.gpt_llm_reasoning_config_id, + prompt="${Question}", + reasoning_effort="low", + ), + ), + ColumnSchemaCreate( + id="Reasoning Model with Agent Mode", + dtype="str", + gen_config=LLMGenConfig( + model=setup.gpt_llm_reasoning_config_id, + prompt="${Question}", + tools=[WebSearchTool()], + reasoning_effort="low", + ), + ), + ColumnSchemaCreate( + id="Chat Model", + dtype="str", + gen_config=LLMGenConfig( + model=setup.gpt_llm_model_id, + prompt="${Question}", + ), + ), + ColumnSchemaCreate( + id="Chat Model with Agent Mode", + dtype="str", + gen_config=LLMGenConfig( + model=setup.gpt_llm_model_id, + prompt="${Question}", + tools=[WebSearchTool()], + ), + ), + ] + with create_table(client, table_type, cols=cols) as table: + data = [dict(Question="What is the current US interest rate?")] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + for row in response.rows: + reasoning = row.columns["Reasoning Model"].reasoning_content + assert "Searched the web for " not in reasoning + assert len(reasoning) > 0 + answer = row.columns["Reasoning Model"].content.lower() + assert len(answer) > 0 + assert "ERROR" not in answer + + reasoning = row.columns["Reasoning Model with Agent Mode"].reasoning_content + assert "Searched the web for " in reasoning + reasoning = reasoning.lower() + assert len(reasoning) > 0 + answer = row.columns["Reasoning Model with Agent Mode"].content.lower() + assert len(answer) > 0 + assert "ERROR" not in answer + + reasoning = row.columns["Chat Model"].reasoning_content + assert reasoning is None or reasoning == "" + answer = row.columns["Chat Model"].content.lower() + assert len(answer) > 0 + assert "ERROR" not in answer + + reasoning = row.columns["Chat Model with Agent Mode"].reasoning_content + assert "Searched the web for " in reasoning + answer = row.columns["Chat Model with Agent Mode"].content.lower() + assert len(answer) > 0 + assert "ERROR" not in answer + # List rows + _rows = list_table_rows(client, table_type, table.id) + assert len(_rows.items) == 1 + + +@pytest.mark.parametrize("table_type", [TableType.KNOWLEDGE]) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_knowledge_table_embedding( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Test Knowledge Table embeddings: + - Missing Title, Text, or both + - Embedding vector with invalid length + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type, cols=[]) as table: + data = [ + # Complete + dict( + Title="Six-spot burnet", + Text="The six-spot burnet is a moth of the family Zygaenidae.", + ), + # Missing Title + dict( + Text="A neural network is a model inspired by biological neural networks.", + ), + # Missing Text + dict( + Title="A supercomputer has a high level of performance.", + ), + # Missing both + dict(), + ] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else len(data) + assert all(len(r.columns) == 0 for r in response.rows) + rows = list_table_rows(client, table_type, table.id) + assert rows.total == len(data) + # Check embeddings + for row in rows.values: + assert_is_vector_or_none(row["Title Embed"], allow_none=False) + assert_is_vector_or_none(row["Text Embed"], allow_none=False) + # Check values + row = rows.values[0] + assert row["Title"] == data[0]["Title"], row + assert row["Text"] == data[0]["Text"], row + row = rows.values[1] + assert row["Title"] is None, row + assert row["Text"] == data[1]["Text"], row + row = rows.values[2] + assert row["Title"] == data[2]["Title"], row + assert row["Text"] is None, row + row = rows.values[3] + assert row["Title"] is None, row + assert row["Text"] is None, row + # If embedding with invalid length is added, it will be coerced to None + # Original vector will be saved into state + response = add_table_rows( + client, + table_type, + table.id, + [{"Title": "test", "Title Embed": [1, 2, 3]}], + stream=stream, + ) + # We currently dont return anything if LLM is not called + assert len(response.rows) == 0 if stream else 1 + assert all(len(r.columns) == 0 for r in response.rows) + # Check the vectors + rows = list_table_rows(client, table_type, table.id) + assert rows.total == 5 + row = rows.values[-1] + assert row["Title"] == "test", f"{row['Title']=}" + assert row["Title Embed"] is None, f"{row['Title Embed']=}" + assert row["Text"] is None, f"{row['Title']=}" + assert_is_vector_or_none(row["Text Embed"], allow_none=False) + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_rag( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Test RAG: + - Empty Knowledge Table + - Text query + - Single-turn and multi-turn + - Add and regen + - Text + Image query + - Single-turn and multi-turn + - Add and regen + - Chat thread references + - Inline citations + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. + """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table( + client, TableType.KNOWLEDGE, cols=[ColumnSchemaCreate(id="Species", dtype="str")] + ) as kt: + ### --- Perform RAG --- ### + system_prompt = 'Reply "Unsure" if you don\'t know the answer. Do not guess. Be concise.' 
+ gen_config_kwargs = dict( + model=setup.gpt_llm_model_id, + system_prompt=system_prompt, + prompt="${image}\n\nquestion: ${question}" if table_type == TableType.CHAT else "", + max_tokens=50, + temperature=0.001, + top_p=0.001, + ) + rag_kwargs = dict( + table_id=kt.id, + search_query="", # Generate using LM + k=2, + ) + cols = [ + ColumnSchemaCreate(id="question", dtype="str"), + ColumnSchemaCreate(id="image", dtype="image"), + ColumnSchemaCreate( + id="single", + dtype="str", + gen_config=LLMGenConfig( + multi_turn=False, + rag_params=RAGParams(reranking_model=None, **rag_kwargs), + **gen_config_kwargs, + ), + ), + ColumnSchemaCreate( + id="single-rerank", + dtype="str", + gen_config=LLMGenConfig( + multi_turn=False, + rag_params=RAGParams(reranking_model="", inline_citations=False, **rag_kwargs), + **gen_config_kwargs, + ), + ), + ColumnSchemaCreate( + id="multi", + dtype="str", + gen_config=LLMGenConfig( + multi_turn=True, + rag_params=RAGParams( + reranking_model=None, inline_citations=False, **rag_kwargs + ), + **gen_config_kwargs, + ), + ), + ] + + def _check_references(ref: References | None): + if ref is None: + return + _rows = list_table_rows(client, TableType.KNOWLEDGE, kt.id).values + ref_document_ids = {d["File ID"] for d in _rows[:2]} + document_ids = set(r.document_id for r in ref.chunks) + assert document_ids == ref_document_ids + ref_texts = {d["Text"] for d in _rows[:2]} + texts = set(r.text for r in ref.chunks) + assert len(texts) == min(len(_rows), rag_kwargs["k"]) + assert texts == ref_texts + contexts = [r.context for r in ref.chunks] + assert all("Species" in m for m in contexts) + metas = [r.metadata for r in ref.chunks] + assert all("rrf_score" in m for m in metas) + + def _check_row_references(references: list[dict[str, References]]): + for ref in references: + for r in ref.values(): + _check_references(r) + + def _get_content(row: RowCompletionResponse, col: str) -> str: + ref = row.columns[col].references + assert isinstance(ref, References) + _check_references(ref) + return row.columns[col].content.lower().strip() + + ### --- RAG on empty Knowledge Table --- ### + with create_table(client, table_type, cols=cols) as table: + col_map = {col.id: col.gen_config for col in table.cols} + # Assert that a default reranking model is set + assert col_map["single-rerank"].rag_params.reranking_model == setup.rerank_model_id + assert col_map["single"].rag_params.reranking_model is None + assert col_map["multi"].rag_params.reranking_model is None + # RAG + data = [dict(question="What is the name of the rabbit?")] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + # List rows (should have references) + rows = list_table_rows(client, table_type, table.id) + assert rows.total == len(data) + assert len(rows.references) == len(data) + _check_row_references(rows.references) + + ### --- Add data into Knowledge Table --- ### + data = [ + # Context + { + "Title": "Animal", + "Text": "Its name is Latte.", + "Species": "rabbit", + "File ID": "s3://animal-rabbit.jpeg", + }, + { + "Title": "Animal", + "Text": "Its name is Bambi.", + "Species": "doe", + "File ID": "s3://animal-doe.jpeg", + }, + # Distractor + { + "Title": "Country", + "Text": "Kuala Lumpur is the capital of Malaysia.", + "File ID": "s3://country-kuala-lumpur.pdf", + }, + ] + response = add_table_rows(client, TableType.KNOWLEDGE, kt.id, data, stream=False) + assert len(response.rows) == len(data) + kt_rows = list_table_rows(client, 
TableType.KNOWLEDGE, kt.id) + assert kt_rows.total == len(data) + + ### Text query + with create_table(client, table_type, cols=cols) as table: + col_map = {col.id: col.gen_config for col in table.cols} + # Assert that a default reranking model is set + assert col_map["single-rerank"].rag_params.reranking_model == setup.rerank_model_id + assert col_map["single"].rag_params.reranking_model is None + assert col_map["multi"].rag_params.reranking_model is None + # RAG + data = [ + dict(question="What is the name of the rabbit?"), # Latte + dict(question="What is its name again?"), # Unsure (single), Latte (multi) + ] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + # List rows (should have references) + rows = list_table_rows(client, table_type, table.id) + assert rows.total == len(data) + assert len(rows.references) == len(data) + _check_row_references(rows.references) + # Check answers + single = _get_content(response.rows[0], "single") + assert "latte" in single + assert len(re.findall(CITATION_PATTERN, single)) > 0 + assert "latte" in _get_content(response.rows[0], "single-rerank") + assert "latte" in _get_content(response.rows[0], "multi") + # "Unsure" tests are fragile + # assert "unsure" in _get_content(response.rows[1], "single") + # assert "unsure" in _get_content(response.rows[1], "single-rerank") + assert len(_get_content(response.rows[1], "single")) > 0 + assert len(_get_content(response.rows[1], "single-rerank")) > 0 + assert "latte" in _get_content(response.rows[1], "multi") + ### Update and regen + # Update question + row_ids = [r["ID"] for r in rows.items] + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={row_ids[0]: dict(question="What is the name of the deer?")}, # Bambi + ), + ) + assert isinstance(response, OkResponse) + response = regen_table_rows(client, table_type, table.id, row_ids, stream=stream) + assert len(response.rows) == len(data) + # Check answers + single = _get_content(response.rows[0], "single") + assert "bambi" in single + assert len(re.findall(CITATION_PATTERN, single)) > 0 + assert "bambi" in _get_content(response.rows[0], "single-rerank") + assert "bambi" in _get_content(response.rows[0], "multi") + assert len(_get_content(response.rows[1], "single")) > 0 + assert len(_get_content(response.rows[1], "single-rerank")) > 0 + assert "bambi" in _get_content(response.rows[1], "multi") + + ### Text + Image query + image_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + with create_table(client, table_type, cols=cols) as table: + col_map = {col.id: col.gen_config for col in table.cols} + # Assert that a default reranking model is set + assert col_map["single-rerank"].rag_params.reranking_model == setup.rerank_model_id + assert col_map["single"].rag_params.reranking_model is None + assert col_map["multi"].rag_params.reranking_model is None + # RAG + data = [ + # Latte + dict(question="What is the name of the animal?", image=image_uri, User="lala"), + # Unsure (single), Latte (multi) + dict(question="What is its name again?", User="lala"), + ] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + # List rows (should have references) + rows = list_table_rows(client, table_type, table.id) + assert rows.total == len(data) + assert len(rows.references) == len(data) + _check_row_references(rows.references) + assert "latte" in _get_content(response.rows[0], "single") 
+ assert "latte" in _get_content(response.rows[0], "single-rerank") + assert "latte" in _get_content(response.rows[0], "multi") + # "Unsure" tests are fragile + # assert "unsure" in _get_content(response.rows[1], "single") + # assert "unsure" in _get_content(response.rows[1], "single-rerank") + assert len(_get_content(response.rows[1], "single")) > 0 + assert len(_get_content(response.rows[1], "single-rerank")) > 0 + assert "latte" in _get_content(response.rows[1], "multi") + ### Update and regen + # Update KT + kt_row_ids = [r["ID"] for r in kt_rows.items] + response = client.table.update_table_rows( + TableType.KNOWLEDGE, + MultiRowUpdateRequest( + table_id=kt.id, + data={kt_row_ids[1]: dict(Text="Its name is Daisy")}, + ), + ) + assert isinstance(response, OkResponse) + # Update image + row_ids = [r["ID"] for r in rows.items] + image_uri = upload_file(client, FILES["doe.jpg"]).uri + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + # Daisy + data={row_ids[0]: dict(image=image_uri)}, + ), + ) + assert isinstance(response, OkResponse) + response = regen_table_rows(client, table_type, table.id, row_ids, stream=stream) + assert len(response.rows) == len(data) + # Check answers + assert "daisy" in _get_content(response.rows[0], "single") + assert "daisy" in _get_content(response.rows[0], "single-rerank") + assert "daisy" in _get_content(response.rows[0], "multi") + assert len(_get_content(response.rows[1], "single")) > 0 + assert len(_get_content(response.rows[1], "single-rerank")) > 0 + assert "daisy" in _get_content(response.rows[1], "multi") + + ### Chat thread references + col = "multi" + response = client.table.get_conversation_threads(table_type, table.id) + assert col in response.threads + assert response.table_id == table.id + thread = response.threads[col].thread + assert response.threads[col].column_id == col + for message in thread: + if message.role == "assistant": + assert isinstance(message.references, References) + assert len(message.references.chunks) == rag_kwargs["k"] + _check_references(message.references) + assert isinstance(message.row_id, str) + assert len(message.row_id) > 0 + elif message.role == "user": + assert isinstance(message.row_id, str) + assert len(message.row_id) > 0 + assert message.user_prompt is None + else: + assert isinstance(message.content, str) + assert message.row_id is None + message = thread[1] + assert message.role == "user" + assert isinstance(message.content, list) + assert len(message.content) == 2 + assert isinstance(message.content[0], S3Content) + assert message.content[0].uri == image_uri + assert isinstance(message.content[1], TextContent) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_column_dependency( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Test column dependency graph. + - Add and regen rows + - No dependency (single-turn, multi-turn) + - Single dependency (single-turn, multi-turn) + - Chain dependency + - Fan-in (with and without chain) and fan-out dependencies + - Multi-single-multi + - Gen config partial update + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + gen_config_kwargs = dict(model=setup.echo_model_id, system_prompt="^") + cols = [ + ColumnSchemaCreate(id="c0", dtype="str"), + # ["s1", "m1", "s2", "s3", "m2", "s4", "s5", "s6", "s7", "m3", "s8", "m4"] + # Single dependency (single-turn) + ColumnSchemaCreate( + id="s1", + dtype="str", + gen_config=LLMGenConfig(prompt="s1 ${c0}", **gen_config_kwargs), + ), + # Single dependency (multi-turn) + ColumnSchemaCreate( + id="m1", + dtype="str", + gen_config=LLMGenConfig(prompt="m1 ${c0}", multi_turn=True, **gen_config_kwargs), + ), + # Chain dependency + ColumnSchemaCreate( + id="s2", + dtype="str", + gen_config=LLMGenConfig(prompt="s2 ${s1}", **gen_config_kwargs), + ), + # No dependency (single-turn) + ColumnSchemaCreate( + id="s3", + dtype="str", + gen_config=LLMGenConfig(prompt="s3", **gen_config_kwargs), + ), + # No dependency (multi-turn) + ColumnSchemaCreate( + id="m2", + dtype="str", + gen_config=LLMGenConfig(prompt="m2", multi_turn=True, **gen_config_kwargs), + ), + # Fan-out after chain dependency + ColumnSchemaCreate( + id="s4", + dtype="str", + gen_config=LLMGenConfig(prompt="s4 ${s2}", **gen_config_kwargs), + ), + ColumnSchemaCreate( + id="s5", + dtype="str", + gen_config=LLMGenConfig(prompt="s5 ${s2}", **gen_config_kwargs), + ), + ColumnSchemaCreate( + id="s6", + dtype="str", + gen_config=LLMGenConfig(prompt="s6 ${s5}", **gen_config_kwargs), + ), + # Fan-in (single-turn) + ColumnSchemaCreate( + id="s7", + dtype="str", + gen_config=LLMGenConfig(prompt="s7 ${s4} ${s6}", **gen_config_kwargs), + ), + # Fan-in (multi-turn) + ColumnSchemaCreate( + id="m3", + dtype="str", + gen_config=LLMGenConfig(prompt="m3 ${s4} ${s6}", multi_turn=True, **gen_config_kwargs), + ), + # Single dependency (single-turn after multi-turn) + ColumnSchemaCreate( + id="s8", + dtype="str", + gen_config=LLMGenConfig(prompt="s8 ${m3}", **gen_config_kwargs), + ), + # Multi-single-multi + ColumnSchemaCreate( + id="m4", + dtype="str", + gen_config=LLMGenConfig(prompt="m4 ${s8}", multi_turn=True, **gen_config_kwargs), + ), + ] + + def _content(row: RowCompletionResponse, col: str) -> str | None: + return getattr(row.columns.get(col, None), "content", "").strip() + + def _check(rows: list[RowCompletionResponse], base: str, exc: list[str] = None): + if exc is None: + exc = [] + # Check single-turn + for i, row in enumerate(rows): + assert "s1" in exc or _content(row, "s1") == f"^ s1 {base}{i}" + assert "s2" in exc or _content(row, "s2") == f"^ s2 {_content(row, 's1')}" + assert "s3" in exc or _content(row, "s3") == "^ s3" + assert "s4" in exc or _content(row, "s4") == f"^ s4 {_content(row, 's2')}" + assert "s5" in exc or _content(row, "s5") == f"^ s5 {_content(row, 's2')}" + assert "s6" in exc or _content(row, "s6") == f"^ s6 {_content(row, 's5')}" + assert "s7" in exc or _content(row, "s7") == f'^ s7 {_content(row, "s4")} {_content(row, "s6")}' # fmt:off + # Check multi-turn + gt = dict( + m1=[ + f"^ m1 {base}0", + f"^ m1 {base}0 m1 {base}1", + ], + m2=[ + "^ m2", + "^ m2 m2", + ], + m3=[ + f"^ m3 {_content(rows[0], 's4')} {_content(rows[0], 's6')}", + f"^ m3 {_content(rows[0], 's4')} {_content(rows[0], 's6')} m3 {_content(rows[1], 's4')} {_content(rows[1], 's6')}", + ], + s8=[ + f"^ s8 {_content(rows[0], 'm3')}", + f"^ s8 {_content(rows[1], 'm3')}", + ], + m4=[ + f"^ m4 {_content(rows[0], 's8')}", + f"^ m4 {_content(rows[0], 's8')} m4 {_content(rows[1], 's8')}", + ], + ) + for i, row in enumerate(response.rows): + assert "m1" in exc or 
_content(row, "m1") == gt["m1"][i] + assert "m2" in exc or _content(row, "m2") == gt["m2"][i] + assert "m4" in exc or _content(row, "m3") == gt["m3"][i] + assert "s8" in exc or _content(row, "s8") == gt["s8"][i] + assert "m4" in exc or _content(row, "m4") == gt["m4"][i] + + with create_table(client, table_type, cols=cols) as table: + ### --- Add rows --- ### + data = [dict(c0="r0"), dict(c0="r1")] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + _check(response.rows, "r") + ### --- Regen rows --- ### + row_ids = [r.row_id for r in response.rows] + # Regen all + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={row.row_id: dict(c0=f"z{i}") for i, row in enumerate(response.rows)}, + ), + ) + response = regen_table_rows( + client, + table_type, + table.id, + row_ids, + stream=stream, + regen_strategy=RegenStrategy.RUN_ALL, + ) + assert len(response.rows) == len(data) + _check(response.rows, "z") + # Regen before + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={row.row_id: dict(c0=f"aa{i}") for i, row in enumerate(response.rows)}, + ), + ) + response = regen_table_rows( + client, + table_type, + table.id, + row_ids, + stream=stream, + regen_strategy=RegenStrategy.RUN_BEFORE, + output_column_id="m3", + ) + assert len(response.rows) == len(data) + # _check(response.rows, "z", ["s1", "m1", "s2", "s3", "m2", "s4", "s5", "s6", "s7", "m3"]) + _check(response.rows, "aa", ["s8", "m4"]) + # Regen after + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={row.row_id: dict(c0=f"bb{i}") for i, row in enumerate(response.rows)}, + ), + ) + response = regen_table_rows( + client, + table_type, + table.id, + row_ids, + stream=stream, + regen_strategy=RegenStrategy.RUN_AFTER, + output_column_id="s2", + ) + assert len(response.rows) == len(data) + assert _content(response.rows[0], "s2") == "^ s2 ^ s1 aa0" # Still "aa" + assert _content(response.rows[1], "s2") == "^ s2 ^ s1 aa1" # Still "aa" + _check(response.rows, "aa", ["s1", "m1", "s2"]) # Still "aa" + response = regen_table_rows( + client, + table_type, + table.id, + row_ids, + stream=stream, + regen_strategy=RegenStrategy.RUN_AFTER, + output_column_id="s1", + ) + assert len(response.rows) == len(data) + _check(response.rows, "bb") + # Regen selected + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={row.row_id: dict(c0=f"cc{i}") for i, row in enumerate(response.rows)}, + ), + ) + response = regen_table_rows( + client, + table_type, + table.id, + row_ids, + stream=stream, + regen_strategy=RegenStrategy.RUN_SELECTED, + output_column_id="m1", + ) + assert len(response.rows) == len(data) + # _check(response.rows, "bb", ["m1"]) + assert _content(response.rows[0], "m1") == "^ m1 cc0" + assert _content(response.rows[1], "m1") == "^ m1 cc0 m1 cc1" + # Update gen config and regen + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, column_map=dict(s8=LLMGenConfig(prompt="s8 ${m2}")) + ), + ) + gen_configs = {c.id: c.gen_config for c in table.cols} + assert gen_configs["s8"].system_prompt == "^" + assert gen_configs["s8"].prompt == "s8 ${m2}" + response = regen_table_rows( + client, + table_type, + table.id, + row_ids, + stream=stream, + regen_strategy=RegenStrategy.RUN_AFTER, + output_column_id="s8", + ) + assert _content(response.rows[0], "m4") 
== "^ m4 ^ s8 ^ m2" + assert _content(response.rows[1], "m4") == "^ m4 ^ s8 ^ m2 m4 ^ s8 ^ m2 m2" + + +@pytest.mark.parametrize( + "python_code", + [ + { + "input": "Hello, World!", + "code": "row['result_column']=row['input']", + "expected": "Hello, World!", + }, + { + "input": "2", + "code": "row['result_column'] = int(row['input']) + int(row['input'])", + "expected": "4", + }, + # Test error handling: + { + "input": "DUMMY", + "code": "row['result_column']=row['undefined']", + "expected": "KeyError: 'undefined'", + }, + ], +) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +async def test_python_fixed_function_str( + setup: ServingContext, + stream: bool, + python_code: dict, +): + table_type = TableType.ACTION + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input", dtype="str"), + ColumnSchemaCreate( + id="result_column", + dtype="str", + gen_config=PythonGenConfig(python_code=python_code["code"]), + ), + ] + with create_table(client, table_type, cols=cols) as table: + data = [{"input": python_code["input"]}] + # Add rows + response = add_table_rows( + client, table_type, table.id, data, stream=stream, check_usage=False + ) + assert len(response.rows) == len(data) + rows = list_table_rows(client, table_type, table.id) + row_ids = [r.row_id for r in response.rows] + assert rows.total == len(data) + assert rows.values[0]["result_column"] == python_code["expected"] + # Regen rows + response = regen_table_rows( + client, table_type, table.id, row_ids, stream=stream, check_usage=False + ) + assert len(response.rows) == len(data) + rows = list_table_rows(client, table_type, table.id) + assert rows.total == len(data) + assert rows.values[0]["result_column"] == python_code["expected"] + + +def _read_file_content(file_path): + with open(file_path, "rb") as f: + return f.read() + + +@pytest.mark.parametrize( + "image_path", + [ + FILES["cifar10-deer.jpg"], + FILES["rabbit.png"], + FILES["rabbit_cifar10-deer.gif"], + FILES["rabbit_cifar10-deer.webp"], + ], + ids=lambda x: basename(x), +) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +async def test_python_fixed_function_image( + setup: ServingContext, + stream: bool, + image_path: str, +): + table_type = TableType.ACTION + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="source_image", dtype="image"), + ColumnSchemaCreate( + id="result_column", + dtype="image", + gen_config=PythonGenConfig(python_code="row['result_column']=row['source_image']"), + ), + ] + + with create_table(client, table_type, cols=cols) as table: + image_uri = upload_file(client, image_path).uri + data = [{"source_image": image_uri}] + # Add rows + response = add_table_rows( + client, table_type, table.id, data, stream=stream, check_usage=False + ) + assert len(response.rows) == len(data) + rows = list_table_rows(client, table_type, table.id) + row_ids = [r.row_id for r in response.rows] + assert rows.total == len(data) + file_uri = rows.values[0]["result_column"] + assert file_uri.startswith(("file://", "s3://")) + response = client.file.get_raw_urls([file_uri]) + assert isinstance(response, GetURLResponse) + # Compare the contents + downloaded_content = httpx.get(response.urls[0]).content + original_content = _read_file_content(image_path) + assert original_content == downloaded_content, f"Content mismatch for file: {image_path}" + # Regen rows + response = regen_table_rows( + client, table_type, table.id, row_ids, stream=stream, 
check_usage=False + ) + assert len(response.rows) == len(data) + rows = list_table_rows(client, table_type, table.id) + assert rows.total == len(data) + file_uri = rows.values[0]["result_column"] + assert file_uri.startswith(("file://", "s3://")) + response = client.file.get_raw_urls([file_uri]) + assert isinstance(response, GetURLResponse) + # Compare the contents + downloaded_content = httpx.get(response.urls[0]).content + original_content = _read_file_content(image_path) + assert original_content == downloaded_content, f"Content mismatch for file: {image_path}" + + +def _assert_context_error(content: str) -> None: + assert "maximum context length is 10 tokens" in content + assert content.startswith("[ERROR]") + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_error_cases( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Test error cases. + - Row add & regen: Downstream columns exceed context length + - Row add & regen: All columns exceed context length + - Error circuit breaker + - Non-existent output column during regen + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. + """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + max_tokens = 8 + num_output_cols = 2 + cols = [ColumnSchemaCreate(id="c0", dtype="str")] + cols += [ + ColumnSchemaCreate( + id=f"c{i + 1}", + dtype="str", + gen_config=LLMGenConfig( + model=setup.short_llm_model_id, + system_prompt=".", + prompt=f"${{c{i}}}", + max_tokens=max_tokens, + ), + ) + for i in range(num_output_cols) + ] + with create_table(client, table_type, cols=cols) as table: + ### --- Context length --- ### + ### Downstream exceed context length + # Row add + data = [dict(c0="0"), dict(c0="1")] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + for row in response.rows: + assert "Lorem ipsum dolor sit amet" in row.columns["c1"].content + _assert_context_error(row.columns["c2"].content) + # Row regen + response = regen_table_rows( + client, + table_type, + table.id, + [r.row_id for r in response.rows], + stream=stream, + regen_strategy=RegenStrategy.RUN_ALL, + ) + for row in response.rows: + assert "Lorem ipsum dolor sit amet" in row.columns["c1"].content + _assert_context_error(row.columns["c2"].content) + ### All exceed context length + # Row add + data = [dict(c0="0 0"), dict(c0="1 1")] + response = add_table_rows(client, table_type, table.id, data, stream=stream) + assert len(response.rows) == len(data) + for row in response.rows: + _assert_context_error(row.columns["c1"].content) + assert "Upstream columns errored out" in row.columns["c2"].content + # Row regen + response = regen_table_rows( + client, + table_type, + table.id, + [r.row_id for r in response.rows], + stream=stream, + regen_strategy=RegenStrategy.RUN_ALL, + ) + for row in response.rows: + _assert_context_error(row.columns["c1"].content) + assert "Upstream columns errored out" in row.columns["c2"].content + + ### --- Regen rows with invalid column --- ### + row_ids = [r.row_id for r in response.rows] + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map={ + f"c{i + 1}": LLMGenConfig(max_tokens=2) for i in range(num_output_cols) + }, + ), + ) + strategies = [ + RegenStrategy.RUN_ALL, + RegenStrategy.RUN_BEFORE, + RegenStrategy.RUN_AFTER, + 
            RegenStrategy.RUN_SELECTED,
+        ]
+        for strategy in strategies:
+            with pytest.raises(ResourceNotFoundError):
+                regen_table_rows(
+                    client,
+                    table_type,
+                    table.id,
+                    row_ids,
+                    stream=stream,
+                    regen_strategy=strategy,
+                    output_column_id="x",
+                )
+
+
+def _assert_consecutive(lst: list) -> bool:
+    """
+    Check whether identical elements occur consecutively (grouped together) in the list.
+
+    Args:
+        lst: List of elements.
+
+    Returns:
+        bool: True if identical elements are grouped together, False otherwise.
+
+    Raises:
+        AssertionError: If the list is empty.
+    """
+    if not lst:
+        raise AssertionError("List is empty")
+    seen = {lst[0]}
+    current_element = lst[0]
+    for element in lst[1:]:
+        if element != current_element:
+            # We're starting a new group
+            if element in seen:
+                return False
+            seen.add(element)
+            current_element = element
+    return True
+
+
+@flaky(max_runs=5, min_passes=1)
+@pytest.mark.parametrize("table_type", TABLE_TYPES)
+def test_concurrency_stream(
+    setup: ServingContext,
+    table_type: TableType,
+):
+    client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id)
+    max_tokens = 10
+    num_output_cols = 3
+    num_rows = 2
+    cols = [ColumnSchemaCreate(id="str", dtype="str")]
+    cols += [
+        ColumnSchemaCreate(
+            id=f"o{i + 1}",
+            dtype="str",
+            gen_config=LLMGenConfig(
+                model=setup.lorem_llm_model_id,
+                system_prompt="",
+                prompt="",
+                max_tokens=max_tokens,
+            ),
+        )
+        for i in range(num_output_cols)
+    ]
+    with create_table(client, table_type, cols=cols) as table:
+        response = client.table.add_table_rows(
+            table_type,
+            MultiRowAddRequest(
+                table_id=table.id,
+                data=[dict(str="Lorem ipsum dolor sit amet")] * num_rows,
+                stream=True,
+            ),
+        )
+        chunks = [r for r in response if isinstance(r, CellCompletionResponse)]
+        ### --- Column concurrency --- ###
+        # Assert that all columns are concurrently generated
+        rows: dict[str, list[CellCompletionResponse]] = defaultdict(list)
+        for c in chunks:
+            rows[c.row_id].append(c)
+        for row in rows.values():
+            chunk_cols = [r.output_column_name for r in row]
+            assert len(chunk_cols) > num_output_cols * num_rows
+            _cols = set(chunk_cols[: len(chunk_cols) // 2])
+            assert len(_cols) >= 1
+            assert not _assert_consecutive(chunk_cols)
+        ### --- Row concurrency --- ###
+        row_ids = list(rows.keys())
+        chunk_rows = [c.row_id for c in chunks]
+        # print(f"{[row_ids.index(c.row_id) for c in chunks]=}")
+        multiturn_cols = [c for c in table.cols if getattr(c.gen_config, "multi_turn", False)]
+        if len(multiturn_cols) > 0:
+            # Tables with a multi-turn column must have their rows generated sequentially
+            for i, row_id in enumerate(row_ids):
+                chunks_per_row = len(chunk_rows) // len(row_ids)
+                _chunks = chunk_rows[i * chunks_per_row : (i + 1) * chunks_per_row]
+                assert row_id in _chunks
+            assert _assert_consecutive(chunk_rows)
+        else:
+            # Tables without one must have their rows generated concurrently
+            _rows = set(chunk_rows[: len(chunk_rows) // num_rows])
+            assert len(_rows) == num_rows
+            for row_id in row_ids:
+                assert row_id in _rows
+            assert not _assert_consecutive(chunk_rows)
+
+
+@pytest.mark.parametrize("table_type", TABLE_TYPES)
+@pytest.mark.parametrize("stream", **STREAM_PARAMS)
+def test_multimodal_multiturn(
+    setup: ServingContext,
+    table_type: TableType,
+    stream: bool,
+):
+    """
+    Tests multimodal multiturn generation.
+    - Ensure files are fetched/interpolated from the correct row in a multiturn setting
+    - Ensure files in history are updated after an earlier row is updated
+    - Add and regen row
+
+    Args:
+        setup (ServingContext): Setup.
+        table_type (TableType): Table type.
+        stream (bool): Stream (SSE) or not.
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="str", dtype="str"), + ColumnSchemaCreate(id="image", dtype="image"), + ColumnSchemaCreate(id="audio", dtype="audio"), + ColumnSchemaCreate(id="document", dtype="document"), + ColumnSchemaCreate( + id="chat", + dtype="str", + gen_config=LLMGenConfig( + model=setup.desc_llm_model_id, + system_prompt="", + prompt="${str} ${image} ${audio} ${document}", + max_tokens=20, + multi_turn=True, + ), + ), + ] + with ( + TemporaryDirectory() as tmp_dir, + create_table(client, table_type, cols=cols) as table, + ): + text_fp = join(tmp_dir, "test.txt") + with open(text_fp, "w") as f: + f.write("Two tokens") + doc_uri = upload_file(client, text_fp).uri + image_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + audio_uri = upload_file(client, FILES["gutter.mp3"]).uri + ### --- Add rows --- ### + response = add_table_rows( + client, + table_type, + table.id, + [ + dict(str="one", image=image_uri, audio=audio_uri, document=doc_uri), + dict(str="one", image=image_uri, audio=audio_uri, document=doc_uri), + ], + stream=stream, + ) + # Check returned chunks / response + for row in response.rows: + chat = row.columns["chat"].content + # print(chat) + chat_contents = chat.split("\n") + assert "System prompt:" in chat_contents[0] + assert _extract_number(chat_contents[0]) > 10 + assert "[image/jpeg], shape [(1200, 1600, 3)]" in chat + assert "[image/jpeg], shape [(32, 32, 3)]" not in chat + assert "[audio/mpeg]" in chat + assert "text with [5] tokens" in chat + assert len(response.rows) == 2 + chat = response.rows[0].columns["chat"].content + chat_contents = chat.split("\n") + assert len(chat.split("\n")) == 4 + chat = response.rows[1].columns["chat"].content + chat_contents = chat.split("\n") + assert len(chat.split("\n")) == 7 + # Update image in first row + image_uri = upload_file(client, FILES["cifar10-deer.jpg"]).uri + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={response.rows[0].row_id: dict(image=image_uri)}, + ), + ) + # Add a row + response = add_table_rows( + client, + table_type, + table.id, + [dict(str="one")], + stream=stream, + ) + assert len(response.rows) == 1 + chat = response.rows[0].columns["chat"].content + # print(chat) + assert "[image/jpeg], shape [(1200, 1600, 3)]" in chat + assert "[image/jpeg], shape [(32, 32, 3)]" in chat # Updated image + assert "[audio/mpeg]" in chat + assert "text with [5] tokens" in chat + assert "text with [1] tokens" in chat + ### --- Regen row --- ### + row = response.rows[0] + response = regen_table_rows(client, table_type, table.id, [row.row_id], stream=stream) + assert len(response.rows) == 1 + chat = response.rows[0].columns["chat"].content + assert "[image/jpeg], shape [(1200, 1600, 3)]" in chat + assert "[image/jpeg], shape [(32, 32, 3)]" in chat # Updated image + assert "[audio/mpeg]" in chat + assert "text with [5] tokens" in chat + assert "text with [1] tokens" in chat + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_add_get_list_rows( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Test adding a row to a table. + - All column dtypes + - Various languages + + Test get row and list rows from a table. 
+    - offset and limit
+    - order_by and order_ascending
+    - where
+    - search_query and search_columns
+    - column subset
+    - float & vector precision
+    - vector column exclusion
+
+    Args:
+        setup (ServingContext): Setup.
+        table_type (TableType): Table type.
+        stream (bool): Stream (SSE) or not.
+    """
+    client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id)
+    cols = [ColumnSchemaCreate(id=c, dtype=c) for c in INPUT_COLUMNS]
+    cols += [
+        ColumnSchemaCreate(
+            id=c,
+            dtype="str",
+            gen_config=LLMGenConfig(
+                model="",
+                system_prompt="",
+                prompt="",
+                max_tokens=10,
+            ),
+        )
+        for c in OUTPUT_COLUMNS
+    ]
+    with create_table(client, table_type, cols=cols) as table:
+        ### --- Add row with all dtypes --- ###
+        _, data = _add_row_default_data(
+            setup,
+            client,
+            table_type=table_type,
+            table_name=table.id,
+            stream=stream,
+        )
+        num_data = len(data.data_list)
+
+        ### --- List rows --- ###
+        rows = list_table_rows(client, table_type, table.id)
+        # Check row count
+        assert len(rows.items) == len(data.data_list), (
+            f"Row count mismatch: {len(rows.items)=} != {num_data=}"
+        )
+        assert rows.total == len(data.data_list), (
+            f"Row count mismatch: {rows.total=} != {num_data=}"
+        )
+        # Check row data
+        _check_rows(rows.values, data.action_data_list)
+        _check_knowledge_chat_data(table_type, rows.values, data)
+        # Check output columns
+        for row in rows.values:
+            for c in OUTPUT_COLUMNS:
+                summary = row[c]
+                assert "There is a text" in summary, summary
+                if row["image"]:
+                    assert "There is an image with MIME type [image/jpeg]" in summary, summary
+                if row["audio"]:
+                    assert "There is an audio with MIME type [audio/mpeg]" in summary, summary
+        # Check columns
+        _check_columns(table_type, rows.items)
+
+        ### --- Get row --- ###
+        for row in rows.items:
+            _row = get_table_row(client, table_type, table.id, row["ID"])
+            assert _row == row, f'Row "{row["ID"]}" mismatch: {_row=} != {row=}'
+
+        ### --- List rows (offset and limit) --- ###
+        _rows = list_table_rows(client, table_type, table.id, offset=0, limit=1)
+        assert len(_rows.items) == 1
+        assert _rows.total == num_data
+        assert _rows.items[0]["ID"] == rows.items[0]["ID"], f"{_rows.items=}"
+        _rows = list_table_rows(client, table_type, table.id, offset=1, limit=1)
+        assert len(_rows.items) == 1
+        assert _rows.total == num_data
+        assert _rows.items[0]["ID"] == rows.items[1]["ID"], f"{_rows.items=}"
+        # Offset >= num rows
+        _rows = list_table_rows(client, table_type, table.id, offset=num_data, limit=1)
+        assert len(_rows.items) == 0
+        assert _rows.total == num_data
+        _rows = list_table_rows(client, table_type, table.id, offset=num_data + 1, limit=1)
+        assert len(_rows.items) == 0
+        assert _rows.total == num_data
+        # Invalid offset and limit
+        with pytest.raises(BadInputError):
+            list_table_rows(client, table_type, table.id, offset=0, limit=0)
+        with pytest.raises(BadInputError):
+            list_table_rows(client, table_type, table.id, offset=-1, limit=1)
+
+        ### --- List rows (order_by and order_ascending) --- ###
+        _rows = list_table_rows(client, table_type, table.id, order_ascending=False)
+        assert len(_rows.items) == num_data
+        assert _rows.total == num_data
+        assert _rows.items[::-1] == rows.items
+        _rows = list_table_rows(client, table_type, table.id, order_by="str")
+        assert len(_rows.items) == num_data
+        assert _rows.total == num_data
+        assert _rows.items[::-1] == rows.items
+
+        ### --- List rows (where) --- ###
+        _rows = list_table_rows(client, table_type, table.id, search_query="Arri")
+        assert len(_rows.items) == 3
+        assert
_rows.total == 3 + assert _rows.total != num_data + _id = rows.items[0]["ID"] + _rows = list_table_rows( + client, table_type, table.id, search_query="Arri", where=f""""ID" > '{_id}'""" + ) + assert len(_rows.items) == 2 + assert _rows.total == 2 + _rows = list_table_rows(client, table_type, table.id, where=f""""ID" = '{_id}'""") + assert len(_rows.items) == 1 + assert _rows.total == 1 + + ### --- List rows (search_query and search_columns) --- ### + _rows = list_table_rows(client, table_type, table.id, search_query="Arri") + assert len(_rows.items) == 3 + assert _rows.total == 3 + assert _rows.total != num_data + _rows = list_table_rows(client, table_type, table.id, search_query="Arri", offset=1) + assert len(_rows.items) == 2 + assert _rows.total == 3 + assert _rows.total != num_data + _rows = list_table_rows( + client, table_type, table.id, search_query="Arri", search_columns=["str"] + ) + assert len(_rows.items) == 3 + assert _rows.total == 3 + assert _rows.total != num_data + _rows = list_table_rows( + client, table_type, table.id, search_query="Arri", search_columns=OUTPUT_COLUMNS + ) + assert len(_rows.items) == 0 + assert _rows.total == 0 + + ### --- Get & List rows (column subset) --- ### + _rows = list_table_rows(client, table_type, table.id, limit=2, columns=["str", "bool"]) + expected_columns = {"ID", "Updated at", "str", "bool"} + for row in _rows.items: + cols = set(row.keys()) + assert cols == expected_columns, ( + f"Column order mismatch: {cols=} != {expected_columns=}" + ) + _row = get_table_row(client, table_type, table.id, row["ID"], columns=["str", "bool"]) + assert _row == row, f'Row "{row["ID"]}" mismatch: {_row=} != {row=}' + assert "value" in row["bool"], _row + assert "value" in _row["bool"], _row + + ### --- Get & List rows (float & vector precision) --- ### + # Round to 1 decimal + _rows = list_table_rows( + client, table_type, table.id, limit=2, float_decimals=1, vec_decimals=1 + ) + for row in _rows.items: + exponent = _get_exponent(row["float"]["value"]) + assert exponent >= -1, exponent + if table_type == TableType.KNOWLEDGE: + for col in ["Title Embed", "Text Embed"]: + exponents = [_get_exponent(v) for v in row[col]["value"]] + assert all(e >= -1 for e in exponents), exponents + _row = get_table_row( + client, table_type, table.id, row["ID"], float_decimals=1, vec_decimals=1 + ) + assert _row == row, f'Row "{row["ID"]}" mismatch: {_row=} != {row=}' + # No vector columns + _rows = list_table_rows( + client, table_type, table.id, limit=2, float_decimals=1, vec_decimals=-1 + ) + for row in _rows.items: + exponent = _get_exponent(row["float"]["value"]) + assert exponent >= -1, exponent + assert "Title Embed" not in row + assert "Text Embed" not in row + _row = get_table_row( + client, table_type, table.id, row["ID"], float_decimals=1, vec_decimals=-1 + ) + assert _row == row, f'Row "{row["ID"]}" mismatch: {_row=} != {row=}' + + +def test_list_rows_case_insensitive_sort(setup: ServingContext): + table_type = TableType.ACTION + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ColumnSchemaCreate(id="str", dtype="str")] + with create_table(client, table_type, cols=cols) as table: + add_table_rows( + client, + table_type, + table.id, + [dict(str="a"), dict(str="B"), dict(str="C"), dict(str="d")][::-1], + stream=False, + ) + ### --- List rows --- ### + rows = list_table_rows(client, table_type, table.id) + assert [r["str"] for r in rows.values] == ["a", "B", "C", "d"][::-1] + rows = list_table_rows(client, table_type, table.id, 
order_by="str") + assert [r["str"] for r in rows.values] == ["a", "B", "C", "d"] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_update_row( + setup: ServingContext, + table_type: TableType, +): + """ + Test row updates. + - All column dtypes + - ID should not be updated even if provided + - Updating data with wrong dtype or vector length should store None + - Updating embedding directly should work + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + ### --- Add row with all dtypes --- ### + data = [ + { + "ID": "0", + "Updated at": "1990-05-13T09:01:50.010756+00:00", + "int": 1, + "float": -1.25, + "bool": True, + "str": "moka", + "image": setup.image_uri, + "audio": setup.audio_uri, + "document": setup.document_uri, + "Title": "Dune: Part Two.", + "Text": '"Dune: Part Two" is a film.', + "Title Embed": [-1.25] * setup.embedding_size, + "Text Embed": [0.25] * setup.embedding_size, + "User": "Hi", + "AI": "Hello", + } + ] + add_table_rows(client, table_type, table.id, data, stream=False) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + row = rows.values[0] + t0 = datetime.fromisoformat(row["Updated at"]) + + # ID should not be updated, the rest OK + data = dict(ID="2", float=1.0, bool=False) + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest(table_id=table.id, data={row["ID"]: data}), + ) + assert isinstance(response, OkResponse) + _rows = list_table_rows(client, table_type, table.id) + assert len(_rows.items) == 1 + _row = _rows.values[0] + t1 = datetime.fromisoformat(_row["Updated at"]) + assert _row["float"] == data["float"] + assert _row["bool"] == data["bool"] + _assert_dict_equal(row, _row, exclude=["Updated at", "float", "bool"]) + assert t1 > t0 + + # Test updating data with wrong dtype + data = dict(ID="2", int="str", float="str", bool="str") + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest(table_id=table.id, data={row["ID"]: data}), + ) + assert isinstance(response, OkResponse) + _rows = list_table_rows(client, table_type, table.id) + assert len(_rows.items) == 1 + _row = _rows.values[0] + t2 = datetime.fromisoformat(_row["Updated at"]) + assert _row["int"] is None + assert _row["float"] is None + assert _row["bool"] is None + _assert_dict_equal(row, _row, exclude=["Updated at", "int", "float", "bool"]) + assert t2 > t1 + + if table_type == TableType.KNOWLEDGE: + # Test updating embedding columns directly + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={ + row["ID"]: { + "Title Embed": [0] * len(row["Title Embed"]), + "Text Embed": [1] * len(row["Text Embed"]), + } + }, + ), + ) + assert isinstance(response, OkResponse) + _rows = list_table_rows(client, table_type, table.id) + assert len(_rows.items) == 1 + _row = _rows.values[0] + t3 = datetime.fromisoformat(_row["Updated at"]) + assert sum(_row["Title Embed"]) == 0 + assert sum(_row["Text Embed"]) == len(row["Text Embed"]) + assert t3 > t2 + # Test updating embedding columns with wrong length + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={row["ID"]: {"Title Embed": [0], "Text Embed": [0]}}, + ), + ) + assert isinstance(response, OkResponse) + _rows = list_table_rows(client, table_type, table.id) + assert len(_rows.items) 
== 1 + _row = _rows.values[0] + t4 = datetime.fromisoformat(_row["Updated at"]) + assert _row["Title Embed"] is None + assert _row["Text Embed"] is None + assert t4 > t3 + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_regen_embedding( + setup: ServingContext, + stream: bool, +): + table_type = TableType.KNOWLEDGE + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type, cols=[]) as table: + # Add row + data = [{"Title": "Dune: Part Two.", "Text": '"Dune: Part Two" is a film.'}] + add_table_rows(client, table_type, table.id, data, stream=False) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + r0 = rows.values[0] + t0 = datetime.fromisoformat(r0["Updated at"]) + # Update row + response = client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={r0["ID"]: {"Title": "hi", "Text": "papaya"}}, + ), + ) + assert isinstance(response, OkResponse) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + r1 = rows.values[0] + t1 = datetime.fromisoformat(r1["Updated at"]) + assert t1 > t0 + assert r1["Title"] != r0["Title"] + assert r1["Text"] != r0["Text"] + assert r1["Title Embed"] == r0["Title Embed"] + assert r1["Text Embed"] == r0["Text Embed"] + # Regen row + regen_table_rows(client, table_type, table.id, [r0["ID"]], stream=stream) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + r2 = rows.values[0] + t2 = datetime.fromisoformat(r2["Updated at"]) + assert t2 > t1 + assert r2["Title"] != r0["Title"] + assert r2["Text"] != r0["Text"] + assert r2["Title Embed"] != r0["Title Embed"] + assert r2["Text Embed"] != r0["Text Embed"] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_multiturn_regen( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + """ + Tests multiturn row regen. + - Each row correctly sees the regenerated output of the previous row + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. + stream (bool): Stream (SSE) or not. + """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=setup.gpt_llm_model_id, + system_prompt="", + prompt="${User}", + max_tokens=20, + multi_turn=True, + ), + ), + ] + if table_type == TableType.CHAT: + chat_cols, cols = cols, [] + else: + chat_cols = None + with create_table(client, table_type, cols=cols, chat_cols=chat_cols) as table: + ### --- Add rows --- ### + response = add_table_rows( + client, + table_type, + table.id, + [ + dict(User="Hi", AI="How are you?"), + dict(User="Repeat your previous response."), + dict(User="Repeat your previous response."), + ], + stream=stream, + ) + # Check returned chunks / response + if stream: + assert len(response.rows) == 2 + else: + assert len(response.rows) == 3 + response.rows = response.rows[1:] + for row in response.rows: + chat = row.columns["AI"].content.strip() + assert chat == "How are you?", f"{row.columns=}" + # Update the second row + client.table.update_table_rows( + table_type, + MultiRowUpdateRequest( + table_id=table.id, + data={response.rows[0].row_id: dict(User="Good. 
What is 5+5?")}, + ), + ) + ### --- Regen rows --- ### + response = regen_table_rows( + client, + table_type, + table.id, + [response.rows[0].row_id, response.rows[1].row_id], + stream=stream, + ) + assert len(response.rows) == 2 + for row in response.rows: + chat = row.columns["AI"].content.strip() + assert chat != "How are you?", f"{row.columns=}" + assert "10" in chat, f"{row.columns=}" diff --git a/services/api/tests/gen_table/test_table_ops.py b/services/api/tests/gen_table/test_table_ops.py new file mode 100644 index 0000000..9512ff6 --- /dev/null +++ b/services/api/tests/gen_table/test_table_ops.py @@ -0,0 +1,2160 @@ +import re +from contextlib import contextmanager +from dataclasses import dataclass +from os.path import dirname, join, realpath +from typing import Generator + +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + ActionTableSchemaCreate, + AddActionColumnSchema, + AddChatColumnSchema, + AddKnowledgeColumnSchema, + CellCompletionResponse, + ChatCompletionResponse, # Assuming this might be needed for detailed checks later + ChatTableSchemaCreate, + ColumnDropRequest, + ColumnRenameRequest, + ColumnReorderRequest, + ColumnSchema, + ColumnSchemaCreate, + DeploymentCreate, + GenConfigUpdateRequest, + KnowledgeTableSchemaCreate, + MultiRowAddRequest, + MultiRowCompletionResponse, + OrganizationCreate, + RAGParams, + RowCompletionResponse, + TableMetaResponse, +) +from owl.types import ( + CloudProvider, + LLMGenConfig, + Role, + TableType, +) +from owl.utils.exceptions import ( + BadInputError, + ResourceExistsError, + ResourceNotFoundError, +) +from owl.utils.test import ( + ELLM_EMBEDDING_CONFIG, + ELLM_EMBEDDING_DEPLOYMENT, + GPT_4O_MINI_CONFIG, + GPT_4O_MINI_DEPLOYMENT, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + create_deployment, + create_model_config, + create_organization, + create_project, + create_user, + get_file_map, + list_table_rows, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) +EMBEDDING_MODEL = "openai/text-embedding-3-small" +TABLE_TYPES = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] +REGULAR_COLUMN_DTYPES: list[str] = ["int", "float", "bool", "str"] +SAMPLE_DATA = { + "int": -1, + "float": -0.9, + "bool": True, + "str": '"Arrival" is a 2016 science fiction film. "Arrival" è un film di fantascienza del 2016. 「Arrivalã€ã¯2016å¹´ã®SF映画ã§ã™ã€‚', +} +KT_FIXED_COLUMN_IDS = ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] +CT_FIXED_COLUMN_IDS = ["User"] + +TABLE_ID_A = "table_a" +TABLE_ID_B = "table_b" +TABLE_ID_C = "table_c" +TABLE_ID_X = "table_x" +TEXT = '"Arrival" is a 2016 American science fiction drama film directed by Denis Villeneuve and adapted by Eric Heisserer.' +TEXT_CN = ( + '"Arrival" 《é™ä¸´ã€‹æ˜¯ä¸€éƒ¨ 2016 年美国科幻剧情片,由丹尼斯·维伦纽瓦执导,埃里克·海瑟尔改编。' +) +TEXT_JP = '"Arrival" 「Arrivalã€ã¯ã€ãƒ‰ã‚¥ãƒ‹ãƒ»ãƒ´ã‚£ãƒ«ãƒŒãƒ¼ãƒ´ãŒç›£ç£ã—ã€ã‚¨ãƒªãƒƒã‚¯ãƒ»ãƒã‚¤ã‚»ãƒ©ãƒ¼ãŒè„šè‰²ã—ãŸ2016å¹´ã®ã‚¢ãƒ¡ãƒªã‚«ã®SFドラマ映画ã§ã™ã€‚' + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str + user_id: str + org_id: str + project_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. 
+ """ + with ( + # Create superuser + create_user() as superuser, + # Create user + create_user({"email": "testuser@example.com", "name": "Test User"}) as user, + # Create organization + create_organization( + body=OrganizationCreate(name="Clubhouse"), user_id=superuser.id + ) as org, + # Create project + create_project(dict(name="Bucket A"), user_id=superuser.id, organization_id=org.id) as p0, + ): + assert superuser.id == "0" + assert org.id == "0" + client = JamAI(user_id=superuser.id) + # Join organization and project + client.organizations.join_organization( + user_id=user.id, organization_id=org.id, role=Role.ADMIN + ) + client.projects.join_project(user_id=user.id, project_id=p0.id, role=Role.ADMIN) + + # Create models + with ( + create_model_config(GPT_4O_MINI_CONFIG), + create_model_config( + { + "id": "openai/Qwen/Qwen-2-Audio-7B", + "type": "llm", + "name": "ELLM Qwen2 Audio (7B)", + "capabilities": ["chat", "audio"], + "context_length": 128000, + "languages": ["en"], + } + ) as llm_config_audio, + create_model_config(ELLM_EMBEDDING_CONFIG), + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG), + ): + # Create deployments + with ( + create_deployment(GPT_4O_MINI_DEPLOYMENT), + create_deployment( + DeploymentCreate( + model_id=llm_config_audio.id, + name="ELLM Qwen2 Audio (7B) Deployment", + provider=CloudProvider.ELLM, + routing_id=llm_config_audio.id, + api_base="https://llmci.embeddedllm.com/audio/v1", + ) + ), + create_deployment(ELLM_EMBEDDING_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + yield ServingContext( + superuser_id=superuser.id, + user_id=user.id, + org_id=org.id, + project_id=p0.id, + ) + + +def _get_chat_model(client: JamAI) -> str: + models = client.model_ids(prefer="openai/gpt-4o-mini", capabilities=["chat"]) + return models[0] + + +def _get_image_models(client: JamAI) -> list[str]: + models = client.model_ids(prefer="openai/gpt-4o-mini", capabilities=["image"]) + return models + + +def _get_chat_only_model(client: JamAI) -> str: + chat_models = client.model_ids(capabilities=["chat"]) + image_models = _get_image_models(client) + chat_only_models = [model for model in chat_models if model not in image_models] + if not chat_only_models: + pytest.skip("No chat-only model available for testing.") + return chat_only_models[0] + + +def _get_reranking_model(client: JamAI) -> str: + models = client.model_ids(capabilities=["rerank"]) + return models[0] + + +@contextmanager +def _create_table( + client: JamAI, + table_type: TableType, + table_id: str = TABLE_ID_A, + cols: list[ColumnSchemaCreate] | None = None, + chat_cols: list[ColumnSchemaCreate] | None = None, + embedding_model: str | None = None, +): + try: + if cols is None: + cols = [ + ColumnSchemaCreate(id="good", dtype="bool"), + ColumnSchemaCreate(id="words", dtype="int"), + ColumnSchemaCreate(id="stars", dtype="float"), + ColumnSchemaCreate(id="inputs", dtype="str"), + ColumnSchemaCreate(id="photo", dtype="image"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + # Interpolate string and non-string input columns + prompt="Summarise this in ${words} words:\n\n${inputs}", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ColumnSchemaCreate( + id="captioning", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="You are a concise assistant.", + # Interpolate file input column + prompt="${photo} \n\nWhat's in the image?", + 
temperature=0.001, + top_p=0.001, + max_tokens=300, + ), + ), + ] + if chat_cols is None: + chat_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a wacky assistant.", + temperature=0.001, + top_p=0.001, + max_tokens=5, + ), + ), + ] + + if table_type == TableType.ACTION: + table = client.table.create_action_table( + ActionTableSchemaCreate(id=table_id, cols=cols) + ) + elif table_type == TableType.KNOWLEDGE: + if embedding_model is None: + embedding_model = "" + table = client.table.create_knowledge_table( + KnowledgeTableSchemaCreate(id=table_id, cols=cols, embedding_model=embedding_model) + ) + elif table_type == TableType.CHAT: + table = client.table.create_chat_table( + ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + ) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + yield table + finally: + try: + client.table.delete_table(table_type, table_id) + except ResourceNotFoundError: + pass # Ignore if already deleted + + +@contextmanager +def _create_table_v2( + client: JamAI, + table_type: TableType, + table_id: str = TABLE_ID_A, + cols: list[ColumnSchemaCreate] | None = None, + chat_cols: list[ColumnSchemaCreate] | None = None, + llm_model: str = "", + embedding_model: str = "", + system_prompt: str = "", + prompt: str = "", +) -> Generator[TableMetaResponse, None, None]: + try: + if cols is None: + _input_cols = [ + ColumnSchemaCreate(id=f"in_{dtype}", dtype=dtype) + for dtype in REGULAR_COLUMN_DTYPES + ] + _output_cols = [ + ColumnSchemaCreate( + id=f"out_{dtype}", + dtype=dtype, + gen_config=LLMGenConfig( + model=llm_model, + system_prompt=system_prompt, + prompt=" ".join(f"${{{col.id}}}" for col in _input_cols) + prompt, + max_tokens=10, + ), + ) + for dtype in ["str"] + ] + cols = _input_cols + _output_cols + if chat_cols is None: + chat_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=llm_model, + system_prompt=system_prompt, + max_tokens=10, + ), + ), + ] + + expected_cols = {"ID", "Updated at"} + expected_cols |= {c.id for c in cols} + if table_type == TableType.ACTION: + table = client.table.create_action_table( + ActionTableSchemaCreate(id=table_id, cols=cols) + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.create_knowledge_table( + KnowledgeTableSchemaCreate(id=table_id, cols=cols, embedding_model=embedding_model) + ) + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + table = client.table.create_chat_table( + ChatTableSchemaCreate(id=table_id, cols=chat_cols + cols) + ) + expected_cols |= {c.id for c in chat_cols} + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + col_ids = set(c.id for c in table.cols) + assert col_ids == expected_cols + yield table + finally: + try: + client.table.delete_table(table_type, table_id) + except Exception: + pass # Ignore if already deleted + + +def _add_row( + client: JamAI, + table_type: TableType, + stream: bool, + table_name: str = TABLE_ID_A, + data: dict | None = None, + knowledge_data: dict | None = None, + chat_data: dict | None = None, +): + if data is None: + # Use a placeholder URI, actual file upload isn't needed for table ops tests + data = dict( + good=True, + words=5, + stars=7.9, + 
inputs=TEXT, + photo="rabbit.jpeg", + ) + + if knowledge_data is None: + knowledge_data = dict( + Title="Dune: Part Two.", + Text='"Dune: Part Two" is a 2024 American epic science fiction film.', + ) + if chat_data is None: + chat_data = dict(User="Tell me a joke.") + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + data.update(knowledge_data) + elif table_type == TableType.CHAT: + data.update(chat_data) + else: + raise ValueError(f"Invalid table type: {table_type}") + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), + ) + if stream: + # Consume the stream to ensure completion for tests that need data populated + return response + # list(response) + # return None # Streamed responses are handled differently + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 1 + return response.rows[0] + + +def _add_row_v2( + client: JamAI, + table_type: TableType, + stream: bool, + table_name: str = TABLE_ID_A, + data: dict | None = None, + knowledge_data: dict | None = None, + chat_data: dict | None = None, + include_output_data: bool = False, +) -> MultiRowCompletionResponse | None: + if data is None: + data = {f"in_{dtype}": SAMPLE_DATA[dtype] for dtype in REGULAR_COLUMN_DTYPES} + if include_output_data: + data.update({f"out_{dtype}": SAMPLE_DATA[dtype] for dtype in ["str"]}) + + if knowledge_data is None: + knowledge_data = dict( + Title="Dune: Part Two.", + Text='"Dune: Part Two" is a 2024 American epic science fiction film.', + ) + if include_output_data: + knowledge_data.update({"Title Embed": None, "Text Embed": None}) + if chat_data is None: + chat_data = dict(User="Tell me a joke.") + if include_output_data: + chat_data.update({"AI": "Nah"}) + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + data.update(knowledge_data) + elif table_type == TableType.CHAT: + data.update(chat_data) + else: + raise ValueError(f"Invalid table type: {table_type}") + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest(table_id=table_name, data=[data], stream=stream), + ) + if stream: + # Consume the stream + _ = list(response) + # For simplicity in table ops tests, we might not need to reconstruct the full response + return None + assert isinstance(response, MultiRowCompletionResponse) + assert response.object == "gen_table.completion.rows" + assert len(response.rows) == 1 + return response + + +@contextmanager +def _rename_table( + client: JamAI, + table_type: TableType, + table_id_src: str, + table_id_dst: str, +): + try: + table = client.table.rename_table(table_type, table_id_src, table_id_dst) + assert isinstance(table, TableMetaResponse) + yield table + finally: + try: + client.table.delete_table(table_type, table_id_dst) + except ResourceNotFoundError: + pass # Ignore if already deleted + + +@contextmanager +def _duplicate_table( + client: JamAI, + table_type: TableType, + table_id_src: str, + table_id_dst: str, + include_data: bool = True, + create_as_child: bool = False, +): + try: + table = client.table.duplicate_table( + table_type, + table_id_src, + table_id_dst, + include_data=include_data, + create_as_child=create_as_child, + ) + assert isinstance(table, TableMetaResponse) + yield table + finally: + try: + client.table.delete_table(table_type, table_id_dst) + except ResourceNotFoundError: + pass # Ignore if already deleted + + +@contextmanager +def _create_child_table( + client: JamAI, + 
table_type: TableType, + table_id_src: str, + table_id_dst: str | None, +): + created_id = None + try: + table = client.table.duplicate_table( + table_type, table_id_src, table_id_dst, create_as_child=True + ) + created_id = table.id # Store the actual ID created + assert isinstance(table, TableMetaResponse) + yield table + finally: + if created_id: + try: + client.table.delete_table(table_type, created_id) + except ResourceNotFoundError: + pass # Ignore if already deleted + + +def _collect_text( + responses: MultiRowCompletionResponse | Generator[ChatCompletionResponse, None, None], + col: str, +): + if isinstance(responses, MultiRowCompletionResponse): + # Assuming only one row for simplicity in these tests + if col in responses.rows[0].columns: + return responses.rows[0].columns[col].content + else: + return "" # Column might not exist (e.g., AI in non-chat table) + # Handling stream (simplified for table ops) + content = "" + for r in responses: + if hasattr(r, "output_column_name") and r.output_column_name == col: + content += r.content + return content + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize( + "table_id", ["a", "0", "a.b", "a-b", "a_b", "a-_b", "a-_0b", "a.-_0b", "0_0"] +) +def test_create_table_valid_table_id( + setup: ServingContext, + table_type: TableType, + table_id: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type, table_id) as table: + assert isinstance(table, TableMetaResponse) + assert table.id == table_id + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_create_table_valid_column_id( + setup: ServingContext, + table_type: TableType, +): + table_id = TABLE_ID_A + col_ids = ["a", "0", "a b", "a-b", "a_b", "a-_b", "a-_0b", "a -_0b", "0_0"] + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # --- Test input column --- # + cols = [ColumnSchemaCreate(id=_id, dtype="str") for _id in col_ids] + with _create_table(client, table_type, table_id, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + created_col_ids = {c.id for c in table.cols if c.id in col_ids} + assert created_col_ids == set(col_ids) + + client.table.delete_table(table_type, table_id) + # --- Test output column --- # + cols = [ + ColumnSchemaCreate( + id=_id, + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="You are a concise assistant.", + prompt="Reply yes", + temperature=0.001, + top_p=0.001, + max_tokens=3, + ), + ) + for _id in col_ids + ] + with _create_table(client, table_type, table_id, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + created_col_ids = {c.id for c in table.cols if c.id in col_ids} + assert created_col_ids == set(col_ids) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize( + "invalid_table_id", ["a_", "_a", "_aa", "aa_", "_a_", "-a", ".a", "a" * 101] +) +def test_create_table_invalid_table_id( + setup: ServingContext, + table_type: TableType, + invalid_table_id: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ColumnSchemaCreate(id="valid_col", dtype="str")] + with pytest.raises(BadInputError): + with _create_table(client, table_type, invalid_table_id, cols=cols): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("column_id", ["a_", "_a", "_aa", "aa_", "_a_", "-a", ".a", "a" * 101]) +def test_create_table_invalid_column_id( + setup: ServingContext, + table_type: TableType, + 
column_id: str, +): + table_id = TABLE_ID_A + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # --- Test input column --- # + cols = [ + ColumnSchemaCreate(id=column_id, dtype="str"), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, table_id, cols=cols): + pass + + # --- Test output column --- # + cols = [ + ColumnSchemaCreate( + id=column_id, + dtype="str", + gen_config=LLMGenConfig(), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, table_id, cols=cols): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_create_table_invalid_model( + setup: ServingContext, + table_type: TableType, +): + table_id = TABLE_ID_A + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(model="INVALID_MODEL_ID"), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, table_id, cols=cols): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_create_table_invalid_column_ref( + setup: ServingContext, + table_type: TableType, +): + table_id = TABLE_ID_A + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(prompt="Summarise ${input_non_existent}"), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, table_id, cols=cols): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_create_table_invalid_rag( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # Create the knowledge table first + with _create_table(client, TableType.KNOWLEDGE, TABLE_ID_B, cols=[]) as ktable: + # --- Valid knowledge table ID --- # + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig( + rag_params=RAGParams(table_id=ktable.id), + ), + ), + ] + # --- Invalid knowledge table ID --- # + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig( + rag_params=RAGParams(table_id="INVALID_KT_ID"), + ), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, cols=cols): + pass + + # --- Valid reranker --- # + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig( + rag_params=RAGParams( + table_id=ktable.id, reranking_model=_get_reranking_model(client) + ), + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + + # --- Invalid reranker --- # + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig( + rag_params=RAGParams(table_id=ktable.id, reranking_model="INVALID_RERANKER"), + ), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, cols=cols): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_default_llm_model( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input0", 
dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(), + ), + ColumnSchemaCreate( + id="output1", + dtype="str", + gen_config=None, + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert isinstance(cols_dict["output0"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output0"].gen_config.model, str) + assert len(cols_dict["output0"].gen_config.model) > 0 + assert cols_dict["output1"].gen_config is None + if table_type == TableType.CHAT: + assert isinstance(cols_dict["AI"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["AI"].gen_config.model, str) + assert len(cols_dict["AI"].gen_config.model) > 0 + + # --- Update gen config --- # + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=TABLE_ID_A, + column_map=dict( + output0=None, + output1=LLMGenConfig(), + ), + ), + ) + assert isinstance(table, TableMetaResponse) + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert cols_dict["output0"].gen_config is None + assert isinstance(cols_dict["output1"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output1"].gen_config.model, str) + assert len(cols_dict["output1"].gen_config.model) > 0 + if table_type == TableType.CHAT: + assert isinstance(cols_dict["AI"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["AI"].gen_config.model, str) + assert len(cols_dict["AI"].gen_config.model) > 0 + + # --- Add column --- # + add_cols = [ + ColumnSchemaCreate( + id="output2", + dtype="str", + gen_config=None, + ), + ColumnSchemaCreate( + id="output3", + dtype="str", + gen_config=LLMGenConfig(), + ), + ] + if table_type == TableType.ACTION: + table = client.table.add_action_columns( + AddActionColumnSchema(id=table.id, cols=add_cols) + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=add_cols) + ) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=add_cols)) + else: + raise ValueError(f"Invalid table type: {table_type}") + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert cols_dict["output0"].gen_config is None + assert isinstance(cols_dict["output1"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output1"].gen_config.model, str) + assert len(cols_dict["output1"].gen_config.model) > 0 + assert cols_dict["output2"].gen_config is None + assert isinstance(cols_dict["output3"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output3"].gen_config.model, str) + assert len(cols_dict["output3"].gen_config.model) > 0 + if table_type == TableType.CHAT: + assert isinstance(cols_dict["AI"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["AI"].gen_config.model, str) + assert len(cols_dict["AI"].gen_config.model) > 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_default_image_model( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + available_image_models = _get_image_models(client) + if not available_image_models: + pytest.skip("No image model available for testing.") + + cols = [ + ColumnSchemaCreate(id="input0", dtype="image"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(prompt="${input0}"), + ), + ColumnSchemaCreate( + id="output1", 
+ dtype="str", + gen_config=None, + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert isinstance(cols_dict["output0"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output0"].gen_config.model, str) + assert cols_dict["output0"].gen_config.model in available_image_models + assert cols_dict["output1"].gen_config is None + if table_type == TableType.CHAT: + assert isinstance(cols_dict["AI"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["AI"].gen_config.model, str) + # Default AI model might not be an image model if not needed + # assert cols_dict["AI"].gen_config.model in available_image_models + + # --- Update gen config --- # + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=TABLE_ID_A, + column_map=dict( + output0=None, + output1=LLMGenConfig(prompt="${input0}"), + ), + ), + ) + assert isinstance(table, TableMetaResponse) + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert cols_dict["output0"].gen_config is None + assert isinstance(cols_dict["output1"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output1"].gen_config.model, str) + assert cols_dict["output1"].gen_config.model in available_image_models + + # --- Add column --- # + add_cols_1 = [ + ColumnSchemaCreate( + id="output2", + dtype="str", + gen_config=LLMGenConfig(prompt="${input0}"), + ), + ColumnSchemaCreate(id="file_input1", dtype="image"), + ColumnSchemaCreate( + id="output3", + dtype="str", + gen_config=LLMGenConfig(prompt="${file_input1}"), + ), + ] + if table_type == TableType.ACTION: + table = client.table.add_action_columns( + AddActionColumnSchema(id=table.id, cols=add_cols_1) + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=add_cols_1) + ) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns( + AddChatColumnSchema(id=table.id, cols=add_cols_1) + ) + else: + raise ValueError(f"Invalid table type: {table_type}") + + # Add a column with default prompt (should pick image model if image inputs exist) + add_cols_2 = [ + ColumnSchemaCreate( + id="output4", + dtype="str", + gen_config=LLMGenConfig(), + ), + ] + if table_type == TableType.ACTION: + table = client.table.add_action_columns( + AddActionColumnSchema(id=table.id, cols=add_cols_2) + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=add_cols_2) + ) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns( + AddChatColumnSchema(id=table.id, cols=add_cols_2) + ) + else: + raise ValueError(f"Invalid table type: {table_type}") + + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert cols_dict["output0"].gen_config is None + for output_column_name in ["output1", "output2", "output3", "output4"]: + assert isinstance(cols_dict[output_column_name].gen_config, LLMGenConfig) + model = cols_dict[output_column_name].gen_config.model + assert isinstance(model, str) + assert model in available_image_models, ( + f'Column {output_column_name} has invalid default model "{model}". 
Valid: {available_image_models}' + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_invalid_image_model( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + available_image_models = _get_image_models(client) + if not available_image_models: + pytest.skip("No image model available for testing.") + try: + chat_only_model = _get_chat_only_model(client) + except IndexError: + pytest.skip("No chat-only model available for testing.") + + cols = [ + ColumnSchemaCreate(id="input0", dtype="image"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(model=chat_only_model, prompt="${input0}"), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, cols=cols): + pass + + cols_valid = [ + ColumnSchemaCreate(id="input0", dtype="image"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(prompt="${input0}"), + ), + ] + with _create_table(client, table_type, cols=cols_valid) as table: + assert isinstance(table, TableMetaResponse) + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert isinstance(cols_dict["output0"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output0"].gen_config.model, str) + assert cols_dict["output0"].gen_config.model in available_image_models + + # --- Update gen config --- # + with pytest.raises(BadInputError): + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=TABLE_ID_A, + column_map=dict( + output0=LLMGenConfig( + model=chat_only_model, + prompt="${input0}", + ), + ), + ), + ) + # Ensure update with valid model works + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=TABLE_ID_A, + column_map=dict( + output0=LLMGenConfig(prompt="${input0}"), + ), + ), + ) + assert isinstance(table, TableMetaResponse) + # Check gen configs + cols_dict = {c.id: c for c in table.cols} + assert isinstance(cols_dict["output0"].gen_config, LLMGenConfig) + assert isinstance(cols_dict["output0"].gen_config.model, str) + assert cols_dict["output0"].gen_config.model in available_image_models + + # --- Add column --- # + add_cols = [ + ColumnSchemaCreate( + id="output1", + dtype="str", + gen_config=LLMGenConfig(model=chat_only_model, prompt="${input0}"), + ) + ] + with pytest.raises(BadInputError): + if table_type == TableType.ACTION: + table = client.table.add_action_columns( + AddActionColumnSchema(id=table.id, cols=add_cols) + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=add_cols) + ) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns( + AddChatColumnSchema(id=table.id, cols=add_cols) + ) + else: + raise ValueError(f"Invalid table type: {table_type}") + + +def test_default_embedding_model( + setup: ServingContext, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, TableType.KNOWLEDGE, cols=[], embedding_model="") as table: + assert isinstance(table, TableMetaResponse) + for col in table.cols: + if col.vlen == 0: + continue + assert len(col.gen_config.embedding_model) > 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_default_reranker( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # Create the knowledge table first + with _create_table(client, 
TableType.KNOWLEDGE, TABLE_ID_B, cols=[]) as ktable: + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig( + rag_params=RAGParams(table_id=ktable.id, reranking_model=""), + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + cols_dict = {c.id: c for c in table.cols} + rag_params = cols_dict["output0"].gen_config.rag_params + assert isinstance(rag_params, RAGParams) + reranking_model = rag_params.reranking_model + assert isinstance(reranking_model, str) + assert len(reranking_model) > 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_default_prompts( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="input0", dtype="str"), + ColumnSchemaCreate(id="input1", dtype="str"), + ColumnSchemaCreate( + id="output0", + dtype="str", + gen_config=LLMGenConfig(), # Empty gen_config to trigger defaults + ), + ColumnSchemaCreate( + id="output1", + dtype="str", + gen_config=LLMGenConfig(), # Empty gen_config to trigger defaults + ), + ColumnSchemaCreate( + id="output2", + dtype="str", + gen_config=LLMGenConfig( + system_prompt="You are an assistant.", + prompt="Summarise ${input0}.", + ), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + # Define expected input columns based on table type + input_cols_set = {"input0", "input1"} + if table_type == TableType.KNOWLEDGE: + input_cols_set |= {"Title", "Text", "File ID", "Page"} + elif table_type == TableType.CHAT: + input_cols_set |= {"User"} + + cols_dict = {c.id: c for c in table.cols} + + # Check ["output0", "output1"] for default prompts referencing all inputs + for col_id in ["output0", "output1"]: + gen_config = cols_dict[col_id].gen_config + assert isinstance(gen_config, LLMGenConfig) + assert isinstance(gen_config.prompt, str) + referenced_cols = set(re.findall(r"\$\{(\w+(?:\s\w+)*)\}", gen_config.prompt)) + # Default prompt should reference all non-ID, non-updated_at, non-output, non-vector columns + expected_referenced = { + c.id + for c in table.cols + if c.id not in ("ID", "Updated at") + and c.gen_config is None + and "Embed" not in c.id + } + assert referenced_cols == expected_referenced, ( + f"Col {col_id}: Expected {expected_referenced}, got {referenced_cols}" + ) + + # Check ["output2"] for provided prompts + gen_config_2 = cols_dict["output2"].gen_config + assert isinstance(gen_config_2, LLMGenConfig) + assert gen_config_2.system_prompt == "You are an assistant." + assert gen_config_2.prompt == "Summarise ${input0}." 
+ referenced_cols_2 = set(re.findall(r"\$\{(\w+(?:\s\w+)*)\}", gen_config_2.prompt)) + assert referenced_cols_2 == {"input0"} + + # --- Add column --- # + add_cols = [ + ColumnSchemaCreate( + id="input2", + dtype="int", + ), + ColumnSchemaCreate( + id="output3", + dtype="str", + gen_config=LLMGenConfig(), # Trigger default prompt + ), + ] + if table_type == TableType.ACTION: + table = client.table.add_action_columns( + AddActionColumnSchema(id=table.id, cols=add_cols) + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=add_cols) + ) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=add_cols)) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + + cols_dict = {c.id: c for c in table.cols} + + # Check ["output3"] for default prompt referencing all *current* inputs + gen_config_3 = cols_dict["output3"].gen_config + assert isinstance(gen_config_3, LLMGenConfig) + assert isinstance(gen_config_3.prompt, str) + referenced_cols_3 = set(re.findall(r"\$\{(\w+(?:\s\w+)*)\}", gen_config_3.prompt)) + expected_referenced_3 = { + c.id + for c in table.cols + if c.id not in ("ID", "Updated at") and c.gen_config is None and "Embed" not in c.id + } + assert referenced_cols_3 == expected_referenced_3, ( + f"Col output3: Expected {expected_referenced_3}, got {referenced_cols_3}" + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_add_drop_columns( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table_v2(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + _add_row_v2( + client, + table_type, + stream=False, + include_output_data=False, + ) + + # --- COLUMN ADD --- # + _input_cols = [ + ColumnSchemaCreate(id=f"add_in_{dtype}", dtype=dtype) + for dtype in REGULAR_COLUMN_DTYPES + ] + _output_cols = [ + ColumnSchemaCreate( + id=f"add_out_{dtype}", + dtype=dtype, + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt=" ".join(f"${{{col.id}}}" for col in _input_cols), + max_tokens=10, + ), + ) + for dtype in ["str"] + ] + cols = _input_cols + _output_cols + expected_cols = {"ID", "Updated at"} + expected_cols |= {f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} + expected_cols |= {f"out_{dtype}" for dtype in ["str"]} + expected_cols |= {f"add_in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} + expected_cols |= {f"add_out_{dtype}" for dtype in ["str"]} + if table_type == TableType.ACTION: + table = client.table.add_action_columns(AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=cols) + ) + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + table = client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=cols)) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + cols = set(c.id for c in table.cols) + assert cols == expected_cols, cols + # Existing row of new columns should contain None + rows = list_table_rows(client, table_type, table.id) + 
assert all(set(r.keys()) == expected_cols for r in rows.items) + assert len(rows.items) == 1 + row = rows.values[0] + for col_id, col in row.items(): + if not col_id.startswith("add_"): + continue + assert col is None + # Test adding a new row + data = {} + for dtype in REGULAR_COLUMN_DTYPES: + data[f"in_{dtype}"] = SAMPLE_DATA[dtype] + data[f"out_{dtype}"] = SAMPLE_DATA[dtype] + data[f"add_in_{dtype}"] = SAMPLE_DATA[dtype] + data[f"add_out_{dtype}"] = SAMPLE_DATA[dtype] + _add_row_v2(client, table_type, False, data=data) + rows = list_table_rows(client, table_type, table.id) + assert all(set(r.keys()) == expected_cols for r in rows.items) + assert len(rows.items) == 2 + row = rows.values[-1] + for col_id, col in row.items(): + if not col_id.startswith("add_"): + continue + assert col is not None + + # --- COLUMN DROP --- # + table = client.table.drop_columns( + table_type, + ColumnDropRequest( + table_id=table.id, + column_names=[f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES] + + [f"out_{dtype}" for dtype in ["str"]], + ), + ) + expected_cols = {"ID", "Updated at"} + expected_cols |= {f"add_in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} + expected_cols |= {f"add_out_{dtype}" for dtype in ["str"]} + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + cols = set(c.id for c in table.cols) + assert cols == expected_cols, cols + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 2 + assert all(set(r.keys()) == expected_cols for r in rows.items) + # Test adding a new row + _add_row_v2(client, table_type, False, data=data) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 3 + assert all(set(r.keys()) == expected_cols for r in rows.items), [ + list(r.keys()) for r in rows.items + ] + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_add_drop_file_column( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table_v2(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + _add_row_v2( + client, + table_type, + stream=False, + include_output_data=False, + ) + + # --- COLUMN ADD --- # + cols = [ + ColumnSchemaCreate(id="add_in_file", dtype="image"), + ColumnSchemaCreate( + id="add_out_str", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="Describe image ${add_in_file}", + max_tokens=10, + ), + ), + ] + expected_cols = {"ID", "Updated at", "add_in_file", "add_out_str"} + expected_cols |= {f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES} + expected_cols |= {f"out_{dtype}" for dtype in ["str"]} + if table_type == TableType.ACTION: + table = client.table.add_action_columns(AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=cols) + ) + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + table = 
client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=cols)) + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + cols = set(c.id for c in table.cols) + assert cols == expected_cols, cols + # Existing row of new columns should contain None + rows = list_table_rows(client, table_type, table.id) + assert all(set(r.keys()) == expected_cols for r in rows.items) + assert len(rows.items) == 1 + row = rows.values[0] + for col_id, col in row.items(): + if not col_id.startswith("add_"): + continue + assert col is None + # Test adding a new row + upload_response = upload_file(client, FILES["rabbit.jpeg"]) + data = {"add_in_file": upload_response.uri} + for dtype in REGULAR_COLUMN_DTYPES: + data[f"in_{dtype}"] = SAMPLE_DATA[dtype] + response = _add_row_v2(client, table_type, False, data=data) + assert len(response.rows[0].columns["add_out_str"].content) > 0 + rows = list_table_rows(client, table_type, table.id) + assert all(set(r.keys()) == expected_cols for r in rows.items) + assert len(rows.items) == 2 + row = rows.values[-1] + for col_id, col in row.items(): + if not col_id.startswith("add_in_"): + continue + assert col is not None + + # Block file output column + with pytest.raises(BadInputError): + cols = [ + ColumnSchemaCreate( + id="add_out_file", + dtype="image", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="Describe image ${add_in_file}", + max_tokens=10, + ), + ), + ] + if table_type == TableType.ACTION: + client.table.add_action_columns(AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == TableType.KNOWLEDGE: + client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=cols) + ) + elif table_type == TableType.CHAT: + client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=cols)) + else: + raise ValueError(f"Invalid table type: {table_type}") + + # --- COLUMN DROP --- # + table = client.table.drop_columns( + table_type, + ColumnDropRequest( + table_id=table.id, + column_names=[f"in_{dtype}" for dtype in REGULAR_COLUMN_DTYPES] + + [f"out_{dtype}" for dtype in ["str"]], + ), + ) + expected_cols = {"ID", "Updated at", "add_in_file", "add_out_str"} + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + else: + raise ValueError(f"Invalid table type: {table_type}") + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + cols = set(c.id for c in table.cols) + assert cols == expected_cols, cols + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 2 + assert all(set(r.keys()) == expected_cols for r in rows.items) + # Test adding a new row + _add_row_v2(client, table_type, False, data={"add_in_file": upload_response.uri}) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 3 + assert all(set(r.keys()) == expected_cols for r in rows.items), [ + list(r.keys()) for r in rows.items + ] + + +def test_kt_drop_invalid_columns(setup: ServingContext): + table_type = "knowledge" + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + for col in KT_FIXED_COLUMN_IDS: + with 
pytest.raises(BadInputError): + client.table.drop_columns( + table_type, + ColumnDropRequest(table_id=table.id, column_names=[col]), + ) + + +def test_ct_drop_invalid_columns(setup: ServingContext): + table_type = "chat" + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + for col in CT_FIXED_COLUMN_IDS: + with pytest.raises(BadInputError): + client.table.drop_columns( + table_type, + ColumnDropRequest(table_id=table.id, column_names=[col]), + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_rename_columns( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="x", dtype="str"), + ColumnSchemaCreate( + id="y", + dtype="str", + gen_config=LLMGenConfig(prompt=r"Summarise ${x}, \${x}"), + ), + ] + with _create_table(client, table_type, cols=cols) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + # Test rename on empty table + table = client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map=dict(y="z")), + ) + assert isinstance(table, TableMetaResponse) + expected_cols = {"ID", "Updated at", "x", "z"} + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + else: + raise ValueError(f"Invalid table type: {table_type}") + cols = set(c.id for c in table.cols) + assert cols == expected_cols + + table = client.table.get_table(table_type, table.id) + assert isinstance(table, TableMetaResponse) + cols = set(c.id for c in table.cols) + assert cols == expected_cols + # Test adding data with new column names + _add_row(client, table_type, False, data=dict(x="True", z="")) + # Test rename table with data + # Test also auto gen config reference update + table = client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map=dict(x="a")), + ) + assert isinstance(table, TableMetaResponse) + expected_cols = {"ID", "Updated at", "a", "z"} + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + expected_cols |= {"Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"} + elif table_type == TableType.CHAT: + expected_cols |= {"User", "AI"} + else: + raise ValueError(f"Invalid table type: {table_type}") + cols = set(c.id for c in table.cols) + assert cols == expected_cols + table = client.table.get_table(table_type, table.id) + assert isinstance(table, TableMetaResponse) + cols = set(c.id for c in table.cols) + assert cols == expected_cols + # Test auto gen config reference update + cols = {c.id: c for c in table.cols} + prompt = cols["z"].gen_config.prompt + assert "${a}" in prompt + assert "\\${x}" in prompt # Escaped reference syntax + + # Repeated new column names + with pytest.raises(ResourceExistsError): + client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="b")), + ) + # Rename to existing column name + with pytest.raises(ResourceExistsError): + client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map=dict(z="a")), + ) + # Overlapping new and old column names is OK depending on rename order + 
client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map=dict(a="b", z="a")), + ) + table = client.table.get_table(table_type, table.id) + assert isinstance(table, TableMetaResponse) + cols = set(c.id for c in table.cols) + assert len({"ID", "Updated at", "b", "a"} - cols) == 0 + + +def test_kt_rename_invalid_columns(setup: ServingContext): + table_type = "knowledge" + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + for col in KT_FIXED_COLUMN_IDS: + with pytest.raises(BadInputError): + client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map={col: col}), + ) + + +def test_ct_rename_invalid_columns(setup: ServingContext): + table_type = "chat" + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + for col in CT_FIXED_COLUMN_IDS: + with pytest.raises(BadInputError): + client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map={col: col}), + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_reorder_columns( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + table = client.table.get_table(table_type, TABLE_ID_A) + assert isinstance(table, TableMetaResponse) + + column_names = [ + "ID", + "Updated at", + "inputs", + "good", + "words", + "stars", + "photo", + "summary", + "captioning", + ] + expected_order = [ + "ID", + "Updated at", + "good", + "words", + "stars", + "inputs", + "photo", + "summary", + "captioning", + ] + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + column_names += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + expected_order = ( + expected_order[:2] + + ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + + expected_order[2:] + ) + elif table_type == TableType.CHAT: + column_names += ["User", "AI"] + expected_order = expected_order[:2] + ["User", "AI"] + expected_order[2:] + else: + raise ValueError(f"Invalid table type: {table_type}") + cols = [c.id for c in table.cols] + assert cols == expected_order, cols + # Test reorder empty table + table = client.table.reorder_columns( + table_type, + ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), + ) + expected_order = [ + "ID", + "Updated at", + "inputs", + "good", + "words", + "stars", + "photo", + "summary", + "captioning", + ] + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + expected_order += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + elif table_type == TableType.CHAT: + expected_order += ["User", "AI"] + else: + raise ValueError(f"Invalid table type: {table_type}") + cols = [c.id for c in table.cols] + assert cols == expected_order, cols + table = client.table.get_table(table_type, TABLE_ID_A) + assert isinstance(table, TableMetaResponse) + cols = [c.id for c in table.cols] + assert cols == expected_order, cols + # Test add row + response = _add_row( + client, + table_type, + True, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT), + ) + summary = _collect_text(list(response), 
"summary") + assert len(summary) > 0 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_reorder_columns_invalid( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + assert all(isinstance(c, ColumnSchema) for c in table.cols) + table = client.table.get_table(table_type, TABLE_ID_A) + assert isinstance(table, TableMetaResponse) + + column_names = [ + "ID", + "Updated at", + "inputs", + "good", + "words", + "stars", + "photo", + "summary", + "captioning", + ] + expected_order = [ + "ID", + "Updated at", + "good", + "words", + "stars", + "inputs", + "photo", + "summary", + "captioning", + ] + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + column_names += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + expected_order = ( + expected_order[:2] + + ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + + expected_order[2:] + ) + elif table_type == TableType.CHAT: + column_names += ["User", "AI"] + expected_order = expected_order[:2] + ["User", "AI"] + expected_order[2:] + else: + raise ValueError(f"Invalid table type: {table_type}") + cols = [c.id for c in table.cols] + assert cols == expected_order, cols + + # --- Test validation by putting "summary" on the left of "words" --- # + column_names = [ + "ID", + "Updated at", + "inputs", + "good", + "stars", + "summary", + "words", + "photo", + "captioning", + ] + if table_type == TableType.ACTION: + pass + elif table_type == TableType.KNOWLEDGE: + column_names += ["Title", "Title Embed", "Text", "Text Embed", "File ID", "Page"] + elif table_type == TableType.CHAT: + column_names += ["User", "AI"] + else: + raise ValueError(f"Invalid table type: {table_type}") + with pytest.raises(BadInputError): + client.table.reorder_columns( + table_type, + ColumnReorderRequest(table_id=TABLE_ID_A, column_names=column_names), + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) +def test_null_gen_config( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest(table_id=table.id, column_map=dict(summary=None)), + ) + response = _add_row( + client, table_type, stream, data=dict(good=True, words=5, stars=9.9, inputs=TEXT) + ) + if stream: + # Must wait until stream ends + responses = [r for r in response] + assert all(isinstance(r, CellCompletionResponse) for r in responses) + else: + assert isinstance(response, RowCompletionResponse) + rows = list_table_rows(client, table_type, table.id) + assert len(rows.items) == 1 + row = rows.values[0] + assert row["summary"] is None + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_invalid_referenced_column( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # --- Non-existent column --- # + cols = [ + ColumnSchemaCreate(id="words", dtype="int"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + prompt="Summarise ${inputs}", + 
temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, cols=cols): + pass + + # --- Vector column --- # + cols = [ + ColumnSchemaCreate(id="words", dtype="int"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + prompt="Summarise ${Text Embed}", + temperature=0.001, + top_p=0.001, + max_tokens=10, + ).model_dump(), + ), + ] + with pytest.raises(BadInputError): + with _create_table(client, table_type, cols=cols): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("stream", [True, False], ids=["stream", "non-stream"]) +def test_gen_config_empty_prompts( + setup: ServingContext, + table_type: TableType, + stream: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="words", dtype="int"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + temperature=0.001, + top_p=0.001, + max_tokens=10, + ), + ), + ] + chat_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + temperature=0.001, + top_p=0.001, + max_tokens=5, + ), + ), + ] + with _create_table(client, table_type, cols=cols, chat_cols=chat_cols) as table: + assert isinstance(table, TableMetaResponse) + data = dict(words=5) + if table_type == TableType.KNOWLEDGE: + data["Title"] = "Dune: Part Two." + data["Text"] = "Dune: Part Two is a 2024 American epic science fiction film." + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest(table_id=table.id, data=[data], stream=stream), + ) + if stream: + # Must wait until stream ends + responses = [r for r in response] + assert all(isinstance(r, CellCompletionResponse) for r in responses) + summary = "".join(r.content for r in responses if r.output_column_name == "summary") + assert len(summary) > 0 + if table_type == TableType.CHAT: + ai = "".join(r.content for r in responses if r.output_column_name == "AI") + assert len(ai) > 0 + else: + assert isinstance(response.rows[0], RowCompletionResponse) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_table_search_and_parent_id( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # _delete_tables(client) + with ( + _create_table(client, table_type, "beast") as table, + _create_table(client, table_type, "feast"), + _create_table(client, table_type, "bear"), + _create_table(client, table_type, "fear"), + ): + assert isinstance(table, TableMetaResponse) + with ( + _create_child_table(client, table_type, "beast", "least"), + _create_child_table(client, table_type, "beast", "lease"), + _create_child_table(client, table_type, "beast", "yeast"), + ): + # Regular list + tables = client.table.list_tables(table_type, limit=3) + assert isinstance(tables.items, list) + assert tables.total == 7 + assert tables.offset == 0 + assert tables.limit == 3 + assert len(tables.items) == 3 + assert all(isinstance(r, TableMetaResponse) for r in tables.items) + # Search + tables = client.table.list_tables(table_type, search_query="be", limit=3) + assert isinstance(tables.items, list) + assert tables.total == 2 + assert tables.offset == 0 + assert tables.limit == 3 + assert len(tables.items) == 2 + assert 
all(isinstance(r, TableMetaResponse) for r in tables.items) + # Search + tables = client.table.list_tables(table_type, search_query="ast", limit=3) + assert isinstance(tables.items, list) + assert tables.total == 4 + assert tables.offset == 0 + assert tables.limit == 3 + assert len(tables.items) == 3 + assert all(isinstance(r, TableMetaResponse) for r in tables.items) + # Search with parent ID + tables = client.table.list_tables(table_type, search_query="ast", parent_id="beast") + assert isinstance(tables.items, list) + assert tables.total == 2 + assert tables.offset == 0 + assert tables.limit == 100 + assert len(tables.items) == 2 + assert all(isinstance(r, TableMetaResponse) for r in tables.items) + # Search with parent ID + tables = client.table.list_tables(table_type, search_query="as", parent_id="beast") + assert isinstance(tables.items, list) + assert tables.total == 3 + assert tables.offset == 0 + assert tables.limit == 100 + assert len(tables.items) == 3 + assert all(isinstance(r, TableMetaResponse) for r in tables.items) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_duplicate_table( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + + # Duplicate with data + with _duplicate_table(client, table_type, TABLE_ID_A, TABLE_ID_B) as table: + # Add another to table A + _add_row( + client, + table_type, + False, + table_name=TABLE_ID_A, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + assert table.id == TABLE_ID_B + rows = list_table_rows(client, table_type, TABLE_ID_B) + assert len(rows.items) == 1 + + # Duplicate without data + with _duplicate_table( + client, table_type, TABLE_ID_A, TABLE_ID_C, include_data=False + ) as table: + assert table.id == TABLE_ID_C + rows = list_table_rows(client, table_type, TABLE_ID_C) + assert len(rows.items) == 0 + + # # Deploy with data + # with _duplicate_table(client, table_type, TABLE_ID_A, TABLE_ID_B, deploy=True) as table: + # assert table.id == TABLE_ID_B + # assert table.parent_id == TABLE_ID_A + # rows = list_table_rows(client,table_type, TABLE_ID_B) + # assert len(rows.items) == 2 + + # # Deploy will always include data + # with _duplicate_table( + # client, table_type, TABLE_ID_A, TABLE_ID_C, deploy=True, include_data=False + # ) as table: + # assert table.id == TABLE_ID_C + # assert table.parent_id == TABLE_ID_A + # rows = list_table_rows(client,table_type, TABLE_ID_C) + # assert len(rows.items) == 2 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize("table_id_dst", ["Y", None]) +@pytest.mark.parametrize("include_data", [True, False]) +@pytest.mark.parametrize("create_as_child", [True, False]) +def test_duplicate_table_nonexistent( + setup: ServingContext, + table_type: TableType, + table_id_dst: str | None, + include_data: bool, + create_as_child: bool, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with pytest.raises(ResourceNotFoundError): + client.table.duplicate_table( + table_type, + "X", + table_id_dst, + include_data=include_data, + create_as_child=create_as_child, + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +@pytest.mark.parametrize( + "table_id_dst", + ["a_", "_a", "_aa", "aa_", "_a_", "-a", ".a", "a" * 101], +) +def 
test_duplicate_table_invalid_name( + setup: ServingContext, + table_type: TableType, + table_id_dst: str, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + + with pytest.raises(BadInputError): + with _duplicate_table(client, table_type, TABLE_ID_A, table_id_dst): + pass + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_create_child_table( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type) as table_a: + assert isinstance(table_a, TableMetaResponse) + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + # Create child table with data + with _create_child_table(client, table_type, TABLE_ID_A, TABLE_ID_B) as table_b: + assert isinstance(table_b, TableMetaResponse) + # Add another to table A + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + assert table_b.id == TABLE_ID_B + # Ensure the parent ID metadata has been set correctly. + assert table_b.parent_id == TABLE_ID_A + rows = list_table_rows(client, table_type, TABLE_ID_B) + assert len(rows.items) == 1 + + # Create child table with no dst id + with _create_child_table(client, table_type, TABLE_ID_A, None) as table_c: + assert isinstance(table_c.id, str) + assert table_c.id.startswith(TABLE_ID_A) + assert table_c.id != TABLE_ID_A + # Ensure the parent ID metadata has been set correctly. + assert table_c.parent_id == TABLE_ID_A + rows = list_table_rows(client, table_type, table_c.id) + assert len(rows.items) == 2 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_rename_table( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + with _create_table(client, table_type, TABLE_ID_A) as table: + assert isinstance(table, TableMetaResponse) + _add_row( + client, + table_type, + False, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + # Create child table + with _create_child_table(client, table_type, TABLE_ID_A, TABLE_ID_B) as child: + assert isinstance(child, TableMetaResponse) + # Rename + with _rename_table(client, table_type, TABLE_ID_A, TABLE_ID_C) as table: + rows = list_table_rows(client, table_type, TABLE_ID_C) + assert len(rows.items) == 1 + # Assert the old table is gone + with pytest.raises(ResourceNotFoundError): + list_table_rows(client, table_type, TABLE_ID_A) + # Assert the child table parent ID is updated + assert client.table.get_table(table_type, child.id).parent_id == TABLE_ID_C + # Add rows to both tables + _add_row( + client, + table_type, + False, + TABLE_ID_B, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + _add_row( + client, + table_type, + False, + TABLE_ID_C, + data=dict(good=True, words=5, stars=9.9, inputs=TEXT, summary=""), + ) + + +def test_chat_table_gen_config( + setup: ServingContext, +): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=_get_chat_model(client), + system_prompt="You are a concise assistant.", + multi_turn=False,
+ temperature=0.001, + top_p=0.001, + max_tokens=20, + ), + ), + ] + with _create_table(client, "chat", cols=[], chat_cols=cols) as table: + cfg_map = {c.id: c.gen_config for c in table.cols} + # AI column gen config will be multi turn regardless of input params + assert cfg_map["AI"].multi_turn is True diff --git a/services/api/tests/gen_table/test_table_ops_v2.py b/services/api/tests/gen_table/test_table_ops_v2.py new file mode 100644 index 0000000..6b27b36 --- /dev/null +++ b/services/api/tests/gen_table/test_table_ops_v2.py @@ -0,0 +1,824 @@ +from copy import deepcopy +from dataclasses import dataclass +from os.path import dirname, join, realpath +from tempfile import TemporaryDirectory + +import pytest +from sqlmodel import text + +from jamaibase import JamAI +from jamaibase.types import ( + AddActionColumnSchema, + AddChatColumnSchema, + AddKnowledgeColumnSchema, + ColumnSchemaCreate, + GenConfigUpdateRequest, + OkResponse, + OrganizationCreate, + OrgMemberRead, + ProjectMemberRead, + RAGParams, + TableImportRequest, + TableMetaResponse, +) +from owl.db import sync_session +from owl.types import ( + LLMGenConfig, + Role, + TableType, +) +from owl.utils.exceptions import BadInputError, ResourceNotFoundError +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + ELLM_EMBEDDING_CONFIG, + ELLM_EMBEDDING_DEPLOYMENT, + GPT_41_NANO_CONFIG, + GPT_41_NANO_DEPLOYMENT, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + add_table_rows, + create_deployment, + create_model_config, + create_organization, + create_project, + create_table, + create_user, + list_table_rows, + list_tables, +) + +TEST_DIR = dirname(dirname(realpath(__file__))) +TABLE_TYPES = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str + user_id: str + superorg_id: str + project_id: str + llm_model_id: str + desc_llm_model_id: str + rerank_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. 
+ """ + with ( + # Create superuser and user + create_user() as superuser, + create_user(dict(email="user@up.com", name="User")) as user, + # Create organization + create_organization( + body=OrganizationCreate(name="Clubhouse"), user_id=superuser.id + ) as superorg, + # Create project + create_project( + dict(name="Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + assert superuser.id == "0" + assert superorg.id == "0" + # Join organization and project as member + client = JamAI(user_id=superuser.id) + membership = client.organizations.join_organization( + user.id, + organization_id=superorg.id, + role=Role.MEMBER, + ) + assert isinstance(membership, OrgMemberRead) + membership = client.projects.join_project( + user.id, + project_id=p0.id, + role=Role.MEMBER, + ) + assert isinstance(membership, ProjectMemberRead) + + # Create models + gpt_config = deepcopy(GPT_41_NANO_CONFIG) + gpt_config.name = "A OpenAI GPT-4.1 nano" + with ( + # Purposely include a model name that starts with A to test default model sorting + create_model_config(gpt_config) as llm_config, + # Default model should still prefer ELLM model + create_model_config(ELLM_DESCRIBE_CONFIG) as desc_llm_config, + create_model_config(ELLM_EMBEDDING_CONFIG), + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_config, + ): + # Create deployments + with ( + create_deployment(GPT_41_NANO_DEPLOYMENT), + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(ELLM_EMBEDDING_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + yield ServingContext( + superuser_id=superuser.id, + user_id=user.id, + superorg_id=superorg.id, + project_id=p0.id, + llm_model_id=llm_config.id, + desc_llm_model_id=desc_llm_config.id, + rerank_model_id=rerank_config.id, + ) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_default_model_default_prompts( + setup: ServingContext, + table_type: TableType, +): + """ + Test default model and prompts: + - Default model + - Default prompts + - Table creation (should set default prompts) + - Multi-turn column + - Column add (should set default prompts) + - Gen config update (should NOT set default prompts) + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="str", dtype="str"), + # Default for system prompt and prompt + ColumnSchemaCreate( + id="o1", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="", + ), + ), + ColumnSchemaCreate(id="float", dtype="float"), + # Default for system prompt + ColumnSchemaCreate( + id="o2", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="What is love?", + ), + ), + # Default for prompt + ColumnSchemaCreate( + id="o3", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="Baby don't hurt me", + prompt="", + ), + ), + # Default for system prompt and prompt (multi-turn) + ColumnSchemaCreate( + id="o4", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="", + multi_turn=True, + ), + ), + ] + with create_table(client, table_type, cols=cols) as table: + table = client.table.get_table(table_type, table.id) + assert isinstance(table, TableMetaResponse) + ### --- Default model --- ### + for col in ["o1", "o2", "o3", "o4"]: + gen_config = table.cfg_map[col] + assert isinstance(gen_config, LLMGenConfig) + assert gen_config.model == setup.desc_llm_model_id + ### --- Default prompts --- ### + default_sys_phrase = ( + "You are a versatile data generator. " + "Your task is to process information from input data and generate appropriate responses based on the specified column name and input data." + ) + # Table creation + assert default_sys_phrase in table.cfg_map["o1"].system_prompt + assert default_sys_phrase in table.cfg_map["o2"].system_prompt + assert table.cfg_map["o3"].system_prompt == "Baby don't hurt me" + assert "You are an agent named" in table.cfg_map["o4"].system_prompt + + def _check_prompt(prompt: str): + assert "${str}" in prompt + assert "${ID}" not in prompt # Info columns + assert "${Updated at}" not in prompt # Info columns + if table_type == TableType.KNOWLEDGE: + assert "${Title}" in prompt + assert "${Text}" in prompt + assert "${File ID}" in prompt + assert "${Page}" in prompt + assert "${Title Embed}" not in prompt # Vector columns + assert "${Text Embed}" not in prompt # Vector columns + elif table_type == TableType.CHAT: + assert "${User}" in prompt + + gen_config = table.cfg_map["o1"] + assert "${float}" not in gen_config.prompt # Columns on its right + _check_prompt(gen_config.prompt) + assert table.cfg_map["o2"].prompt == "What is love?" 
+ gen_config = table.cfg_map["o3"] + assert "${float}" in gen_config.prompt + _check_prompt(gen_config.prompt) + gen_config = table.cfg_map["o4"] + assert "${float}" in gen_config.prompt + _check_prompt(gen_config.prompt) + # Column add + cols = [ + ColumnSchemaCreate( + id="o5", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="", + ), + ), + ] + if table_type == TableType.ACTION: + client.table.add_action_columns(AddActionColumnSchema(id=table.id, cols=cols)) + elif table_type == TableType.KNOWLEDGE: + client.table.add_knowledge_columns(AddKnowledgeColumnSchema(id=table.id, cols=cols)) + else: + client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=cols)) + table = client.table.get_table(table_type, table.id) + gen_config = table.cfg_map["o5"] + assert default_sys_phrase in gen_config.system_prompt + assert "${float}" in gen_config.prompt + _check_prompt(gen_config.prompt) + # Update gen config to empty prompt + client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map={ + "o5": LLMGenConfig( + model="", + system_prompt="", + prompt="", + ) + }, + ), + ) + table = client.table.get_table(table_type, table.id) + gen_config = table.cfg_map["o5"] + assert isinstance(gen_config, LLMGenConfig) + assert gen_config.model == setup.desc_llm_model_id # Default model + assert gen_config.system_prompt == "" + assert gen_config.prompt == "" + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_create_delete_table( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + with create_table(client, table_type) as table: + assert isinstance(table, TableMetaResponse) + # Delete + response = client.table.delete_table(table_type, table.id) + assert isinstance(response, OkResponse) + # After deleting + with pytest.raises(ResourceNotFoundError, match="is not found."): + client.table.get_table(table_type, table.id) + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_get_list_tables( + setup: ServingContext, + table_type: TableType, +): + """ + Test get table and list tables. + - offset and limit + - order_by and order_ascending + - created_by + - parent_id (list project with agents, chat agent, chat, all tables) + - search_query + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. 
+ """ + super_client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ColumnSchemaCreate(id="int", dtype="int")] + + ### --- Test get and list on DB without schemas --- ### + with sync_session() as session: + for table_type in TableType: + session.exec(text(f'DROP SCHEMA IF EXISTS "{setup.project_id}_{table_type}" CASCADE')) + session.commit() + tables = list_tables(client, table_type) + assert len(tables.items) == 0 + assert tables.total == 0 + with pytest.raises(ResourceNotFoundError, match="Table .+ is not found."): + client.table.get_table(table_type, "123") + + ### --- Create tables --- ### + with ( + create_table(super_client, table_type, "Table 2", cols=cols) as t0, + create_table(super_client, table_type, "table 1", cols=cols) as t1, + create_table(client, table_type, "Table 0", cols=cols) as t2, + ): + assert isinstance(t0, TableMetaResponse) + assert isinstance(t1, TableMetaResponse) + assert isinstance(t2, TableMetaResponse) + num_tables = 3 + ### --- List tables --- ### + tables = list_tables(client, table_type) + assert len(tables.items) == num_tables + assert tables.total == num_tables + assert [t.id for t in tables.items] == [t0.id, t1.id, t2.id] + + ### --- Get table --- ### + for table in tables.items: + _table = client.table.get_table(table_type, table.id) + assert isinstance(_table, TableMetaResponse) + assert _table.model_dump(exclude={"num_rows"}) == table.model_dump( + exclude={"num_rows"} + ) + + ### --- List tables (case-insensitive sort) --- ### + _tables = list_tables(client, table_type, order_by="id") + assert _tables.total == num_tables + assert [t.id for t in _tables.items] == [t2.id, t1.id, t0.id] + + ### --- List tables (offset and limit) --- ### + _tables = list_tables(client, table_type, offset=0, limit=1) + assert len(_tables.items) == 1 + assert _tables.total == num_tables + assert _tables.items[0].id == tables.items[0].id, f"{_tables.items=}" + _tables = list_tables(client, table_type, offset=1, limit=1) + assert len(_tables.items) == 1 + assert _tables.total == num_tables + assert _tables.items[0].id == tables.items[1].id, f"{_tables.items=}" + # Offset >= num tables + _tables = list_tables(client, table_type, offset=num_tables, limit=1) + assert len(_tables.items) == 0 + assert _tables.total == num_tables + _tables = list_tables(client, table_type, offset=num_tables + 1, limit=1) + assert len(_tables.items) == 0 + assert _tables.total == num_tables + # Invalid offset and limit + with pytest.raises(BadInputError): + list_tables(client, table_type, offset=0, limit=0) + with pytest.raises(BadInputError): + list_tables(client, table_type, offset=-1, limit=1) + + ### --- List tables (order_by and order_ascending) --- ### + _tables = list_tables(client, table_type, order_ascending=False) + assert len(tables.items) == num_tables + assert _tables.total == num_tables + assert [t.id for t in _tables.items[::-1]] == [t.id for t in tables.items] + _tables = list_tables(client, table_type, order_by="id") + assert len(tables.items) == num_tables + assert _tables.total == num_tables + assert [t.id for t in _tables.items[::-1]] == [t.id for t in tables.items] + + ### --- List tables (created_by) --- ### + _tables = list_tables(client, table_type, created_by=setup.superuser_id) + assert len(_tables.items) == 2 + assert _tables.total == 2 + assert _tables.total != num_tables + _tables = list_tables(client, table_type, created_by=setup.user_id) + assert len(_tables.items) == 1 
+ assert _tables.total == 1 + assert _tables.total != num_tables + + ### --- List tables (parent_id) --- ### + if table_type == TableType.CHAT: + # Create a child table + _table = client.table.duplicate_table(table_type, t0.id, None, create_as_child=True) + try: + assert isinstance(_table, TableMetaResponse) + # List projects with chat agent list + projects = client.projects.list_projects(setup.superorg_id, list_chat_agents=True) + assert len(projects.items) == 1 + assert projects.total == 1 + _project = projects.items[0] + assert len(_project.chat_agents) == num_tables + # List all chat agents + _tables = list_tables(client, table_type, parent_id="_agent_") + assert len(_tables.items) == num_tables + assert _tables.total == num_tables + assert {t.id for t in _tables.items} == {t.id for t in _project.chat_agents} + _tables = list_tables(client, table_type, parent_id="_agent_", offset=1) + assert len(_tables.items) == num_tables - 1 + assert _tables.total == num_tables + # List all chats + _tables = list_tables(client, table_type, parent_id="_chat_") + assert len(_tables.items) == 1 + assert _tables.total == 1 + # List all tables + _tables = list_tables(client, table_type, parent_id=None) + assert len(_tables.items) == num_tables + 1 + assert _tables.total == num_tables + 1 + finally: + client.table.delete_table(table_type, _table.id) + + ### --- List tables (search_query) --- ### + _tables = list_tables(client, table_type, search_query="1") + assert len(_tables.items) == 1 + assert _tables.total == 1 + assert _tables.total != num_tables + assert _tables.items[0].id == t1.id + _tables = list_tables(client, table_type, search_query="1", offset=1) + assert len(_tables.items) == 0 + assert _tables.total == 1 + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_update_gen_config( + setup: ServingContext, + table_type: TableType, +): + """ + Test updating table generation config: + - Partial update + - Switch to/from None + - Chat AI column must always have gen config + - Chat AI column multi-turn must always be True + - Invalid column reference + - Invalid LLM model + - Invalid knowledge table ID + - Invalid reranker model + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. 
+ """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="i0", dtype="str"), + ColumnSchemaCreate(id="o0", dtype="str", gen_config=LLMGenConfig()), + ColumnSchemaCreate(id="o1", dtype="str", gen_config=None), + ] + with ( + create_table(client, TableType.KNOWLEDGE) as kt, + create_table(client, table_type, cols=cols) as table, + ): + assert isinstance(table.cfg_map["o0"], LLMGenConfig) + assert len(table.cfg_map["o0"].system_prompt) > 0 + assert len(table.cfg_map["o0"].prompt) > 0 + assert table.cfg_map["o1"] is None + if table_type == TableType.CHAT: + assert isinstance(table.cfg_map["AI"], LLMGenConfig) + + # --- Partial update --- # + old_cfg = table.cfg_map["o0"].model_dump() + # Update prompt + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(o0=LLMGenConfig(prompt="test")), + ), + ) + assert isinstance(table, TableMetaResponse) + assert isinstance(table.cfg_map["o0"], LLMGenConfig) + assert len(table.cfg_map["o0"].system_prompt) > 0 + assert table.cfg_map["o0"].prompt == "test" + new_cfg = table.cfg_map["o0"].model_dump() + assert old_cfg != new_cfg + old_cfg["prompt"] = "test" + assert old_cfg == new_cfg + + # --- Switch to/from None --- # + # Flip configs + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(o0=None, o1=LLMGenConfig()), + ), + ) + assert isinstance(table, TableMetaResponse) + assert table.cfg_map["o0"] is None + assert isinstance(table.cfg_map["o1"], LLMGenConfig) + assert len(table.cfg_map["o1"].system_prompt) == 0 + assert len(table.cfg_map["o1"].prompt) == 0 + if table_type == TableType.CHAT: + assert isinstance(table.cfg_map["AI"], LLMGenConfig) + + # --- Chat AI column must always have gen config --- # + # --- Chat AI column multi-turn must always be True --- # + if table_type == TableType.CHAT: + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(AI=None), + ), + ) + assert isinstance(table, TableMetaResponse) + assert isinstance(table.cfg_map["AI"], LLMGenConfig) + table.cfg_map["AI"].multi_turn = False + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(AI=table.cfg_map["AI"]), + ), + ) + assert isinstance(table, TableMetaResponse) + assert isinstance(table.cfg_map["AI"], LLMGenConfig) + assert table.cfg_map["AI"].multi_turn is True + + # --- Invalid column reference --- # + with pytest.raises(BadInputError, match="invalid source columns"): + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(o1=LLMGenConfig(prompt="${o2}")), + ), + ) + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(o1=LLMGenConfig(prompt="${o0}")), + ), + ) + assert table.cfg_map["o1"].prompt == "${o0}" + + # --- Invalid LLM model --- # + with pytest.raises(BadInputError, match="LLM model .+ is not found"): + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(o0=LLMGenConfig(model="INVALID")), + ), + ) + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict(o0=LLMGenConfig(model=setup.llm_model_id)), + ), + ) + assert table.cfg_map["o0"].model == setup.llm_model_id + + # --- Invalid knowledge table ID 
--- # + with pytest.raises(BadInputError, match="Knowledge Table .+ does not exist"): + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict( + o0=LLMGenConfig(rag_params=RAGParams(table_id="INVALID")), + ), + ), + ) + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict( + o0=LLMGenConfig(rag_params=RAGParams(table_id=kt.id)), + ), + ), + ) + assert isinstance(table.cfg_map["o0"].rag_params, RAGParams) + assert table.cfg_map["o0"].rag_params.table_id == kt.id + + # --- Invalid reranker model --- # + with pytest.raises(BadInputError, match="Reranking model .+ is not found"): + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict( + o0=LLMGenConfig(rag_params=RAGParams(reranking_model="INVALID")), + ), + ), + ) + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map=dict( + o0=LLMGenConfig( + rag_params=RAGParams(reranking_model=setup.rerank_model_id), + ), + ), + ), + ) + assert isinstance(table.cfg_map["o0"].rag_params, RAGParams) + assert table.cfg_map["o0"].rag_params.reranking_model == setup.rerank_model_id + assert table.cfg_map["o0"].rag_params.table_id == kt.id + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_long_table_column_ids( + setup: ServingContext, + table_type: TableType, +): + """ + Test various table and row operations on a table with long table and column IDs (100 characters). + - Check default prompts + - Update gen config + - Rename table and column + - Add row before and after: + - Table and column renames + - Column add and drop + - List rows + - Hybrid search + - RAG + - Import and export + + Args: + setup (ServingContext): Setup. + table_type (TableType): Table type. 
+ """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # 100 characters + kt_id = "one two three four five six seven eight nine ten eleven twelve thirteen fourteen fifteen sixteen (0)" + table_id = "one two three four five six seven eight nine ten eleven twelve thirteen fourteen fifteen sixteen (1)" + col_ids = [table_id, table_id.replace("one", "111"), table_id.replace("one", "112")] + cols = [ + ColumnSchemaCreate(id=col_ids[0], dtype="str").model_dump(), + ColumnSchemaCreate( + id=col_ids[1], dtype="str", gen_config=LLMGenConfig(model=setup.desc_llm_model_id) + ).model_dump(), + ColumnSchemaCreate( + id=col_ids[2], + dtype="str", + gen_config=LLMGenConfig( + model=setup.desc_llm_model_id, rag_params=RAGParams(table_id=kt_id) + ), + ), + ] + with ( + create_table(client, TableType.KNOWLEDGE, table_id=kt_id, cols=[]) as kt, + create_table(client, table_type, table_id=table_id, cols=cols) as table, + ): + assert kt.id == kt_id + assert table.id == table_id + col_map = {c.id: c for c in table.cols} + # Add knowledge data + add_table_rows(client, TableType.KNOWLEDGE, kt.id, [dict(), dict()], stream=False) + rows = list_table_rows(client, TableType.KNOWLEDGE, kt.id) + assert len(rows.values) == 2 + assert rows.total == 2 + # Check default prompts + gen_cfg = col_map[col_ids[1]].gen_config + assert isinstance(gen_cfg, LLMGenConfig) + assert isinstance(gen_cfg.system_prompt, str) + assert len(gen_cfg.system_prompt) > 1 + assert isinstance(gen_cfg.prompt, str) + assert len(gen_cfg.prompt) > 1 + assert f'Table name: "{table.id}"' in gen_cfg.prompt + assert f"{col_ids[0]}: ${{{col_ids[0]}}}" in gen_cfg.prompt + assert f'column "{col_ids[1]}"' in gen_cfg.prompt + # Update prompt and multi-turn + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest( + table_id=table.id, + column_map={ + col_ids[1]: LLMGenConfig(prompt=f"${{{col_ids[0]}}}", multi_turn=True) + }, + ), + ) + assert isinstance(table, TableMetaResponse) + # Add row + row_data = {"Title": "", "Text": "", "User": "Hi", "AI": "Hello"} + response = add_table_rows( + client, table_type, table.id, [{col_ids[0]: "one", **row_data}], stream=False + ) + content = response.rows[0].columns[col_ids[1]].content + assert "System prompt: There is a text with [40] tokens." in content + assert "There is a text with [1] tokens." in content + # Rename table + table_id_dst = table.id.replace("one", "two") + table = client.table.rename_table(table_type, table.id, table_id_dst) + assert isinstance(table, TableMetaResponse) + assert table.id == table_id_dst + # Rename column + col_id_dst = col_ids[1].replace("111", "222") + table = client.table.rename_columns( + table_type, + dict(table_id=table.id, column_map={col_ids[1]: col_id_dst}), + ) + assert isinstance(table, TableMetaResponse) + col_ids[1] = col_id_dst + col_map = {c.id: c for c in table.cols} + assert col_id_dst in col_map + # Add row + response = add_table_rows( + client, table_type, table.id, [{col_ids[0]: "one two", **row_data}], stream=True + ) + content = response.rows[0].columns[col_ids[1]].content + assert "System prompt: There is a text with [40] tokens." in content + assert "There is a text with [1] tokens." in content + assert "There is a text with [2] tokens." 
in content + # Add column + new_col_id = col_ids[1].replace("222", "333") + new_cols = [ + ColumnSchemaCreate( + id=new_col_id, dtype="str", gen_config=LLMGenConfig(model=setup.desc_llm_model_id) + ).model_dump(), + ] + if table_type == TableType.ACTION: + table = client.table.add_action_columns(dict(id=table.id, cols=new_cols)) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns(dict(id=table.id, cols=new_cols)) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns(dict(id=table.id, cols=new_cols)) + else: + raise ValueError(f"Unknown table type: {table_type}") + col_map = {c.id: c for c in table.cols} + assert new_col_id in col_map + # Check default prompts + gen_cfg = col_map[new_col_id].gen_config + assert isinstance(gen_cfg, LLMGenConfig) + assert isinstance(gen_cfg.system_prompt, str) + assert len(gen_cfg.system_prompt) > 1 + assert isinstance(gen_cfg.prompt, str) + assert len(gen_cfg.prompt) > 1 + assert f'Table name: "{table.id}"' in gen_cfg.prompt + assert f"{col_ids[0]}: ${{{col_ids[0]}}}" in gen_cfg.prompt + assert f'column "{new_col_id}"' in gen_cfg.prompt + # Add row + response = add_table_rows( + client, table_type, table.id, [{col_ids[0]: "a b c", **row_data}], stream=True + ) + content = response.rows[0].columns[col_ids[1]].content + assert "System prompt: There is a text with [40] tokens." in content + assert "There is a text with [1] tokens." in content + assert "There is a text with [2] tokens." in content + assert "There is a text with [3] tokens." in content + content = response.rows[0].columns[new_col_id].content + assert "There is a text with" in content + # Drop column + table = client.table.drop_columns( + table_type, dict(table_id=table.id, column_names=[new_col_id]) + ) + assert isinstance(table, TableMetaResponse) + col_map = {c.id: c for c in table.cols} + assert new_col_id not in col_map + # Add row + response = add_table_rows( + client, table_type, table.id, [{col_ids[0]: "a b c d", **row_data}], stream=True + ) + content = response.rows[0].columns[col_ids[1]].content + assert "System prompt: There is a text with [40] tokens." in content + assert "There is a text with [1] tokens." in content + assert "There is a text with [2] tokens." in content + assert "There is a text with [3] tokens." in content + assert "There is a text with [4] tokens." 
in content + assert len(response.rows[0].columns) == 2 + # List rows + rows = list_table_rows(client, table_type, table.id) + assert len(rows.values) == 4 + assert rows.total == 4 + for r in rows.references: + assert len(r[col_ids[2]].chunks) == 2 + rows = list_table_rows(client, table_type, table.id, where=f""""{col_ids[1]}" ~* '3'""") + assert len(rows.values) == 2 + assert rows.total == 2 + with pytest.raises(BadInputError): + list_table_rows(client, table_type, table.id, where=f""""{col_ids[1]}" ~* 3""") + # Hybrid search + results = client.table.hybrid_search( + table_type, dict(table_id=table.id, query="token", limit=2) + ) + assert isinstance(results, list) + assert len(results) == 2 + for r in results: + assert "rrf_score" in r + for c in col_ids: + assert c in r + # Export table + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, f"{table.id}.parquet") + with open(file_path, "wb") as f: + f.write(client.table.export_table(table_type, table.id)) + # Import table + import_table_id = table.id.replace("(1)", "(2)") + response = client.table.import_table( + table_type, + TableImportRequest( + file_path=file_path, table_id_dst=import_table_id, blocking=True + ), + ) + rows = list_table_rows(client, table_type, import_table_id) + assert len(rows.values) == 4 + assert rows.total == 4 + for r in rows.references: + assert len(r[col_ids[2]].chunks) == 2 diff --git a/services/api/tests/gen_table/test_v1.py b/services/api/tests/gen_table/test_v1.py new file mode 100644 index 0000000..1f1cf48 --- /dev/null +++ b/services/api/tests/gen_table/test_v1.py @@ -0,0 +1,366 @@ +from dataclasses import dataclass +from os.path import dirname, join, realpath +from tempfile import TemporaryDirectory + +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + ActionTableSchemaCreate, + AddActionColumnSchema, + AddChatColumnSchema, + AddKnowledgeColumnSchema, + ChatTableSchemaCreate, + ChatThreadResponse, + ColumnDropRequest, + ColumnRenameRequest, + ColumnReorderRequest, + ColumnSchemaCreate, + GenConfigUpdateRequest, + KnowledgeTableSchemaCreate, + MultiRowAddRequest, + MultiRowCompletionResponse, + MultiRowDeleteRequest, + MultiRowRegenRequest, + OkResponse, + OrganizationCreate, + Page, + RowUpdateRequest, + SearchRequest, + TableDataImportRequest, + TableImportRequest, + TableMetaResponse, +) +from owl.types import ( + LLMGenConfig, + TableType, +) +from owl.utils.crypt import generate_key +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + ELLM_EMBEDDING_CONFIG, + ELLM_EMBEDDING_DEPLOYMENT, + GPT_41_NANO_CONFIG, + GPT_41_NANO_DEPLOYMENT, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + create_deployment, + create_model_config, + create_organization, + create_project, + create_user, +) + +TEST_DIR = dirname(dirname(realpath(__file__))) +TABLE_TYPES = [TableType.ACTION, TableType.KNOWLEDGE, TableType.CHAT] + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str + superorg_id: str + project_id: str + llm_model_id: str + desc_llm_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. 
+ """ + with ( + # Create superuser + create_user() as superuser, + # Create organization + create_organization( + body=OrganizationCreate(name="Clubhouse"), user_id=superuser.id + ) as superorg, + # Create project + create_project( + dict(name="Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + assert superuser.id == "0" + assert superorg.id == "0" + + # Create models + with ( + create_model_config(GPT_41_NANO_CONFIG) as llm_config, + create_model_config(ELLM_DESCRIBE_CONFIG) as desc_llm_config, + create_model_config(ELLM_EMBEDDING_CONFIG), + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG), + ): + # Create deployments + with ( + create_deployment(GPT_41_NANO_DEPLOYMENT), + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(ELLM_EMBEDDING_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + yield ServingContext( + superuser_id=superuser.id, + superorg_id=superorg.id, + project_id=p0.id, + llm_model_id=llm_config.id, + desc_llm_model_id=desc_llm_config.id, + ) + + +def _gen_id() -> str: + return generate_key(8, "table-") + + +@pytest.mark.parametrize("table_type", TABLE_TYPES) +def test_gen_table_v1( + setup: ServingContext, + table_type: TableType, +): + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + cols = [ + ColumnSchemaCreate(id="int", dtype="int"), + ColumnSchemaCreate( + id="summary", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="", + prompt="", + max_tokens=10, + ), + ), + ] + # Create table + if table_type == TableType.ACTION: + table = client.table.create_action_table( + ActionTableSchemaCreate(id=_gen_id(), cols=cols), v1=True + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.create_knowledge_table( + KnowledgeTableSchemaCreate(id=_gen_id(), cols=cols, embedding_model=""), v1=True + ) + elif table_type == TableType.CHAT: + cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model="", + system_prompt="You are a wacky assistant.", + max_tokens=5, + ), + ), + ] + cols + table = client.table.create_chat_table( + ChatTableSchemaCreate(id=_gen_id(), cols=cols), v1=True + ) + else: + raise ValueError(f"Unknown table type: {table_type}") + assert isinstance(table, TableMetaResponse) + cols = {c.id: c for c in table.cols} + assert "int" in cols + + # Duplicate table + new_table = client.table.duplicate_table(table_type, table_id_src=table.id, v1=True) + assert isinstance(new_table, TableMetaResponse) + assert new_table.id != table.id + + # Get table + _table = client.table.get_table(table_type, table_id=table.id, v1=True) + assert isinstance(_table, TableMetaResponse) + assert _table.id == table.id + + # List tables + tables = client.table.list_tables(table_type, v1=True) + assert isinstance(tables, Page) + assert len(tables.items) == 2 + assert tables.total == 2 + + # Rename table + table_id_dst = _gen_id() + _table = client.table.rename_table( + table_type, new_table.id, table_id_dst=table_id_dst, v1=True + ) + assert isinstance(_table, TableMetaResponse) + assert _table.id != new_table.id + assert _table.id == table_id_dst + new_table = _table + + # Delete table + response = client.table.delete_table(table_type, table_id=new_table.id, v1=True) + assert isinstance(response, OkResponse) + + # Add columns + cols = [ColumnSchemaCreate(id="str", dtype="str")] + if table_type == TableType.ACTION: + table = client.table.add_action_columns( + AddActionColumnSchema(id=table.id, 
cols=cols), v1=True + ) + elif table_type == TableType.KNOWLEDGE: + table = client.table.add_knowledge_columns( + AddKnowledgeColumnSchema(id=table.id, cols=cols), v1=True + ) + elif table_type == TableType.CHAT: + table = client.table.add_chat_columns(AddChatColumnSchema(id=table.id, cols=cols), v1=True) + else: + raise ValueError(f"Unknown table type: {table_type}") + assert isinstance(table, TableMetaResponse) + cols = {c.id: c for c in table.cols} + assert "int" in cols + assert "str" in cols + + # Rename columns + table = client.table.rename_columns( + table_type, + ColumnRenameRequest(table_id=table.id, column_map={"int": "integer"}), + v1=True, + ) + assert isinstance(table, TableMetaResponse) + cols = {c.id: c for c in table.cols} + assert "int" not in cols + assert "integer" in cols + + # Update gen config + table = client.table.update_gen_config( + table_type, + GenConfigUpdateRequest(table_id=table.id, column_map={"summary": None}), + v1=True, + ) + assert isinstance(table, TableMetaResponse) + cols = {c.id: c for c in table.cols} + assert cols["summary"].gen_config is None + + # Reorder columns + if table_type == TableType.ACTION: + table = client.table.reorder_columns( + table_type, + ColumnReorderRequest(table_id=table.id, column_names=["str", "integer", "summary"]), + v1=True, + ) + assert isinstance(table, TableMetaResponse) + assert [c.id for c in table.cols][-3:] == ["str", "integer", "summary"] + + # Drop columns + table = client.table.drop_columns( + table_type, + ColumnDropRequest(table_id=table.id, column_names=["integer"]), + v1=True, + ) + assert isinstance(table, TableMetaResponse) + cols = {c.id: c for c in table.cols} + assert "integer" not in cols + + # Add rows + response = client.table.add_table_rows( + table_type, + MultiRowAddRequest( + table_id=table.id, data=[{"str": "foo", "summary": "bar"}] * 3, stream=False + ), + v1=True, + ) + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 3 + + # List rows + rows = client.table.list_table_rows(table_type, table.id, v1=True) + assert isinstance(rows, Page) + assert len(rows.items) == 3 + assert rows.total == 3 + for row in rows.items: + assert "value" in row["str"] + + # List rows (V1 value bug) + rows = client.table.list_table_rows(table_type, table.id, columns=["str"], v1=True) + assert isinstance(rows, Page) + assert len(rows.items) == 3 + assert rows.total == 3 + for row in rows.items: + assert "value" not in row["str"] + + # Get row + row_id = rows.items[0]["ID"] + row = client.table.get_table_row(table_type, table.id, row_id, v1=True) + assert isinstance(row, dict) + + # Get conversation thread + if table_type == TableType.CHAT: + thread = client.table.get_conversation_thread(table_type, table.id, "AI") + assert isinstance(thread, ChatThreadResponse) + + # Hybrid search + response = client.table.hybrid_search( + table_type, + SearchRequest(table_id=table.id, query="foo"), + v1=True, + ) + assert isinstance(response, list) + assert len(response) == 3 + + # Regen rows + response = client.table.regen_table_rows( + table_type, + MultiRowRegenRequest(table_id=table.id, row_ids=[row_id], stream=False), + v1=True, + ) + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 1 + + # Update row + response = client.table.update_table_row( + table_type, + RowUpdateRequest(table_id=table.id, row_id=row_id, data={"str": "baz"}), + ) + assert isinstance(response, OkResponse) + + # Delete rows + response = client.table.delete_table_rows( + table_type, + 
MultiRowDeleteRequest(table_id=table.id, row_ids=[row_id]), + v1=True, + ) + assert isinstance(response, OkResponse) + + # Delete row + response = client.table.delete_table_row(table_type, table.id, rows.items[1]["ID"]) + assert isinstance(response, OkResponse) + + # Data import export + csv_bytes = client.table.export_table_data(table_type, table.id, v1=True) + assert len(csv_bytes) > 0 + with TemporaryDirectory() as tmp_dir: + fp = join(tmp_dir, "test.csv") + with open(fp, "wb") as f: + f.write(csv_bytes) + response = client.table.import_table_data( + table_type, + TableDataImportRequest(file_path=fp, table_id=table.id, stream=False), + v1=True, + ) + assert isinstance(response, MultiRowCompletionResponse) + assert len(response.rows) == 1 + + # Table import export + parquet_bytes = client.table.export_table(table_type, table.id, v1=True) + assert len(parquet_bytes) > 0 + with TemporaryDirectory() as tmp_dir: + fp = join(tmp_dir, "test.parquet") + with open(fp, "wb") as f: + f.write(parquet_bytes) + _table = client.table.import_table( + table_type, + TableImportRequest(file_path=fp, table_id_dst=_gen_id()), + v1=True, + ) + assert isinstance(_table, TableMetaResponse) + assert _table.id != table.id + + # Embed file + if table_type == TableType.KNOWLEDGE: + with TemporaryDirectory() as tmp_dir: + fp = join(tmp_dir, "test.txt") + with open(fp, "w") as f: + f.write("Lorem ipsum") + response = client.table.embed_file(fp, table.id, v1=True) + assert isinstance(response, OkResponse) diff --git a/services/api/tests/gen_table_core/test_gen_table_core.py b/services/api/tests/gen_table_core/test_gen_table_core.py new file mode 100644 index 0000000..1a5fcb5 --- /dev/null +++ b/services/api/tests/gen_table_core/test_gen_table_core.py @@ -0,0 +1,1909 @@ +import asyncio +import csv +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from jamaibase.types import ProjectRead +from owl.db.gen_table import ( + GENTABLE_ENGINE, + ColumnDtype, + ColumnMetadata, + GenerativeTableCore, + TableMetadata, +) +from owl.types import LLMGenConfig, TableType +from owl.utils.exceptions import BadInputError, ResourceNotFoundError +from owl.utils.test import ( + GPT_41_NANO_CONFIG, + GPT_41_NANO_DEPLOYMENT, + create_deployment, + create_model_config, + create_project, + setup_organizations, +) + +VECTOR_LEN = 2 + + +@dataclass(slots=True) +class Session: + projects: list[ProjectRead] + chat_model_id: str + + +@dataclass(slots=True) +class Setup: + projects: list[ProjectRead] + chat_model_id: str + table_type: str + table_id: str + schema_id: str + table: GenerativeTableCore + + +@pytest.fixture(autouse=True, scope="module") +def session(): + with setup_organizations() as ctx: + with ( + create_project(dict(name="Mickey 17"), user_id=ctx.superuser.id) as p0, + create_project(dict(name="Mickey 18"), user_id=ctx.superuser.id) as p1, + create_model_config(GPT_41_NANO_CONFIG) as llm_config, + create_deployment(GPT_41_NANO_DEPLOYMENT), + ): + yield Session( + projects=[p0, p1], + chat_model_id=llm_config.id, + ) + + +@pytest.fixture(autouse=True, scope="function") +async def setup(session: Session): + """Fixture to set up and tear down test environment""" + table_type = TableType.ACTION + table_id = "Table (test)" + project_id = session.projects[0].id + schema_id = f"{project_id}_{table_type}" + # Drop schema + await GenerativeTableCore.drop_schema(project_id=project_id, table_type=table_type) + + # Create table + table 
= await GenerativeTableCore.create_table(
+        project_id=project_id,
+        table_type=table_type,
+        table_metadata=TableMetadata(
+            table_id=table_id,
+            title="Test Table",
+            parent_id=None,
+            version="1",
+            versioning_enabled=True,
+            meta={},
+        ),
+        column_metadata_list=[
+            ColumnMetadata(
+                column_id="col (1)",
+                table_id=table_id,
+                dtype=ColumnDtype.STR,
+                vlen=0,
+                gen_config=None,
+                column_order=1,
+                meta={},
+            ),
+            ColumnMetadata(
+                column_id="col (2)",
+                table_id=table_id,
+                dtype=ColumnDtype.INT,
+                vlen=0,
+                gen_config=None,
+                column_order=2,
+                meta={},
+            ),
+            ColumnMetadata(
+                column_id="vector_col",
+                table_id=table_id,
+                dtype=ColumnDtype.FLOAT,
+                vlen=VECTOR_LEN,
+                gen_config=None,
+                column_order=3,
+                meta={},
+            ),
+        ],
+    )
+    yield Setup(
+        projects=session.projects,
+        chat_model_id=session.chat_model_id,
+        table_type=table_type,
+        table_id=table_id,
+        schema_id=schema_id,
+        table=table,
+    )
+    # Clean up table
+    async with GENTABLE_ENGINE.transaction() as conn:
+        await conn.execute(f"""
+            DROP SCHEMA IF EXISTS "{schema_id}" CASCADE
+        """)
+    # https://github.com/MagicStack/asyncpg/issues/293#issuecomment-395069799
+    # Need to close the connection so that the next test will create a pool on the new event loop
+    await GENTABLE_ENGINE.close()
+
+
+@contextmanager
+def assert_updated_time(table: GenerativeTableCore):
+    """Assert that table "updated_at" has been updated"""
+    start_time = table.table_metadata.updated_at
+    try:
+        yield
+    finally:
+        assert table.table_metadata.updated_at > start_time
+
+
+class TestImportExportOperations:
+    async def test_export_empty_table(self, setup: Setup, tmp_path):
+        """Test exporting and importing an empty table preserves schema"""
+        table = setup.table
+
+        # Export empty table
+        export_path = tmp_path / "empty_export.parquet"
+        await table.export_table(export_path)
+        assert export_path.exists()
+
+        # Import empty table
+        imported_table = await GenerativeTableCore.import_table(
+            project_id=setup.projects[0].id,
+            table_type=setup.table_type,
+            source=export_path,
+            table_id_dst="imported_empty_table",
+        )
+
+        # Verify schema preserved
+        assert len((await imported_table.list_rows()).items) == 0
+        assert len(imported_table.column_metadata) == len(table.column_metadata)
+        for orig_col, imp_col in zip(
+            table.column_metadata, imported_table.column_metadata, strict=True
+        ):
+            assert orig_col.column_id == imp_col.column_id
+            assert orig_col.dtype == imp_col.dtype
+            assert orig_col.vlen == imp_col.vlen
+
+    async def test_import_table_to_new_project(self, setup: Setup, tmp_path):
+        """Test importing an empty table into a different project preserves its schema"""
+        table = setup.table
+        new_project_id = setup.projects[1].id
+        # Clean up before the test
+        await GenerativeTableCore.drop_schema(
+            project_id=new_project_id, table_type=setup.table_type
+        )
+
+        # Export empty table
+        export_path = tmp_path / "empty_export.parquet"
+        await table.export_table(export_path)
+        assert export_path.exists()
+
+        # Import empty table
+        imported_table = await GenerativeTableCore.import_table(
+            project_id=new_project_id,
+            table_type=setup.table_type,
+            source=export_path,
+            table_id_dst="imported_empty_table",
+        )
+
+        # Verify schema preserved
+        assert len((await imported_table.list_rows()).items) == 0
+        assert len(imported_table.column_metadata) == len(table.column_metadata)
+        for orig_col, imp_col in zip(
+            table.column_metadata, imported_table.column_metadata, strict=True
+        ):
+            assert orig_col.column_id == imp_col.column_id
+            assert orig_col.dtype == imp_col.dtype
+            assert orig_col.vlen 
== imp_col.vlen + + async def test_state_column_preservation(self, setup: Setup, tmp_path): + """Test state columns are preserved during export/import""" + table = setup.table + + # Add row with state values + new_row = { + "col (1)": "test", + "col (2)": 123, + "vector_col": np.random.rand(VECTOR_LEN), + } + await table.add_rows([new_row]) + + # Export table + export_path = tmp_path / "state_export.parquet" + await table.export_table(export_path) + + # Import table + imported_table = await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="imported_state_table", + ) + + # Verify state columns preserved + rows = (await imported_table.list_rows(remove_state_cols=False)).items + assert len(rows) == 1 + for state_col, _ in new_row.items(): + assert rows[0].get(f"{state_col}_") is not None + + async def test_export_table_basic(self, setup: Setup, tmp_path): + """Test basic table export functionality""" + # Create table + table = setup.table + + # Add test data + await table.add_rows( + [{"col (1)": "test1", "col (2)": 123, "vector_col": np.random.rand(VECTOR_LEN)}] + ) + + # Export table + export_path = tmp_path / "exported_table.parquet" + await table.export_table(export_path) + + # Verify file exists + assert export_path.exists() + assert export_path.stat().st_size > 0 + + async def test_export_table_error_cases(self, setup: Setup, tmp_path): + """Test error cases for table export""" + # Create table + table = setup.table + + # Test invalid path + invalid_path = Path("/invalid/path/export.parquet") + with pytest.raises(ResourceNotFoundError): + await table.export_table(invalid_path) + + async def test_import_table_basic(self, setup: Setup, tmp_path): + """Test basic table import functionality""" + # Create and export test table + table = setup.table + await table.add_rows( + [{"col (1)": "test1", "col (2)": 123, "vector_col": np.random.rand(VECTOR_LEN)}] + ) + export_path = tmp_path / "exported_table.parquet" + await table.export_table(export_path) + + # Import table with new name + imported_table = await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="imported_table", + ) + + # Verify imported data + rows = (await imported_table.list_rows()).items + assert len(rows) == 1 + assert rows[0]["col (1)"] == "test1" + assert rows[0]["col (2)"] == 123 + assert len(rows[0]["vector_col"]) == VECTOR_LEN + + async def test_import_table_error_cases(self, setup: Setup, tmp_path): + """Test error cases for table import""" + # Test invalid path + invalid_path = Path("/invalid/path/import.parquet") + with pytest.raises(ResourceNotFoundError): + await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=invalid_path, + table_id_dst="imported_table", + ) + + async def test_export_import_parquet_basic(self, setup: Setup, tmp_path): + """Test basic Parquet export/import functionality with detailed verification""" + table = setup.table + + # Add test data with different types + test_data = [ + { + "col (1)": "test1", + "col (2)": 123, + "vector_col": np.random.rand(VECTOR_LEN), + } + ] + await table.add_rows(test_data) + + # Get original metadata and columns + original_metadata = table.table_metadata + original_columns = table.column_metadata + + # Export to Parquet + export_path = tmp_path / "exported_table.parquet" + await table.export_table(export_path) + + # Verify file exists + 
assert export_path.exists() + assert export_path.stat().st_size > 0 + + # Import with new name + imported_table = await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="imported_table_parquet", + ) + + # Verify imported data + rows = (await imported_table.list_rows()).items + assert len(rows) == 1 + + # Detailed data comparison + original_row = (await table.list_rows()).items[0] + imported_row = rows[0] + + # Compare all fields except internal IDs + for data in test_data: + for key in data.keys(): + if isinstance(data[key], np.ndarray): + np.testing.assert_array_equal(original_row[key], imported_row[key]) + else: + assert original_row[key] == imported_row[key] + + # Verify metadata preservation + imported_metadata = imported_table.table_metadata + assert imported_metadata.title == original_metadata.title + assert imported_metadata.meta == original_metadata.meta + # assert imported_metadata.version == original_metadata.version + + # Verify column preservation + imported_columns = imported_table.column_metadata + assert len(imported_columns) == len(original_columns) + + for orig_col, imp_col in zip(original_columns, imported_columns, strict=True): + assert orig_col.column_id == imp_col.column_id + assert orig_col.dtype == imp_col.dtype + assert orig_col.vlen == imp_col.vlen + assert orig_col.column_order == imp_col.column_order + + async def test_import_recreates_indexes(self, setup: Setup, tmp_path): + """Verify imported tables have all indexes recreated""" + # Export original table + export_path = tmp_path / "export.parquet" + await setup.table.export_table(export_path) + + # Import to new table + imported = await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="imported_with_indexes", + ) + + # Verify indexes exist + async with GENTABLE_ENGINE.transaction() as conn: + # Check FTS index + fts_index = await conn.fetchval( + """ + SELECT COUNT(*) FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 AND indexname LIKE '%_fts_idx' + """, + imported.schema_id, + imported.table_id, + ) + assert fts_index > 0 + + # Check vector indexes + vec_indexes = await conn.fetchval( + """ + SELECT COUNT(*) FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 AND indexname LIKE '%_vec_idx' + """, + imported.schema_id, + imported.table_id, + ) + assert vec_indexes == len(imported.vector_column_names) + + async def test_export_import_parquet_large_data(self, setup: Setup, tmp_path): + """Test Parquet export/import with large dataset and detailed verification""" + table = setup.table + + # Add 1000 rows of test data + test_data = [ + { + "col (1)": f"test{i}", + "col (2)": i, + "vector_col": np.random.rand(VECTOR_LEN), + } + for i in range(1000) + ] + await table.add_rows(test_data) + + # Get original metadata and columns + original_metadata = table.table_metadata + original_columns = table.column_metadata + + # Export to Parquet + export_path = tmp_path / "large_export.parquet" + await table.export_table(export_path) + + # Import with new name + imported_table = await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="large_import_parquet", + ) + + # Verify all data was imported + rows = (await imported_table.list_rows(limit=1000)).items + assert len(rows) == 1000 + + # Get original rows for comparison + original_rows = (await 
table.list_rows(limit=1000)).items
+
+        # Detailed data comparison
+        for orig_row, imp_row in zip(original_rows, rows, strict=True):
+            assert orig_row["col (1)"] == imp_row["col (1)"]
+            assert orig_row["col (2)"] == imp_row["col (2)"]
+            np.testing.assert_array_equal(orig_row["vector_col"], imp_row["vector_col"])
+
+        # Verify metadata preservation
+        imported_metadata = imported_table.table_metadata
+        assert imported_metadata.title == original_metadata.title
+        assert imported_metadata.meta == original_metadata.meta
+        # assert imported_metadata.version == original_metadata.version
+
+        # Verify column preservation
+        imported_columns = imported_table.column_metadata
+        assert len(imported_columns) == len(original_columns)
+
+        for orig_col, imp_col in zip(original_columns, imported_columns, strict=True):
+            assert orig_col.column_id == imp_col.column_id
+            assert orig_col.dtype == imp_col.dtype
+            assert orig_col.vlen == imp_col.vlen
+            assert orig_col.column_order == imp_col.column_order
+
+    async def test_export_parquet_error_cases(self, setup: Setup, tmp_path):
+        """Test Parquet export error cases"""
+        table = setup.table
+
+        # Test invalid path
+        invalid_path = Path("/invalid/path/export.parquet")
+        with pytest.raises(ResourceNotFoundError):
+            await table.export_table(invalid_path)
+
+        # Test invalid format
+        with pytest.raises(BadInputError):
+            await table.export_table(tmp_path / "test.csv")
+
+    async def test_import_parquet_invalid_path_cases(self, setup: Setup, tmp_path):
+        """Test Parquet import invalid path cases"""
+        # Test invalid path
+        invalid_path = Path("/invalid/path/import.parquet")
+        with pytest.raises(ResourceNotFoundError):
+            await GenerativeTableCore.import_table(
+                project_id=setup.projects[0].id,
+                table_type=setup.table_type,
+                source=invalid_path,
+                table_id_dst="imported_table",
+            )
+
+    async def test_import_corrupt_files(self, setup: Setup, tmp_path):
+        """Test handling of corrupted import files."""
+        # Setup
+        corrupt_path = tmp_path / "corrupt.parquet"
+
+        # Test malformed file
+        corrupt_path.write_bytes(b"PAR1\x00\x00INVALID\x00PAR1")
+        with pytest.raises(BadInputError, match="contains bad data"):
+            await GenerativeTableCore.import_table(
+                project_id=setup.projects[0].id,
+                table_type=setup.table_type,
+                source=corrupt_path,
+                table_id_dst="corrupt_test_table",
+            )
+
+        # Test partial file (truncated)
+        with open(corrupt_path, "wb") as f:
+            f.write(b"PAR1")  # Only magic bytes
+        with pytest.raises(BadInputError, match="contains bad data"):
+            await GenerativeTableCore.import_table(
+                project_id=setup.projects[0].id,
+                table_type=setup.table_type,
+                source=corrupt_path,
+                table_id_dst="corrupt_test_table",
+            )
+
+        # Test invalid metadata
+        df = pd.DataFrame({"col (1)": [1, 2, 3]})
+        df.to_parquet(corrupt_path)
+        # Corrupt the metadata by overwriting footer
+        with open(corrupt_path, "r+b") as f:
+            f.seek(-100, 2)
+            f.write(b"X" * 100)
+        with pytest.raises(BadInputError, match="contains bad data"):
+            await GenerativeTableCore.import_table(
+                project_id=setup.projects[0].id,
+                table_type=setup.table_type,
+                source=corrupt_path,
+                table_id_dst="corrupt_test_table",
+            )
+
+    async def test_export_data_basic(self, setup: Setup, tmp_path):
+        """Test basic data export to CSV"""
+        # Create table
+        table = setup.table
+
+        # Add test data
+        await table.add_rows(
+            [{"col (1)": "test1", "col (2)": 123, "vector_col": np.random.rand(VECTOR_LEN)}]
+        )
+
+        # Export data
+        export_path = tmp_path / "exported_data.csv"
+        await table.export_data(export_path)
+
+        # Verify CSV content
+        with 
open(export_path, "r") as f: + reader = csv.DictReader(f) + rows = list(reader) + assert len(rows) == 1 + assert rows[0]["col (1)"] == "test1" + assert rows[0]["col (2)"] == "123" + assert len(rows[0]["vector_col"].split(",")) == VECTOR_LEN + + async def test_export_data_error_cases(self, setup: Setup, tmp_path): + """Test error cases for data export""" + # Create table + table = setup.table + + # Test invalid path + invalid_path = Path("/invalid/path/export.csv") + with pytest.raises(BadInputError): + await table.export_data(invalid_path) + + async def test_import_data(self, setup: Setup, tmp_path): + """Test importing data from CSV""" + # Create table + table = setup.table + + # Create test CSV + csv_path = tmp_path / "import.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["col (1)", "col (2)", "vector_col"]) + writer.writeheader() + writer.writerow( + { + "col (1)": "import1", + "col (2)": "1", + "vector_col": np.random.rand(VECTOR_LEN).tolist(), + } + ) + writer.writerow( + { + "col (1)": "import2", + "col (2)": "2", + "vector_col": np.random.rand(VECTOR_LEN).tolist(), + } + ) + + # Import data + await table.import_data(csv_path) + + # Verify imported data + rows = (await table.list_rows()).items + assert len(rows) == 2 + assert rows[0]["col (1)"] == "import1" + assert rows[0]["col (2)"] == 1 + assert len(rows[0]["vector_col"]) == VECTOR_LEN + + async def test_import_with_column_mapping(self, setup: Setup, tmp_path): + """Test importing data with column mapping""" + # Create table + table = setup.table + + # Create test CSV with different column names + csv_path = tmp_path / "import.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["csv_col1", "csv_col2", "csv_vector"]) + writer.writeheader() + writer.writerow( + { + "csv_col1": "mapped1", + "csv_col2": "1", + "csv_vector": np.random.rand(VECTOR_LEN).tolist(), + } + ) + + # Import with column mapping + column_id_mapping = { + "csv_col1": "col (1)", + "csv_col2": "col (2)", + "csv_vector": "vector_col", + } + await table.import_data(csv_path, column_id_mapping=column_id_mapping) + + # Verify imported data + rows = (await table.list_rows()).items + assert len(rows) == 1 + assert rows[0]["col (1)"] == "mapped1" + assert rows[0]["col (2)"] == 1 + assert len(rows[0]["vector_col"]) == VECTOR_LEN + + async def test_import_error_handling(self, setup: Setup, tmp_path): + """Test error handling during import""" + # Create table + table = setup.table + + # Test missing file + with pytest.raises(ResourceNotFoundError): + await table.import_data(Path("/nonexistent/file.csv")) + + # Test invalid column mapping + csv_path = tmp_path / "import.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["invalid_col"]) + writer.writeheader() + writer.writerow({"invalid_col": "value"}) + await table.import_data(csv_path) + assert len((await table.list_rows()).items) == 0 + + # Test invalid vector data + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["vector_col"]) + writer.writeheader() + writer.writerow({"vector_col": "invalid,vector,data"}) + with pytest.raises(BadInputError): + await table.import_data(csv_path) + + +# Include existing test classes here... 
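
[Editor's note, not part of the patch: for orientation, a minimal sketch of the Parquet export/import round trip that the TestImportExportOperations cases above exercise. It only uses calls that appear in this file (export_table, import_table, list_rows) and the imports already declared above; the project id, destination table id, and temporary directory are placeholder values, and the assertion mirrors what the tests check rather than prescribing behaviour.]

from owl.db.gen_table import GenerativeTableCore
from owl.types import TableType


async def parquet_round_trip(table: GenerativeTableCore, project_id: str, tmp_dir) -> GenerativeTableCore:
    # Export the table (schema, rows, state and vector columns) to a Parquet file
    export_path = tmp_dir / f"{table.table_id}.parquet"
    await table.export_table(export_path)
    # Re-import under a new table id; the tests above verify that indexes are recreated on import
    imported = await GenerativeTableCore.import_table(
        project_id=project_id,
        table_type=TableType.ACTION,
        source=export_path,
        table_id_dst=f"{table.table_id} (copy)",
    )
    # The imported table should hold the same number of rows as the source table
    assert len((await imported.list_rows()).items) == len((await table.list_rows()).items)
    return imported
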
+class TestTableOperations: + async def test_table_creation(self, setup: Setup): + """Test creating a new data table with metadata""" + # Verify table exists + async with GENTABLE_ENGINE.transaction() as conn: + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."TableMetadata" WHERE table_id = $1)', + "Table (test)", + ) + assert exists + + async def test_table_creation_concurrent(self, session: Session): + """Test creating a new data table with metadata concurrently""" + table_type = TableType.ACTION + project_id = session.projects[0].id + # Drop schema + await GenerativeTableCore.drop_schemas(project_id) + # Create table + num_tables = 3 + await asyncio.gather( + *[ + GenerativeTableCore.create_table( + project_id=project_id, + table_type=table_type, + table_metadata=TableMetadata( + table_id=f"Table {i}", + title="Test Table", + parent_id=None, + version="1", + versioning_enabled=True, + meta={}, + ), + column_metadata_list=[ + ColumnMetadata( + column_id="col", + table_id=f"Table {i}", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=1, + meta={}, + ), + ], + ) + for i in range(num_tables) + ] + ) + # Verify table exists + async with GENTABLE_ENGINE.transaction() as conn: + for i in range(num_tables): + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{project_id}_{table_type}"."TableMetadata" WHERE table_id = $1)', + f"Table {i}", + ) + assert exists + + async def test_table_duplication(self, setup: Setup): + """Test duplicating a table with data""" + # Create original table + table = setup.table + + # Insert test data + test_data = [ + { + "col (1)": "value1", + "col (2)": 1, + }, + { + "col (1)": "value2", + "col (2)": 2, + }, + { + "col (1)": None, + "col (2)": 3, + }, # Test null handling + ] + await table.add_rows(test_data) + + # Duplicate table + new_table = await GenerativeTableCore.duplicate_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + table_id_src=setup.table_id, + table_id_dst="test_table_copy", + ) + + # Verify new table exists + async with GENTABLE_ENGINE.transaction() as conn: + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."TableMetadata" WHERE table_id = $1)', + "test_table_copy", + ) + assert exists + + # Verify data was copied correctly + original_rows = (await table.list_rows()).items + new_rows = (await new_table.list_rows()).items + + # Verify row count matches + assert len(new_rows) == len(original_rows) + + # Verify specific data values + for row, test_row in zip(new_rows, test_data, strict=True): + assert row["col (1)"] == test_row["col (1)"] + assert row["col (2)"] == test_row["col (2)"] + + async def test_duplicate_recreates_indexes(self, setup: Setup): + """Verify duplicated tables have all indexes recreated""" + original = setup.table + duplicated = await GenerativeTableCore.duplicate_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + table_id_src=original.table_id, + table_id_dst="duplicated_with_indexes", + ) + + # Verify indexes exist + async with GENTABLE_ENGINE.transaction() as conn: + # Check FTS index + fts_index = await conn.fetchval( + """ + SELECT COUNT(*) FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 AND indexname LIKE '%_fts_idx' + """, + duplicated.schema_id, + duplicated.table_id, + ) + assert fts_index > 0 + + # Check vector indexes match original count + original_vec_count = len(original.vector_column_names) + duplicated_vec_count = await conn.fetchval( + """ + SELECT 
COUNT(*) FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 AND indexname LIKE '%_vec_idx' + """, + duplicated.schema_id, + duplicated.table_id, + ) + assert duplicated_vec_count == original_vec_count + + async def test_rename_table(self, setup: Setup): + """Verify renaming table works properly by checking it can be opened and the associated ColumnMetadata and TableMetadata exists""" + table = setup.table + new_name = "renamed_table" + with assert_updated_time(table): + # Rename table + await table.rename_table(new_name) + # Verify table was renamed by opening it. + new_table = await GenerativeTableCore.open_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + table_id=new_name, + ) + # verify associated ColumnMetadata and TableMetadata exist + assert all([col.table_id == new_name for col in new_table.column_metadata]) + assert new_table.table_metadata.table_id == new_name + + async def test_rename_table_has_indexes(self, setup: Setup): + """Verify renaming a table updates all associated indexes""" + table = setup.table + new_name = "renamed_table" + + # Get original index names + async with GENTABLE_ENGINE.transaction() as conn: + original_indexes = await conn.fetch( + """ + SELECT indexname FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 + """, + table.schema_id, + table.table_id, + ) + + # Rename table + await table.rename_table(new_name) + + # Verify indexes were renamed + async with GENTABLE_ENGINE.transaction() as conn: + new_indexes = await conn.fetch( + """ + SELECT indexname FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 + """, + table.schema_id, + new_name, + ) + + # Check all indexes were renamed properly + assert len(new_indexes) == len(original_indexes) + for new_idx in new_indexes: + assert new_name in new_idx["indexname"] + + async def test_table_drop(self, setup: Setup): + """Test dropping a table""" + table = setup.table + + # Verify table exists + async with GENTABLE_ENGINE.transaction() as conn: + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."TableMetadata" WHERE table_id = $1)', + setup.table_id, + ) + assert exists + + # Drop table + await table.drop_table() + + # Verify table does not exists + async with GENTABLE_ENGINE.transaction() as conn: + # check TableMetadata + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."TableMetadata" WHERE table_id = $1)', + setup.table_id, + ) + assert not exists + # check columnmetadata + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."ColumnMetadata" WHERE table_id = $1)', + setup.table_id, + ) + assert not exists + + # check table not in schema + ret = await conn.fetch( + """ + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + """, + setup.schema_id, + ) + assert setup.table_id not in [r["table_name"] for r in ret] + + +class TestColumnOperations: + async def test_add_column(self, setup: Setup): + """Test adding a new column""" + table = setup.table + with assert_updated_time(table): + # Add new column + new_column = ColumnMetadata( + column_id="new_col", + table_id=setup.table_id, + dtype=ColumnDtype.FLOAT, + vlen=0, + gen_config=None, + column_order=4, + ) + await table.add_column(new_column) + # Verify new column exists + async with GENTABLE_ENGINE.transaction() as conn: + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."ColumnMetadata" WHERE column_id = $1)', + "new_col", + ) + assert exists + + # Verify 
column was added to the actual table + columns = await conn.fetch( + "SELECT column_name FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2", + setup.schema_id, + setup.table_id, + ) + assert "new_col" in [c["column_name"] for c in columns] + + async def test_drop_columns(self, setup: Setup): + """Test removing a column""" + table = setup.table + with assert_updated_time(table): + # Remove column + await table.drop_columns(["col (1)"]) + # Verify column removed + async with GENTABLE_ENGINE.transaction() as conn: + exists = await conn.fetchval( + f'SELECT EXISTS (SELECT 1 FROM "{setup.schema_id}"."ColumnMetadata" WHERE column_id = $1)', + "col (1)", + ) + assert not exists + + async def test_column_dtype_storage(self, setup: Setup): + """Verify ColumnMetadata stores original dtype not PostgreSQL type""" + table = setup.table + + # Check existing columns + for col in table.column_metadata: + assert isinstance(col.dtype, ColumnDtype) # Should be enum value + assert not col.dtype.isnumeric() # Shouldn't be PostgreSQL type string + + # Add new column and verify + new_col = ColumnMetadata( + column_id="test_dtype", + table_id=table.table_id, + dtype=ColumnDtype.FLOAT, + vlen=VECTOR_LEN, + ) + table = await table.add_column(new_col) + + # verify dtype + test_col = next(c for c in table.column_metadata if c.column_id == "test_dtype") + assert test_col.dtype == ColumnDtype.FLOAT + + async def test_column_ordering(self, setup: Setup): + """Test column ordering""" + table = setup.table + with assert_updated_time(table): + # Reorder columns + await table.reorder_columns(["ID", "Updated at", "col (2)", "vector_col", "col (1)"]) + # Verify new order + async with GENTABLE_ENGINE.transaction() as conn: + columns = await conn.fetch( + f'SELECT column_id FROM "{setup.schema_id}"."ColumnMetadata" ORDER BY column_order' + ) + columns = [c["column_id"] for c in columns if not c["column_id"].endswith("_")] + assert columns == ["ID", "Updated at", "col (2)", "vector_col", "col (1)"] + + async def test_update_column_gen_config_to_null(self, setup: Setup): + """Test updating column gen_config to NULL""" + table = setup.table + with assert_updated_time(table): + # Add column with proper LLMGenConfig instance + new_column = ColumnMetadata( + column_id="output_col", + table_id=setup.table_id, + dtype=ColumnDtype.STR, + vlen=0, + gen_config=LLMGenConfig( + model=setup.chat_model_id, + temperature=0.7, + system_prompt="Test system", + prompt="Test prompt", + multi_turn=False, + ), + column_order=4, + ) + table = await table.add_column(new_column) + assert {c.column_id: c for c in table.column_metadata}[ + "output_col" + ].gen_config is not None + # Update gen_config to NULL + table = await table.update_gen_config(update_mapping={"output_col": None}) + assert {c.column_id: c for c in table.column_metadata}["output_col"].gen_config is None + + async def test_update_gen_config_basic(self, setup: Setup): + """Test basic gen_config updates""" + table = setup.table + with assert_updated_time(table): + # Add column with NULL config + new_col = ColumnMetadata( + column_id="output_col", + table_id=table.table_id, + dtype=ColumnDtype.STR, + gen_config=None, + ) + table = await table.add_column(new_col) + # Update to valid config + new_config = LLMGenConfig( + model=setup.chat_model_id, + temperature=0.7, + system_prompt="Test", + prompt="Test prompt", + ) + updated = await table.update_gen_config(update_mapping={"output_col": new_config}) + # Verify update + col = next(c for c in updated.column_metadata 
if c.column_id == "output_col") + assert col.gen_config == new_config + assert col.is_output_column + + async def test_update_gen_config_change_existing(self, setup: Setup): + """Test updating from one gen_config to another""" + table = setup.table + with assert_updated_time(table): + # Initial config + initial_config = LLMGenConfig( + model=setup.chat_model_id, + temperature=0.5, + system_prompt="Initial", + prompt="Initial prompt", + ) + + # Add column with initial config + new_col = ColumnMetadata( + column_id="output_col", + table_id=table.table_id, + dtype=ColumnDtype.STR, + gen_config=initial_config, + ) + table = await table.add_column(new_col) + + # New config with different values + updated_config = LLMGenConfig( + model=setup.chat_model_id, + temperature=0.7, + system_prompt="Updated", + prompt="Updated prompt", + ) + + # Update config + updated = await table.update_gen_config(update_mapping={"output_col": updated_config}) + + # Verify all fields changed + col = next(c for c in updated.column_metadata if c.column_id == "output_col") + assert col.gen_config.model == setup.chat_model_id + assert col.gen_config.temperature == 0.7 + assert col.gen_config.system_prompt == "Updated" + assert col.gen_config.prompt == "Updated prompt" + + # Verify persistence after reload + table = await GenerativeTableCore.open_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + table_id=setup.table_id, + ) + reloaded_col = next(c for c in table.column_metadata if c.column_id == "output_col") + assert reloaded_col.gen_config == updated_config + + async def test_update_gen_config_invalid_column(self, setup: Setup): + """Test updating non-existent column""" + table = setup.table + config = LLMGenConfig( + model=setup.chat_model_id, + temperature=0.7, + system_prompt="Test", + prompt="Test prompt", + ) + with pytest.raises(ResourceNotFoundError): + await table.update_gen_config(update_mapping={"nonexistent_col": config}) + + async def test_update_gen_config_partial_changes(self, setup: Setup): + """Test updating only some config fields""" + table = setup.table + with assert_updated_time(table): + initial_config = LLMGenConfig( + model=setup.chat_model_id, + temperature=0.5, + system_prompt="Initial", + prompt="Initial prompt", + ) + + new_col = ColumnMetadata( + column_id="output_col", + table_id=table.table_id, + dtype=ColumnDtype.STR, + gen_config=initial_config, + ) + table = await table.add_column(new_col) + + # Only update temperature + updated_config = LLMGenConfig( + model=setup.chat_model_id, # Same model + temperature=0.8, # Updated + system_prompt="Initial", # Same + prompt="Initial prompt", # Same + ) + + updated = await table.update_gen_config(update_mapping={"output_col": updated_config}) + col = next(c for c in updated.column_metadata if c.column_id == "output_col") + + assert col.gen_config.model == setup.chat_model_id + assert col.gen_config.temperature == 0.8 + assert col.gen_config.system_prompt == "Initial" + assert col.gen_config.prompt == "Initial prompt" + + +class TestSearchOperations: + @pytest.fixture + def test_vectors(self): + return { + "valid_vector": np.random.rand(VECTOR_LEN), + "empty_vector": np.array([]), + "wrong_dim_vector": np.random.rand(VECTOR_LEN * 2), + "list_vector": np.random.rand(VECTOR_LEN).tolist(), + } + + async def test_vector_search_basic(self, setup: Setup, test_vectors): + """Test basic vector search functionality""" + table = setup.table + # Insert test vectors + test_data = [ + {"col (1)": "foo", "col (2)": 1, "vector_col": 
np.random.rand(VECTOR_LEN)}, + {"col (1)": "bar", "col (2)": 2, "vector_col": np.random.rand(VECTOR_LEN)}, + {"col (1)": "baz", "col (2)": 3, "vector_col": np.random.rand(VECTOR_LEN)}, + ] + await table.add_rows(test_data) + + # Test search with numpy array + results = await table.vector_search( + "dummy_query", + embedding_fn=lambda _, __: test_vectors["valid_vector"], + vector_column_names=["vector_col"], + ) + assert len(results) == 3 + assert "score" in results[0] + # Scores are distances (lower is better) + assert results[0]["score"] <= results[1]["score"] <= results[2]["score"] + + # Test search with list input + list_results = await table.vector_search( + "dummy_query", + embedding_fn=lambda _, __: test_vectors["list_vector"], + vector_column_names=["vector_col"], + ) + assert len(list_results) == 3 + + @pytest.mark.asyncio + async def test_multi_column_vector_search(self, setup: Setup, test_vectors): + """Test basic vector search functionality on multiple columns""" + table = setup.table + # Add vector column + table = await table.add_column( + ColumnMetadata( + column_id="vector_col2", + table_id="Table (test)", + dtype=ColumnDtype.FLOAT, + vlen=VECTOR_LEN, + gen_config=None, + column_order=4, + meta={}, + ) + ) + # Insert test vectors + test_data = [ + { + "col (1)": "foo", + "col (2)": 1, + "vector_col": test_vectors["valid_vector"], + "vector_col2": test_vectors["valid_vector"], + }, + { + "col (1)": "bar", + "col (2)": 2, + "vector_col": test_vectors["valid_vector"], + "vector_col2": test_vectors["valid_vector"], + }, + { + "col (1)": "baz", + "col (2)": 3, + "vector_col": np.random.rand(VECTOR_LEN), + "vector_col2": np.random.rand(VECTOR_LEN), + }, + ] + await table.add_rows(test_data) + + # Test search with numpy array + results = await table.vector_search( + "dummy_query", + embedding_fn=lambda _, __: test_vectors["valid_vector"], + vector_column_names=["vector_col"], + ) + # Scores are distances (lower is better) + assert results[0]["score"] <= results[1]["score"] <= results[2]["score"] + # Ensure not matching row is last + assert results[2]["col (1)"] == "baz" + + async def test_vector_search_errors(self, setup: Setup, test_vectors): + """Test vector search error handling""" + + table = setup.table + + # Test invalid input types + with pytest.raises(BadInputError): + await table.vector_search( + "dummy_query", + embedding_fn=lambda _, __: "invalid_type", + vector_column_names=["vector_col"], + ) + + # Test empty vector + with pytest.raises(BadInputError): + await table.vector_search( + "dummy_query", + embedding_fn=lambda _, __: test_vectors["empty_vector"], + vector_column_names=["vector_col"], + ) + + # Test dimension mismatch + with pytest.raises(BadInputError): + await table.vector_search( + "dummy_query", + embedding_fn=lambda _, __: test_vectors["wrong_dim_vector"], + vector_column_names=["vector_col"], + ) + + async def test_fts_search_basic(self, setup: Setup): + """Test basic full text search functionality""" + new_column = ColumnMetadata( + column_id="search_col", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + meta={}, + ) + + table = setup.table + table = await table.add_column(new_column) + + # Insert test data with searchable content + test_data = [ + {"col (1)": "foo", "col (2)": 1, "search_col": "quick brown fox"}, + {"col (1)": "bar", "col (2)": 2, "search_col": "lazy dog"}, + {"col (1)": "baz", "col (2)": 3, "search_col": "quick dog"}, + ] + await table.add_rows(test_data) + + # Test basic search + 
results = await table.fts_search("quick") + assert len(results) == 2 + assert {r["search_col"] for r in results} == {"quick brown fox", "quick dog"} + + async def test_fts_uses_index(self, setup: Setup): + """Verify FTS queries actually use the index""" + table = setup.table + rows_to_add = [{"col (1)": "test search term"}] + await table.add_rows(rows_to_add) + + # Test basic search + results = await table.fts_search("quick", explain=True) + + # Verify index scan is used + if not any(["Index Scan" in res["QUERY PLAN"] for res in results]): + # add more rows to force index scan + await table.add_rows(rows_to_add * 1000) + results = await table.fts_search("quick", force_use_index=True, explain=True) + assert any(["Index Scan" in res["QUERY PLAN"] for res in results]) + else: + assert True + + async def test_fts_search_pagination(self, setup: Setup): + """Test search with pagination""" + # Create table with text column + new_column = ColumnMetadata( + column_id="search_col", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + meta={}, + ) + + table = setup.table + table = await table.add_column(new_column) + + # Insert test data with searchable content + test_data = [ + {"col (1)": "foo", "col (2)": 1, "search_col": "quick brown fox"}, + {"col (1)": "bar", "col (2)": 2, "search_col": "lazy dog"}, + {"col (1)": "baz", "col (2)": 3, "search_col": "quick dog"}, + ] + await table.add_rows(test_data) + + # Test limit/offset + page1 = await table.fts_search("dog", limit=1) + assert len(page1) == 1 + + page2 = await table.fts_search("dog", limit=1, offset=1) + assert len(page2) == 1 + assert page1[0]["ID"] != page2[0]["ID"] + + async def test_fts_search_state_inclusion(self, setup: Setup): + """Test that state columns are included in search results""" + # Create table with text and state columns + new_column = ColumnMetadata( + column_id="search_col", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + ) + + table = setup.table + table = await table.add_column(new_column) + + # Insert test data + await table.add_rows( + [ + { + "col (1)": "foo", + "col (2)": 1, + "search_col": "test value", + } + ] + ) + + # Verify that state columns appear in search results + results = await table.fts_search("value", remove_state_cols=False) + assert "search_col_" in results[0].keys() + + @pytest.mark.asyncio + async def test_multi_column_fts_search(self, setup: Setup): + """Test searches across multiple columns""" + table = setup.table + + # Add multiple text columns + table = await table.add_column( + ColumnMetadata( + column_id="text1", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + ) + ) + table = await table.add_column( + ColumnMetadata( + column_id="text2", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=5, + ) + ) + + # Insert test data + await table.add_rows( + [ + {"text1": "first column text", "text2": "unrelated"}, + {"text1": "unrelated", "text2": "second column text"}, + ] + ) + + # Test FTS across multiple columns + results = await table.fts_search("text") + assert len(results) == 2 + assert {r["text1"] for r in results} == {"first column text", "unrelated"} + assert {r["text2"] for r in results} == {"unrelated", "second column text"} + + @pytest.mark.parametrize( + "text", + [ + "中文测试", # Chinese + "日本語テスト", # Japanese + "한국어 테스트", # Korean + ], + ) + @pytest.mark.asyncio + async def 
test_cjk_search(self, setup: Setup, text: str): + """Test CJK language support in FTS""" + table = setup.table + table = await table.add_column( + ColumnMetadata( + column_id="cjk_text", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + ) + ) + + # Insert CJK text + await table.add_rows( + [ + {"cjk_text": text}, + ] + ) + + # Search for the text + results = await table.fts_search(text) + assert len(results) == 1 + assert results[0]["cjk_text"] == text + + @pytest.mark.asyncio + async def test_multi_term_fts_search(self, setup: Setup): + """Test FTS with multiple search terms""" + table = setup.table + table = await table.add_column( + ColumnMetadata( + column_id="multi_text", + table_id="Table (test)", + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + ) + ) + + # Insert test data + await table.add_rows( + [ + {"multi_text": "quick brown fox"}, + {"multi_text": "lazy dog"}, + {"multi_text": "quick dog"}, + ] + ) + + # Test AND semantics + results = await table.fts_search("dog Quick") + assert len(results) == 1 + assert results[0]["multi_text"] == "quick dog" + + # Test OR semantics + results = await table.fts_search("quick OR lazy") + assert len(results) == 3 + + async def test_hybrid_search_basic(self, setup: Setup): + """Test basic hybrid search functionality""" + table = setup.table + + # Add text column for FTS + text_col = ColumnMetadata( + column_id="text_col", + table_id=table.table_id, + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + ) + table = await table.add_column(text_col) + + # Insert test data with both text and vector content + same_vector = np.random.rand(VECTOR_LEN) + test_data = [ + {"col (1)": "foo", "text_col": "quick brown fox", "vector_col": same_vector}, + {"col (1)": "bar", "text_col": "lazy dog", "vector_col": np.random.rand(VECTOR_LEN)}, + {"col (1)": "bay", "text_col": "slow hound", "vector_col": np.random.rand(VECTOR_LEN)}, + {"col (1)": "baz", "text_col": "quick dog", "vector_col": same_vector}, + {"col (1)": "bax", "text_col": "tardy fox", "vector_col": np.random.rand(VECTOR_LEN)}, + ] + await table.add_rows(test_data) + + # Mock embedding function + async def mock_embed_fn(model: str, text: str): + return same_vector + + # Test hybrid search + results = await table.hybrid_search( + fts_query="quick", + vs_query="quick", + embedding_fn=mock_embed_fn, + vector_column_names=["vector_col"], + use_bm25_ranking=False, + ) + + assert len(results) > 0 + assert all("rrf_score" in r for r in results) # Check for rrf_score key + results = sorted(results, key=lambda x: x["rrf_score"], reverse=True) + # Verify the top results contain 'quick' from FTS and the matching vector + # The exact ranking depends on RRF scoring, but the relevant items should be highly ranked. 
+ assert all(["quick" in r["text_col"] for r in results[:2]]) + + async def test_hybrid_search_with_bm25(self, setup: Setup): + """Test hybrid search with bm25 functionality""" + table = setup.table + + # Add text column for FTS + text_col = ColumnMetadata( + column_id="text_col", + table_id=table.table_id, + dtype=ColumnDtype.STR, + vlen=0, + gen_config=None, + column_order=4, + ) + table = await table.add_column(text_col) + + # Insert test data with both text and vector content + same_vector = np.random.rand(VECTOR_LEN) + test_data = [ + {"col (1)": "foo", "text_col": "quick brown fox", "vector_col": same_vector}, + {"col (1)": "bar", "text_col": "lazy dog", "vector_col": np.random.rand(VECTOR_LEN)}, + {"col (1)": "bay", "text_col": "slow hound", "vector_col": np.random.rand(VECTOR_LEN)}, + {"col (1)": "baz", "text_col": "quick dog", "vector_col": same_vector}, + {"col (1)": "bax", "text_col": "tardy fox", "vector_col": np.random.rand(VECTOR_LEN)}, + ] + await table.add_rows(test_data) + + # Mock embedding function + async def mock_embed_fn(model: str, text: str): + return same_vector + + # Test hybrid search + results = await table.hybrid_search( + fts_query="quick", + vs_query="quick", + embedding_fn=mock_embed_fn, + vector_column_names=["vector_col"], + ) + + assert len(results) > 0 + assert all("rrf_score" in r for r in results) # Check for rrf_score key + results = sorted(results, key=lambda x: x["rrf_score"], reverse=True) + # VS scores should be the same, the difference should comes from FTS + # longer document reduced BM25 scores + assert results[0]["text_col"] == "quick dog" + assert results[1]["text_col"] == "quick brown fox" + + +class TestRowOperations: + async def test_add_rows(self, setup: Setup): + table = setup.table + with assert_updated_time(table): + # Insert data + row_data = [{"col (1)": "test value", "col (2)": 123, "version": "1"}] + await table.add_rows(row_data) + # Verify data inserted + async with GENTABLE_ENGINE.transaction() as conn: + result = await conn.fetchrow( + f'SELECT * FROM "{setup.schema_id}"."{setup.table_id}"' + ) + assert result["col (1)"] == "test value" + assert result["col (2)"] == 123 + + async def test_add_rows_batch(self, setup: Setup): + table = setup.table + with assert_updated_time(table): + # Insert data + row_data = [ + {"col (1)": "test value", "col (2)": 1, "version": "1"}, + {"col (1)": "test value 2", "col (2)": 2, "version": "1"}, + {"col (1)": "test value 3", "col (2)": 3, "version": "1"}, + ] + await table.add_rows(row_data) + # Verify data inserted + async with GENTABLE_ENGINE.transaction() as conn: + result = await conn.fetch(f'SELECT * FROM "{setup.schema_id}"."{setup.table_id}"') + assert result[0]["col (1)"] == "test value" + assert result[0]["col (2)"] == 1 + assert result[-1]["col (1)"] == "test value 3" + assert result[-1]["col (2)"] == 3 + + async def test_list_rows(self, setup: Setup): + table = setup.table + # Insert data + row_data = [ + {"col (1)": "llama", "col (2)": 1}, + {"col (1)": "lama", "col (2)": 2}, + {"col (1)": "DROP TABLE", "col (2)": 3}, + ] + await table.add_rows(row_data) + # List data + rows = (await table.list_rows()).items + rows_reversed = (await table.list_rows(order_ascending=False)).items + assert all(rr == r for rr, r in zip(rows_reversed[::-1], rows, strict=True)) + # Verify data inserted + assert rows[0]["col (1)"] == "llama" + assert rows[0]["col (2)"] == 1 + assert rows[-1]["col (1)"] == "DROP TABLE" + assert rows[-1]["col (2)"] == 3 + + async def test_list_rows_search_query(self, setup: 
Setup): + # Create table + table = setup.table + # Insert data + row_data = [ + {"col (1)": "llama", "col (2)": 1}, + {"col (1)": "lama", "col (2)": 2}, + {"col (1)": "1", "col (2)": 3}, + ] + await table.add_rows(row_data) + # Search + rows = (await table.list_rows(search_query="lama")).items + assert len(rows) == 2 + rows = (await table.list_rows(search_query="^lama")).items + assert len(rows) == 1 + assert rows[0]["col (1)"] == "lama" + rows = ( + await table.list_rows(search_query="1", search_columns=["col (1)", "col (2)"]) + ).items + assert len(rows) == 2 + rows = (await table.list_rows(search_query="1", search_columns=["col (2)"])).items + assert len(rows) == 1 + assert rows[0]["col (1)"] == "llama" + assert rows[0]["col (2)"] == 1 + + async def test_count_rows(self, setup: Setup): + """Verify count_rows() returns correct counts""" + table = setup.table + # Empty table + assert await table.count_rows() == 0 + # After insert + await table.add_rows([{"col (1)": "test"}]) + assert await table.count_rows() == 1 + # After delete + rows = (await table.list_rows()).items + await table.delete_rows(row_ids=[rows[0]["ID"]]) + assert await table.count_rows() == 0 + + async def test_update_rows(self, setup: Setup, tmp_path): + """Test updating rows including NULL values""" + table = setup.table + with assert_updated_time(table): + # Insert initial data + row_data = [{"col (1)": "initial value", "col (2)": 123}] + row_added = (await (await table.add_rows(row_data)).list_rows()).items + + # Update data + update_data = { + "col (1)": "updated value", + "col (2)": 456, + } + await table.update_rows({row_added[0]["ID"]: update_data}) + + # Verify data updated + retrieved_row = await table.get_row(row_added[0]["ID"]) + assert retrieved_row["col (1)"] == update_data["col (1)"] + assert retrieved_row["col (2)"] == update_data["col (2)"] + + # Test NULL value updates + # Case 1: Set existing value to NULL + await table.update_rows({row_added[0]["ID"]: {"col (1)": None}}) + updated = await table.get_row(row_added[0]["ID"]) + assert updated["col (1)"] is None + assert updated["col (2)"] == 456 + + # Case 2: Verify NULLs persist through export/import + export_path = tmp_path / "null_test.parquet" + await table.export_table(export_path) + new_table = await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="null_test_table", + ) + imported = await new_table.get_row(row_added[0]["ID"]) + assert imported["col (1)"] is None + + async def test_delete_rows_with_id(self, setup: Setup): + table = setup.table + with assert_updated_time(table): + # Insert data + row_data = [ + {"col (1)": "test value", "col (2)": 1}, + {"col (1)": "test value", "col (2)": 2}, + {"col (1)": "test value", "col (2)": 3}, + ] + new_rows = (await (await table.add_rows(row_data)).list_rows()).items + + # Delete data + await table.delete_rows(row_ids=[new_rows[0]["ID"], new_rows[2]["ID"]]) + + # Verify data deleted + with pytest.raises(ResourceNotFoundError, match="Row .+ not found in table"): + await table.get_row(new_rows[0]["ID"]) + with pytest.raises(ResourceNotFoundError, match="Row .+ not found in table"): + await table.get_row(new_rows[2]["ID"]) + + async def test_delete_rows_with_where(self, setup: Setup): + table = setup.table + with assert_updated_time(table): + # Insert data + row_data = [ + {"col (1)": "test value", "col (2)": 1}, + {"col (1)": "test value", "col (2)": 2}, + {"col (1)": "test value", "col (2)": 3}, + ] + new_rows = (await 
(await table.add_rows(row_data)).list_rows()).items + # Delete data + await table.delete_rows(where='"col (2)" > 1') + # Verify data deleted + with pytest.raises(ResourceNotFoundError, match="Row .+ not found in table"): + await table.get_row(new_rows[1]["ID"]) + with pytest.raises(ResourceNotFoundError, match="Row .+ not found in table"): + await table.get_row(new_rows[2]["ID"]) + + async def test_delete_rows_with_id_where(self, setup: Setup): + table = setup.table + with assert_updated_time(table): + # Insert data + row_data = [ + {"col (1)": "test value", "col (2)": 1}, + {"col (1)": "test value", "col (2)": 2}, + {"col (1)": "test value", "col (2)": 3}, + ] + new_rows = (await (await table.add_rows(row_data)).list_rows()).items + # Delete data + await table.delete_rows( + row_ids=[new_rows[1]["ID"], new_rows[2]["ID"]], where='"col (2)" > 2' + ) + # Verify data deleted + response = await table.get_row(new_rows[0]["ID"]) + assert isinstance(response, dict) + response = await table.get_row(new_rows[1]["ID"]) + assert isinstance(response, dict) + with pytest.raises(ResourceNotFoundError, match="Row .+ not found in table"): + await table.get_row(new_rows[2]["ID"]) + + +# --- Fixtures and Tests for Stateful Operations --- + + +async def setup_table_newly_created(table: GenerativeTableCore): + """Provides an async setup function for the newly created table.""" + # No op needed here, just return the table + return table + + +async def setup_table_with_added_column(table: GenerativeTableCore): + """Provides an async setup function for a table with an added column.""" + new_col = ColumnMetadata( + column_id="added_col_state_test", table_id=table.table_id, dtype=ColumnDtype.BOOL + ) + return await table.add_column(new_col) + + +async def setup_table_with_dropped_column(table: GenerativeTableCore): + """Provides an async setup function for a table with 'col (1)' dropped.""" + return await table.drop_columns(["col (1)"]) + + +async def setup_table_renamed(table: GenerativeTableCore): + """Provides an async setup function for a renamed table.""" + new_name = "renamed_state_test" + return await table.rename_table(new_name) + + +async def setup_table_duplicated(table: GenerativeTableCore, setup: Setup): + """Provides an async setup function for a duplicated table.""" + # Add data just before duplicating + await table.add_rows( + [{"col (1)": "data_for_dup", "col (2)": 111, "vector_col": np.random.rand(VECTOR_LEN)}] + ) + return await GenerativeTableCore.duplicate_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + table_id_src=table.table_id, + table_id_dst="duplicated_state_test", + ) + + +async def setup_table_imported(table: GenerativeTableCore, setup: Setup, tmp_path): + """Provides an async setup function for an imported table.""" + export_path = tmp_path / "state_test_export.parquet" + # Add data just before exporting + await table.add_rows( + [{"col (1)": "data_for_import", "col (2)": 222, "vector_col": np.random.rand(VECTOR_LEN)}] + ) + await table.export_table(export_path) + return await GenerativeTableCore.import_table( + project_id=setup.projects[0].id, + table_type=setup.table_type, + source=export_path, + table_id_dst="imported_state_test", + ) + + +class TestStatefulOperations: + # List of setup fixture names to parametrize over + SETUP_TABLE_STATE_FIXTURES = [ + "setup_table_newly_created", + "setup_table_with_added_column", + "setup_table_with_dropped_column", + "setup_table_renamed", + "setup_table_duplicated", + "setup_table_imported", # Use the new fixture 
name + ] + + def parametrized_setup(self, setup_fixture_name, setup: Setup, tmp_path): + if setup_fixture_name == "setup_table_newly_created": + return setup_table_newly_created(setup.table) + elif setup_fixture_name == "setup_table_with_added_column": + return setup_table_with_added_column(setup.table) + elif setup_fixture_name == "setup_table_with_dropped_column": + return setup_table_with_dropped_column(setup.table) + elif setup_fixture_name == "setup_table_renamed": + return setup_table_renamed(setup.table) + elif setup_fixture_name == "setup_table_duplicated": + return setup_table_duplicated(setup.table, setup) + elif setup_fixture_name == "setup_table_imported": + return setup_table_imported(setup.table, setup, tmp_path) + + @pytest.mark.parametrize("setup_fixture_name", SETUP_TABLE_STATE_FIXTURES) + @pytest.mark.asyncio + async def test_core_ops_on_various_tables(self, setup_fixture_name, setup: Setup, tmp_path): + """ + Tests core operations (add row, add col, drop col, search) + on tables in various states (new, col added, col dropped, renamed, duplicated, imported). + """ + # --- Setup --- + # Setup the table as per the parameter + table = await self.parametrized_setup(setup_fixture_name, setup, tmp_path) + + assert isinstance(table, GenerativeTableCore) # Ensure we have a table object + + # Store fixture name for easier debugging in asserts/fails + current_state_name = setup_fixture_name.replace("setup_", "") + + initial_row_count = await table.count_rows() + initial_col_count = len(table.column_metadata) + + # --- Test Add Row --- + row_data = {"col (2)": 999} # Use a column likely to exist across states + if "col (1)" in table.data_table_model.get_column_ids(exclude_state=True): + row_data["col (1)"] = f"state_test_{current_state_name}" + if "vector_col" in table.data_table_model.get_column_ids(exclude_state=True): + row_data["vector_col"] = np.random.rand(VECTOR_LEN) # Use correct dimension + if "added_col_state_test" in table.data_table_model.get_column_ids(exclude_state=True): + row_data["added_col_state_test"] = True + await table.add_rows([row_data]) + assert await table.count_rows() == initial_row_count + 1 + + # --- Test Add Column --- + temp_col_id = "temp_col_in_test" + temp_col_meta = ColumnMetadata( + column_id=temp_col_id, table_id=table.table_id, dtype=ColumnDtype.FLOAT + ) + table = await table.add_column( + temp_col_meta + ) # Reassign table as add_column returns updated instance + assert any(col.column_id == temp_col_id for col in table.column_metadata) + assert ( + len(table.column_metadata) == initial_col_count + 2 + ) # +1 for data col, +1 for state col + + # --- Test Drop Column --- + table = await table.drop_columns( + [temp_col_id] + ) # Reassign table as drop_columns returns updated instance + assert not any(col.column_id == temp_col_id for col in table.column_metadata) + assert len(table.column_metadata) == initial_col_count # Should be back to original count + + # --- Test Search (Index Check) --- + # FTS Search (ensure index exists and query runs) + # Use a column that exists in most states or adapt search term + search_term = ( + "state_test" + if "col (1)" in table.data_table_model.get_column_ids(exclude_state=True) + else "a" + ) # Generic search if col (1) dropped + try: + await table.fts_search(search_term, limit=1, explain=False) + except Exception as e: + pytest.fail(f"FTS search failed on {current_state_name} state: {e}") + + # Vector Search (ensure index exists and query runs) + if "vector_col" in table.vector_column_names: + + async def 
mock_embed_fn(model: str, text: str): + # Return a vector of the correct dimension for the column + vlen = next(c.vlen for c in table.column_metadata if c.column_id == "vector_col") + return np.random.rand(vlen) + + try: + await table.vector_search( + query="dummy", + embedding_fn=mock_embed_fn, + vector_column_names=["vector_col"], + limit=1, + ) + except Exception as e: + pytest.fail(f"Vector search failed on {current_state_name} state: {e}") + elif current_state_name not in [ + "table_with_dropped_column" + ]: # Expect vector col unless explicitly dropped + pytest.fail( + f"Vector column 'vector_col' missing unexpectedly in {current_state_name} state" + ) diff --git a/services/api/tests/gen_table_core/test_manipulation.py b/services/api/tests/gen_table_core/test_manipulation.py new file mode 100644 index 0000000..e69de29 diff --git a/services/api/tests/routers/test_conversation.py b/services/api/tests/routers/test_conversation.py new file mode 100644 index 0000000..fe1c11b --- /dev/null +++ b/services/api/tests/routers/test_conversation.py @@ -0,0 +1,816 @@ +from dataclasses import dataclass +from os.path import dirname, join, realpath +from typing import Generator + +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + AgentMetaResponse, + CellCompletionResponse, + ChatTableSchemaCreate, + ColumnSchemaCreate, + ConversationCreateRequest, + ConversationMetaResponse, + LLMGenConfig, + MessageAddRequest, + MessagesRegenRequest, + MessageUpdateRequest, + OkResponse, + OrganizationCreate, + OrgMemberRead, + Page, + ProjectMemberRead, + Role, + TableType, +) +from owl.utils.exceptions import ResourceNotFoundError +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + GPT_4O_MINI_CONFIG, + GPT_4O_MINI_DEPLOYMENT, + create_deployment, + create_model_config, + create_organization, + create_project, + create_user, + get_file_map, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) + + +@dataclass(slots=True) +class ConversationContext: + """Dataclass to hold context for conversation tests.""" + + superuser_id: str + user_id: str + org_id: str + project_id: str + template_table_id: str + real_template_table_id: str + multimodal_template_table_id: str + + +@pytest.fixture(scope="module") +def setup() -> Generator[ConversationContext, None, None]: + """ + Fixture to set up the necessary environment for conversation tests. 
+ """ + with ( + create_user() as superuser, + create_user({"email": "testuser@example.com", "name": "Test User"}) as user, + create_organization( + body=OrganizationCreate(name="Convo Org"), + user_id=superuser.id, + ) as superorg, + create_project( + dict(name="Convo Project"), user_id=superuser.id, organization_id=superorg.id + ) as project, + ): + assert superuser.id == "0" + assert superorg.id == "0" + client = JamAI(user_id=superuser.id) + membership = client.organizations.join_organization( + user_id=user.id, organization_id=superorg.id, role=Role.MEMBER + ) + assert isinstance(membership, OrgMemberRead) + membership = client.projects.join_project( + user_id=user.id, project_id=project.id, role=Role.MEMBER + ) + assert isinstance(membership, ProjectMemberRead) + client = JamAI(user_id=superuser.id, project_id=project.id) + + with ( + create_model_config(GPT_4O_MINI_CONFIG) as llm_config, + create_model_config(ELLM_DESCRIBE_CONFIG) as llm_describe_config, + create_deployment(GPT_4O_MINI_DEPLOYMENT), + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + ): + # TODO: Don't call these templates since we have actual templates + # Standard Template + template_id = "chat-template-v2" + template_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig(model=llm_describe_config.id, multi_turn=True), + ), + ] + client.table.create_chat_table( + ChatTableSchemaCreate(id=template_id, cols=template_cols) + ) + + # Real Template - for regeneration tests + real_template_id = "real-chat-template-v2" + real_template_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig(model=llm_config.id, multi_turn=True, temperature=1.0), + ), + ] + client.table.create_chat_table( + ChatTableSchemaCreate(id=real_template_id, cols=real_template_cols) + ) + + # Multimodal Template + multimodal_template_id = "multimodal-chat-template-v2" + multimodal_cols = [ + ColumnSchemaCreate(id="User", dtype="str"), + ColumnSchemaCreate(id="Photo", dtype="image"), + ColumnSchemaCreate(id="Audio", dtype="audio"), + ColumnSchemaCreate(id="Doc", dtype="document"), + ColumnSchemaCreate( + id="AI", + dtype="str", + gen_config=LLMGenConfig( + model=llm_describe_config.id, + multi_turn=True, + prompt="Photo: ${Photo} \nAudio: ${Audio} \nDocument: ${Doc} \n\n${User}", + ), + ), + ] + client.table.create_chat_table( + ChatTableSchemaCreate(id=multimodal_template_id, cols=multimodal_cols) + ) + try: + yield ConversationContext( + superuser_id=superuser.id, + user_id=user.id, + org_id=superorg.id, + project_id=project.id, + template_table_id=template_id, + real_template_table_id=real_template_id, + multimodal_template_table_id=multimodal_template_id, + ) + finally: + client.table.delete_table(TableType.CHAT, template_id, missing_ok=True) + client.table.delete_table(TableType.CHAT, real_template_id, missing_ok=True) + client.table.delete_table(TableType.CHAT, multimodal_template_id, missing_ok=True) + + +def _create_conversation_and_get_id( + client: JamAI, + setup_context: ConversationContext, + initial_data: dict | None = None, + check_regen: bool = False, + multimodal: bool = False, +) -> str: + """Helper to create a conversation and extract its ID from the streamed metadata.""" + # TODO: This function should just take in table ID instead of the booleans + if check_regen: + template_id = setup_context.real_template_table_id + elif multimodal: + template_id = 
setup_context.multimodal_template_table_id + else: + template_id = setup_context.template_table_id + if initial_data is None: + initial_data = {"User": "First message"} + + create_req = ConversationCreateRequest(agent_id=template_id, data=initial_data) + response_stream = client.conversations.create_conversation(create_req) + responses = [r for r in response_stream] + + metadata = responses[0] + assert isinstance(metadata, ConversationMetaResponse), "Stream did not yield metadata first" + conv_id = metadata.conversation_id + assert conv_id is not None, "Metadata event did not contain conversation_id" + return conv_id + + +def test_create_conversation(setup: ConversationContext): + """ + Tests creating a new conversation and that a title is automatically generated. + - Creates a conversation with a specific user prompt. + - Verifies the first message is saved correctly. + - Verifies an AI-generated title is set on the conversation metadata. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # 1. Create the conversation + user_prompt = "What is the theory of relativity?" + conv_id = _create_conversation_and_get_id(client, setup, initial_data={"User": user_prompt}) + assert isinstance(conv_id, str) + + # 2. Verify the conversation was created with the correct message + conv_details = client.conversations.list_messages(conv_id) + assert conv_details.total == 1 + assert conv_details.items[0]["User"] == user_prompt + + # 3. Verify that the title was auto-generated and saved + meta_after_creation = client.conversations.get_conversation(conv_id) + assert isinstance(meta_after_creation.title, str) + assert len(meta_after_creation.title) > 0, ( + "Title should have been auto-generated but is empty." + ) + assert "There is a text with" in meta_after_creation.title + + +def test_create_conversation_with_provided_title(setup: ConversationContext): + """ + Tests that providing a title during creation skips automatic generation. + - Creates a conversation and passes a custom `title` parameter. + - Verifies the conversation is created successfully. + - Asserts that the final conversation title matches the one provided. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # 1. Create the conversation with provided title + provided_title = "Custom Test Title" + create_req = ConversationCreateRequest( + agent_id=setup.template_table_id, + title=provided_title, + data={"User": "This message should not be used for a title"}, + ) + response_stream = client.conversations.create_conversation(create_req) + responses = [r for r in response_stream] + metadata = responses[0] + conv_id = metadata.conversation_id + assert conv_id is not None + + # 2. Verify the conversation was created + conv_details = client.conversations.list_messages(conv_id) + assert conv_details.total == 1 + + # 3. Verify that the provided title was used + meta_after_creation = client.conversations.get_conversation(conv_id) + assert meta_after_creation.title == provided_title + + +def test_list_conversations(setup: ConversationContext): + """ + Tests listing conversations, ensuring only child chats are returned. + - Creates a new conversation. + - Calls the list endpoint. + - Verifies the new conversation is in the list. + - Verifies that parent agents/templates are not in the list. + - Verifies conversation title search. 
+ """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + convos_page = client.conversations.list_conversations() + assert convos_page.total >= 1 + assert any(c.conversation_id == conv_id for c in convos_page.items) + # Verify that parent templates are NOT in the list + assert not any(c.conversation_id == setup.template_table_id for c in convos_page.items) + assert not any( + c.conversation_id == setup.multimodal_template_table_id for c in convos_page.items + ) + conv_id = _create_conversation_and_get_id(client, setup) + client.conversations.rename_conversation_title(conv_id, "text with [3600] tokens") + convos_page = client.conversations.list_conversations() + assert convos_page.total >= 2 + # Verify literal search + convos_page_search = client.conversations.list_conversations(search_query="[3600] tokens") + assert convos_page_search.total == 1 + # Verify regex search + convos_page_search = client.conversations.list_conversations(search_query="[0-9]{4}") + assert convos_page_search.total == 1 + convos_page_search = client.conversations.list_conversations(search_query="text with") + assert convos_page_search.total >= 2 + + +def test_list_agents(setup: ConversationContext): + """ + Tests listing agents, ensuring only parent templates are returned. + - Creates a new child conversation. + - Calls the list_agents endpoint. + - Verifies parent templates are in the list. + - Verifies the new child conversation is NOT in the list. + - Verifies agent id search. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + agents_page = client.conversations.list_agents() + assert agents_page.total == 3 + # Verify that parent templates ARE in the list + assert any(c.conversation_id == setup.template_table_id for c in agents_page.items) + assert any(c.conversation_id == setup.multimodal_template_table_id for c in agents_page.items) + # Verify that child conversations are NOT in the list + assert not any(c.conversation_id == conv_id for c in agents_page.items) + agents_page_search = client.conversations.list_agents(search_query="multimodal-") + assert agents_page_search.total == 1 + agents_page_search = client.conversations.list_agents(search_query="chat-template-v2") + assert agents_page_search.total == 3 + + +def test_get_conversation(setup: ConversationContext): + """ + Tests fetching the metadata for a single, specific conversation. + - Creates a conversation. + - Fetches it by its ID. + - Verifies the returned metadata matches the created conversation. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + convo_meta = client.conversations.get_conversation(conv_id) + assert isinstance(convo_meta, ConversationMetaResponse) + assert convo_meta.conversation_id == conv_id + assert convo_meta.parent_id == setup.template_table_id + assert convo_meta.created_by == setup.user_id + + +def test_get_agent(setup: ConversationContext): + """ + Tests fetching the metadata for a single, specific agent/template. + - Fetches a known agent by its ID. + - Verifies the returned metadata is correct. 
+ """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + agent_meta = client.conversations.get_agent(setup.template_table_id) + assert isinstance(agent_meta, AgentMetaResponse) + assert agent_meta.agent_id == setup.template_table_id + assert agent_meta.created_by == setup.superuser_id + + +def test_generate_conversation_title(setup: ConversationContext): + """ + Tests explicitly generating a title for an existing conversation. + - Creates a conversation (which gets an auto-generated title). + - Calls the dedicated `generate_title` endpoint. + - Verifies the conversation's title is updated to the newly generated one. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + response = client.conversations.generate_title(conversation_id=conv_id) + assert isinstance(response, ConversationMetaResponse) + assert isinstance(response.title, str) + assert len(response.title) > 0 + + updated_table_meta = client.conversations.get_conversation(conv_id) + assert updated_table_meta.title == response.title + + +def test_rename_conversation_title(setup: ConversationContext): + """ + Tests renaming the title of an existing conversation. + - Creates a conversation. + - Calls the rename endpoint with a new title. + - Verifies the conversation metadata reflects the new title. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + new_title = "renamed-conversation-title" + table_meta = client.conversations.get_conversation(conv_id) + assert table_meta.title != new_title, "Title should not match the new title initially" + + rename_response = client.conversations.rename_conversation_title(conv_id, new_title) + assert isinstance(rename_response, ConversationMetaResponse) + assert rename_response.title == new_title + + updated_table_meta = client.conversations.get_conversation(conv_id) + assert updated_table_meta.title == new_title + + +def test_delete_conversation(setup: ConversationContext): + """ + Tests the permanent deletion of a conversation. + - Creates a conversation. + - Deletes it. + - Verifies that fetching the conversation by its ID now raises a ResourceNotFoundError. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + delete_response = client.conversations.delete_conversation(conv_id) + assert isinstance(delete_response, OkResponse) + with pytest.raises(ResourceNotFoundError): + client.conversations.list_messages(conv_id) + + +def test_send_message(setup: ConversationContext): + """ + Tests sending a follow-up message to an existing conversation. + - Creates a conversation with one message. + - Sends a second message to the same conversation. + - Verifies the conversation now contains two messages. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_table_id = _create_conversation_and_get_id(client, setup) + second_prompt = "And what is the capital of Germany?" 
+    send_req = MessageAddRequest(
+        conversation_id=conv_table_id,
+        data={"User": second_prompt},
+    )
+    stream_gen = client.conversations.send_message(send_req)
+    ai_response_chunks = [c for c in stream_gen]
+    assert len(ai_response_chunks) > 0, "Send message stream was empty"
+
+    conv_details = client.conversations.list_messages(conv_table_id)
+    assert conv_details.total == 2
+    assert conv_details.items[1]["User"] == second_prompt
+    assert "text with [8] tokens" in conv_details.items[1]["AI"]
+
+
+def test_list_messages(setup: ConversationContext):
+    """
+    Tests fetching the full message history of a conversation.
+    - Creates a conversation with an initial message.
+    - Fetches the message list.
+    - Verifies the content of the first message.
+    """
+    client = JamAI(user_id=setup.user_id, project_id=setup.project_id)
+    conv_id = _create_conversation_and_get_id(client, setup)
+    convo_details = client.conversations.list_messages(conv_id)
+    assert isinstance(convo_details, Page)
+    assert convo_details.total == 1
+    first_turn = convo_details.items[0]
+    assert first_turn["User"] == "First message"
+    assert "text with [3] tokens" in first_turn["AI"]
+    # Threads
+    # TODO: Move this to its own test
+    response = client.conversations.get_threads(conv_id)
+    thread = response.threads["AI"].thread
+    assert len(thread) > 2
+    assert thread[0].role == "system"
+    assert thread[1].role == "user"
+    assert thread[1].user_prompt == "First message"
+    assert thread[2].role == "assistant"
+    assert "text with [3] tokens" in thread[2].content
+
+
+def test_regen_message(setup: ConversationContext):
+    """
+    Tests regenerating the last AI response in a conversation.
+    - Creates a conversation.
+    - Stores the original AI response.
+    - Calls the regeneration endpoint.
+    - Verifies the new AI response is different from the original.
+    """
+    client = JamAI(user_id=setup.user_id, project_id=setup.project_id)
+    conv_id = _create_conversation_and_get_id(
+        client, setup, initial_data={"User": "Suggest a movie"}, check_regen=True
+    )
+    # 1. Get the initial message
+    convo_details = client.conversations.list_messages(conv_id)
+    assert convo_details.total == 1
+    message_row = convo_details.items[0]
+    row_id = message_row["ID"]
+    original_ai_content = message_row["AI"]
+    assert original_ai_content is not None
+
+    # 2. Update the message (Optional)
+    new_content = "Suggest a movie before 1950."
+    update_req = MessageUpdateRequest(
+        conversation_id=conv_id,
+        row_id=row_id,
+        data={"User": new_content},
+    )
+    update_response = client.conversations.update_message(update_req)
+    assert isinstance(update_response, OkResponse)
+
+    # 3. Regenerate the AI response
+    regen_req = MessagesRegenRequest(
+        conversation_id=conv_id,
+        row_id=row_id,
+    )
+    stream_gen = client.conversations.regen_message(regen_req)
+    responses = list(stream_gen)
+    assert len(responses) > 0
+    assert all(isinstance(r, CellCompletionResponse) for r in responses)
+
+    # 4. Verify the regeneration
+    updated_details = client.conversations.list_messages(conv_id)
+    assert updated_details.total == 1
+    updated_message_row = updated_details.items[0]
+    assert updated_message_row["AI"] != original_ai_content
+
+
+def test_regen_messages(setup: ConversationContext):
+    """
+    Tests regenerating from an earlier point in a multi-message conversation.
+    - Creates a conversation with three messages.
+    - Calls regenerate starting from the first message's ID.
+    - Verifies that all three AI responses have changed.
+ """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id( + client, setup, initial_data={"User": "Suggest a movie"}, check_regen=True + ) + # 1. Send messages + send_req = MessageAddRequest( + conversation_id=conv_id, + data={"User": "Suggest second movie"}, + ) + list(client.conversations.send_message(send_req)) # consume stream + send_req = MessageAddRequest( + conversation_id=conv_id, + data={"User": "Describe the movies"}, + ) + list(client.conversations.send_message(send_req)) # consume stream + + # 2. Get the conversation details + convo_details = client.conversations.list_messages(conv_id) + assert convo_details.total == 3 + first_row = convo_details.items[0] + second_row = convo_details.items[1] + third_row = convo_details.items[2] + assert first_row["User"] == "Suggest a movie" + assert second_row["User"] == "Suggest second movie" + assert third_row["User"] == "Describe the movies" + + # 3. Update the message (Optional) + new_content = "Suggest a movie before 1950." + update_req = MessageUpdateRequest( + conversation_id=conv_id, + row_id=first_row["ID"], + data={"User": new_content}, + ) + update_response = client.conversations.update_message(update_req) + assert isinstance(update_response, OkResponse) + + # 4. Regenerate both messages + regen_req = MessagesRegenRequest( + conversation_id=conv_id, + row_id=first_row["ID"], + ) + stream_gen = client.conversations.regen_message(regen_req) + responses = list(stream_gen) + assert len(responses) > 0 + assert all(isinstance(r, CellCompletionResponse) for r in responses) + + # 5. Verify the regeneration + updated_details = client.conversations.list_messages(conv_id) + assert updated_details.total == 3 + updated_first_row = updated_details.items[0] + updated_second_row = updated_details.items[1] + updated_third_row = updated_details.items[2] + assert updated_first_row["User"] != first_row["User"] + assert updated_second_row["User"] == second_row["User"] + assert updated_third_row["User"] == third_row["User"] + assert updated_first_row["AI"] != first_row["AI"] + assert updated_second_row["AI"] != second_row["AI"] + assert updated_third_row["AI"] != third_row["AI"] + + +def test_update_message(setup: ConversationContext): + """ + Tests editing the content of a specific message. + - Creates a conversation. + - Updates the 'User' content of the first message. + - Verifies the change while ensuring the 'AI' content is untouched. + - Updates the 'AI' content and verifies the change. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id = _create_conversation_and_get_id(client, setup) + # 1. Get the initial message to find its row_id + convo_details = client.conversations.list_messages(conv_id) + assert convo_details.total == 1 + message_row = convo_details.items[0] + row_id = message_row["ID"] + assert message_row["User"] == "First message" + + # 2. Update the message + new_content = "This is the edited first message." + update_req = MessageUpdateRequest( + conversation_id=conv_id, + row_id=row_id, + data={"User": new_content}, + ) + update_response = client.conversations.update_message(update_req) + assert isinstance(update_response, OkResponse) + + # 3. 
Verify the update
+    updated_details = client.conversations.list_messages(conv_id)
+    assert updated_details.total == 1
+    updated_message_row = updated_details.items[0]
+    assert updated_message_row["User"] == new_content
+    assert updated_message_row["AI"] == message_row["AI"]  # AI part should be unchanged
+
+    # 4. Update the AI message
+    new_ai_content = "AI Response"
+    update_req = MessageUpdateRequest(
+        conversation_id=conv_id,
+        row_id=row_id,
+        data={"AI": new_ai_content},
+    )
+    update_response = client.conversations.update_message(update_req)
+    assert isinstance(update_response, OkResponse)
+
+    # 5. Verify the update
+    updated_details = client.conversations.list_messages(conv_id)
+    assert updated_details.total == 1
+    updated_message_row = updated_details.items[0]
+    assert updated_message_row["User"] == new_content
+    assert updated_message_row["AI"] == new_ai_content
+
+
+def test_conversation_with_image(setup: ConversationContext):
+    """
+    Tests starting a conversation with a multimodal (image) input.
+    - Uploads an image to get a file URI.
+    - Creates a conversation using a multimodal agent, passing the image URI.
+    - Verifies the AI response correctly identifies the image content.
+    """
+    client = JamAI(user_id=setup.user_id, project_id=setup.project_id)
+    photo_uri = upload_file(client, FILES["rabbit.jpeg"]).uri
+    initial_data = {"User": "What animal is in this image?", "Photo": photo_uri}
+    conv_id = _create_conversation_and_get_id(
+        client, setup, initial_data=initial_data, multimodal=True
+    )
+    messages = client.conversations.list_messages(conv_id)
+    assert messages.total == 1
+    assert "[image/jpeg], shape [(1200, 1600, 3)]" in messages.items[0]["AI"].lower()
+
+
+def test_conversation_with_audio(setup: ConversationContext):
+    """
+    Tests starting a conversation with a multimodal (audio) input.
+    - Uploads an audio file to get a file URI.
+    - Creates a conversation using a multimodal agent, passing the audio URI.
+    - Verifies the AI response indicates successful processing of the audio.
+    """
+    client = JamAI(user_id=setup.user_id, project_id=setup.project_id)
+    audio_uri = upload_file(client, FILES["turning-a4-size-magazine.mp3"]).uri
+    initial_data = {"User": "What does the audio say?", "Audio": audio_uri}
+    conv_id = _create_conversation_and_get_id(
+        client, setup, initial_data=initial_data, multimodal=True
+    )
+    messages = client.conversations.list_messages(conv_id)
+    assert messages.total == 1
+    assert "[audio/mpeg]" in messages.items[0]["AI"].lower()
+
+
+def test_conversation_with_document(setup: ConversationContext):
+    """
+    Tests starting a conversation with a multimodal (document) input.
+    - Uploads a document to get a file URI.
+    - Creates a conversation using a multimodal agent, passing the document URI.
+    - Verifies the AI response correctly processes the document content.
+    """
+    client = JamAI(user_id=setup.user_id, project_id=setup.project_id)
+    doc_uri = upload_file(client, FILES["creative-story.md"]).uri
+    initial_data = {"User": "Summarize this document in one sentence.", "Doc": doc_uri}
+    conv_id = _create_conversation_and_get_id(
+        client, setup, initial_data=initial_data, multimodal=True
+    )
+    messages = client.conversations.list_messages(conv_id)
+    assert messages.total == 1
+    assert "text with [398] tokens" in messages.items[0]["AI"].lower()
+
+
+def test_full_lifecycle(setup: ConversationContext):
+    """
+    Tests the complete sequence of user actions from creation to deletion.
+    - Creates a conversation, which auto-generates a title.
+ - Sends a follow-up message. + - Regenerates the last message. + - Renames the conversation title. + - Deletes the conversation and verifies it's gone. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # 1. Create + conv_id = _create_conversation_and_get_id( + client, setup, initial_data={"User": "Suggest a movie"}, check_regen=True + ) + assert client.conversations.list_messages(conv_id).total == 1 + meta = client.conversations.get_conversation(conv_id) + assert len(meta.title) > 0 + + # 2. Send Message + send_req = MessageAddRequest( + conversation_id=conv_id, + data={"User": "Suggest second movie"}, + ) + list(client.conversations.send_message(send_req)) # consume stream + messages_after_send = client.conversations.list_messages(conv_id) + assert messages_after_send.total == 2 + + # 3. Regenerate Message + last_message = messages_after_send.items[-1] + last_row_id = last_message["ID"] + original_ai_content = last_message["AI"] + regen_req = MessagesRegenRequest(conversation_id=conv_id, row_id=last_row_id) + list(client.conversations.regen_message(regen_req)) # consume stream + messages_after_regen = client.conversations.list_messages(conv_id) + assert messages_after_regen.total == 2 + regenerated_message = messages_after_regen.items[-1] + assert regenerated_message["User"] == "Suggest second movie" + assert regenerated_message["AI"] != original_ai_content + + # 4. Rename + new_title = "Best Movie Agent" + client.conversations.rename_conversation_title(conv_id, new_title) + updated_table_meta = client.conversations.get_conversation(conv_id) + assert updated_table_meta.title == new_title + + # 5. Delete + client.conversations.delete_conversation(conv_id) + with pytest.raises(ResourceNotFoundError): + client.conversations.list_messages(conv_id) + + +def test_full_lifecycle_with_files(setup: ConversationContext): + """ + Tests a complete lifecycle using multimodal inputs. + - Creates a conversation with an image. + - Sends a follow-up with audio. + - Updates and regenerates the first (image) message. + - Sends a final follow-up with a document. + """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # 1. Create with image + photo_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + initial_data = {"User": "What is this animal?", "Photo": photo_uri} + conv_id = _create_conversation_and_get_id(client, setup, initial_data, multimodal=True) + messages = client.conversations.list_messages(conv_id) + assert messages.total == 1 + response_text = messages.items[0]["AI"].lower() + assert ( + "[image/jpeg], shape [(1200, 1600, 3)]" in response_text + and "text with [10] tokens" in response_text + ) + first_row_id = messages.items[0]["ID"] + + # 2. Send a follow-up with audio + audio_uri = upload_file(client, FILES["turning-a4-size-magazine.mp3"]).uri + send_req = MessageAddRequest( + conversation_id=conv_id, + data={"User": "What sound is this?", "Audio": audio_uri}, + ) + list(client.conversations.send_message(send_req)) + messages = client.conversations.list_messages(conv_id) + assert messages.total == 2 + assert "[audio/mpeg]" in messages.items[1]["AI"].lower() + + # 3. Update and Regenerate the first message (the one with the image) + new_content = "What is this animal? Why is it so popular?" 
+ update_req = MessageUpdateRequest( + conversation_id=conv_id, + row_id=first_row_id, + data={"User": new_content}, + ) + update_response = client.conversations.update_message(update_req) + assert isinstance(update_response, OkResponse) + regen_req = MessagesRegenRequest(conversation_id=conv_id, row_id=first_row_id) + list(client.conversations.regen_message(regen_req)) + messages_after_regen = client.conversations.list_messages(conv_id) + assert messages_after_regen.total == 2 + response_text = messages_after_regen.items[0]["AI"].lower() + assert ( + "[image/jpeg], shape [(1200, 1600, 3)]" in response_text + and "text with [15] tokens" in response_text + ) + + # 4. Send a follow-up with a document + doc_uri = upload_file(client, FILES["creative-story.md"]).uri + send_req = MessageAddRequest( + conversation_id=conv_id, + data={"User": "Summarize this document in one sentence.", "Doc": doc_uri}, + ) + list(client.conversations.send_message(send_req)) + messages = client.conversations.list_messages(conv_id) + assert messages.total == 3 + assert "text with [398] tokens" in messages.items[2]["AI"].lower() + + +def test_conversation_permissions(setup: ConversationContext): + """ + Tests that users cannot access conversations they do not own. + - User1 creates a conversation. + - User2 (in the same project) tries to access User1's conversation. + - Asserts that all access attempts by User2 fail with ResourceNotFoundError. + """ + client1 = JamAI(user_id=setup.user_id, project_id=setup.project_id) + conv_id1 = _create_conversation_and_get_id(client1, setup) + + with create_user({"email": "user2@example.com", "name": "user2"}) as user2: + su_client = JamAI(user_id=setup.superuser_id) + su_client.organizations.join_organization( + user_id=user2.id, organization_id=setup.org_id, role=Role.GUEST + ) + su_client.projects.join_project( + user_id=user2.id, project_id=setup.project_id, role=Role.GUEST + ) + client2 = JamAI(user_id=user2.id, project_id=setup.project_id) + + assert client2.conversations.list_conversations().total == 0 + + with pytest.raises(ResourceNotFoundError): + client2.conversations.list_messages(conv_id1) + + +def test_invalid_operations(setup: ConversationContext): + """ + Tests various invalid API calls to ensure they fail with the correct errors. + - Tries to get/rename a non-existent conversation. + - Tries to create a conversation from a non-existent agent template. 
+ """ + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + non_existent_id = "non-existent-conversation-id" + + with pytest.raises(ResourceNotFoundError): + client.conversations.list_messages(non_existent_id) + with pytest.raises(ResourceNotFoundError): + client.conversations.rename_conversation_title(non_existent_id, "new-title") + + with pytest.raises(ResourceNotFoundError): + create_req = ConversationCreateRequest( + agent_id="non-existent-template", + data={"User": "test"}, + ) + list(client.conversations.create_conversation(create_req)) diff --git a/services/api/tests/routers/test_models.py b/services/api/tests/routers/test_models.py new file mode 100644 index 0000000..0a20e13 --- /dev/null +++ b/services/api/tests/routers/test_models.py @@ -0,0 +1,204 @@ +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + DeploymentRead, + DeploymentUpdate, + ModelConfigRead, +) +from owl.types import CloudProvider, DeploymentCreate, OkResponse, Page +from owl.utils.exceptions import ( + BadInputError, + ForbiddenError, + ResourceExistsError, + ResourceNotFoundError, +) +from owl.utils.test import ( + GPT_4O_MINI_CONFIG, + GPT_4O_MINI_DEPLOYMENT, + SMOL_LM2_CONFIG, + create_deployment, + create_model_config, + setup_organizations, +) + + +def test_create_model_config(): + with setup_organizations(): + with create_model_config(SMOL_LM2_CONFIG) as model: + assert isinstance(model, ModelConfigRead) + assert model.id == SMOL_LM2_CONFIG.id + assert model.type == SMOL_LM2_CONFIG.type + assert model.name == SMOL_LM2_CONFIG.name + assert model.context_length == SMOL_LM2_CONFIG.context_length + assert model.capabilities == SMOL_LM2_CONFIG.capabilities + + +def test_create_existing_model_config(): + with setup_organizations(): + with create_model_config(SMOL_LM2_CONFIG) as model: + assert model.id == SMOL_LM2_CONFIG.id + with pytest.raises(ResourceExistsError): + with create_model_config(SMOL_LM2_CONFIG): + pass + + +def test_list_system_model_configs(): + with setup_organizations() as ctx: + with create_model_config(SMOL_LM2_CONFIG): + # OK + models = JamAI(user_id=ctx.superuser.id).models.list_model_configs() + assert isinstance(models, Page) + assert len(models.items) == 1 + assert models.total == 1 + + +@pytest.mark.cloud +def test_list_system_model_configs_permission(): + with setup_organizations() as ctx: + with create_model_config(SMOL_LM2_CONFIG): + # No permission + with pytest.raises(ForbiddenError): + JamAI(user_id=ctx.user.id).models.list_model_configs() + + +def test_get_model_config(): + with setup_organizations() as ctx: + with create_model_config(SMOL_LM2_CONFIG) as model: + client = JamAI(user_id=ctx.superuser.id) + # Fetch + response = client.models.get_model_config(model.id) + assert isinstance(response, ModelConfigRead) + assert response.model_dump() == model.model_dump() + + +def test_get_nonexistent_model_config(): + with setup_organizations() as ctx: + client = JamAI(user_id=ctx.superuser.id) + with pytest.raises(ResourceNotFoundError): + client.models.get_model_config("nonexistent-model") + + +def test_update_model_config(): + """ + Test updating a model config. 
+    - Update name
+    - Update ID and ensure foreign keys of deployments are updated
+    - `owned_by` and `id` must match for ELLM models
+    """
+    with setup_organizations() as ctx:
+        with (
+            create_model_config(GPT_4O_MINI_CONFIG) as model,
+            create_deployment(GPT_4O_MINI_DEPLOYMENT) as deployment,
+        ):
+            assert isinstance(model, ModelConfigRead)
+            client = JamAI(user_id=ctx.superuser.id)
+            # Update name
+            new_name = "NEW MODEL"
+            model = client.models.update_model_config(model.id, dict(name=new_name))
+            assert isinstance(model, ModelConfigRead)
+            assert model.id == GPT_4O_MINI_CONFIG.id  # ID should be unchanged
+            assert model.name == new_name
+            # Update meta
+            meta = dict(icon="openai")
+            model = client.models.update_model_config(model.id, dict(meta=meta))
+            assert isinstance(model, ModelConfigRead)
+            assert model.id == GPT_4O_MINI_CONFIG.id  # ID should be unchanged
+            assert model.meta == meta
+            # `owned_by` and `id` must match for ELLM models
+            new_owned_by = "ellm"
+            new_id = "ellm/biglm2:135m"
+            with pytest.raises(BadInputError, match="ELLM models must have `owned_by"):
+                client.models.update_model_config(model.id, dict(owned_by=new_owned_by))
+            with pytest.raises(BadInputError, match="ELLM models must have `owned_by"):
+                client.models.update_model_config(model.id, dict(id=new_id))
+            # Update ID and `owned_by`
+            model = client.models.update_model_config(
+                model.id, dict(id=new_id, owned_by=new_owned_by)
+            )
+            assert isinstance(model, ModelConfigRead)
+            assert model.id == new_id
+            assert model.name == new_name
+            assert model.meta == meta
+            assert model.owned_by == new_owned_by
+            # Fetch again
+            model = client.models.get_model_config(model.id)
+            assert isinstance(model, ModelConfigRead)
+            assert model.id == new_id
+            assert model.name == new_name
+            assert model.meta == meta
+            assert model.owned_by == new_owned_by
+            # Fetch deployment to ensure foreign key is updated
+            response = client.models.get_deployment(deployment.id)
+            assert isinstance(response, DeploymentRead)
+            assert response.model.id == new_id
+
+
+def test_delete_model_config():
+    with setup_organizations() as ctx:
+        with create_model_config(SMOL_LM2_CONFIG) as model:
+            client = JamAI(user_id=ctx.superuser.id)
+            response = client.models.delete_model_config(model.id)
+            assert isinstance(response, OkResponse)
+            with pytest.raises(ResourceNotFoundError):
+                client.models.get_model_config(model.id)
+
+
+def test_create_cloud_deployment():
+    with setup_organizations() as ctx:
+        with (
+            create_model_config(GPT_4O_MINI_CONFIG) as model,
+            create_deployment(GPT_4O_MINI_DEPLOYMENT) as deployment,
+        ):
+            assert deployment.model_id == model.id
+            assert deployment.name == GPT_4O_MINI_DEPLOYMENT.name
+            assert deployment.provider == CloudProvider.OPENAI
+            assert deployment.routing_id == GPT_4O_MINI_DEPLOYMENT.routing_id
+
+            model = JamAI(user_id=ctx.superuser.id).models.get_model_config(model.id)
+            assert isinstance(model, ModelConfigRead)
+
+
+def test_get_deployment():
+    with setup_organizations() as ctx:
+        with (
+            create_model_config(GPT_4O_MINI_CONFIG) as model,
+            create_deployment(
+                DeploymentCreate(
+                    model_id=model.id,
+                    name="Test Deployment",
+                    provider=CloudProvider.OPENAI,
+                    routing_id="openai/gpt-4o-mini",
+                )
+            ) as deployment,
+        ):
+            client = JamAI(user_id=ctx.superuser.id)
+            # Fetch
+            response = client.models.get_deployment(deployment.id)
+            assert isinstance(response, DeploymentRead)
+            assert response.model_dump() == deployment.model_dump()
+
+
+def test_update_deployment():
+    with setup_organizations() as ctx:
+        with (
+            create_model_config(GPT_4O_MINI_CONFIG),
create_deployment(GPT_4O_MINI_DEPLOYMENT) as deployment, + ): + assert deployment.name == GPT_4O_MINI_DEPLOYMENT.name + client = JamAI(user_id=ctx.superuser.id) + # Update + new_name = "NEW DEPLOYMENT" + deployment = client.models.update_deployment( + deployment.id, DeploymentUpdate(name=new_name) + ) + assert isinstance(deployment, DeploymentRead) + assert deployment.name == new_name + # Fetch again + deployment = client.models.get_deployment(deployment.id) + assert isinstance(deployment, DeploymentRead) + assert deployment.name == new_name + + +if __name__ == "__main__": + test_create_model_config() diff --git a/services/api/tests/routers/test_organizations.py b/services/api/tests/routers/test_organizations.py new file mode 100644 index 0000000..ac1989e --- /dev/null +++ b/services/api/tests/routers/test_organizations.py @@ -0,0 +1,380 @@ +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + OrganizationRead, + OrganizationUpdate, + OrgMemberRead, + Page, +) +from owl.configs import ENV_CONFIG +from owl.db import TEMPLATE_ORG_ID, sync_session +from owl.db.models import Organization +from owl.types import ChatCompletionResponse, ChatRequest, Role, StripePaymentInfo +from owl.utils.exceptions import ( + BadInputError, + ForbiddenError, + ResourceNotFoundError, + UpgradeTierError, +) +from owl.utils.test import ( + BASE_PLAN_ID, + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + GPT_4O_MINI_CONFIG, + GPT_4O_MINI_DEPLOYMENT, + GPT_41_NANO_CONFIG, + GPT_41_NANO_DEPLOYMENT, + create_deployment, + create_model_config, + create_organization, + create_user, + setup_organizations, + setup_projects, +) + + +def test_create_superorg(): + """ + Test creating organizations + - Assert that superorg and template org are created + - Assert that external keys are persisted correctly + """ + with ( + create_user() as superuser, + create_organization( + dict(name="Clubhouse", external_keys={"key": "value", "openai": "openai"}), + user_id=superuser.id, + ) as superorg, + ): + with create_organization( + dict(name="Clubhouse", external_keys={"key": "value", "openai": "openai"}), + user_id=superuser.id, + ) as org: + assert superorg.id == "0" + assert superorg.name == "Clubhouse" + assert superorg.external_keys == {"key": "value", "openai": "openai"} + assert org.id != "0" + assert org.name == "Clubhouse" + assert org.external_keys == {"key": "value", "openai": "openai"} + # Check org memberships + user = JamAI(user_id=superuser.id).users.get_user(superuser.id) + assert len(user.org_memberships) == 3 # Superorg + Template + Org + org_memberships = {m.organization_id: m for m in user.org_memberships} + assert org_memberships["0"].role == Role.ADMIN + assert org_memberships[TEMPLATE_ORG_ID].role == Role.ADMIN + assert org_memberships[org.id].role == Role.ADMIN + + # Assert template org and sys org still exist + with sync_session() as session: + assert session.get(Organization, TEMPLATE_ORG_ID) is not None + assert session.get(Organization, "0") is not None + # Assert template org and sys org are deleted + with sync_session() as session: + assert session.get(Organization, TEMPLATE_ORG_ID) is None + assert session.get(Organization, "0") is None + + +# @pytest.mark.cloud +# def test_create_superorg_permission(): +# with create_user(), create_user(dict(email="russell@up.com", name="Russell")) as user: +# with pytest.raises(ForbiddenError): +# with create_organization(user_id=user.id): +# pass + + +@pytest.mark.cloud +def test_create_organization_base_tier_limit(): + """ + A user can only 
have one organization with a base tier plan. + """ + with ( + create_user() as superuser, + create_user(dict(name="user", email="russell@up.com")) as user, + # Internal org "0" is not counted against the limit + create_organization(dict(name="Admin org"), user_id=superuser.id) as superorg, + # First base tier org + create_organization(dict(name="Org 1"), user_id=superuser.id) as o1, + ): + assert superorg.id == "0" + assert o1.id != "0" + # Auto-subscribed to base tier plan + assert o1.price_plan_id == BASE_PLAN_ID + # Second base tier org + with pytest.raises( + (ForbiddenError, UpgradeTierError), + match="can have only one organization with Free Plan", + ): + with create_organization(dict(name="Org 2"), user_id=superuser.id): + pass + # Create another org with a different plan + super_client = JamAI(user_id=superuser.id) + client = JamAI(user_id=user.id) + with ( + create_organization( + dict(name="Org 2"), user_id=superuser.id, subscribe_plan=False + ) as o2, + create_organization(dict(name="Org X"), user_id=user.id, subscribe_plan=False) as ox, + ): + assert o2.price_plan_id is None + assert o2.active is False + plans = super_client.prices.list_price_plans().items + plan = next((p for p in plans if p.id != BASE_PLAN_ID), None) + assert plan is not None + invoice = super_client.organizations.subscribe_plan( + organization_id=o2.id, price_plan_id=plan.id + ) + assert isinstance(invoice, StripePaymentInfo) + assert invoice.amount_due == 0 # Stripe not enabled + # Second base tier org + with pytest.raises( + (ForbiddenError, UpgradeTierError), + match="can have only one organization with Free Plan", + ): + with create_organization(dict(name="Org 3"), user_id=superuser.id): + pass + # Cannot subscribe to base tier plan + with pytest.raises( + (ForbiddenError, UpgradeTierError), + match="can have only one organization with Free Plan", + ): + super_client.organizations.subscribe_plan( + organization_id=o2.id, price_plan_id=BASE_PLAN_ID + ) + # Auto-subscribed to base tier plan + assert ox.price_plan_id == BASE_PLAN_ID + with pytest.raises(BadInputError, match="already subscribed to .+ plan"): + client.organizations.subscribe_plan( + organization_id=ox.id, price_plan_id=BASE_PLAN_ID + ) + + +def test_list_organizations(): + with setup_organizations() as ctx: + orgs = JamAI(user_id=ctx.superuser.id).organizations.list_organizations() + assert isinstance(orgs, Page) + assert len(orgs.items) == 3 # 2 orgs + template + assert orgs.total == 3 + + +@pytest.mark.cloud +def test_list_organizations_permission(): + with setup_organizations() as ctx: + with pytest.raises(ForbiddenError): + JamAI(user_id=ctx.user.id).organizations.list_organizations() + + +def test_get_org(): + """ + Test fetch organization. 
+    - Admin can view API keys
+    - Member cannot view API keys
+    - System user can fetch org but not API keys
+    """
+    with setup_organizations() as ctx:
+        super_client = JamAI(user_id=ctx.superuser.id)
+        client = JamAI(user_id=ctx.user.id)
+        # Add API key
+        super_client.organizations.update_organization(
+            ctx.superorg.id, OrganizationUpdate(external_keys=dict(x="x"))
+        )
+        client.organizations.update_organization(
+            ctx.org.id, OrganizationUpdate(external_keys=dict(x="x"))
+        )
+        # Join organization as member
+        membership = super_client.organizations.join_organization(
+            ctx.user.id,
+            organization_id=ctx.superorg.id,
+            role=Role.MEMBER,
+        )
+        assert isinstance(membership, OrgMemberRead)
+        # Admin can view API keys
+        org = super_client.organizations.get_organization(ctx.superorg.id)
+        assert isinstance(org, OrganizationRead)
+        assert org.id == ctx.superorg.id
+        assert org.external_keys["x"] == "x"
+        # Member cannot view API keys (cloud only)
+        org = client.organizations.get_organization(ctx.superorg.id)
+        assert isinstance(org, OrganizationRead)
+        assert org.id == ctx.superorg.id
+        assert org.external_keys["x"] == ("x" if ENV_CONFIG.is_oss else "***")
+        # System user can fetch org but not API keys (cloud only)
+        user = super_client.users.get_user()
+        assert ctx.org.id not in {m.organization_id for m in user.org_memberships}
+        org = super_client.organizations.get_organization(ctx.org.id)
+        assert isinstance(org, OrganizationRead)
+        assert org.id == ctx.org.id
+        assert org.external_keys["x"] == ("x" if ENV_CONFIG.is_oss else "***")
+
+
+def test_update_org():
+    """
+    Test update organization.
+    - Partial update org
+    - Partial update external keys
+    """
+    with setup_organizations() as ctx:
+        client = JamAI(user_id=ctx.user.id)
+        # Partial update org
+        org = client.organizations.update_organization(
+            ctx.org.id, OrganizationUpdate(name="Updated Name")
+        )
+        assert isinstance(org, OrganizationRead)
+        assert org.name == "Updated Name"
+        assert org.timezone is None
+        org = client.organizations.update_organization(
+            ctx.org.id, OrganizationUpdate(timezone="Asia/Kuala_Lumpur")
+        )
+        assert isinstance(org, OrganizationRead)
+        assert org.name == "Updated Name"
+        assert org.timezone == "Asia/Kuala_Lumpur"
+        with pytest.raises(BadInputError, match="currency"):
+            # Only USD is accepted for now
+            client.organizations.update_organization(ctx.org.id, dict(currency="EUR"))
+        with pytest.raises(BadInputError, match="timezone"):
+            # Strict timezone validation
+            client.organizations.update_organization(
+                ctx.org.id, dict(timezone="asia/kuala_lumpur")
+            )
+        # Update external keys
+        org = client.organizations.update_organization(
+            ctx.org.id, OrganizationUpdate(external_keys=dict(x="x"))
+        )
+        assert isinstance(org, OrganizationRead)
+        assert org.external_keys == dict(x="x")
+        org = client.organizations.update_organization(
+            ctx.org.id, OrganizationUpdate(external_keys=dict(y="y"))
+        )
+        assert isinstance(org, OrganizationRead)
+        assert org.external_keys == dict(y="y")
+
+
+@pytest.mark.cloud
+def test_update_org_permission():
+    """
+    Test update organization permissions.
+ - Only admin can update org + """ + with setup_organizations() as ctx: + super_client = JamAI(user_id=ctx.superuser.id) + client = JamAI(user_id=ctx.user.id) + # Test update permission + membership = client.organizations.join_organization( + ctx.superuser.id, + organization_id=ctx.org.id, + role=Role.MEMBER, + ) + assert isinstance(membership, OrgMemberRead) + # Member fail + with pytest.raises(ForbiddenError): + super_client.organizations.update_organization( + ctx.org.id, OrganizationUpdate(name="New Name") + ) + # Admin OK + org = client.organizations.update_organization( + ctx.org.id, OrganizationUpdate(name="New Name") + ) + assert isinstance(org, OrganizationRead) + assert org.name == "New Name" + + +def test_delete_org(): + with setup_organizations() as ctx: + ok_response = JamAI(user_id=ctx.user.id).organizations.delete_organization( + ctx.org.id, missing_ok=False + ) + assert ok_response.ok is True + client = JamAI(user_id=ctx.superuser.id) + with pytest.raises(ResourceNotFoundError): + client.organizations.get_organization(ctx.org.id) + # Assert users are not deleted + users = client.users.list_users() + assert isinstance(users, Page) + assert len(users.items) == 2 + + +@pytest.mark.cloud +def test_delete_org_permission(): + with setup_organizations() as ctx: + client = JamAI(user_id=ctx.user.id) + with pytest.raises(ForbiddenError): + client.organizations.delete_organization(ctx.superorg.id, missing_ok=False) + + +def test_organisation_model_catalogue(): + """ + Test listing model configs: + - System level + - Organization level + - Private models via allow list and block list + - Run chat completion + """ + with setup_projects() as ctx: + with ( + # Common models + create_model_config(GPT_4O_MINI_CONFIG) as m0, + # Private models (allow list) + create_model_config( + dict( + **ELLM_DESCRIBE_CONFIG.model_dump(exclude_unset=True), + allowed_orgs=[ctx.org.id], + ) + ) as m1, + # Private models (block list) + create_model_config( + dict( + **GPT_41_NANO_CONFIG.model_dump(exclude_unset=True), + allowed_orgs=[ctx.org.id, ctx.superorg.id], + blocked_orgs=[ctx.org.id], + ) + ) as m2, + create_deployment(GPT_4O_MINI_DEPLOYMENT), + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(GPT_41_NANO_DEPLOYMENT), + ): + assert m0.is_private is False + assert m1.is_private is True + assert m2.is_private is True + super_client = JamAI(user_id=ctx.superuser.id, project_id=ctx.projects[0].id) + client = JamAI(user_id=ctx.user.id, project_id=ctx.projects[1].id) + # System-level + models = super_client.models.list_model_configs() + assert isinstance(models, Page) + assert len(models.items) == 3 + assert models.total == 3 + # Organisation-level + models = super_client.organizations.model_catalogue(organization_id=ctx.superorg.id) + assert isinstance(models, Page) + assert len(models.items) == 2 + assert models.total == 2 + model_ids = {m.id for m in models.items} + assert GPT_4O_MINI_CONFIG.id in model_ids + assert ELLM_DESCRIBE_CONFIG.id not in model_ids + assert GPT_41_NANO_CONFIG.id in model_ids + # Organisation-level + models = client.organizations.model_catalogue(organization_id=ctx.org.id) + assert isinstance(models, Page) + assert len(models.items) == 2 + assert models.total == 2 + model_ids = {m.id for m in models.items} + assert GPT_4O_MINI_CONFIG.id in model_ids + assert ELLM_DESCRIBE_CONFIG.id in model_ids + assert GPT_41_NANO_CONFIG.id not in model_ids + # Run chat completion + req = ChatRequest( + model=ELLM_DESCRIBE_CONFIG.id, + messages=[{"role": "user", "content": "Hi 
there"}], + max_tokens=4, + stream=False, + ) + response = client.generate_chat_completions(req) + assert isinstance(response, ChatCompletionResponse) + assert len(response.content) > 0 + assert response.prompt_tokens == 2 + assert response.completion_tokens > 0 + with pytest.raises(ResourceNotFoundError): + super_client.generate_chat_completions(req) + + +if __name__ == "__main__": + test_list_organizations() diff --git a/services/api/tests/routers/test_projects.py b/services/api/tests/routers/test_projects.py new file mode 100644 index 0000000..77fab85 --- /dev/null +++ b/services/api/tests/routers/test_projects.py @@ -0,0 +1,577 @@ +import re +from dataclasses import dataclass +from os.path import dirname, join, realpath +from tempfile import TemporaryDirectory + +import httpx +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + GetURLResponse, + LLMGenConfig, + OrganizationCreate, + OrgMemberRead, + Page, + ProjectCreate, + ProjectMemberRead, + ProjectRead, + ProjectUpdate, + RAGParams, + TableImportRequest, +) +from jamaibase.utils.exceptions import ( + AuthorizationError, + ForbiddenError, + ResourceExistsError, + ResourceNotFoundError, +) +from owl.db import TEMPLATE_ORG_ID +from owl.types import GEN_CONFIG_VAR_PATTERN, ColumnDtype, Role, TableType +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + TEXT_EMBEDDING_3_SMALL_CONFIG, + TEXT_EMBEDDING_3_SMALL_DEPLOYMENT, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + create_deployment, + create_model_config, + create_organization, + create_project, + create_user, + get_file_map, + list_table_rows, + setup_organizations, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) + +FILE_COLUMNS = ["image", "audio", "document"] + + +def test_create_project(): + with setup_organizations() as ctx: + # Standard creation + with ( + create_project(user_id=ctx.superuser.id), + create_project( + dict(name="Mickey 17"), user_id=ctx.user.id, organization_id=ctx.org.id + ), + ): + pass + + +@pytest.mark.cloud +def test_create_project_auth(): + from owl.db import sync_session + from owl.db.models.cloud import APIKey + + with ( + setup_organizations() as ctx, + create_project(user_id=ctx.user.id, organization_id=ctx.org.id) as p0, + ): + ### --- Test Project Key auth --- ### + # Project-linked PAT + pat = JamAI(user_id=ctx.user.id).users.create_pat(dict(name="pat", project_id=p0.id)) + assert pat.id.startswith("jamai_pat_") + client = JamAI(user_id=ctx.user.id, token=pat.id) + name = "Mickey 18" + p1 = client.projects.create_project(dict(name=name, organization_id=ctx.org.id)) + assert p1.name == name + with pytest.raises(AuthorizationError, match="invalid authorization token"): + JamAI(user_id=ctx.user.id, token=f"{pat.id}xx").projects.create_project( + dict(name=name, organization_id=ctx.org.id) + ) + # No project link + pat = JamAI(user_id=ctx.user.id).users.create_pat(dict(name="pat", project_id=None)) + assert pat.id.startswith("jamai_pat_") + client = JamAI(user_id=ctx.user.id, token=pat.id) + name = "Mickey 19" + p1 = client.projects.create_project(dict(name=name, organization_id=ctx.org.id)) + assert p1.name == name + with pytest.raises(AuthorizationError, match="invalid authorization token"): + JamAI(user_id=ctx.user.id, token=f"{pat.id}xx").projects.create_project( + dict(name=name, organization_id=ctx.org.id) + ) + + ### --- Test Legacy Organization Key auth --- ### + with sync_session() as 
session: + key = APIKey(id="jamai_sk_legacy", organization_id=ctx.org.id) + session.add(key) + session.commit() + session.refresh(key) + client = JamAI(user_id=ctx.user.id, token=key.id) + name = "Mickey 20" + p1 = client.projects.create_project(dict(name=name, organization_id=ctx.org.id)) + assert p1.name == name + with pytest.raises(AuthorizationError, match="invalid authorization token"): + JamAI(user_id=ctx.user.id, token=f"{key.id}xx").projects.create_project( + dict(name=name, organization_id=ctx.org.id) + ) + + # List projects + projects = client.projects.list_projects(ctx.org.id) + assert isinstance(projects, Page) + assert len(projects.items) == 4 + assert projects.total == 4 + + +@pytest.mark.cloud +def test_create_project_permission(): + with setup_organizations() as ctx: + assert ctx.user.id != "0" + with pytest.raises(ForbiddenError): + with create_project( + dict(name="My First Project", organization_id=ctx.superorg.id), + user_id=ctx.user.id, + ): + pass + + +# def test_create_existing_project(): +# with setup_organizations() as ctx: +# with create_project(user_id=ctx.superuser.id) as project: +# with pytest.raises(ResourceExistsError): +# with create_project( +# dict(id=project.id, name="Mickey 1"), user_id=ctx.superuser.id +# ): +# pass + + +def test_create_project_duplicate_name(): + with setup_organizations() as ctx, create_project(user_id=ctx.superuser.id) as p0: + with ( + create_project(dict(name=p0.name), user_id=ctx.superuser.id) as p1, + create_project(dict(name=p0.name), user_id=ctx.superuser.id) as p2, + ): + assert isinstance(p1, ProjectRead) + assert p1.name == f"{p0.name} (1)" + assert isinstance(p2, ProjectRead) + assert p2.name == f"{p0.name} (2)" + assert len({p0.id, p1.id, p2.id}) == 3 + + +def test_create_project_missing_org(): + with setup_organizations() as ctx: + with pytest.raises((ForbiddenError, ResourceNotFoundError)): + with create_project( + dict(name="My First Project"), + user_id=ctx.superuser.id, + organization_id="nonexistent", + ): + pass + + +def test_list_projects(): + with setup_organizations() as ctx: + with ( + create_project(user_id=ctx.superuser.id), + create_project(dict(name="Mickey 1"), user_id=ctx.superuser.id), + ): + projects = JamAI(user_id=ctx.superuser.id).projects.list_projects(ctx.superorg.id) + assert isinstance(projects, Page) + assert len(projects.items) == 2 + + +@pytest.mark.cloud +def test_list_projects_permission(): + """ + Test project list permission. + - ADMIN and MEMBER can list all projects. + - Non-members cannot list projects at all. + - GUEST can only list projects that they are a member of. 
+ """ + with ( + setup_organizations() as ctx, + create_project(user_id=ctx.superuser.id), + create_project(user_id=ctx.superuser.id) as p1, + create_project(user_id=ctx.user.id, organization_id=ctx.org.id), + ): + super_client = JamAI(user_id=ctx.superuser.id) + client = JamAI(user_id=ctx.user.id) + ### --- Admin can list all projects --- ### + projects = super_client.projects.list_projects(ctx.superorg.id) + assert isinstance(projects, Page) + assert len(projects.items) == 2 + ### --- Non-member fail --- ### + with pytest.raises(ForbiddenError): + client.projects.list_projects(ctx.superorg.id) + ### --- Guest can list projects that they are a member of --- ### + # Join organization as guest and project + membership = super_client.organizations.join_organization( + ctx.user.id, + organization_id=ctx.superorg.id, + role=Role.GUEST, + ) + assert isinstance(membership, OrgMemberRead) + membership = super_client.projects.join_project( + ctx.user.id, + project_id=p1.id, + role=Role.MEMBER, + ) + assert isinstance(membership, ProjectMemberRead) + projects = client.projects.list_projects(ctx.superorg.id) + assert isinstance(projects, Page) + assert len(projects.items) == 1 + # Project role doesn't matter + membership = super_client.projects.update_member_role( + user_id=ctx.user.id, + project_id=p1.id, + role=Role.GUEST, + ) + assert isinstance(membership, ProjectMemberRead) + assert membership.role == Role.GUEST + projects = client.projects.list_projects(ctx.superorg.id) + assert isinstance(projects, Page) + assert len(projects.items) == 1 + ### --- Member can list all projects --- ### + # Update org role to MEMBER + membership = super_client.organizations.update_member_role( + user_id=ctx.user.id, + organization_id=ctx.superorg.id, + role=Role.MEMBER, + ) + assert isinstance(membership, OrgMemberRead) + assert membership.role == Role.MEMBER + projects = client.projects.list_projects(ctx.superorg.id) + assert isinstance(projects, Page) + assert len(projects.items) == 2 + + +@pytest.mark.cloud +def test_update_project_permission(): + with ( + setup_organizations() as ctx, + create_project(user_id=ctx.user.id, organization_id=ctx.org.id) as project, + ): + client = JamAI(user_id=ctx.user.id) + # Join organization and project as member + membership = client.organizations.join_organization( + ctx.superuser.id, + organization_id=ctx.org.id, + role=Role.MEMBER, + ) + assert isinstance(membership, OrgMemberRead) + membership = client.projects.join_project( + ctx.superuser.id, + project_id=project.id, + role=Role.MEMBER, + ) + assert isinstance(membership, ProjectMemberRead) + # Admin OK + updated_proj = client.projects.update_project(project.id, ProjectUpdate(name="New Name")) + assert isinstance(updated_proj, ProjectRead) + # Member fail + with pytest.raises(ForbiddenError): + JamAI(user_id=ctx.superuser.id).projects.update_project( + project.id, ProjectUpdate(name="Another Name") + ) + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str + superorg_id: str + project_id: str + embedding_size: int + image_uri: str + audio_uri: str + document_uri: str + chat_model_id: str + embed_model_id: str + rerank_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. 
+ """ + with ( + create_user() as superuser, + create_organization( + body=OrganizationCreate(name="Superorg"), user_id=superuser.id + ) as superorg, + create_project( + dict(name="Superorg Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + assert superuser.id == "0" + assert superorg.id == "0" + + bge = "ellm/BAAI/bge-m3" + with ( + # Create models + create_model_config(ELLM_DESCRIBE_CONFIG) as desc_llm_config, + create_model_config(TEXT_EMBEDDING_3_SMALL_CONFIG) as embed_config, + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_config, + create_model_config( + TEXT_EMBEDDING_3_SMALL_CONFIG.model_copy(update=dict(id=bge, owned_by="ellm")) + ), + # Create deployments + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(TEXT_EMBEDDING_3_SMALL_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + create_deployment( + TEXT_EMBEDDING_3_SMALL_DEPLOYMENT.model_copy(update=dict(model_id=bge)) + ), + ): + client = JamAI(user_id=superuser.id, project_id=p0.id) + image_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + audio_uri = upload_file(client, FILES["gutter.mp3"]).uri + document_uri = upload_file( + client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"] + ).uri + yield ServingContext( + superuser_id=superuser.id, + superorg_id=superorg.id, + project_id=p0.id, + embedding_size=embed_config.final_embedding_size, + image_uri=image_uri, + audio_uri=audio_uri, + document_uri=document_uri, + chat_model_id=desc_llm_config.id, + embed_model_id=embed_config.id, + rerank_model_id=rerank_config.id, + ) + + +def _check_tables(user_id: str, project_id: str): + client = JamAI(user_id=user_id, project_id=project_id) + for table_type in TableType: + tables = client.table.list_tables(table_type, parent_id="_agent_") + assert tables.total == 1 + rows = list_table_rows(client, table_type, tables.items[0].id) + assert rows.total == 1 + if table_type == TableType.ACTION: + # Check image content + urls = client.file.get_raw_urls([rows.values[0]["image"]]) + assert isinstance(urls, GetURLResponse) + image = httpx.get(urls.urls[0]).content + with open(FILES["cifar10-deer.jpg"], "rb") as f: + assert image == f.read() + + +def test_project_import_export( + setup: ServingContext, +): + """ + Test project import and export. + + Args: + setup (ServingContext): Setup. 
+ """ + client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id) + tables = [] + try: + # Create the tables + for table_type in TableType: + if table_type == TableType.CHAT: + parquet_filepath = FILES["export-v0.4-chat-agent.parquet"] + else: + parquet_filepath = FILES[f"export-v0.4-{table_type}.parquet"] + table = client.table.import_table( + table_type, + TableImportRequest(file_path=parquet_filepath, table_id_dst=None), + ) + tables.append(table) + # Export + with TemporaryDirectory() as tmp_dir: + file_path = join(tmp_dir, f"{setup.project_id}.parquet") + with open(file_path, "wb") as f: + f.write(client.projects.export_project(setup.project_id)) + # Import as new project + imported_project = client.projects.import_project( + file_path, + project_id="", + organization_id=setup.superorg_id, + ) + assert isinstance(imported_project, ProjectRead) + assert imported_project.id != setup.project_id + _check_tables(setup.superuser_id, imported_project.id) + # Import into existing project + with create_project( + dict(name="Superorg Project 1"), + user_id=setup.superuser_id, + organization_id=setup.superorg_id, + ) as p1: + imported_project = client.projects.import_project( + file_path, + project_id=p1.id, + organization_id="", + ) + assert isinstance(imported_project, ProjectRead) + assert imported_project.id == p1.id + _check_tables(setup.superuser_id, imported_project.id) + # Should not change existing metadata + project = client.projects.get_project(p1.id) + assert project.name == "Superorg Project 1" + # Should fail if tables already exist + with pytest.raises(ResourceExistsError): + client.projects.import_project( + file_path, + project_id=p1.id, + organization_id="", + ) + finally: + for table in tables: + client.table.delete_table(table_type, table.id) + + +@pytest.mark.parametrize("version", ["v0.4"]) +def test_project_import_parquet( + setup: ServingContext, + version: str, +): + """ + Test project import from an existing Parquet file. + - Import as new project from v0.4 file + - Import into existing parquet from v0.4 file + - Import v0.4 file with table and column names that are too long (test truncation) + + Args: + setup (ServingContext): Setup. 
+    """
+    client = JamAI(user_id=setup.superuser_id, project_id=setup.project_id)
+    ### --- Import as new project --- ###
+    imported_project = client.projects.import_project(
+        FILES[f"export-{version}-project.parquet"],
+        project_id="",
+        organization_id=setup.superorg_id,
+    )
+    assert imported_project.id != setup.project_id
+    assert imported_project.name == "Test Project 新しい"
+    _check_tables(setup.superuser_id, imported_project.id)
+    ### --- Import into existing project --- ###
+    with create_project(
+        dict(name="Superorg Project 2"),
+        user_id=setup.superuser_id,
+        organization_id=setup.superorg_id,
+    ) as p1:
+        imported_project = client.projects.import_project(
+            FILES[f"export-{version}-project.parquet"],
+            project_id=p1.id,
+            organization_id="",
+        )
+        assert imported_project.id == p1.id
+        assert imported_project.name == p1.name
+        assert imported_project.name != "Test Project 新しい"
+        _check_tables(setup.superuser_id, imported_project.id)
+    ### --- Import table and column names that are too long --- ###
+    imported_project = client.projects.import_project(
+        FILES[f"export-{version}-project-long-name.parquet"],
+        project_id="",
+        organization_id=setup.superorg_id,
+    )
+    assert imported_project.id != setup.project_id
+    client = JamAI(user_id=setup.superuser_id, project_id=imported_project.id)
+    # Check tables
+    tables = client.table.list_tables(TableType.KNOWLEDGE)
+    assert len(tables.items) == 1
+    assert tables.total == 1
+    kt = tables.items[0]
+    assert len(kt.id) == 100
+    tables = client.table.list_tables(TableType.ACTION)
+    assert len(tables.items) == 1
+    assert tables.total == 1
+    at = tables.items[0]
+    assert len(at.id) == 100
+    assert len(at.cols) == 4
+    for col in at.cols[2:]:
+        assert len(col.id) == 100
+    assert at.cols[2].dtype == ColumnDtype.IMAGE
+    assert at.cols[3].dtype == ColumnDtype.STR
+    cfg = at.cols[3].gen_config
+    assert isinstance(cfg, LLMGenConfig)
+    ref_ids = re.findall(GEN_CONFIG_VAR_PATTERN, cfg.prompt)
+    assert len(ref_ids) == 1
+    assert ref_ids[0] == at.cols[2].id
+    assert isinstance(cfg.rag_params, RAGParams)
+    assert cfg.rag_params.table_id == kt.id
+    tables = client.table.list_tables(TableType.CHAT)
+    assert len(tables.items) == 2
+    assert tables.total == 2
+    tables = client.table.list_tables(TableType.CHAT, parent_id="_agent_")
+    assert len(tables.items) == 1
+    assert tables.total == 1
+    agent = tables.items[0]
+    assert len(agent.id) == 100
+    tables = client.table.list_tables(TableType.CHAT, parent_id="_chat_")
+    assert len(tables.items) == 1
+    assert tables.total == 1
+    ct = tables.items[0]
+    assert len(ct.id) == 100
+    assert agent.parent_id is None
+    assert ct.parent_id == agent.id
+
+
+def test_template_import_export(
+    setup: ServingContext,
+):
+    """
+    Test template import.
+
+    Args:
+        setup (ServingContext): Setup.
+ """ + # Create template + template = JamAI(user_id=setup.superuser_id).projects.create_project( + ProjectCreate(organization_id=TEMPLATE_ORG_ID, name="Template") + ) + client = JamAI(user_id=setup.superuser_id, project_id=template.id) + tables = [] + try: + # Create the tables + for table_type in TableType: + if table_type == TableType.CHAT: + parquet_filepath = FILES["export-v0.4-chat-agent.parquet"] + else: + parquet_filepath = FILES[f"export-v0.4-{table_type}.parquet"] + table = client.table.import_table( + table_type, + TableImportRequest(file_path=parquet_filepath, table_id_dst=None), + ) + tables.append(table) + # Import as new project + imported_project = client.projects.import_template( + template.id, + project_id="", + organization_id=setup.superorg_id, + ) + assert isinstance(imported_project, ProjectRead) + assert imported_project.id != setup.project_id + _check_tables(setup.superuser_id, imported_project.id) + # Import into existing project + with create_project( + dict(name="Superorg Project 2"), + user_id=setup.superuser_id, + organization_id=setup.superorg_id, + ) as p1: + imported_project = client.projects.import_template( + template.id, + project_id=p1.id, + organization_id="", + ) + assert isinstance(imported_project, ProjectRead) + assert imported_project.id == p1.id + _check_tables(setup.superuser_id, imported_project.id) + # Should not change existing metadata + project = client.projects.get_project(p1.id) + assert project.name == "Superorg Project 2" + # Should fail if tables already exist + with pytest.raises(ResourceExistsError): + client.projects.import_template( + template.id, + project_id=p1.id, + organization_id="", + ) + finally: + for table in tables: + client.table.delete_table(table_type, table.id) + + +if __name__ == "__main__": + test_list_projects() diff --git a/services/api/tests/routers/test_serving.py b/services/api/tests/routers/test_serving.py new file mode 100644 index 0000000..729081c --- /dev/null +++ b/services/api/tests/routers/test_serving.py @@ -0,0 +1,1201 @@ +import base64 +from copy import deepcopy +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from time import sleep +from typing import Generator + +import numpy as np +import pytest +from flaky import flaky + +from jamaibase import JamAI, JamAIAsync +from jamaibase.types import ( + ChatCompletionChoice, + ChatCompletionChunkResponse, + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionUsage, + ChatEntry, + ChatRequest, + DeploymentCreate, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingUsage, + ModelInfoListResponse, + OkResponse, + OrganizationCreate, + RAGParams, + References, + RerankingRequest, + StripePaymentInfo, + TextContent, +) +from jamaibase.utils.exceptions import BadInputError, ForbiddenError, ResourceNotFoundError +from owl.configs import ENV_CONFIG +from owl.types import ( + CloudProvider, + ModelCapability, + ModelType, + Role, + TableType, +) +from owl.utils import uuid7_str +from owl.utils.test import ( + DS_PARAMS, + ELLM_EMBEDDING_DEPLOYMENT, + GPT_41_NANO_CONFIG, + GPT_41_NANO_DEPLOYMENT, + STREAM_PARAMS, + TEXT_EMBEDDING_3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + add_table_rows, + create_deployment, + create_model_config, + create_organization, + create_project, + create_table, + create_user, +) + +METER_RETRY = 50 +METER_RETRY_DELAY = 1 +# Together AI sometimes take a long time +CHAT_TIMEOUT = 30 +RERANK_TIMEOUT = 60 +EMBED_TIMEOUT = 30 + + 
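+# NOTE: A summary of the conventions the assertions in this module rely on. This is a
+# sketch inferred from the helper functions and tests below, not from a documented
+# contract:
+#   - Metric responses are expected to carry {"data": [{"groupBy": {...}, "value": n}, ...]},
+#     where "groupBy" includes keys such as "model", "type" ("input"/"output"), and "category".
+#   - LLM and embedding usage is priced per million tokens (tokens * 1e-6 * cost_per_mtoken),
+#     reranking per thousand searches (searches * 1e-3 * cost_per_ksearch); billing values
+#     are compared after rounding to 8 decimal places.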
+@dataclass(slots=True) +class ServingContext: + superuser_id: str + user_id: str + superorg_id: str + org_id: str + project_ids: list[str] + chat_model_id: str + short_chat_model_id: str + embedding_model_id: str + rerank_model_id: str + chat_deployment_id: str + embedding_deployment_id: str + rerank_deployment_id: str + llm_input_costs: float + llm_output_costs: float + embed_costs: float + rerank_costs: float + chat_request: ChatRequest + chat_request_text_array: ChatRequest + chat_request_short: ChatRequest + embedding_request: EmbeddingRequest + reranking_request: RerankingRequest + + +def _metrics_match_llm_token_counts(metrics_data, serving_info): + count_true = 0 + for entry in metrics_data.get("data", []): + if entry["groupBy"].get("model", "") == serving_info["model"]: + if ( + entry["groupBy"]["type"] == "input" + and entry["value"] == serving_info["prompt_tokens"] + ): + count_true += 1 + if ( + entry["groupBy"]["type"] == "output" + and entry["value"] == serving_info["completion_tokens"] + ): + count_true += 1 + return count_true == 2 + + +def _metrics_match_llm_spent(metrics_data, serving_info): + count_true = 0 + for entry in metrics_data["data"]: + if ( + entry["groupBy"].get("model", "") == serving_info["model"] + and entry["groupBy"].get("category", "") == "llm_tokens" + ): + if ( + entry["groupBy"]["type"] == "input" + and round(entry["value"], 8) == serving_info["prompt_costs"] + ): + count_true += 1 + if ( + entry["groupBy"]["type"] == "output" + and round(entry["value"], 8) == serving_info["completion_costs"] + ): + count_true += 1 + return count_true == 2 + + +def _metrics_match_embed_token_counts(metrics_data, serving_info): + count_true = 0 + for entry in metrics_data["data"]: + if entry["groupBy"].get("model", "") == serving_info["model"]: + if entry["value"] == serving_info["tokens"]: + count_true += 1 + return count_true == 1 + + +def _metrics_match_embed_spent(metrics_data, serving_info): + count_true = 0 + for entry in metrics_data["data"]: + if ( + entry["groupBy"].get("model", "") == serving_info["model"] + and entry["groupBy"].get("category", "") == "embedding_tokens" + ): + if round(entry["value"], 8) == serving_info["costs"]: + count_true += 1 + return count_true == 1 + + +def _metrics_match_rerank_search_counts(metrics_data, serving_info): + count_true = 0 + for entry in metrics_data["data"]: + if entry["groupBy"].get("model", "") == serving_info["model"]: + if entry["value"] == serving_info["documents"]: + count_true += 1 + return count_true == 1 + + +def _metrics_match_rerank_spent(metrics_data, serving_info): + count_true = 0 + for entry in metrics_data["data"]: + if ( + entry["groupBy"].get("model", "") == serving_info["model"] + and entry["groupBy"].get("category", "") == "reranker_searches" + ): + if round(entry["value"], 8) == serving_info["costs"]: + count_true += 1 + return count_true == 1 + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization, models, deployments, and projects for serving tests. 
+ """ + with ( + # Create superuser + create_user() as superuser, + # Create user + create_user({"email": "testuser@example.com", "name": "Test User"}) as user, + # Create organization + create_organization(body=OrganizationCreate(name="TSP"), user_id=superuser.id) as superorg, + create_organization(body=OrganizationCreate(name="Org"), user_id=user.id) as org, + # Create project + create_project(dict(name="P0"), user_id=superuser.id, organization_id=superorg.id) as p0, + create_project(dict(name="P1"), user_id=superuser.id, organization_id=superorg.id) as p1, + create_project(dict(name="P2"), user_id=user.id, organization_id=org.id) as p2, + ): + assert superuser.id == "0" + assert superorg.id == "0" + projects = [p0, p1, p2] + client = JamAI(user_id=superuser.id) + # Join organization and project + client.organizations.join_organization( + user_id=user.id, organization_id=superorg.id, role=Role.ADMIN + ) + client.projects.join_project(user_id=user.id, project_id=p0.id, role=Role.ADMIN) + client.projects.join_project(user_id=user.id, project_id=p1.id, role=Role.ADMIN) + # Create models + with ( + create_model_config(GPT_41_NANO_CONFIG) as llm_config, + create_model_config( + dict( + # Max context length = 5 + id=f"ellm/lorem-context-5/{uuid7_str()}", + type=ModelType.LLM, + name="Short-Context Chat Model", + capabilities=[ModelCapability.CHAT], + context_length=5, + languages=["en"], + owned_by="ellm", + ) + ) as short_llm_config, + create_model_config(TEXT_EMBEDDING_3_SMALL_CONFIG) as embed_config, + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_config, + ): + # Create deployments + with ( + create_deployment(GPT_41_NANO_DEPLOYMENT) as chat_deployment, + create_deployment( + DeploymentCreate( + model_id=short_llm_config.id, + name="Short chat Deployment", + provider="custom", + routing_id=short_llm_config.id, + api_base=ENV_CONFIG.test_llm_api_base, + ) + ), + create_deployment( + ELLM_EMBEDDING_DEPLOYMENT.model_copy(update=dict(model_id=embed_config.id)) + ) as embedding_deployment, + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT) as rerank_deployment, + ): + # Yield the setup data for use in tests + yield ServingContext( + superuser_id=superuser.id, + user_id=user.id, + superorg_id=superorg.id, + org_id=org.id, + project_ids=[project.id for project in projects], + chat_model_id=llm_config.id, + short_chat_model_id=short_llm_config.id, + embedding_model_id=embed_config.id, + rerank_model_id=rerank_config.id, + chat_deployment_id=chat_deployment.id, + embedding_deployment_id=embedding_deployment.id, + rerank_deployment_id=rerank_deployment.id, + llm_input_costs=llm_config.llm_input_cost_per_mtoken, + llm_output_costs=llm_config.llm_output_cost_per_mtoken, + embed_costs=embed_config.embedding_cost_per_mtoken, + rerank_costs=rerank_config.reranking_cost_per_ksearch, + chat_request=ChatRequest( + model=llm_config.id, + # Test malformed input + messages=[ChatEntry.system(content=""), ChatEntry.user(content="Hi")], + max_tokens=3, + stream=False, + ), + # TODO: Test image and audio input + chat_request_text_array=ChatRequest( + model=llm_config.id, + messages=[ + ChatEntry.user( + content=[ + TextContent(text="Hi "), + TextContent(text="there"), + ] + ) + ], + max_tokens=3, + stream=False, + ), + chat_request_short=ChatRequest( + model=short_llm_config.id, + messages=[{"role": "user", "content": "Hi there how is your day going?"}], + max_tokens=4, + stream=False, + ), + embedding_request=EmbeddingRequest( + model=embed_config.id, + input="This is a test input.", + # 
encoding_format="base64", + ), + reranking_request=RerankingRequest( + model=rerank_config.id, + query="What is the capital of France?", + documents=["London", "Berlin", "Paris"], + ), + ) + + +@pytest.mark.cloud +def test_model_prices(setup: ServingContext): + del setup + client = JamAI() + prices = client.prices.list_model_prices() + assert len(prices.llm_models) == 2 + assert len(prices.embed_models) == 1 + assert len(prices.rerank_models) == 1 + + +def test_model_info(setup: ServingContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + chat_model_id = setup.chat_model_id + + response = client.model_info() + assert isinstance(response, ModelInfoListResponse) + assert len(response.data) == 4 + + response = client.model_info(model=chat_model_id) + assert len(response.data) == 1 + assert response.data[0].id == chat_model_id + assert response.data[0].capabilities == ["chat", "image", "tool"] + + response = client.model_info(capabilities=["chat"]) + assert len(response.data) > 1 + assert all("chat" in m.capabilities for m in response.data) + + response = client.model_info(model="non-existent-model") + assert len(response.data) == 0 # Ensure no data is returned for a non-existent model + + +def test_model_ids(setup: ServingContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + embedding_model_id = setup.embedding_model_id + + response = client.model_ids() + assert isinstance(response, list) + assert len(response) == 4 + + response = client.model_ids(prefer=embedding_model_id) + assert isinstance(response, list) + assert len(response) == 4 + assert embedding_model_id == response[0] + + +@pytest.mark.cloud +def test_chat_completion_without_credit(setup: ServingContext): + # Only Cloud enforces quota and credits + super_client = JamAI(user_id=setup.superuser_id) + # Set zero credit + response = super_client.organizations.set_credit_grant(setup.org_id, amount=0) + assert isinstance(response, OkResponse) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[2]) + with pytest.raises(ForbiddenError, match="Insufficient .+ credits"): + client.generate_chat_completions(setup.chat_request) + + +def _test_chat_completion_stream( + setup: ServingContext, request: ChatRequest +) -> list[ChatCompletionChunkResponse | References]: + request.stream = True + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + _responses = client.generate_chat_completions(request) + responses: list[ChatCompletionChunkResponse | References] = [item for item in _responses] + assert len(responses) > 0 + assert all(isinstance(r, (ChatCompletionChunkResponse, References)) for r in responses) + _chat_chunks = [r for r in responses if isinstance(r, ChatCompletionChunkResponse)] + assert all(isinstance(r.content, str) for r in _chat_chunks) + assert len("".join(r.content for r in _chat_chunks)) > 1 + response = responses[-1] + assert isinstance(response.usage, ChatCompletionUsage) + assert isinstance(response.usage.prompt_tokens, int) + assert isinstance(response.usage.completion_tokens, int) + assert isinstance(response.usage.total_tokens, int) + assert response.prompt_tokens > 0 + assert response.completion_tokens > 0 + assert response.total_tokens == response.prompt_tokens + response.completion_tokens + return responses + + +def _compile_and_check_responses( + response: (Generator[ChatCompletionChunkResponse, None, None] | ChatCompletionResponse), + stream: bool, +): + if stream: + responses: list[ChatCompletionChunkResponse] = [item for 
item in response] + for r in responses: + assert isinstance(r, ChatCompletionChunkResponse) + assert r.object == "chat.completion.chunk" + assert r.usage is None or isinstance(r.usage, ChatCompletionUsage) + content = "".join(getattr(r.choices[0].delta, "content", "") or "" for r in responses) + reasoning_content = "".join( + getattr(r.choices[0].delta, "reasoning_content", "") or "" for r in responses + ) + usage = responses[-1].usage + assert isinstance(usage, ChatCompletionUsage) + + choice = responses[0].choices[0] + assert isinstance(choice, ChatCompletionChoice) + + message = ChatCompletionMessage(content=content) + assert isinstance(message, ChatCompletionMessage) + + if reasoning_content: + message.reasoning_content = reasoning_content + assert isinstance(message.reasoning_content, str) + assert len(message.reasoning_content) > 0 + + choice.delta = None + choice.message = message + + response = ChatCompletionResponse( + id=responses[0].id, + object="chat.completion", + created=responses[0].created, + model=responses[0].model, + choices=[choice], + usage=usage, + service_tier=responses[0].service_tier, + system_fingerprint=responses[0].system_fingerprint, + ) + + assert isinstance(response, ChatCompletionResponse) + assert isinstance(response.id, str) + assert response.object == "chat.completion" + assert isinstance(response.created, int) + assert isinstance(response.model, str) + assert isinstance(response.choices[0], ChatCompletionChoice) + assert isinstance(response.choices[0].message, ChatCompletionMessage) + assert isinstance(response.choices[0].message.content, str) + assert len(response.choices[0].message.content) > 1 + assert isinstance(response.usage, ChatCompletionUsage) + assert isinstance(response.prompt_tokens, int) + assert isinstance(response.completion_tokens, int) + assert response.prompt_tokens > 0 + assert response.completion_tokens > 0 + assert response.usage.total_tokens == response.prompt_tokens + response.completion_tokens + + return response + + +@pytest.mark.cloud +def test_serving_credit(setup: ServingContext): + setup = deepcopy(setup) + super_client = JamAI(user_id=setup.superuser_id, project_id=setup.project_ids[0]) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[2]) + # Assert credit grant is consumed first + response = super_client.organizations.set_credit_grant(setup.org_id, amount=0.01) + assert isinstance(response, OkResponse) + client.generate_chat_completions(setup.chat_request, timeout=CHAT_TIMEOUT) + sleep(1.0) + org = client.organizations.get_organization(setup.org_id) + assert org.credit == 0 + assert org.credit_grant < 0.01 + # Set credit to zero + super_client.organizations.set_credit_grant(setup.org_id, amount=0) + # Chat completion + for stream in [True, False]: + setup.chat_request.stream = stream + with pytest.raises(ForbiddenError, match="Insufficient quota or credits"): + client.generate_chat_completions(setup.chat_request, timeout=CHAT_TIMEOUT) + # Embedding + with pytest.raises(ForbiddenError, match="Insufficient quota or credits"): + client.generate_embeddings(setup.embedding_request, timeout=EMBED_TIMEOUT) + # Reranking + with pytest.raises(ForbiddenError, match="Insufficient quota or credits"): + client.rerank(setup.reranking_request, timeout=RERANK_TIMEOUT) + # Assert credit is consumed if there is no credit grant + response = client.organizations.purchase_credits(setup.org_id, amount=1) + assert isinstance(response, StripePaymentInfo) + super_client.organizations.set_credit_grant(setup.org_id, amount=0) + 
client.generate_chat_completions(setup.chat_request, timeout=CHAT_TIMEOUT) + sleep(1.0) + org = client.organizations.get_organization(setup.org_id) + assert org.credit < 1 + assert org.credit_grant == 0 + + +def _test_chat_completion( + client: JamAI, + request: ChatRequest, + stream: bool, + timeout: int = 60, +): + request.stream = stream + response = client.generate_chat_completions(request, timeout=timeout) + return _compile_and_check_responses(response, stream) + + +def _test_chat_completion_no_stream( + setup: ServingContext, request: ChatRequest +) -> ChatCompletionResponse: + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + response = client.generate_chat_completions(request) + assert isinstance(response, ChatCompletionResponse) + assert isinstance(response.content, str) + assert len(response.content) > 1 + assert isinstance(response.usage, ChatCompletionUsage) + assert isinstance(response.prompt_tokens, int) + assert isinstance(response.completion_tokens, int) + assert response.prompt_tokens > 0 + assert response.completion_tokens > 0 + assert response.usage.total_tokens == response.prompt_tokens + response.completion_tokens + return response + + +def test_chat_completion_auto_model(setup: ServingContext): + setup = deepcopy(setup) + setup.chat_request = ChatRequest( + **setup.chat_request.model_dump( + exclude={"model"}, exclude_unset=True, exclude_defaults=True + ) + ) + _test_chat_completion_no_stream(setup, setup.chat_request) + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_completion(setup: ServingContext, stream: bool): + setup = deepcopy(setup) + if stream: + _test_chat_completion_stream(setup, setup.chat_request) + else: + _test_chat_completion_no_stream(setup, setup.chat_request) + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_completion_text_array(setup: ServingContext, stream: bool): + setup = deepcopy(setup) + if stream: + _test_chat_completion_stream(setup, setup.chat_request_text_array) + else: + _test_chat_completion_no_stream(setup, setup.chat_request_text_array) + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_completion_rag(setup: ServingContext, stream: bool): + """ + Chat completion with RAG. + - RAG on empty table: stream and non-stream + - RAG on non-empty table: stream and non-stream + + Args: + setup (ServingContext): Setup. + stream (bool): Stream (SSE) or not. 
+ """ + setup = deepcopy(setup) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + with create_table(client, TableType.KNOWLEDGE, cols=[]) as kt: + setup.chat_request.rag_params = RAGParams( + reranking_model=None, + table_id=kt.id, + search_query="", + k=2, + ) + ### --- RAG on empty table --- ### + if stream: + responses = _test_chat_completion_stream(setup, setup.chat_request) + assert isinstance(responses[0], References) + else: + response = _test_chat_completion_no_stream(setup, setup.chat_request) + assert isinstance(response.references, References) + ### --- Add data into Knowledge Table --- ### + data = [dict(Title="Pet", Text="My pet's name is Latte.")] + response = add_table_rows(client, TableType.KNOWLEDGE, kt.id, data, stream=False) + assert len(response.rows) == len(data) + if stream: + responses = _test_chat_completion_stream(setup, setup.chat_request) + assert isinstance(responses[0], References) + else: + response = _test_chat_completion_no_stream(setup, setup.chat_request) + assert isinstance(response.references, References) + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +async def test_chat_completion_error_cases(setup: ServingContext, stream: bool): + """ + Test chat completion error cases. + - Sync and async + - Exceed context length + - Model not found + + Args: + setup (ServingContext): Setup. + stream (bool): Stream (SSE) or not. + """ + setup = deepcopy(setup) + model_id = setup.chat_request_short.model + setup.chat_request_short.stream = stream + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + aclient = JamAIAsync(user_id=setup.user_id, project_id=setup.project_ids[0]) + # Prompt too long, max tokens too large + with pytest.raises(BadInputError, match="maximum context length"): + client.generate_chat_completions(setup.chat_request_short) + with pytest.raises(BadInputError, match="maximum context length"): + await aclient.generate_chat_completions(setup.chat_request_short) + # Max tokens is too large + setup.chat_request_short.messages[0].content = "Hi there" + with pytest.raises(BadInputError, match="maximum context length"): + client.generate_chat_completions(setup.chat_request_short) + with pytest.raises(BadInputError, match="maximum context length"): + await aclient.generate_chat_completions(setup.chat_request_short) + # Unknown model + setup.chat_request_short.model = "unknown" + with pytest.raises(ResourceNotFoundError, match="Model .+ is not found"): + client.generate_chat_completions(setup.chat_request_short) + with pytest.raises(ResourceNotFoundError, match="Model .+ is not found"): + await aclient.generate_chat_completions(setup.chat_request_short) + # OK + setup.chat_request_short.model = model_id + setup.chat_request_short.max_tokens = 1 + if stream: + responses = list(client.generate_chat_completions(setup.chat_request_short)) + assert len(responses) > 0 + assert all(isinstance(r, ChatCompletionChunkResponse) for r in responses) + assert all(isinstance(r.content, str) for r in responses) + assert len("".join(r.content for r in responses)) > 1 + response = responses[-1] + else: + response = client.generate_chat_completions(setup.chat_request_short) + assert isinstance(response.usage, ChatCompletionUsage) + assert isinstance(response.usage.prompt_tokens, int) + assert isinstance(response.usage.completion_tokens, int) + assert isinstance(response.usage.total_tokens, int) + assert response.prompt_tokens > 0 + assert response.completion_tokens > 0 + assert response.total_tokens == 
response.prompt_tokens + response.completion_tokens + + +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("data_source", **DS_PARAMS) +def test_get_llm_usage_metrics(setup: ServingContext, stream: bool, data_source: str): + setup = deepcopy(setup) + setup.chat_request.stream = stream + start_dt = datetime.now(tz=timezone.utc) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + if stream: + responses = list(client.generate_chat_completions(setup.chat_request)) + response = responses[-1] + else: + response = client.generate_chat_completions(setup.chat_request) + serving_info = { + "model": setup.chat_model_id, + "prompt_tokens": response.prompt_tokens, + "completion_tokens": response.completion_tokens, + } + response_match = False + for _ in range(METER_RETRY): + response = client.meters.get_usage_metrics( + type="llm", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + proj_ids=[setup.project_ids[0]], + group_by=["type", "model"], + data_source=data_source, + ) + if _metrics_match_llm_token_counts(response.model_dump(), serving_info): + response_match = True + break + sleep(METER_RETRY_DELAY) + assert response_match + + response = client.organizations.get_organization_metrics( + metric_id="llm", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + org_id=setup.superorg_id, + proj_ids=[setup.project_ids[0]], + group_by=["type", "model"], + data_source=data_source, + ) + assert _metrics_match_llm_token_counts(response.model_dump(), serving_info) + + +@pytest.mark.cloud +@flaky(max_runs=5, min_passes=1) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +@pytest.mark.parametrize("data_source", **DS_PARAMS) +def test_get_llm_billing_metrics(setup: ServingContext, stream: bool, data_source: str): + setup = deepcopy(setup) + start_dt = datetime.now(tz=timezone.utc) + setup.chat_request.stream = stream + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + if stream: + responses = list(client.generate_chat_completions(setup.chat_request)) + response = responses[-1] + else: + response = client.generate_chat_completions(setup.chat_request) + serving_info = { + "model": setup.chat_model_id, + "prompt_costs": round(response.prompt_tokens * 1e-6 * setup.llm_input_costs, 8), + "completion_costs": round(response.completion_tokens * 1e-6 * setup.llm_output_costs, 8), + } + response_match = False + for _ in range(METER_RETRY): + response = client.meters.get_billing_metrics( + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + proj_ids=[setup.project_ids[0]], + group_by=["type", "model", "category"], + data_source=data_source, + ) + if _metrics_match_llm_spent(response.model_dump(), serving_info): + response_match = True + break + sleep(METER_RETRY_DELAY) + assert response_match + + response = client.organizations.get_organization_metrics( + metric_id="spent", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + org_id=setup.superorg_id, + proj_ids=[setup.project_ids[0]], + group_by=["type", "model", "category"], + data_source=data_source, + ) + assert _metrics_match_llm_spent(response.model_dump(), serving_info) + + +def _test_chat_reasoning_cloud( + setup: ServingContext, + provider: CloudProvider, + routing_id: str, + stream: bool, + max_tokens: int, + timeout: int = 60, + prompt: str = "How many R is in Red?", + reasoning_effort: str | None = None, + thinking_budget: int | None = None, +): + model_id = setup.chat_model_id + 
super_client = JamAI(user_id=setup.superuser_id) + super_client.models.update_deployment( + setup.chat_deployment_id, + dict(provider=provider, routing_id=routing_id), + ) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + chat_request = ChatRequest( + model=model_id, + messages=[ChatEntry.user(content=prompt)], + max_tokens=max_tokens, + stream=stream, + reasoning_effort=reasoning_effort, + thinking_budget=thinking_budget, + temperature=0, + top_p=0.6, + ) + + response = _test_chat_completion(client, chat_request, stream, timeout) + assert response.model == model_id + return response + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_reasoning_openai(setup: ServingContext, stream: bool): + kwargs = dict( + setup=setup, + provider=CloudProvider.OPENAI, + stream=stream, + max_tokens=1000, + ) + # Test default params + response = _test_chat_reasoning_cloud( + routing_id="gpt-5-mini", + **kwargs, + ) + assert len(response.content) > 0 + # Test disabling reasoning + response = _test_chat_reasoning_cloud( + routing_id="gpt-5-mini", + reasoning_effort="disable", + **kwargs, + ) + assert len(response.content) > 0 + assert response.reasoning_tokens < 300 + # Test reasoning effort + med_response = _test_chat_reasoning_cloud( + routing_id="gpt-5-mini", + thinking_budget=512, + **kwargs, + ) + assert len(med_response.content) > 0 + assert med_response.usage.reasoning_tokens > 0 + + +@flaky(max_runs=3, min_passes=1) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_chat_reasoning_anthropic(setup: ServingContext, stream: bool): + kwargs = dict( + setup=setup, + provider=CloudProvider.ANTHROPIC, + routing_id="claude-sonnet-4-0", + stream=stream, + max_tokens=2200, + ) + # Test default params + response = _test_chat_reasoning_cloud(**kwargs) + assert len(response.content) > 0 + kwargs["routing_id"] = "claude-3-7-sonnet-latest" + response = _test_chat_reasoning_cloud(**kwargs) + assert len(response.content) > 0 + # Test disabling reasoning + response = _test_chat_reasoning_cloud( + reasoning_effort="disable", + **kwargs, + ) + # Test reasoning effort + kwargs["max_tokens"] = 5000 + med_response = _test_chat_reasoning_cloud( + reasoning_effort="medium", + **kwargs, + ) + assert len(med_response.content) > 0 + assert med_response.usage.reasoning_tokens > 0 + + +# @flaky(max_runs=3, min_passes=1) +# @pytest.mark.parametrize("stream", **STREAM_PARAMS) +# def test_chat_reasoning_gemini(setup: ServingContext, stream: bool): +# kwargs = dict( +# setup=setup, +# provider=CloudProvider.GEMINI, +# stream=stream, +# max_tokens=1024, +# ) +# # Test default params +# response = _test_chat_reasoning_cloud( +# routing_id="gemini-2.5-flash-lite", +# **kwargs, +# ) +# assert len(response.content) > 0 +# # Test disabling reasoning +# response = _test_chat_reasoning_cloud( +# reasoning_effort="disable", +# routing_id="gemini-2.5-pro", +# **kwargs, +# ) +# assert len(response.content) > 0 +# # Test reasoning effort +# high_response = _test_chat_reasoning_cloud( +# thinking_budget=512, +# routing_id="gemini-2.5-flash", +# **kwargs, +# ) +# assert len(high_response.content) > 0 +# assert high_response.reasoning_tokens > 0 + + +@flaky(max_runs=5, min_passes=1) +def test_generate_embeddings_auto_model(setup: ServingContext): + setup = deepcopy(setup) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + setup.embedding_request.model = "" + response = client.generate_embeddings(setup.embedding_request, 
timeout=EMBED_TIMEOUT) + assert len(response.data) > 0 + embedding = response.data[0].embedding + assert isinstance(embedding, list) + assert len(embedding) > 1 + assert all(isinstance(x, float) for x in embedding) + + +@flaky(max_runs=5, min_passes=1) +@pytest.mark.parametrize( + "texts", + ["What is a llama?", ["What is a llama?", "What is an alpaca?"]], + ids=["str", "list[str]"], +) +def test_generate_embeddings(setup: ServingContext, texts: str | list[str]): + setup = deepcopy(setup) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + setup.embedding_request.input = texts + # Float embeddings + response = client.generate_embeddings(setup.embedding_request, timeout=EMBED_TIMEOUT) + assert isinstance(response, EmbeddingResponse) + assert isinstance(response.model, str) + assert isinstance(response.usage, EmbeddingUsage) + assert isinstance(response.data, list) + if isinstance(texts, str): + assert len(response.data) == 1 + else: + assert len(response.data) == len(texts) + for d in response.data: + assert isinstance(d.embedding, list) + assert len(d.embedding) > 1 + assert all(isinstance(x, float) for x in d.embedding) + embed_float = np.asarray(response.data[0].embedding, dtype=np.float32) + + # Base64 embeddings + setup.embedding_request.encoding_format = "base64" + response = client.generate_embeddings(setup.embedding_request, timeout=EMBED_TIMEOUT) + assert isinstance(response, EmbeddingResponse) + assert isinstance(response.model, str) + assert isinstance(response.usage, EmbeddingUsage) + assert isinstance(response.data, list) + if isinstance(texts, str): + assert len(response.data) == 1 + else: + assert len(response.data) == len(texts) + for d in response.data: + assert isinstance(d.embedding, str) + assert len(d.embedding) > 1 + embed_base64 = np.frombuffer(base64.b64decode(response.data[0].embedding), dtype=np.float32) + assert len(embed_float) == len(embed_base64) + assert np.allclose(embed_float, embed_base64, atol=0.01, rtol=0.05) + + +@flaky(max_runs=5, min_passes=1) +@pytest.mark.parametrize("data_source", **DS_PARAMS) +def test_get_embed_usage_metrics(setup: ServingContext, data_source: str): + start_dt = datetime.now(tz=timezone.utc) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + response = client.generate_embeddings(setup.embedding_request, timeout=EMBED_TIMEOUT) + serving_info = { + "model": setup.embedding_model_id, + "tokens": response.usage.total_tokens, + } + response_match = False + for _ in range(METER_RETRY): + response = client.meters.get_usage_metrics( + type="embedding", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + group_by=["model"], + data_source=data_source, + ) + if _metrics_match_embed_token_counts(response.model_dump(), serving_info): + response_match = True + break + sleep(METER_RETRY_DELAY) + + assert response_match + + response = client.organizations.get_organization_metrics( + metric_id="embedding", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + org_id=setup.superorg_id, + group_by=["model"], + data_source=data_source, + ) + + assert _metrics_match_embed_token_counts(response.model_dump(), serving_info) + + +# response = client.projects.get_usage_metrics( +# type="embedding", +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[0], +# group_by=["model"], +# ) +# assert _metrics_match_embed_token_counts(response.json(), serving_info) + +# response = 
client.projects.get_usage_metrics( +# type="embedding", +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[1], +# group_by=["model"], +# ) +# assert not _metrics_match_embed_token_counts(response.json(), serving_info) + + +@pytest.mark.cloud +@flaky(max_runs=5, min_passes=1) +@pytest.mark.parametrize("data_source", **DS_PARAMS) +def test_get_embed_billing_metrics(setup: ServingContext, data_source: str): + start_dt = datetime.now(tz=timezone.utc) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + response = client.generate_embeddings(setup.embedding_request, timeout=EMBED_TIMEOUT) + serving_info = { + "model": setup.embedding_model_id, + "costs": round(response.usage.total_tokens * 1e-6 * setup.embed_costs, 8), + } + response_match = False + for _ in range(METER_RETRY): + response = client.meters.get_billing_metrics( + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + group_by=["model", "category"], + data_source=data_source, + ) + if _metrics_match_embed_spent(response.model_dump(), serving_info): + response_match = True + break + sleep(METER_RETRY_DELAY) + + assert response_match + + response = client.organizations.get_organization_metrics( + metric_id="spent", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + org_id=setup.superorg_id, + group_by=["model", "category"], + data_source=data_source, + ) + assert _metrics_match_embed_spent(response.model_dump(), serving_info) + + +# response = client.projects.get_billing_metrics( +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[0], +# group_by=["model", "category"], +# ) +# assert _metrics_match_embed_spent(response.json(), serving_info) + +# response = client.projects.get_billing_metrics( +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[1], +# group_by=["model", "category"], +# ) +# assert not _metrics_match_embed_spent(response.json(), serving_info) + + +@flaky(max_runs=5, min_passes=1) +def test_rerank_auto_model(setup: ServingContext): + setup = deepcopy(setup) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + setup.reranking_request.model = "" + response = client.rerank(setup.reranking_request, timeout=RERANK_TIMEOUT) + assert response.results[0].index == 2, f"Reranking results are unsorted: {response.results}" + relevance_scores = [x.relevance_score for x in response.results] + assert len(relevance_scores) == 3 + assert relevance_scores[0] > relevance_scores[1] + + +@flaky(max_runs=5, min_passes=1) +def test_rerank(setup: ServingContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + response = client.rerank(setup.reranking_request, timeout=RERANK_TIMEOUT) + assert response.results[0].index == 2, f"Reranking results are unsorted: {response.results}" + relevance_scores = [x.relevance_score for x in response.results] + assert len(relevance_scores) == 3 + assert relevance_scores[0] > relevance_scores[1] + + +@flaky(max_runs=5, min_passes=1) +@pytest.mark.parametrize("data_source", **DS_PARAMS) +def test_get_rerank_usage_metrics(setup: ServingContext, data_source: str): + start_dt = datetime.now(tz=timezone.utc) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + response = client.rerank(setup.reranking_request, timeout=RERANK_TIMEOUT) + serving_info = { + "model": setup.rerank_model_id, + "documents": 
len(response.results), + } + response_match = False + for _ in range(METER_RETRY): + response = client.meters.get_usage_metrics( + type="reranking", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + group_by=["model"], + data_source=data_source, + ) + if _metrics_match_rerank_search_counts(response.model_dump(), serving_info): + response_match = True + break + sleep(METER_RETRY_DELAY) + + assert response_match + + response = client.organizations.get_organization_metrics( + metric_id="reranking", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + org_id=setup.superorg_id, + group_by=["model"], + data_source=data_source, + ) + + assert _metrics_match_rerank_search_counts(response.model_dump(), serving_info) + + +# response = client.projects.get_usage_metrics( +# type="reranking", +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[0], +# group_by=["model"], +# ) +# assert _metrics_match_rerank_search_counts(response.json(), serving_info) + +# response = client.projects.get_usage_metrics( +# type="reranking", +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[1], +# group_by=["model"], +# ) +# assert not _metrics_match_rerank_search_counts(response.json(), serving_info) + + +@pytest.mark.cloud +@flaky(max_runs=5, min_passes=1) +@pytest.mark.parametrize("data_source", **DS_PARAMS) +def test_get_rerank_billing_metrics(setup: ServingContext, data_source: str): + start_dt = datetime.now(tz=timezone.utc) + client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) + response = client.rerank(setup.reranking_request, timeout=RERANK_TIMEOUT) + serving_info = { + "model": setup.rerank_model_id, + "costs": round(len(response.results) * 1e-3 * setup.rerank_costs, 8), + } + response_match = False + for _ in range(METER_RETRY): + response = client.meters.get_billing_metrics( + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + group_by=["model", "category"], + data_source=data_source, + ) + if _metrics_match_rerank_spent(response.model_dump(), serving_info): + response_match = True + break + sleep(METER_RETRY_DELAY) + + assert response_match + + response = client.organizations.get_organization_metrics( + metric_id="spent", + from_=start_dt, + to=start_dt + timedelta(minutes=2), + window_size="10s", + org_id=setup.superorg_id, + group_by=["model", "category"], + data_source=data_source, + ) + assert _metrics_match_rerank_spent(response.model_dump(), serving_info) + + +# response = client.projects.get_billing_metrics( +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[0], +# group_by=["model", "category"], +# ) +# assert _metrics_match_rerank_spent(response.json(), serving_info) + +# response = client.projects.get_billing_metrics( +# from_=start_dt, +# to=start_dt + timedelta(minutes=2), +# window_size="10s", +# proj_id=setup.project_ids[1], +# group_by=["model", "category"], +# ) +# assert not _metrics_match_rerank_spent(response.json(), serving_info) + + +# @flaky(max_runs=5, min_passes=1) +# def test_chat_arbitrary_provider(setup: ServingContext): +# setup = deepcopy(setup) +# client = JamAI(user_id=setup.superuser_id) +# model_id = uuid7_str("llm-model/") +# with create_model_config( +# { +# "id": model_id, +# "type": "llm", +# "name": "Chat Model", +# "capabilities": ["chat"], +# "context_length": 1024, +# "languages": ["en"], +# } +# ): 
+# with create_deployment( +# DeploymentCreate( +# model_id=model_id, +# name="Chat Deployment", +# provider="abc", +# routing_id="openai/gpt-4o-mini", +# api_base="", +# ) +# ): +# client.organizations.update_organization( +# OrganizationUpdate( +# id=setup.org_id, +# external_keys=dict(abc=ENV_CONFIG.openai_api_key_plain), +# ) +# ) +# client = JamAI(user_id=setup.user_id, project_id=setup.project_ids[0]) +# setup.chat_request.model = model_id +# response = client.generate_chat_completions(setup.chat_request, timeout=CHAT_TIMEOUT) +# assert response.model == model_id +# assert isinstance(response.content, str) +# assert len(response.content) > 1 diff --git a/services/api/tests/routers/test_templates.py b/services/api/tests/routers/test_templates.py new file mode 100644 index 0000000..f2e5e1e --- /dev/null +++ b/services/api/tests/routers/test_templates.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass +from os.path import dirname, join, realpath + +import pytest + +from jamaibase import JamAI +from jamaibase.types import ( + OrganizationCreate, + Page, + ProjectCreate, + ProjectRead, + TableImportRequest, + TableMetaResponse, +) +from owl.db import TEMPLATE_ORG_ID +from owl.types import Role, TableType +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + TEXT_EMBEDDING_3_SMALL_CONFIG, + TEXT_EMBEDDING_3_SMALL_DEPLOYMENT, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + add_table_rows, + create_deployment, + create_model_config, + create_organization, + create_project, + create_user, + get_file_map, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) + + +def _create_template(client: JamAI, name: str = "Template") -> ProjectRead: + return client.projects.create_project( + ProjectCreate(organization_id=TEMPLATE_ORG_ID, name=name) + ) + + +# We test template creation just as sanity check +# Template creation, update, deletion operations are the same as projects +def test_create_template(): + with create_user() as superuser, create_organization(user_id=superuser.id) as superorg: + assert superorg.id == "0" + client = JamAI(user_id=superuser.id) + template = _create_template(client, "Template 1") + try: + assert isinstance(template, ProjectRead) + assert template.created_by == superuser.id, f"{template.created_by=}, {superuser.id=}" + assert template.name == "Template 1" + assert template.organization_id == TEMPLATE_ORG_ID + # Check memberships + user = client.users.get_user(superuser.id) + assert len(user.org_memberships) == 2 # Superorg + Template + org_memberships = {m.organization_id: m for m in user.org_memberships} + assert "0" in org_memberships + assert org_memberships["0"].role == Role.ADMIN + assert TEMPLATE_ORG_ID in org_memberships + assert org_memberships[TEMPLATE_ORG_ID].role == Role.ADMIN + proj_memberships = {m.project_id: m for m in user.proj_memberships} + assert proj_memberships[template.id].role == Role.ADMIN + finally: + client.projects.delete_project(template.id) + + +@dataclass(slots=True) +class ServingContext: + superuser_id: str + superorg_id: str + project_id: str + embedding_size: int + image_uri: str + audio_uri: str + document_uri: str + chat_model_id: str + embed_model_id: str + rerank_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. 
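+    Creates a superuser, a super-organization with a single project, the chat/embedding/reranking model configs and deployments used by these tests, and uploads sample image, audio, and document files.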
+ """ + with ( + create_user(dict(email="admin@up.com", name="System Admin")) as superuser, + create_organization( + body=OrganizationCreate(name="Superorg"), user_id=superuser.id + ) as superorg, + create_project( + dict(name="Superorg Project"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + ): + assert superuser.id == "0" + assert superorg.id == "0" + + bge = "ellm/BAAI/bge-m3" + with ( + # Create models + create_model_config(ELLM_DESCRIBE_CONFIG) as desc_llm_config, + create_model_config(TEXT_EMBEDDING_3_SMALL_CONFIG) as embed_config, + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_config, + create_model_config( + TEXT_EMBEDDING_3_SMALL_CONFIG.model_copy(update=dict(id=bge, owned_by="ellm")) + ), + # Create deployments + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(TEXT_EMBEDDING_3_SMALL_DEPLOYMENT), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + create_deployment( + TEXT_EMBEDDING_3_SMALL_DEPLOYMENT.model_copy(update=dict(model_id=bge)) + ), + ): + client = JamAI(user_id=superuser.id, project_id=p0.id) + image_uri = upload_file(client, FILES["rabbit.jpeg"]).uri + audio_uri = upload_file(client, FILES["gutter.mp3"]).uri + document_uri = upload_file( + client, FILES["LLMs as Optimizers [DeepMind ; 2023].pdf"] + ).uri + yield ServingContext( + superuser_id=superuser.id, + superorg_id=superorg.id, + project_id=p0.id, + embedding_size=embed_config.final_embedding_size, + image_uri=image_uri, + audio_uri=audio_uri, + document_uri=document_uri, + chat_model_id=desc_llm_config.id, + embed_model_id=embed_config.id, + rerank_model_id=rerank_config.id, + ) + + +def test_get_list_templates(setup: ServingContext): + super_client = JamAI(user_id=setup.superuser_id) + public_client = JamAI() + # List projects + response = super_client.projects.list_projects(organization_id=setup.superorg_id) + assert isinstance(response, Page) + assert len(response.items) == 1 + assert response.total == 1 + # List templates + response = super_client.templates.list_templates() + assert isinstance(response, Page) + assert len(response.items) == 0 + assert response.total == 0 + assert public_client.templates.list_templates().total == 0 + # Create templates + templates = [] + try: + templates = [_create_template(super_client) for _ in range(2)] + # There are now two templates + response = super_client.templates.list_templates() + assert len(response.items) == 2 + assert response.total == 2 + assert all(t.name.startswith("Template") for t in templates) + assert public_client.templates.list_templates().total == 2 + # There is still just one project + assert super_client.projects.list_projects(organization_id=setup.superorg_id).total == 1 + # Get a template + template = super_client.templates.get_template(templates[0].id) + assert template.id == templates[0].id + assert template.name == templates[0].name + finally: + for template in templates: + super_client.projects.delete_project(template.id) + + +def test_get_list_template_tables_rows(setup: ServingContext): + # Create template + template = _create_template(JamAI(user_id=setup.superuser_id)) + super_client = JamAI(user_id=setup.superuser_id, project_id=template.id) + public_client = JamAI() + tables: list[TableMetaResponse] = [] + try: + # Create the tables + for table_type in TableType: + if table_type == TableType.CHAT: + parquet_filepath = FILES["export-v0.4-chat-agent.parquet"] + else: + parquet_filepath = FILES[f"export-v0.4-{table_type}.parquet"] + table = super_client.table.import_table( + 
table_type, + TableImportRequest(file_path=parquet_filepath, table_id_dst=None), + ) + assert isinstance(table, TableMetaResponse) + tables.append(table) + # Get and list tables + # Get and list table rows + for i, table_type in enumerate(TableType): + table_id = tables[i].id + # List tables + response = super_client.templates.list_tables(template.id, table_type) + assert isinstance(response, Page) + assert all(isinstance(r, TableMetaResponse) for r in response.items) + assert len(response.items) == 1 + assert response.total == 1 + assert public_client.templates.list_tables(template.id, table_type).total == 1 + # Get table + table = super_client.templates.get_table(template.id, table_type, table_id) + assert isinstance(table, TableMetaResponse) + assert table.id == response.items[0].id + table = public_client.templates.get_table(template.id, table_type, table_id) + assert table.id == response.items[0].id + # List rows + rows = super_client.templates.list_table_rows(template.id, table_type, table_id) + assert isinstance(rows, Page) + assert all(isinstance(r, dict) for r in rows.items) + assert len(rows.items) == 1 + assert rows.total == 1 + rows = public_client.templates.list_table_rows(template.id, table_type, table_id) + assert rows.total == 1 + # Get row + row = super_client.templates.get_table_row( + template.id, table_type, table_id, rows.items[0]["ID"] + ) + assert isinstance(row, dict) + assert row["ID"] == rows.items[0]["ID"] + row = public_client.templates.get_table_row( + template.id, table_type, table_id, rows.items[0]["ID"] + ) + assert row["ID"] == rows.items[0]["ID"] + # Try generation + if table_type == TableType.ACTION: + response = add_table_rows( + super_client, table_type, table_id, [{"question": "Why"}], stream=False + ) + assert len(response.rows) == 1 + assert "There is a text" in response.rows[0].columns["answer"].content + elif table_type == TableType.KNOWLEDGE: + response = add_table_rows(super_client, table_type, table_id, [{}], stream=False) + assert len(response.rows) == 1 + else: + response = add_table_rows( + super_client, table_type, table_id, [{"User": "Hi"}], stream=False + ) + assert len(response.rows) == 1 + assert "There is a text" in response.rows[0].columns["AI"].content + # List rows again + rows = super_client.templates.list_table_rows(template.id, table_type, table_id) + assert isinstance(rows, Page) + assert all(isinstance(r, dict) for r in rows.items) + assert len(rows.items) == 2 + assert rows.total == 2 + rows = public_client.templates.list_table_rows(template.id, table_type, table_id) + assert rows.total == 2 + finally: + for table in tables: + super_client.table.delete_table(table_type, table.id) diff --git a/services/api/tests/routers/test_users.py b/services/api/tests/routers/test_users.py new file mode 100644 index 0000000..21a66e0 --- /dev/null +++ b/services/api/tests/routers/test_users.py @@ -0,0 +1,341 @@ +import httpx +import pytest +from pwdlib import PasswordHash + +from jamaibase import JamAI +from jamaibase.types import ( + OkResponse, + Page, + PasswordChangeRequest, + PasswordLoginRequest, + UserRead, +) +from jamaibase.utils.exceptions import ( + AuthorizationError, + BadInputError, + ForbiddenError, + ResourceExistsError, + ResourceNotFoundError, +) +from owl.utils.test import ( + EMAIL, + create_organization, + create_user, + register_password, + setup_organizations, + setup_projects, +) + +# --- Auth --- # + +PASSWORD = "test_password" + + +def test_register_password(): + with register_password(dict(email=EMAIL, name="Carl", 
password=PASSWORD)): + pass + + +def test_login_password(): + with register_password(dict(email=EMAIL, name="Carl", password=PASSWORD)): + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert isinstance(user, UserRead) + + +def test_login_password_wrong_pw(): + with register_password(dict(email=EMAIL, name="Carl", password=PASSWORD)): + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert isinstance(user, UserRead) + # Wrong password should fail + with pytest.raises(AuthorizationError): + JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password="PASSWORD")) + + +def test_login_password_hash(): + with register_password(dict(email=EMAIL, name="Carl", password=PASSWORD)): + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert isinstance(user, UserRead) + # Password hash should fail + hasher = PasswordHash.recommended() + password_hash = hasher.hash(PASSWORD) + with pytest.raises((AuthorizationError, BadInputError)): + JamAI().auth.login_password(dict(email=EMAIL, password=password_hash)) + + +def test_change_password(): + with register_password(dict(email=EMAIL, name="Carl", password=PASSWORD)): + # Existing password OK + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert isinstance(user, UserRead) + # Change password + user = JamAI(user_id=user.id).auth.change_password( + PasswordChangeRequest(email=EMAIL, password=PASSWORD, new_password=PASSWORD * 2) + ) + assert isinstance(user, UserRead) + # Old password should fail + with pytest.raises(AuthorizationError): + JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + # New password OK + user = JamAI().auth.login_password( + PasswordLoginRequest(email=EMAIL, password=PASSWORD * 2) + ) + assert isinstance(user, UserRead) + + +@pytest.mark.cloud +def test_change_password_wrong_user(): + with ( + register_password(dict(email=EMAIL, name="Carl", password=PASSWORD)) as u0, + register_password(dict(email="russell@up.com", name="Russell", password="test")) as u1, + ): + # Existing password OK + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert user.id == u0.id + # Wrong user should fail + with pytest.raises(ForbiddenError): + JamAI(user_id=u1.id).auth.change_password( + PasswordChangeRequest(email=EMAIL, password="PASSWORD", new_password=PASSWORD * 2) + ) + # Old password OK + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert user.id == u0.id + # New password should fail + with pytest.raises(AuthorizationError): + JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD * 2)) + + +def test_change_password_wrong_old_pw(): + with register_password(dict(email=EMAIL, name="Carl", password=PASSWORD)): + # Existing password OK + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert isinstance(user, UserRead) + # Wrong old password should fail + with pytest.raises(AuthorizationError): + JamAI(user_id=user.id).auth.change_password( + PasswordChangeRequest(email=EMAIL, password="PASSWORD", new_password=PASSWORD * 2) + ) + # Old password OK + user = JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD)) + assert isinstance(user, UserRead) + # New password should fail + with pytest.raises(AuthorizationError): + 
JamAI().auth.login_password(PasswordLoginRequest(email=EMAIL, password=PASSWORD * 2)) + + +# --- Users --- # + + +def test_create_superuser(): + with create_user(dict(email=EMAIL, name="Carl", password="test_password")) as user: + assert user.id == "0" + assert user.email == EMAIL + assert isinstance(user.password_hash, str) + assert user.password_hash == "***" + + +def test_create_user(): + with ( + create_user(), + create_user( + dict(email="russell@up.com", name="Russell", password="test_password") + ) as user, + ): + assert user.id != "0" + assert isinstance(user.password_hash, str) + assert user.password_hash == "***" + + +def test_create_user_existing_id(): + with create_user(), create_user(dict(email="russell@up.com", name="Russell")) as user: + with pytest.raises(ResourceExistsError): + with create_user(dict(id=user.id, email="random@up.com", name="Random")): + pass + + +def test_create_user_existing_email(): + with ( + create_user(dict(email=EMAIL, name="Carl", password=PASSWORD)) as user, + ): + with pytest.raises(ResourceExistsError, match="email"): + with create_user(dict(email=user.email, name="Random")): + pass + + +def test_get_list_users(): + relations = {"org_memberships", "proj_memberships", "organizations", "projects"} + dump_kwargs = dict(warnings="error", exclude=relations) + with ( + # Create users with name ordering opposite of creation order + # Also test case sensitivity + create_user(dict(email="russell@up.com", name="Russell")) as superuser, + create_user( + dict(email="carl@up.com", name="carl", google_id="1234", github_id="22") + ) as u1, + create_user(dict(email="aaron@up.com", name="Aaron")) as u2, + create_organization(user_id=superuser.id), + ): + super_client = JamAI(user_id=superuser.id) + ### --- List users --- ### + num_users = 3 + users = super_client.users.list_users() + assert isinstance(users, Page) + assert len(users.items) == num_users + assert users.total == num_users + assert all(isinstance(m, UserRead) for m in users.items) + assert users.items[0].id == superuser.id + assert users.items[1].id == u1.id + + ### --- Get user --- ### + for u in users.items: + _user = super_client.users.get_user(u.id) + assert isinstance(_user, UserRead) + u = u.model_dump(**dump_kwargs) + _user = _user.model_dump(**dump_kwargs) + assert _user == u, f"Data mismatch: {_user=}, {u=}" + # Fetch using Google ID + _user = super_client.users.get_user(f"google-oauth2|{u1.google_id}") + assert isinstance(_user, UserRead) + u = u1.model_dump(**dump_kwargs) + _user = _user.model_dump(**dump_kwargs) + assert _user == u, f"Data mismatch: {_user=}, {u=}" + # Fetch using GitHub ID + _user = super_client.users.get_user(f"github|{u1.github_id}") + assert isinstance(_user, UserRead) + u = u1.model_dump(**dump_kwargs) + _user = _user.model_dump(**dump_kwargs) + assert _user == u, f"Data mismatch: {_user=}, {u=}" + + ### --- List users (offset and limit) --- ### + _users = super_client.users.list_users(offset=0, limit=1) + assert len(_users.items) == 1 + assert _users.total == num_users + assert _users.items[0].id == users.items[0].id, f"{_users.items=}" + _users = super_client.users.list_users(offset=1, limit=1) + assert len(_users.items) == 1 + assert _users.total == num_users + assert _users.items[0].id == users.items[1].id, f"{_users.items=}" + # Offset >= num rows + _users = super_client.users.list_users(offset=num_users, limit=1) + assert len(_users.items) == 0 + assert _users.total == num_users + _users = super_client.users.list_users(offset=num_users + 1, limit=1) + assert 
len(_users.items) == 0 + assert _users.total == num_users + # Invalid offset and limit + with pytest.raises(BadInputError): + super_client.users.list_users(offset=0, limit=0) + with pytest.raises(BadInputError): + super_client.users.list_users(offset=-1, limit=1) + + ### --- List users (order_by and order_ascending) --- ### + _users = super_client.users.list_users(order_ascending=False) + assert len(users.items) == num_users + assert _users.total == num_users + assert [t.id for t in _users.items[::-1]] == [t.id for t in users.items] + _users = super_client.users.list_users(order_by="name") + assert len(users.items) == num_users + assert _users.total == num_users + assert [t.id for t in _users.items[::-1]] == [t.id for t in users.items] + assert [t.name for t in _users.items] == [u2.name, u1.name, superuser.name] + _users = super_client.users.list_users(order_by="name", order_ascending=False) + assert len(users.items) == num_users + assert _users.total == num_users + assert [t.id for t in _users.items] == [t.id for t in users.items] + + ### --- List users (search_query and search_columns) --- ### + _users = super_client.users.list_users(search_query="rus") + assert len(_users.items) == 1 + assert _users.total == 1 + assert _users.total != num_users + assert _users.items[0].id == superuser.id + _users = super_client.users.list_users(search_query="rus", offset=1) + assert len(_users.items) == 0 + assert _users.total == 1 + + +@pytest.mark.cloud +def test_list_users_permission(): + with create_user(), create_user(dict(email="russell@up.com", name="Russell")) as user: + with pytest.raises(ForbiddenError): + JamAI(user_id=user.id).users.list_users() + + +def test_get_nonexistent_user(): + with setup_organizations() as ctx: + client = JamAI(user_id=ctx.superuser.id) + response = client.users.get_user(ctx.user.id) + assert isinstance(response, UserRead) + with pytest.raises(ResourceNotFoundError): + client.users.get_user("fake") + + +def test_update_user(): + with create_user() as user: + client = JamAI(user_id=user.id) + new_name = f"{user.name} {user.name}" + response = client.users.update_user(dict(name=new_name)) + assert isinstance(response, UserRead) + assert response.name == new_name + assert response.model_dump( + exclude={"updated_at", "name", "preferred_name"} + ) == user.model_dump(exclude={"updated_at", "name", "preferred_name"}) + assert response.updated_at > user.updated_at + + +def test_delete_user(): + with ( + create_user() as superuser, + create_user(dict(email="russell@up.com", name="Russell")) as user, + create_organization(user_id=superuser.id), + ): + client = JamAI(user_id=superuser.id) + # Fetch + response = client.users.get_user(user.id) + assert isinstance(response, UserRead) + # Delete + response = JamAI(user_id=user.id).users.delete_user(missing_ok=False) + assert isinstance(response, OkResponse) + assert response.ok is True + # Fetch again + with pytest.raises(ResourceNotFoundError): + client.users.get_user(user.id) + + +def test_cors(): + def _assert_cors(_response: httpx.Response): + assert "Access-Control-Allow-Origin" in _response.headers, _response.headers + assert "Access-Control-Allow-Methods" in _response.headers, _response.headers + assert "Access-Control-Allow-Headers" in _response.headers, _response.headers + assert "Access-Control-Allow-Credentials" in _response.headers, _response.headers + assert _response.headers["Access-Control-Allow-Credentials"].lower() == "true" + + with setup_projects() as ctx: + client = JamAI(user_id=ctx.superuser.id) + + headers = 
{ + "Origin": "http://example.com", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Content-Type", + } + + # Preflight + response = httpx.options(client.api_base, headers=headers) + _assert_cors(response) + print(response.headers) + + endpoint = f"{client.api_base}/v1/models" + # Assert preflight no auth + response = httpx.options(endpoint, headers=headers) + _assert_cors(response) + # Assert CORS headers in methods with auth + response = httpx.get( + endpoint, + headers={ + "Authorization": "Bearer PAT_KEY", + **headers, + }, + ) + assert response.status_code == 401 + assert "Access-Control-Allow-Origin" in response.headers, response.headers + assert "Access-Control-Allow-Credentials" in response.headers, response.headers + assert response.headers["Access-Control-Allow-Credentials"].lower() == "true" diff --git a/services/api/tests/test_db.py b/services/api/tests/test_db.py new file mode 100644 index 0000000..6275c52 --- /dev/null +++ b/services/api/tests/test_db.py @@ -0,0 +1,24 @@ +from sqlmodel import select + +from owl.db import async_session, sync_session +from owl.db.models import User +from owl.types import UserAuth +from owl.utils.test import create_user + + +async def test_async_session(): + with create_user() as user: + assert user.id == "0" + async with async_session() as session: + users = (await session.exec(select(User))).all() + users = [UserAuth.model_validate(user) for user in users] + assert len(users) == 1 + + +async def test_sync_session(): + with create_user() as user: + assert user.id == "0" + with sync_session() as session: + users = (session.exec(select(User))).all() + users = [UserAuth.model_validate(user) for user in users] + assert len(users) == 1 diff --git a/services/api/tests/test_docparse.py b/services/api/tests/test_docparse.py new file mode 100644 index 0000000..01fbb57 --- /dev/null +++ b/services/api/tests/test_docparse.py @@ -0,0 +1,84 @@ +import hashlib +import json +import os +from os.path import basename, dirname, join, realpath + +import pytest + +from owl.docparse import DoclingLoader +from owl.utils.test import get_file_map + +TEST_FILE_DIR = join(dirname(realpath(__file__)), "files") +FILES = get_file_map(TEST_FILE_DIR) + +GT_FILE_DIR = join(dirname(realpath(__file__)), "docling_ground_truth") +GT_FILES = get_file_map(GT_FILE_DIR) + + +def get_canonical_json_hash(data: dict) -> str: + """ + Calculates a SHA256 hash of a dictionary after canonical JSON serialization. + Ensures keys are sorted and spacing is compact for consistent hashing. + """ + if not isinstance(data, dict): + # If data is not a dict (e.g., an error string or None from .get("document", {})), + # we still need a consistent way to hash it. Converting to string is a simple way. + # However, for this test, 'document' should ideally always be a dict or an empty dict. + # If it can be None or other types, this part might need more specific handling + # based on expected behavior. + stable_representation = str(data) + else: + # sort_keys=True: Essential for canonical form. + # separators=(',', ':'): Creates the most compact JSON, removing unnecessary whitespace. 
+ stable_representation = json.dumps(data, sort_keys=True, separators=(",", ":")) + + json_bytes = stable_representation.encode("utf-8") + return hashlib.sha256(json_bytes).hexdigest() + + +@pytest.mark.timeout(180) +@pytest.mark.parametrize( + "doc_path", + [ + FILES["Swire_AR22_e_230406_sample.pdf"], + FILES["GitHub è¡¨å•æž¶æž„语法 - GitHub 文档.pdf"], + ], + ids=lambda x: basename(x), +) +async def test_convert_pdf_document_to_markdown(doc_path: str): + """ + Test the conversion of various document types to markdown. + """ + loader = DoclingLoader( + request_id="test_request", + docling_serve_url="http://localhost:5001", + ) + with open(doc_path, "rb") as f: + doc_content_bytes = f.read() + + api_response_data = await loader.retrieve_document_content( + basename(doc_path), doc_content_bytes + ) + + api_document_content = api_response_data.get("document", {}) + + # Sanity check on md_content from the API response + md_content_from_api = api_document_content.get("md_content", "") + assert isinstance(md_content_from_api, str) + + # --- Ground Truth Comparison --- + gt_file_path = GT_FILES[f"{os.path.splitext(basename(doc_path))[0]}.json"] + + with open(gt_file_path, "r", encoding="utf-8") as f_gt: + expected_document_content = json.load(f_gt).get("document", {}) + + api_content_hash = get_canonical_json_hash(api_document_content) + gt_content_hash = get_canonical_json_hash(expected_document_content) + + assert api_content_hash == gt_content_hash, ( + f"Hash mismatch for the 'document' content of '{basename(doc_path)}'.\n" + f"API Hash: {api_content_hash}\n" + f"GT Hash : {gt_content_hash}\n" + f"API 'document' part:\n{json.dumps(api_document_content, sort_keys=True, indent=2, ensure_ascii=False)}\n" + f"Expected 'document' part (from {basename(gt_file_path)}):\n{json.dumps(expected_document_content, sort_keys=True, indent=2, ensure_ascii=False)}" + ) diff --git a/services/api/tests/test_lance.py b/services/api/tests/test_lance.py deleted file mode 100644 index a766ec9..0000000 --- a/services/api/tests/test_lance.py +++ /dev/null @@ -1,31 +0,0 @@ -from datetime import timedelta -from os.path import join -from pathlib import Path -from shutil import copytree -from tempfile import TemporaryDirectory - -import lancedb - -CURR_DIR = Path(__file__).resolve().parent - - -def test_lance(): - table_id = "test_table" - with TemporaryDirectory() as tmp_dir: - copytree(join(CURR_DIR, f"{table_id}.lance"), join(tmp_dir, f"{table_id}.lance")) - lance_db = lancedb.connect(tmp_dir) - # Try opening table - table = lance_db.open_table(table_id) - assert table.count_rows() > 0 - # Try deleting rows - rows = table._dataset.to_table(offset=0, limit=100).to_pylist() - row_ids = [r["ID"] for r in rows] - for row_id in row_ids[3:]: - table.delete(f"`ID` = '{row_id}'") - # Try table optimization - table.cleanup_old_versions(older_than=timedelta(seconds=0), delete_unverified=False) - table.compact_files() - - -if __name__ == "__main__": - test_lance() diff --git a/services/api/tests/test_protocol.py b/services/api/tests/test_protocol.py new file mode 100644 index 0000000..1c84955 --- /dev/null +++ b/services/api/tests/test_protocol.py @@ -0,0 +1,209 @@ +import pytest +from pydantic import ValidationError + +from owl.types import ( + ChatCompletionChoice, + ChatCompletionChunkResponse, + ChatCompletionDelta, + ChatCompletionMessage, + ChatCompletionResponse, + ChatCompletionUsage, + MultiRowAddRequestWithLimit, + MultiRowUpdateRequestWithLimit, +) + +REQUEST_ID = "chatcmpl-AtBWW4Kf8NoM4WDBaNSBLR8fD0fc6" +MODEL = 
"gpt-3.5-turbo" +CONTENT = "Hello" +SERVICE_TIER = "default" +SYSTEM_FINGERPRINT = "fp_2f141ce944" + + +@pytest.mark.parametrize( + "body", + [ + # Role chunk + ChatCompletionChunkResponse( + id=REQUEST_ID, + model=MODEL, + choices=[ + ChatCompletionChoice( + index=0, + delta=ChatCompletionDelta(role="assistant", content="", refusal=None), + logprobs=None, + finish_reason=None, + ) + ], + ), + # Content chunks + ChatCompletionChunkResponse( + id=REQUEST_ID, + model=MODEL, + choices=[ + ChatCompletionChoice( + index=0, + delta=ChatCompletionDelta(content=CONTENT), + logprobs=None, + finish_reason=None, + ) + ], + ), + # Finish reason chunk + ChatCompletionChunkResponse( + id=REQUEST_ID, + model=MODEL, + choices=[ + ChatCompletionChoice( + index=0, + logprobs=None, + finish_reason="length", + ) + ], + ), + # Usage chunk + ChatCompletionChunkResponse( + id=REQUEST_ID, + model=MODEL, + choices=[], + usage=ChatCompletionUsage( + prompt_tokens=10, + completion_tokens=5, + total_tokens=10 + 5, + ), + ), + ], +) +def test_chat_completion_chunk(body: ChatCompletionChunkResponse): + if len(body.choices) > 0: + if body.message is None: + assert body.delta is None + assert body.content == "" + else: + assert isinstance(body.message, ChatCompletionDelta) + assert isinstance(body.delta, ChatCompletionDelta) + assert isinstance(body.content, str) + else: + assert body.message is None + assert body.delta is None + assert body.content == "" + assert body.finish_reason is None or isinstance(body.finish_reason, str) + assert isinstance(body.prompt_tokens, int) + assert isinstance(body.completion_tokens, int) + assert isinstance(body.total_tokens, int) + if body.usage is not None: + assert body.prompt_tokens == body.usage.prompt_tokens + assert body.completion_tokens == body.usage.completion_tokens + assert body.total_tokens == body.usage.total_tokens + assert body.total_tokens == body.prompt_tokens + body.completion_tokens + + +@pytest.mark.parametrize( + "body", + [ + # Non-stream + ChatCompletionResponse( + id=REQUEST_ID, + model=MODEL, + choices=[ + ChatCompletionChoice( + index=0, + message=ChatCompletionMessage(content=CONTENT), + logprobs=None, + finish_reason="length", + ) + ], + usage=ChatCompletionUsage( + prompt_tokens=10, + completion_tokens=5, + total_tokens=10 + 5, + ), + ) + ], +) +def test_chat_completion(body: ChatCompletionResponse): + if len(body.choices) > 0: + assert isinstance(body.message, ChatCompletionMessage) + assert isinstance(body.content, str) + else: + assert body.message is None + assert body.content is None + assert body.finish_reason is None or isinstance(body.finish_reason, str) + assert isinstance(body.prompt_tokens, int) + assert isinstance(body.completion_tokens, int) + assert isinstance(body.total_tokens, int) + assert body.prompt_tokens == body.usage.prompt_tokens + assert body.completion_tokens == body.usage.completion_tokens + assert body.total_tokens == body.usage.total_tokens + assert body.total_tokens == body.prompt_tokens + body.completion_tokens + + +def test_multirow_add(): + # AAC files not accepted + with pytest.raises(ValidationError, match="Unsupported file type"): + MultiRowAddRequestWithLimit( + table_id="x", + data=[{"col1": "s3://val1.aac", "col2": "val2"}], + ) + body = MultiRowAddRequestWithLimit( + table_id="x", + data=[{"col1": "s3://val1.mp3", "col2": "val2"}], + ) + assert body.data == [{"col1": "s3://val1.mp3", "col2": "val2"}] + # Max 100 rows + with pytest.raises(ValidationError): + MultiRowAddRequestWithLimit( + table_id="x", + data=[{"col1": 
"val1"} for _ in range(101)], + ) + MultiRowAddRequestWithLimit( + table_id="x", + data=[{"col1": "val1"} for _ in range(100)], + ) + # Min 1 row + with pytest.raises(ValidationError): + MultiRowAddRequestWithLimit( + table_id="x", + data=[], + ) + body = MultiRowAddRequestWithLimit( + table_id="x", + data=[{"col1": "val1", "col2": "val2"}], + ) + assert body.table_id == "x" + assert body.data == [{"col1": "val1", "col2": "val2"}] + + +def test_multirow_update(): + # AAC files not accepted + with pytest.raises(ValidationError, match="Unsupported file type"): + MultiRowUpdateRequestWithLimit( + table_id="x", + data={"row1": {"col1": "s3://val1.aac", "col2": "val2"}}, + ) + body = MultiRowUpdateRequestWithLimit( + table_id="x", + data={"row1": {"col1": "s3://val1.mp3", "col2": "val2"}}, + ) + assert body.data == {"row1": {"col1": "s3://val1.mp3", "col2": "val2"}} + # Max 100 rows + with pytest.raises(ValidationError): + MultiRowUpdateRequestWithLimit( + table_id="x", + data={str(i): {"col1": "val1"} for i in range(101)}, + ) + MultiRowUpdateRequestWithLimit( + table_id="x", + data={str(i): {"col1": "val1"} for i in range(100)}, + ) + # Min 1 row + with pytest.raises(ValidationError): + MultiRowUpdateRequestWithLimit( + table_id="x", + data={}, + ) + body = MultiRowUpdateRequestWithLimit( + table_id="x", + data={"row1": {"col1": "val1", "col2": "val2"}}, + ) + assert body.table_id == "x" + assert body.data == {"row1": {"col1": "val1", "col2": "val2"}} diff --git a/services/api/tests/test_table.lance/_indices/80c539f0-1c19-4a7c-b273-cfb237733433/page_data.lance b/services/api/tests/test_table.lance/_indices/80c539f0-1c19-4a7c-b273-cfb237733433/page_data.lance deleted file mode 100644 index 7a54f31..0000000 Binary files a/services/api/tests/test_table.lance/_indices/80c539f0-1c19-4a7c-b273-cfb237733433/page_data.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_indices/80c539f0-1c19-4a7c-b273-cfb237733433/page_lookup.lance b/services/api/tests/test_table.lance/_indices/80c539f0-1c19-4a7c-b273-cfb237733433/page_lookup.lance deleted file mode 100644 index 651a4ad..0000000 Binary files a/services/api/tests/test_table.lance/_indices/80c539f0-1c19-4a7c-b273-cfb237733433/page_lookup.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_indices/c0e1017d-bc45-4449-860d-91a3e588c23b/page_data.lance b/services/api/tests/test_table.lance/_indices/c0e1017d-bc45-4449-860d-91a3e588c23b/page_data.lance deleted file mode 100644 index a899d26..0000000 Binary files a/services/api/tests/test_table.lance/_indices/c0e1017d-bc45-4449-860d-91a3e588c23b/page_data.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_indices/c0e1017d-bc45-4449-860d-91a3e588c23b/page_lookup.lance b/services/api/tests/test_table.lance/_indices/c0e1017d-bc45-4449-860d-91a3e588c23b/page_lookup.lance deleted file mode 100644 index fe90220..0000000 Binary files a/services/api/tests/test_table.lance/_indices/c0e1017d-bc45-4449-860d-91a3e588c23b/page_lookup.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_indices/cf06ed0f-70eb-479f-9824-dfee71a61680/page_data.lance b/services/api/tests/test_table.lance/_indices/cf06ed0f-70eb-479f-9824-dfee71a61680/page_data.lance deleted file mode 100644 index fd9cc60..0000000 Binary files a/services/api/tests/test_table.lance/_indices/cf06ed0f-70eb-479f-9824-dfee71a61680/page_data.lance and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/_indices/cf06ed0f-70eb-479f-9824-dfee71a61680/page_lookup.lance b/services/api/tests/test_table.lance/_indices/cf06ed0f-70eb-479f-9824-dfee71a61680/page_lookup.lance deleted file mode 100644 index fe90220..0000000 Binary files a/services/api/tests/test_table.lance/_indices/cf06ed0f-70eb-479f-9824-dfee71a61680/page_lookup.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_indices/d722107e-5c23-4c94-b9e1-2eb5cc635c9a/page_data.lance b/services/api/tests/test_table.lance/_indices/d722107e-5c23-4c94-b9e1-2eb5cc635c9a/page_data.lance deleted file mode 100644 index b800ae8..0000000 Binary files a/services/api/tests/test_table.lance/_indices/d722107e-5c23-4c94-b9e1-2eb5cc635c9a/page_data.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_indices/d722107e-5c23-4c94-b9e1-2eb5cc635c9a/page_lookup.lance b/services/api/tests/test_table.lance/_indices/d722107e-5c23-4c94-b9e1-2eb5cc635c9a/page_lookup.lance deleted file mode 100644 index 651a4ad..0000000 Binary files a/services/api/tests/test_table.lance/_indices/d722107e-5c23-4c94-b9e1-2eb5cc635c9a/page_lookup.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/_latest.manifest b/services/api/tests/test_table.lance/_latest.manifest deleted file mode 100644 index cfed582..0000000 Binary files a/services/api/tests/test_table.lance/_latest.manifest and /dev/null differ diff --git a/services/api/tests/test_table.lance/_transactions/55-83679c50-04f5-4ad6-a5d3-c90147f82175.txn b/services/api/tests/test_table.lance/_transactions/55-83679c50-04f5-4ad6-a5d3-c90147f82175.txn deleted file mode 100644 index 03767eb..0000000 Binary files a/services/api/tests/test_table.lance/_transactions/55-83679c50-04f5-4ad6-a5d3-c90147f82175.txn and /dev/null differ diff --git a/services/api/tests/test_table.lance/_transactions/56-2a19e3d8-6397-4d41-8667-3f8c005bdb47.txn b/services/api/tests/test_table.lance/_transactions/56-2a19e3d8-6397-4d41-8667-3f8c005bdb47.txn deleted file mode 100644 index 89801b9..0000000 --- a/services/api/tests/test_table.lance/_transactions/56-2a19e3d8-6397-4d41-8667-3f8c005bdb47.txn +++ /dev/null @@ -1 +0,0 @@ -8$2a19e3d8-6397-4d41-8667-3f8c005bdb47Ú \ No newline at end of file diff --git a/services/api/tests/test_table.lance/_transactions/56-c54f1ffd-ae26-4520-9639-0d08015a5dce.txn b/services/api/tests/test_table.lance/_transactions/56-c54f1ffd-ae26-4520-9639-0d08015a5dce.txn deleted file mode 100644 index 7ddf38f..0000000 Binary files a/services/api/tests/test_table.lance/_transactions/56-c54f1ffd-ae26-4520-9639-0d08015a5dce.txn and /dev/null differ diff --git a/services/api/tests/test_table.lance/_versions/56.manifest b/services/api/tests/test_table.lance/_versions/56.manifest deleted file mode 100644 index 15e3105..0000000 Binary files a/services/api/tests/test_table.lance/_versions/56.manifest and /dev/null differ diff --git a/services/api/tests/test_table.lance/_versions/57.manifest b/services/api/tests/test_table.lance/_versions/57.manifest deleted file mode 100644 index 39531ed..0000000 Binary files a/services/api/tests/test_table.lance/_versions/57.manifest and /dev/null differ diff --git a/services/api/tests/test_table.lance/_versions/58.manifest b/services/api/tests/test_table.lance/_versions/58.manifest deleted file mode 100644 index cfed582..0000000 Binary files a/services/api/tests/test_table.lance/_versions/58.manifest and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/data/0218dfb5-7d6a-4594-be2f-b6da21fc2991.lance b/services/api/tests/test_table.lance/data/0218dfb5-7d6a-4594-be2f-b6da21fc2991.lance deleted file mode 100644 index a097971..0000000 Binary files a/services/api/tests/test_table.lance/data/0218dfb5-7d6a-4594-be2f-b6da21fc2991.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/0841f676-2d34-4a0d-beb2-b66eee322f5b.lance b/services/api/tests/test_table.lance/data/0841f676-2d34-4a0d-beb2-b66eee322f5b.lance deleted file mode 100644 index 52c6aa4..0000000 Binary files a/services/api/tests/test_table.lance/data/0841f676-2d34-4a0d-beb2-b66eee322f5b.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/0c4d5f53-dce1-4c49-8095-fe16471dfe92.lance b/services/api/tests/test_table.lance/data/0c4d5f53-dce1-4c49-8095-fe16471dfe92.lance deleted file mode 100644 index dc23fe8..0000000 Binary files a/services/api/tests/test_table.lance/data/0c4d5f53-dce1-4c49-8095-fe16471dfe92.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/0cab0285-7b8d-4019-8662-662342476266.lance b/services/api/tests/test_table.lance/data/0cab0285-7b8d-4019-8662-662342476266.lance deleted file mode 100644 index d0158c7..0000000 Binary files a/services/api/tests/test_table.lance/data/0cab0285-7b8d-4019-8662-662342476266.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/195821c2-10b0-4681-9b60-50f79fa017a9.lance b/services/api/tests/test_table.lance/data/195821c2-10b0-4681-9b60-50f79fa017a9.lance deleted file mode 100644 index 6951936..0000000 Binary files a/services/api/tests/test_table.lance/data/195821c2-10b0-4681-9b60-50f79fa017a9.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/2ac6344b-650c-4991-9bd1-48046519287a.lance b/services/api/tests/test_table.lance/data/2ac6344b-650c-4991-9bd1-48046519287a.lance deleted file mode 100644 index bfb0fa6..0000000 Binary files a/services/api/tests/test_table.lance/data/2ac6344b-650c-4991-9bd1-48046519287a.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/2fe78714-ed7b-4dd7-8388-889e58479f7c.lance b/services/api/tests/test_table.lance/data/2fe78714-ed7b-4dd7-8388-889e58479f7c.lance deleted file mode 100644 index 23692ca..0000000 Binary files a/services/api/tests/test_table.lance/data/2fe78714-ed7b-4dd7-8388-889e58479f7c.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/30fb0867-31d2-4cf6-8d7e-f8c68bcd4882.lance b/services/api/tests/test_table.lance/data/30fb0867-31d2-4cf6-8d7e-f8c68bcd4882.lance deleted file mode 100644 index 6f0b6f1..0000000 Binary files a/services/api/tests/test_table.lance/data/30fb0867-31d2-4cf6-8d7e-f8c68bcd4882.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/375faaf4-394d-4c1d-a217-dce02e301028.lance b/services/api/tests/test_table.lance/data/375faaf4-394d-4c1d-a217-dce02e301028.lance deleted file mode 100644 index d3b06fe..0000000 Binary files a/services/api/tests/test_table.lance/data/375faaf4-394d-4c1d-a217-dce02e301028.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/3a8ee5fd-c9ae-4015-970b-7a21f4a0d4fa.lance b/services/api/tests/test_table.lance/data/3a8ee5fd-c9ae-4015-970b-7a21f4a0d4fa.lance deleted file mode 100644 index 76a9af8..0000000 Binary files a/services/api/tests/test_table.lance/data/3a8ee5fd-c9ae-4015-970b-7a21f4a0d4fa.lance and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/data/3e6187d4-9f32-4fb1-a0e4-a4e028d98574.lance b/services/api/tests/test_table.lance/data/3e6187d4-9f32-4fb1-a0e4-a4e028d98574.lance deleted file mode 100644 index 907d6d7..0000000 Binary files a/services/api/tests/test_table.lance/data/3e6187d4-9f32-4fb1-a0e4-a4e028d98574.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/4038dc14-3f39-4076-b537-d262314ce58e.lance b/services/api/tests/test_table.lance/data/4038dc14-3f39-4076-b537-d262314ce58e.lance deleted file mode 100644 index 10ab98b..0000000 Binary files a/services/api/tests/test_table.lance/data/4038dc14-3f39-4076-b537-d262314ce58e.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/476816e3-7c4a-4b74-aec1-1e5a1a2a9c57.lance b/services/api/tests/test_table.lance/data/476816e3-7c4a-4b74-aec1-1e5a1a2a9c57.lance deleted file mode 100644 index b3ddef7..0000000 Binary files a/services/api/tests/test_table.lance/data/476816e3-7c4a-4b74-aec1-1e5a1a2a9c57.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/4c5cd4e2-3156-4e17-8129-2e600e91d019.lance b/services/api/tests/test_table.lance/data/4c5cd4e2-3156-4e17-8129-2e600e91d019.lance deleted file mode 100644 index 74aa036..0000000 Binary files a/services/api/tests/test_table.lance/data/4c5cd4e2-3156-4e17-8129-2e600e91d019.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/4fb6cee8-9f15-4070-a709-02db49fb21b1.lance b/services/api/tests/test_table.lance/data/4fb6cee8-9f15-4070-a709-02db49fb21b1.lance deleted file mode 100644 index 5bc2b0a..0000000 Binary files a/services/api/tests/test_table.lance/data/4fb6cee8-9f15-4070-a709-02db49fb21b1.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/556e81d3-9067-4d1d-b17d-a4cc0a3eeefc.lance b/services/api/tests/test_table.lance/data/556e81d3-9067-4d1d-b17d-a4cc0a3eeefc.lance deleted file mode 100644 index 0ef6f8d..0000000 Binary files a/services/api/tests/test_table.lance/data/556e81d3-9067-4d1d-b17d-a4cc0a3eeefc.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/5bb527a1-3c37-4cc7-8db9-85e610247143.lance b/services/api/tests/test_table.lance/data/5bb527a1-3c37-4cc7-8db9-85e610247143.lance deleted file mode 100644 index fc1216c..0000000 Binary files a/services/api/tests/test_table.lance/data/5bb527a1-3c37-4cc7-8db9-85e610247143.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/5e04331f-8465-40b0-af53-455928d381a8.lance b/services/api/tests/test_table.lance/data/5e04331f-8465-40b0-af53-455928d381a8.lance deleted file mode 100644 index 3e8ed4b..0000000 Binary files a/services/api/tests/test_table.lance/data/5e04331f-8465-40b0-af53-455928d381a8.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/62661a16-01fe-4513-b58d-9699f70b5ae4.lance b/services/api/tests/test_table.lance/data/62661a16-01fe-4513-b58d-9699f70b5ae4.lance deleted file mode 100644 index 85619d8..0000000 Binary files a/services/api/tests/test_table.lance/data/62661a16-01fe-4513-b58d-9699f70b5ae4.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/67bff393-906c-432b-bf15-5b854effe0a7.lance b/services/api/tests/test_table.lance/data/67bff393-906c-432b-bf15-5b854effe0a7.lance deleted file mode 100644 index e5f2eab..0000000 Binary files a/services/api/tests/test_table.lance/data/67bff393-906c-432b-bf15-5b854effe0a7.lance and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/data/681de983-0000-4db1-a193-ee04cee80253.lance b/services/api/tests/test_table.lance/data/681de983-0000-4db1-a193-ee04cee80253.lance deleted file mode 100644 index 63e8ecf..0000000 Binary files a/services/api/tests/test_table.lance/data/681de983-0000-4db1-a193-ee04cee80253.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/6b47eefb-b495-444c-aec6-626a0ed8e441.lance b/services/api/tests/test_table.lance/data/6b47eefb-b495-444c-aec6-626a0ed8e441.lance deleted file mode 100644 index 8770168..0000000 Binary files a/services/api/tests/test_table.lance/data/6b47eefb-b495-444c-aec6-626a0ed8e441.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/70ba4ca1-f11d-4a37-8bf2-33e642d64baf.lance b/services/api/tests/test_table.lance/data/70ba4ca1-f11d-4a37-8bf2-33e642d64baf.lance deleted file mode 100644 index 157c86b..0000000 Binary files a/services/api/tests/test_table.lance/data/70ba4ca1-f11d-4a37-8bf2-33e642d64baf.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/800d2d2e-e5d3-4f6b-86e5-5406ac625504.lance b/services/api/tests/test_table.lance/data/800d2d2e-e5d3-4f6b-86e5-5406ac625504.lance deleted file mode 100644 index 9f4d456..0000000 Binary files a/services/api/tests/test_table.lance/data/800d2d2e-e5d3-4f6b-86e5-5406ac625504.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/8263336e-11dc-47c1-a0ab-37dced034d86.lance b/services/api/tests/test_table.lance/data/8263336e-11dc-47c1-a0ab-37dced034d86.lance deleted file mode 100644 index 5f21958..0000000 Binary files a/services/api/tests/test_table.lance/data/8263336e-11dc-47c1-a0ab-37dced034d86.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/833f82bb-d3bf-45d6-93d1-d81b280d4db5.lance b/services/api/tests/test_table.lance/data/833f82bb-d3bf-45d6-93d1-d81b280d4db5.lance deleted file mode 100644 index 3d76233..0000000 Binary files a/services/api/tests/test_table.lance/data/833f82bb-d3bf-45d6-93d1-d81b280d4db5.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/85d3b452-5002-4793-bc66-fced85c77ebd.lance b/services/api/tests/test_table.lance/data/85d3b452-5002-4793-bc66-fced85c77ebd.lance deleted file mode 100644 index def4627..0000000 Binary files a/services/api/tests/test_table.lance/data/85d3b452-5002-4793-bc66-fced85c77ebd.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/94e60dc4-ca0e-408d-9977-988f7b1fa27e.lance b/services/api/tests/test_table.lance/data/94e60dc4-ca0e-408d-9977-988f7b1fa27e.lance deleted file mode 100644 index 59bbb41..0000000 Binary files a/services/api/tests/test_table.lance/data/94e60dc4-ca0e-408d-9977-988f7b1fa27e.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/9583d125-a754-4104-ad58-670223a4cfd4.lance b/services/api/tests/test_table.lance/data/9583d125-a754-4104-ad58-670223a4cfd4.lance deleted file mode 100644 index 2b38d51..0000000 Binary files a/services/api/tests/test_table.lance/data/9583d125-a754-4104-ad58-670223a4cfd4.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/9ae5a155-6864-4aee-9622-7c1d04913ae9.lance b/services/api/tests/test_table.lance/data/9ae5a155-6864-4aee-9622-7c1d04913ae9.lance deleted file mode 100644 index b75b958..0000000 Binary files a/services/api/tests/test_table.lance/data/9ae5a155-6864-4aee-9622-7c1d04913ae9.lance and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/data/a20c5619-719e-48c8-a249-607527257890.lance b/services/api/tests/test_table.lance/data/a20c5619-719e-48c8-a249-607527257890.lance deleted file mode 100644 index b92018f..0000000 Binary files a/services/api/tests/test_table.lance/data/a20c5619-719e-48c8-a249-607527257890.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/a4a314e7-13e0-4a60-9d6f-b459576cc577.lance b/services/api/tests/test_table.lance/data/a4a314e7-13e0-4a60-9d6f-b459576cc577.lance deleted file mode 100644 index 0fbfe62..0000000 Binary files a/services/api/tests/test_table.lance/data/a4a314e7-13e0-4a60-9d6f-b459576cc577.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/abbdcc8c-93da-4e2e-ac61-91be0874a587.lance b/services/api/tests/test_table.lance/data/abbdcc8c-93da-4e2e-ac61-91be0874a587.lance deleted file mode 100644 index 89587e8..0000000 Binary files a/services/api/tests/test_table.lance/data/abbdcc8c-93da-4e2e-ac61-91be0874a587.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/af0d8632-619d-4e09-bb2f-ba13d0f568e5.lance b/services/api/tests/test_table.lance/data/af0d8632-619d-4e09-bb2f-ba13d0f568e5.lance deleted file mode 100644 index d934c7b..0000000 Binary files a/services/api/tests/test_table.lance/data/af0d8632-619d-4e09-bb2f-ba13d0f568e5.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/af70162a-d319-403c-bd26-251353c9966c.lance b/services/api/tests/test_table.lance/data/af70162a-d319-403c-bd26-251353c9966c.lance deleted file mode 100644 index 02201c2..0000000 Binary files a/services/api/tests/test_table.lance/data/af70162a-d319-403c-bd26-251353c9966c.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/b25448ee-84a7-46d8-b681-67a7faf2e1fc.lance b/services/api/tests/test_table.lance/data/b25448ee-84a7-46d8-b681-67a7faf2e1fc.lance deleted file mode 100644 index 1726ba9..0000000 Binary files a/services/api/tests/test_table.lance/data/b25448ee-84a7-46d8-b681-67a7faf2e1fc.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/bc2de3a9-4f96-4bd8-9856-e17e7bc6bcf8.lance b/services/api/tests/test_table.lance/data/bc2de3a9-4f96-4bd8-9856-e17e7bc6bcf8.lance deleted file mode 100644 index 1fced8d..0000000 Binary files a/services/api/tests/test_table.lance/data/bc2de3a9-4f96-4bd8-9856-e17e7bc6bcf8.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/beb7ea89-576d-45f7-b8be-cc248d0ccc77.lance b/services/api/tests/test_table.lance/data/beb7ea89-576d-45f7-b8be-cc248d0ccc77.lance deleted file mode 100644 index 1ab40b9..0000000 Binary files a/services/api/tests/test_table.lance/data/beb7ea89-576d-45f7-b8be-cc248d0ccc77.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/c1236cac-af08-4ce5-be1f-ae3e8b2024ef.lance b/services/api/tests/test_table.lance/data/c1236cac-af08-4ce5-be1f-ae3e8b2024ef.lance deleted file mode 100644 index cb2ba90..0000000 Binary files a/services/api/tests/test_table.lance/data/c1236cac-af08-4ce5-be1f-ae3e8b2024ef.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/c6729c82-8cb2-44d5-993a-7b09bf678703.lance b/services/api/tests/test_table.lance/data/c6729c82-8cb2-44d5-993a-7b09bf678703.lance deleted file mode 100644 index 25d9134..0000000 Binary files a/services/api/tests/test_table.lance/data/c6729c82-8cb2-44d5-993a-7b09bf678703.lance and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/data/ca3909ff-b21b-45c5-bf85-4911bba60fde.lance b/services/api/tests/test_table.lance/data/ca3909ff-b21b-45c5-bf85-4911bba60fde.lance deleted file mode 100644 index 29aaf87..0000000 Binary files a/services/api/tests/test_table.lance/data/ca3909ff-b21b-45c5-bf85-4911bba60fde.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/ca71816d-8322-43a4-9502-bca6cd2e020c.lance b/services/api/tests/test_table.lance/data/ca71816d-8322-43a4-9502-bca6cd2e020c.lance deleted file mode 100644 index 1a7b029..0000000 Binary files a/services/api/tests/test_table.lance/data/ca71816d-8322-43a4-9502-bca6cd2e020c.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/d532c28c-7d3c-4d22-afb5-92b5a111e17b.lance b/services/api/tests/test_table.lance/data/d532c28c-7d3c-4d22-afb5-92b5a111e17b.lance deleted file mode 100644 index 884f186..0000000 Binary files a/services/api/tests/test_table.lance/data/d532c28c-7d3c-4d22-afb5-92b5a111e17b.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/da9c0e9e-7f8b-44ea-82a4-160c98ca17ae.lance b/services/api/tests/test_table.lance/data/da9c0e9e-7f8b-44ea-82a4-160c98ca17ae.lance deleted file mode 100644 index 237db2e..0000000 Binary files a/services/api/tests/test_table.lance/data/da9c0e9e-7f8b-44ea-82a4-160c98ca17ae.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/db6054f5-4108-4f23-b508-565b32020f68.lance b/services/api/tests/test_table.lance/data/db6054f5-4108-4f23-b508-565b32020f68.lance deleted file mode 100644 index 4c4d5b4..0000000 Binary files a/services/api/tests/test_table.lance/data/db6054f5-4108-4f23-b508-565b32020f68.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/e4c544bd-eb55-4af4-a2dc-f91d7863e671.lance b/services/api/tests/test_table.lance/data/e4c544bd-eb55-4af4-a2dc-f91d7863e671.lance deleted file mode 100644 index 1c95486..0000000 Binary files a/services/api/tests/test_table.lance/data/e4c544bd-eb55-4af4-a2dc-f91d7863e671.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/e628d9da-95ab-4b1f-8d21-47ca21154b81.lance b/services/api/tests/test_table.lance/data/e628d9da-95ab-4b1f-8d21-47ca21154b81.lance deleted file mode 100644 index c19c9a2..0000000 Binary files a/services/api/tests/test_table.lance/data/e628d9da-95ab-4b1f-8d21-47ca21154b81.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/e6541d20-5c18-4b80-8d71-84f899310512.lance b/services/api/tests/test_table.lance/data/e6541d20-5c18-4b80-8d71-84f899310512.lance deleted file mode 100644 index 9a8b4b3..0000000 Binary files a/services/api/tests/test_table.lance/data/e6541d20-5c18-4b80-8d71-84f899310512.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/e856769d-4dac-492c-bd12-679434292b04.lance b/services/api/tests/test_table.lance/data/e856769d-4dac-492c-bd12-679434292b04.lance deleted file mode 100644 index 82fa577..0000000 Binary files a/services/api/tests/test_table.lance/data/e856769d-4dac-492c-bd12-679434292b04.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/ea967c5f-1d05-4069-8fef-a1d090e42730.lance b/services/api/tests/test_table.lance/data/ea967c5f-1d05-4069-8fef-a1d090e42730.lance deleted file mode 100644 index a55e3ca..0000000 Binary files a/services/api/tests/test_table.lance/data/ea967c5f-1d05-4069-8fef-a1d090e42730.lance and /dev/null differ diff --git 
a/services/api/tests/test_table.lance/data/ee2829e0-90b5-416b-8c05-203ee77b75aa.lance b/services/api/tests/test_table.lance/data/ee2829e0-90b5-416b-8c05-203ee77b75aa.lance deleted file mode 100644 index ae07d93..0000000 Binary files a/services/api/tests/test_table.lance/data/ee2829e0-90b5-416b-8c05-203ee77b75aa.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/fc8dc641-5cba-4570-8ef2-8b10992f2cf8.lance b/services/api/tests/test_table.lance/data/fc8dc641-5cba-4570-8ef2-8b10992f2cf8.lance deleted file mode 100644 index 29737e7..0000000 Binary files a/services/api/tests/test_table.lance/data/fc8dc641-5cba-4570-8ef2-8b10992f2cf8.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/fd13dd8d-4f5c-4040-b8da-af081fa49e9d.lance b/services/api/tests/test_table.lance/data/fd13dd8d-4f5c-4040-b8da-af081fa49e9d.lance deleted file mode 100644 index 382abdf..0000000 Binary files a/services/api/tests/test_table.lance/data/fd13dd8d-4f5c-4040-b8da-af081fa49e9d.lance and /dev/null differ diff --git a/services/api/tests/test_table.lance/data/fe71820f-1a4e-4345-8df1-777443efb74a.lance b/services/api/tests/test_table.lance/data/fe71820f-1a4e-4345-8df1-777443efb74a.lance deleted file mode 100644 index a0ba646..0000000 Binary files a/services/api/tests/test_table.lance/data/fe71820f-1a4e-4345-8df1-777443efb74a.lance and /dev/null differ diff --git a/services/api/tests/test_types.py b/services/api/tests/test_types.py new file mode 100644 index 0000000..26b77de --- /dev/null +++ b/services/api/tests/test_types.py @@ -0,0 +1,108 @@ +from datetime import datetime, timezone + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from owl.types import DatetimeUTC, LanguageCodeList, SanitisedNonEmptyStr +from owl.utils.dates import now_iso +from owl.utils.test import TEXTS + + +class IdTest(BaseModel): + id: SanitisedNonEmptyStr = Field() + + +GOOD_IDS = [ + pytest.param("Hello", id="Simple word"), + pytest.param("Hello World", id="Words with space"), + pytest.param(" Hello", id="Leading space"), + pytest.param("Hello ", id="Trailing space"), + pytest.param(" Hello ", id="Leading and trailing space"), + pytest.param("\nHello", id="Leading newline"), + pytest.param("Hello\n", id="Trailing newline"), + pytest.param("\nHello\n", id="Leading and trailing newline"), + # \u00A0 is NBSP + pytest.param("\u00a0NBSP at start", id="Leading Non-Breaking Space"), + pytest.param("NBSP at end\u00a0", id="Trailing Non-Breaking Space"), + pytest.param("H", id="Single character"), + pytest.param("1", id="Single number"), + pytest.param("?", id="Single symbol"), + pytest.param("123", id="Numbers"), + pytest.param("!@#$", id="Symbols"), + pytest.param("😊", id="Single emoji"), + pytest.param("你好", id="CJK characters"), + pytest.param("مرحبا", id="Arabic characters"), + pytest.param("Привет", id="Cyrillic/Russian characters"), + pytest.param("Hello 😊 World", id="Text with emoji"), + pytest.param("Text with multiple spaces", id="Internal multiple spaces"), + pytest.param("Test-123_ABC", id="Text with symbols"), + pytest.param("a-b_c=d+e*f/g\\h|i[j]k{l}m;n:o'p\"q,r.su?v", id="Complex symbols"), +] + [pytest.param(text, id=lang) for lang, text in TEXTS.items()] + +BAD_IDS = [ + pytest.param("", id="Empty string"), + pytest.param(" ", id="Single space"), + pytest.param("  ", id="Multiple spaces"), + pytest.param("\t", id="Tab only"), + pytest.param("\n", id="Newline only"), + pytest.param("Text\tnewlines", id="Internal tab"), + pytest.param("Text\nwith\nnewlines", 
id="Internal newlines"), + pytest.param(" Hello\nWorld ", id="Leading space, trailing space, internal newline"), + # \u00A0 is NBSP + pytest.param("Okay\u00a0NBSP\u00a0Okay", id="Internal Non-Breaking Space"), + pytest.param("â–ˆ â–„ â–€", id="Block elements"), + pytest.param("─ │ ┌ â”", id="Box drawing"), + pytest.param("⠲⠳⠴⠵", id="Braille"), +] + + +@pytest.mark.parametrize("value", GOOD_IDS) +def test_id_string_good(value: str): + item = IdTest(id=value) + assert item.id == value.strip() + + +@pytest.mark.parametrize("value", BAD_IDS) +def test_id_string_bad(value: str): + with pytest.raises(ValidationError): + IdTest(id=value) + + +def test_string_normalisation(): + # --- + # + # Zalgo text + assert IdTest(id="H̵̛͕̞̦̰̜Ḭ̥̟́͆Ì͂̌͑ͅä̷͔̟͓̬̯̟Í̭͉͈̮͙̣̯̬͚̞̭Ì̀̾͠m̴̡̧̛Ì̯̹̗̹̤̲̺̟̥̈Ì͊̔̑Ì͆̌̀̚ÍÍb̴̢̢̫Ì̠̗̼̬̻̮̺̭͔̘͑̆̎̚ư̵̧̡̥̙̭̿̈̀̒ÌÌŠÍ’Í‘r̷̡̡̲̼̖͎̫̮̜͇̬͌͘g̷̹Í͎̬͕͓͕Ì̃̈Ì̓̆̚Íẻ̵̡̼̬̥̹͇̭͔̯̉͛̈ÌÌ•r̸̮̖̻̮̣̗͚͖Ì̂͌̾̓̀̿̔̀͋̈Ì͌̈Ì̋͜").id == "Hämbưrgẻr" + # + # + # --- + # Arabic + assert IdTest(id="مَرْحَبًا بÙÙƒÙمْ").id == "مرحبا بكم" + # --- + # Thai + assert IdTest(id="สวัสดีครับ คามุย อิอิ").id == "สวสดครบ คามย ออ" + + +def test_datetime_utc(): + class DatetimeTest(BaseModel): + dt: DatetimeUTC = Field() + + now = now_iso("Asia/Kuala_Lumpur") + d = DatetimeTest(dt=now) + assert isinstance(d.dt, datetime) + assert d.dt.tzinfo is timezone.utc + assert datetime.fromisoformat(now) == d.dt + + +def test_language_list(): + class TestModel(BaseModel): + lang: LanguageCodeList + + model = TestModel(lang=["en", "FR", "zh-cn", "ZH-sg"]) + assert set(model.lang) == {"en", "fr", "zh-CN", "zh-SG"} + + model = TestModel(lang=["en", "mul"]) + assert set(model.lang) == {"en", "fr", "es", "zh", "ko", "ja", "it"} + + with pytest.raises(ValidationError): + TestModel(lang=["xx"]) diff --git a/services/api/tests/utils/test_auth.py b/services/api/tests/utils/test_auth.py new file mode 100644 index 0000000..5335e42 --- /dev/null +++ b/services/api/tests/utils/test_auth.py @@ -0,0 +1,126 @@ +import pytest + +from owl.types import OrgMember_, ProjectMember_, Role, UserRead +from owl.utils.auth import has_permissions +from owl.utils.dates import now +from owl.utils.exceptions import ForbiddenError + +USER_ID = "user_id" +ORG_ID = "0" +PROJ_ID = "project_id" +USER_KWARGS = dict( + id=USER_ID, + name="name", + email="email@example.com", + organizations=[], + projects=[], + created_at=now(), + updated_at=now(), + email_verified=True, + password_hash="***", # Password is not used in this test +) +ORG_MEMBER_KWARGS = dict( + user_id=USER_ID, + organization_id=ORG_ID, + created_at=now(), + updated_at=now(), +) +PROJ_MEMBER_KWARGS = dict( + user_id=USER_ID, + project_id=PROJ_ID, + created_at=now(), + updated_at=now(), +) + + +@pytest.mark.cloud +def test_has_permissions(): + ### --- ADMIN permissions --- ### + sys_user = UserRead( + org_memberships=[OrgMember_(role=Role.ADMIN, **ORG_MEMBER_KWARGS)], + proj_memberships=[ProjectMember_(role=Role.ADMIN, **PROJ_MEMBER_KWARGS)], + **USER_KWARGS, + ) + # Must pass in org ID or proj ID + with pytest.raises(ValueError): + has_permissions(sys_user, ["organization"]) + with pytest.raises(ValueError): + has_permissions(sys_user, ["organization.admin"]) + with pytest.raises(ValueError): + has_permissions(sys_user, ["project"]) + with pytest.raises(ValueError): + has_permissions(sys_user, ["project.admin"]) + with pytest.raises(ValueError): + has_permissions(sys_user, ["organization", "project"], project_id=PROJ_ID) + with pytest.raises(ValueError): + has_permissions(sys_user, ["organization", "project"], 
organization_id=ORG_ID) + # Membership checks + assert has_permissions(sys_user, ["system"]) is True + assert has_permissions(sys_user, ["organization"], organization_id=ORG_ID) is True + assert has_permissions(sys_user, ["project"], project_id=PROJ_ID) is True + with pytest.raises(ForbiddenError): + has_permissions(sys_user, ["organization"], organization_id="ORG_ID") + with pytest.raises(ForbiddenError): + has_permissions(sys_user, ["project"], project_id="PROJ_ID") + assert has_permissions(sys_user, ["organization"], organization_id="ORG_ID", raise_error=False) is False # fmt: off + assert has_permissions(sys_user, ["project"], project_id="PROJ_ID", raise_error=False) is False + # Permission checks + assert has_permissions(sys_user, ["system.admin"]) is True + assert has_permissions(sys_user, ["system.member"]) is True + assert has_permissions(sys_user, ["organization.admin"], organization_id=ORG_ID) is True + assert has_permissions(sys_user, ["project.admin"], project_id=PROJ_ID) is True + + ### --- MEMBER permissions --- ### + sys_user = UserRead( + org_memberships=[OrgMember_(role=Role.MEMBER, **ORG_MEMBER_KWARGS)], + proj_memberships=[ProjectMember_(role=Role.MEMBER, **PROJ_MEMBER_KWARGS)], + **USER_KWARGS, + ) + # Membership checks + assert has_permissions(sys_user, ["system"]) is True + assert has_permissions(sys_user, ["organization"], organization_id=ORG_ID) is True + assert has_permissions(sys_user, ["project"], project_id=PROJ_ID) is True + # Permission checks + with pytest.raises(ForbiddenError): + has_permissions(sys_user, ["system.admin"]) + assert has_permissions(sys_user, ["system.member"]) is True + assert has_permissions(sys_user, ["system.guest"]) is True + with pytest.raises(ForbiddenError): + has_permissions(sys_user, ["organization.admin"], organization_id=ORG_ID) + assert has_permissions(sys_user, ["organization.member"], organization_id=ORG_ID) is True + assert has_permissions(sys_user, ["organization.guest"], organization_id=ORG_ID) is True + with pytest.raises(ForbiddenError): + has_permissions(sys_user, ["project.admin"], project_id=PROJ_ID) + assert has_permissions(sys_user, ["project.member"], project_id=PROJ_ID) is True + assert has_permissions(sys_user, ["project.guest"], project_id=PROJ_ID) is True + + ### --- Update membership --- ### + user = sys_user.model_copy(deep=True) + user.org_memberships[0].organization_id = "1" + assert has_permissions(sys_user, ["system"]) is True + with pytest.raises(ForbiddenError): + has_permissions(user, ["system"]) + assert has_permissions(user, ["system", "organization"], organization_id="1") is True + assert ( + has_permissions( + user, + ["system", "organization", "project"], + organization_id="1", + project_id="PROJ_ID", + ) + is True + ) + with pytest.raises(ForbiddenError): + has_permissions( + user, + ["system", "organization", "project"], + organization_id="ORG_ID", + project_id="PROJ_ID", + ) + + ### --- Update permission --- ### + assert has_permissions(sys_user, ["project.member"], project_id=PROJ_ID) is True + sys_user.proj_memberships[0].role = Role.GUEST + with pytest.raises(ForbiddenError): + has_permissions(sys_user, ["project.member"], project_id=PROJ_ID) + assert has_permissions(sys_user, ["project.guest"], project_id=PROJ_ID) is True diff --git a/services/api/tests/utils/test_billing_event.py b/services/api/tests/utils/test_billing_event.py new file mode 100644 index 0000000..3dc4687 --- /dev/null +++ b/services/api/tests/utils/test_billing_event.py @@ -0,0 +1,608 @@ +""" +Tests for the BillingManager's 
event creation and processing for all usage types. + +This module verifies that different API endpoints and periodic tasks trigger the +correct billing events, leading to accurate updates in an organization's usage +and credit records in the database. + +It covers: +- LLM, Embedding, and Reranker token/search usage and costs. +- Egress (bandwidth) usage for streaming responses. +- Database and File Storage usage calculated by the periodic Celery task. +""" + +from contextlib import contextmanager +from dataclasses import dataclass +from os.path import dirname, join, realpath +from time import sleep + +import pytest +from loguru import logger + +from jamaibase import JamAI +from jamaibase import types as t +from owl.types import ( + ChatEntry, + ChatRequest, + ColumnSchemaCreate, + EmbeddingRequest, + LLMGenConfig, + OrganizationRead, + PaymentState, + PricePlan_, + PriceTier, + Product, + Products, + ProjectRead, + RAGParams, + RerankingRequest, + TableType, + UserRead, +) +from owl.utils.dates import now +from owl.utils.test import ( + ELLM_DESCRIBE_CONFIG, + ELLM_DESCRIBE_DEPLOYMENT, + ELLM_EMBEDDING_CONFIG, + ELLM_EMBEDDING_DEPLOYMENT, + GPT_41_NANO_CONFIG, + GPT_41_NANO_DEPLOYMENT, + STREAM_PARAMS, + TEXT_EMBEDDING_3_SMALL_CONFIG, + TEXT_EMBEDDING_3_SMALL_DEPLOYMENT, + RERANK_ENGLISH_v3_SMALL_CONFIG, + RERANK_ENGLISH_v3_SMALL_DEPLOYMENT, + add_table_rows, + create_deployment, + create_model_config, + create_project, + create_table, + get_file_map, + setup_organizations, +) + +USAGE_RETRY = 30 +USAGE_RETRY_DELAY = 1.0 +MODEL_PROVIDER_PARAMS = dict(argvalues=[True, False], ids=["ellm", "other"]) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) + + +@dataclass(slots=True) +class BillingContext: + user: UserRead + org: OrganizationRead + project: ProjectRead + ellm_chat_model_id: str + chat_model_id: str + ellm_embedding_model_id: str + embedding_model_id: str + ellm_rerank_model_id: str + rerank_model_id: str + + +@pytest.fixture(scope="module") +def setup(): + """ + Sets up a test environment with an organization, project, and both internal (ELLM) + and external models configured for billing tests. 
+ """ + with setup_organizations() as ctx: + with create_project(user_id=ctx.user.id, organization_id=ctx.org.id) as project: + # Create ELLM and External models for all types (Chat, Embed, Rerank) + with ( + # --- Chat Models --- + create_model_config(ELLM_DESCRIBE_CONFIG) as ellm_chat_model, + create_model_config(GPT_41_NANO_CONFIG) as chat_model, + # --- Embedding Models --- + create_model_config(ELLM_EMBEDDING_CONFIG) as ellm_embed_model, + create_model_config(TEXT_EMBEDDING_3_SMALL_CONFIG) as embed_model, + # --- Reranking Models --- + create_model_config( + dict( + id=f"ellm/{RERANK_ENGLISH_v3_SMALL_CONFIG.id}", + name=f"ELLM {RERANK_ENGLISH_v3_SMALL_CONFIG.name}", + owned_by="ellm", + **RERANK_ENGLISH_v3_SMALL_CONFIG.model_dump( + exclude={"id", "name", "owned_by"} + ), + ) + ) as ellm_rerank_model, + create_model_config(RERANK_ENGLISH_v3_SMALL_CONFIG) as rerank_model, + # --- Deployments --- + create_deployment(ELLM_DESCRIBE_DEPLOYMENT), + create_deployment(GPT_41_NANO_DEPLOYMENT), + create_deployment(ELLM_EMBEDDING_DEPLOYMENT), + create_deployment(TEXT_EMBEDDING_3_SMALL_DEPLOYMENT), + create_deployment( + dict( + model_id=f"ellm/{RERANK_ENGLISH_v3_SMALL_DEPLOYMENT.model_id}", + name=f"ELLM {RERANK_ENGLISH_v3_SMALL_DEPLOYMENT.name}", + **RERANK_ENGLISH_v3_SMALL_DEPLOYMENT.model_dump( + exclude={"model_id", "name"}, mode="json" + ), + ) + ), + create_deployment(RERANK_ENGLISH_v3_SMALL_DEPLOYMENT), + ): + yield BillingContext( + user=ctx.user, + org=ctx.org, + project=project, + ellm_chat_model_id=ellm_chat_model.id, + chat_model_id=chat_model.id, + ellm_embedding_model_id=ellm_embed_model.id, + embedding_model_id=embed_model.id, + ellm_rerank_model_id=ellm_rerank_model.id, + rerank_model_id=rerank_model.id, + ) + + +def _cmp(new_org: OrganizationRead, org: OrganizationRead, attr: str, op: str) -> bool: + return getattr(getattr(new_org, attr), op)(getattr(org, attr)) + + +@contextmanager +def _test_usage_event( + client: JamAI, + org_id: str, + is_ellm: bool, + usage_attr: str, + quota_attr: str, +): + """ + Helper function to test billing events for a specific usage type. + + Args: + client: The JamAI client instance. + org_id: The ID of the organization to test. + is_ellm: Boolean indicating if an ELLM model is being tested. + usage_attr: The attribute name for usage on the OrganizationRead object. + quota_attr: The attribute name for quota on the OrganizationRead object. 
+ """ + org = client.organizations.get_organization(org_id) + assert isinstance(org, t.OrganizationRead) + yield + for i in range(USAGE_RETRY): + sleep(USAGE_RETRY_DELAY) + logger.info(f"{usage_attr}: Attempt {i}") + new_org = client.organizations.get_organization(org_id) + checks = { + "credit": _cmp(new_org, org, "credit", "__eq__"), + "credit_grant": _cmp(new_org, org, "credit_grant", "__eq__" if is_ellm else "__lt__"), + quota_attr: _cmp(new_org, org, quota_attr, "__eq__"), + usage_attr: _cmp(new_org, org, usage_attr, "__gt__" if is_ellm else "__eq__"), + "egress_quota_gib": _cmp(new_org, org, "egress_quota_gib", "__eq__"), + "egress_usage_gib": _cmp(new_org, org, "egress_usage_gib", "__gt__"), + } + if all(checks.values()): + break + else: + org = {k: getattr(org, k) for k in checks} + new_org = {k: getattr(new_org, k) for k in checks} + raise AssertionError(f"Usage failed to update: {checks=} {new_org=} {org=}") + + +@pytest.mark.cloud +@pytest.mark.parametrize("is_ellm", **MODEL_PROVIDER_PARAMS) +@pytest.mark.parametrize("stream", **STREAM_PARAMS) +def test_create_llm_events(setup: BillingContext, is_ellm: bool, stream: bool): + """Verifies that LLM usage events correctly update organization metrics.""" + client = JamAI(user_id=setup.user.id, project_id=setup.project.id) + request = ChatRequest( + model=setup.ellm_chat_model_id if is_ellm else setup.chat_model_id, + messages=[ChatEntry.user(content="Tell me a very short joke.")], + max_tokens=10, + stream=stream, + ) + + with _test_usage_event( + client=client, + org_id=setup.org.id, + is_ellm=is_ellm, + usage_attr="llm_tokens_usage_mtok", + quota_attr="llm_tokens_quota_mtok", + ): + if stream: + list(client.generate_chat_completions(request)) + else: + client.generate_chat_completions(request) + + +@pytest.mark.cloud +@pytest.mark.parametrize("is_ellm", **MODEL_PROVIDER_PARAMS) +def test_create_embedding_events(setup: BillingContext, is_ellm: bool): + """Verifies that embedding usage events correctly update organization metrics.""" + client = JamAI(user_id=setup.user.id, project_id=setup.project.id) + request = EmbeddingRequest( + model=setup.ellm_embedding_model_id if is_ellm else setup.embedding_model_id, + input="This is a test for embedding billing.", + ) + + with _test_usage_event( + client=client, + org_id=setup.org.id, + is_ellm=is_ellm, + usage_attr="embedding_tokens_usage_mtok", + quota_attr="embedding_tokens_quota_mtok", + ): + client.generate_embeddings(request) + + +@pytest.mark.cloud +@pytest.mark.parametrize("is_ellm", **MODEL_PROVIDER_PARAMS) +def test_create_reranker_events(setup: BillingContext, is_ellm: bool): + """Verifies that reranker usage events correctly update organization metrics.""" + client = JamAI(user_id=setup.user.id, project_id=setup.project.id) + documents = [ + "Paris is the capital of France.", + "The Eiffel Tower is in Paris.", + "Berlin is the capital of Germany.", + ] + request = RerankingRequest( + model=setup.ellm_rerank_model_id if is_ellm else setup.rerank_model_id, + query="What is the capital of France?", + documents=documents, + ) + + with _test_usage_event( + client=client, + org_id=setup.org.id, + is_ellm=is_ellm, + usage_attr="reranker_usage_ksearch", + quota_attr="reranker_quota_ksearch", + ): + client.rerank(request) + + +def _retry(func): + for i in range(USAGE_RETRY): + sleep(USAGE_RETRY_DELAY) + logger.info(f"{func.__name__}: Attempt {i}") + try: + return func() + except Exception: + if i == USAGE_RETRY - 1: + raise + + +def _check_quotas(org: OrganizationRead, new_org: 
OrganizationRead): + # Credits + assert new_org.credit == org.credit + # LLM + assert new_org.llm_tokens_quota_mtok == org.llm_tokens_quota_mtok + # Embed + assert new_org.embedding_tokens_quota_mtok == org.embedding_tokens_quota_mtok + # Rerank (no usage yet) + assert new_org.reranker_quota_ksearch == org.reranker_quota_ksearch + # Egress + assert new_org.egress_quota_gib == org.egress_quota_gib + # DB storage + assert new_org.db_quota_gib == org.db_quota_gib + # File storage + assert new_org.file_quota_gib == org.file_quota_gib + + +@pytest.mark.cloud +@pytest.mark.timeout(180) +def test_gen_table_billing(setup: BillingContext): + client = JamAI(user_id=setup.user.id, project_id=setup.project.id) + org = client.organizations.get_organization(setup.org.id) + with ( + create_table( + client, TableType.KNOWLEDGE, embedding_model=setup.embedding_model_id, cols=[] + ) as kt, + create_table( + client, TableType.KNOWLEDGE, embedding_model=setup.ellm_embedding_model_id, cols=[] + ) as ellm_kt, + ): + ### --- Perform RAG --- ### + system_prompt = "Be concise." + gen_config_kwargs = dict( + system_prompt=system_prompt, + prompt="", + max_tokens=20, + temperature=0.001, + top_p=0.001, + ) + rag_kwargs = dict(search_query="", k=2) + cols = [ + ColumnSchemaCreate(id="question", dtype="str"), + ColumnSchemaCreate(id="image", dtype="image"), + ColumnSchemaCreate( + id="ellm", + dtype="str", + gen_config=LLMGenConfig( + model=setup.ellm_chat_model_id, + multi_turn=False, + rag_params=RAGParams( + reranking_model=setup.ellm_rerank_model_id, + table_id=ellm_kt.id, + **rag_kwargs, + ), + **gen_config_kwargs, + ), + ), + ColumnSchemaCreate( + id="non_ellm", + dtype="str", + gen_config=LLMGenConfig( + model=setup.chat_model_id, + multi_turn=False, + rag_params=RAGParams( + reranking_model=setup.rerank_model_id, + table_id=kt.id, + **rag_kwargs, + ), + **gen_config_kwargs, + ), + ), + ] + + ### --- Embed file --- ### + client.table.embed_file(file_path=FILES["weather.txt"], table_id=kt.id) + client.table.embed_file(file_path=FILES["weather.txt"], table_id=ellm_kt.id) + + # Check the billing data + def _check_embed(): + new_org = client.organizations.get_organization(setup.org.id) + # fmt: off + assert new_org.credit_grant < org.credit_grant, ( + f"{new_org.credit_grant=}, {org.credit_grant=}" + ) + assert new_org.llm_tokens_usage_mtok > org.llm_tokens_usage_mtok, ( + f"{new_org.llm_tokens_usage_mtok=}, {org.llm_tokens_usage_mtok=}" + ) + assert new_org.embedding_tokens_usage_mtok > org.embedding_tokens_usage_mtok, ( + f"{new_org.embedding_tokens_usage_mtok=}, {org.embedding_tokens_usage_mtok=}" + ) + # No usage yet + assert new_org.reranker_usage_ksearch == org.reranker_usage_ksearch, ( + f"{new_org.reranker_usage_ksearch=}, {org.reranker_usage_ksearch=}" + ) + assert new_org.egress_usage_gib > org.egress_usage_gib, ( + f"{new_org.egress_usage_gib=}, {org.egress_usage_gib=}" + ) + assert new_org.db_usage_gib > org.db_usage_gib, ( + f"{new_org.db_usage_gib=}, {org.db_usage_gib=}" + ) + assert new_org.file_usage_gib > org.file_usage_gib, ( + f"{new_org.file_usage_gib=}, {org.file_usage_gib=}" + ) + # fmt: on + _check_quotas(org, new_org) + return new_org + + org = _retry(_check_embed) + + ### --- RAG --- ### + image_uri = client.file.upload_file(FILES["rabbit.jpeg"]).uri + table_type = TableType.ACTION + with create_table(client, table_type, cols=cols) as table: + ### Stream + data = [dict(question="What is it?", image=image_uri)] + response = add_table_rows(client, table_type, table.id, data, stream=True) + 
assert len(response.rows) == len(data) + + # Check the billing data + def _check_rag_stream(): + new_org = client.organizations.get_organization(setup.org.id) + # fmt: off + assert new_org.credit_grant < org.credit_grant, ( + f"{new_org.credit_grant=}, {org.credit_grant=}" + ) + assert new_org.llm_tokens_usage_mtok > org.llm_tokens_usage_mtok, ( + f"{new_org.llm_tokens_usage_mtok=}, {org.llm_tokens_usage_mtok=}" + ) + assert new_org.embedding_tokens_usage_mtok > org.embedding_tokens_usage_mtok, ( + f"{new_org.embedding_tokens_usage_mtok=}, {org.embedding_tokens_usage_mtok=}" + ) + assert new_org.reranker_usage_ksearch > org.reranker_usage_ksearch, ( + f"{new_org.reranker_usage_ksearch=}, {org.reranker_usage_ksearch=}" + ) + assert new_org.egress_usage_gib > org.egress_usage_gib, ( + f"{new_org.egress_usage_gib=}, {org.egress_usage_gib=}" + ) + assert new_org.db_usage_gib > org.db_usage_gib, ( + f"{new_org.db_usage_gib=}, {org.db_usage_gib=}" + ) + assert new_org.file_usage_gib > org.file_usage_gib, ( + f"{new_org.file_usage_gib=}, {org.file_usage_gib=}" + ) + # fmt: on + _check_quotas(org, new_org) + return new_org + + org = _retry(_check_rag_stream) + + ### Non-stream + data = [dict(question="What is it?", image=image_uri)] + response = add_table_rows(client, table_type, table.id, data, stream=False) + assert len(response.rows) == len(data) + + # Check the billing data + def _check_rag_non_stream(): + new_org = client.organizations.get_organization(setup.org.id) + # fmt: off + assert new_org.credit_grant < org.credit_grant, ( + f"{new_org.credit_grant=}, {org.credit_grant=}" + ) + assert new_org.llm_tokens_usage_mtok > org.llm_tokens_usage_mtok, ( + f"{new_org.llm_tokens_usage_mtok=}, {org.llm_tokens_usage_mtok=}" + ) + assert new_org.embedding_tokens_usage_mtok > org.embedding_tokens_usage_mtok, ( + f"{new_org.embedding_tokens_usage_mtok=}, {org.embedding_tokens_usage_mtok=}" + ) + assert new_org.reranker_usage_ksearch > org.reranker_usage_ksearch, ( + f"{new_org.reranker_usage_ksearch=}, {org.reranker_usage_ksearch=}" + ) + assert new_org.egress_usage_gib > org.egress_usage_gib, ( + f"{new_org.egress_usage_gib=}, {org.egress_usage_gib=}" + ) + # No new page allocated + assert new_org.db_usage_gib == org.db_usage_gib, ( + f"{new_org.db_usage_gib=}, {org.db_usage_gib=}" + ) + # No new file uploaded + assert new_org.file_usage_gib == org.file_usage_gib, ( + f"{new_org.file_usage_gib=}, {org.file_usage_gib=}" + ) + # fmt: on + _check_quotas(org, new_org) + return new_org + + org = _retry(_check_rag_non_stream) + + ### --- Tables deleted --- ### + # Check the billing data + def _check_delete(): + new_org = client.organizations.get_organization(setup.org.id) + # fmt: off + assert new_org.credit_grant == org.credit_grant, ( + f"{new_org.credit_grant=}, {org.credit_grant=}" + ) + assert new_org.llm_tokens_usage_mtok == org.llm_tokens_usage_mtok, ( + f"{new_org.llm_tokens_usage_mtok=}, {org.llm_tokens_usage_mtok=}" + ) + assert new_org.embedding_tokens_usage_mtok == org.embedding_tokens_usage_mtok, ( + f"{new_org.embedding_tokens_usage_mtok=}, {org.embedding_tokens_usage_mtok=}" + ) + assert new_org.reranker_usage_ksearch == org.reranker_usage_ksearch, ( + f"{new_org.reranker_usage_ksearch=}, {org.reranker_usage_ksearch=}" + ) + assert new_org.egress_usage_gib > org.egress_usage_gib, ( + f"{new_org.egress_usage_gib=}, {org.egress_usage_gib=}" + ) + assert new_org.db_usage_gib < org.db_usage_gib, ( + f"{new_org.db_usage_gib=}, {org.db_usage_gib=}" + ) + assert new_org.file_usage_gib == 
org.file_usage_gib, ( + f"{new_org.file_usage_gib=}, {org.file_usage_gib=}" + ) + # fmt: on + _check_quotas(org, new_org) + return new_org + + org = _retry(_check_delete) + + +@pytest.mark.cloud +def test_tiered_billing(): + from owl.utils.billing.cloud import BillingManager + + base_kwargs = dict(created_at=now(), updated_at=now()) + price_plan = PricePlan_( + id="free", + name="Free plan", + stripe_price_id_live="stripe_price_id_live", + stripe_price_id_test="stripe_price_id_test", + flat_cost=0.0, + credit_grant=0.0, + max_users=2, # For ease of testing + products=Products( + llm_tokens=Product( + name="ELLM tokens", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="Million Tokens", + ), + embedding_tokens=Product( + name="Embedding tokens", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="Million Tokens", + ), + reranker_searches=Product( + name="Reranker searches", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="Thousand Searches", + ), + db_storage=Product( + name="Database storage", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="GiB", + ), + file_storage=Product( + name="File storage", + included=PriceTier(unit_cost=0.5, up_to=0.75), + tiers=[], + unit="GiB", + ), + egress=Product( + name="Egress bandwidth", + included=PriceTier(unit_cost=0.5, up_to=0.5), + tiers=[ + PriceTier(unit_cost=1.0, up_to=1.0), + PriceTier(unit_cost=2.0, up_to=None), + ], + unit="GiB", + ), + ), + is_private=False, + stripe_price_id="stripe_price_id", + **base_kwargs, + ) + assert price_plan.products.egress.included.unit_cost == 0 + org = OrganizationRead( + id="test_org", + name="test_org", + created_by="", + owner="", + stripe_id="stripe_id", + external_keys={}, + price_plan_id=price_plan.id, + payment_state=PaymentState.SUCCESS, + last_subscription_payment_at=now(), + quota_reset_at=now(), + credit=0.0, + credit_grant=0.0, + llm_tokens_quota_mtok=price_plan.products.llm_tokens.included.up_to, + llm_tokens_usage_mtok=0.0, + embedding_tokens_quota_mtok=price_plan.products.embedding_tokens.included.up_to, + embedding_tokens_usage_mtok=0.0, + reranker_quota_ksearch=price_plan.products.reranker_searches.included.up_to, + reranker_usage_ksearch=0.0, + db_quota_gib=price_plan.products.db_storage.included.up_to, + db_usage_gib=0.0, + db_usage_updated_at=now(), + file_quota_gib=price_plan.products.file_storage.included.up_to, + file_usage_gib=0.0, + file_usage_updated_at=now(), + egress_quota_gib=price_plan.products.egress.included.up_to, + egress_usage_gib=0.0, + quotas={}, + active=True, + price_plan=price_plan, + **base_kwargs, + ) + # Test single charge + billing = BillingManager( + organization=org.model_copy(), + project_id="test_project", + user_id="test_user", + ) + usage = 2.0 + billing.create_egress_events(usage) + billing.org.egress_usage_gib += usage + assert round(billing.cost, 2) == 2.0 + # Test multiple charge + billing = BillingManager( + organization=org.model_copy(), + project_id="test_project", + user_id="test_user", + ) + usage = 0.4 + billing.create_egress_events(usage) + billing.org.egress_usage_gib += usage + assert billing.cost == 0.0 + usage = 0.2 # 0.1 * 0.0 + 0.1 * 1.0 + billing.create_egress_events(usage) + billing.org.egress_usage_gib += usage + assert round(billing.cost, 2) == 0.1 + usage = 1.2 # 0.9 * 1.0 + 0.3 * 2.0 + billing.create_egress_events(usage) + billing.org.egress_usage_gib += usage + assert round(billing.cost, 2) == 1.6 diff --git a/services/api/tests/utils/test_crypt.py 
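A quick arithmetic sketch of the egress figures asserted in test_tiered_billing above. It assumes, consistently with those asserts (the BillingManager internals are not shown in this diff), that each PriceTier's up_to is that tier's capacity rather than a cumulative boundary, that the included tier is free, and that billing.cost accumulates across create_egress_events calls:

def egress_cost(total_gib: float) -> float:
    # Included: first 0.5 GiB free; tier 1: next 1.0 GiB at 1.0/GiB; tier 2: the rest at 2.0/GiB.
    free_gib, tier1_gib = 0.5, 1.0
    tier1 = min(max(total_gib - free_gib, 0.0), tier1_gib) * 1.0
    tier2 = max(total_gib - free_gib - tier1_gib, 0.0) * 2.0
    return tier1 + tier2

assert round(egress_cost(2.0), 2) == 2.0                     # single 2.0 GiB charge
assert round(egress_cost(0.6) - egress_cost(0.4), 2) == 0.1  # second staged charge (0.4 -> 0.6 GiB)
assert round(egress_cost(1.8) - egress_cost(0.6), 2) == 1.5  # third staged charge (0.6 -> 1.8 GiB)
# Cumulative cost after the three staged charges: 0.0 + 0.1 + 1.5 == 1.6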
b/services/api/tests/utils/test_crypt.py new file mode 100644 index 0000000..a92928b --- /dev/null +++ b/services/api/tests/utils/test_crypt.py @@ -0,0 +1,112 @@ +import io + +import pytest + +from owl.utils.crypt import ( + blake2b_hash_file, + decrypt, + encrypt_deterministic, + encrypt_random, + generate_key, + hash_string_blake2b, +) + + +def test_encrypt_random(): + message = "Hello, World!" + password = "secret" + encrypted = encrypt_random(message, password) + decrypted = decrypt(encrypted, password) + assert message == decrypted + + +def test_encrypt_deterministic(): + message = "Hello, World!" + password = "secret" + encrypted1 = encrypt_deterministic(message, password) + encrypted2 = encrypt_deterministic(message, password) + assert encrypted1 == encrypted2 + decrypted = decrypt(encrypted1, password) + assert message == decrypted + + +def test_decrypt_invalid_parts(): + with pytest.raises(ValueError): + decrypt("invalid*format*with*three*parts", "password") + + +def test_decrypt_wrong_password(): + message = "Hello, World!" + password = "correct_password" + wrong_password = "wrong_password" + encrypted = encrypt_random(message, password) + with pytest.raises(ValueError): + decrypt(encrypted, wrong_password) + + +def test_empty_message(): + message = "" + password = "secret" + encrypted = encrypt_random(message, password) + decrypted = decrypt(encrypted, password) + assert message == decrypted + + +def test_long_message(): + message = "A" * 1000000 # 1 million characters + password = "secret" + encrypted = encrypt_random(message, password) + decrypted = decrypt(encrypted, password) + assert message == decrypted + + +def test_hash_string_blake2b(): + string = "Hello, World!" + hashed = hash_string_blake2b(string) + assert len(hashed) == 8 + + +def test_hash_string_blake2b_custom_size(): + string = "Hello, World!" + hashed = hash_string_blake2b(string, key_length=16) + assert len(hashed) == 16 + + +def test_blake2b_hash_file(): + file_content = b"Hello, World!" + file = io.BytesIO(file_content) + hashed = blake2b_hash_file(file) + assert len(hashed) == 128 # Default blake2b digest size is 64 bytes + + +def test_blake2b_hash_file_custom_blocksize(): + file_content = b"Hello, World!" 
* 1000 + file = io.BytesIO(file_content) + hashed = blake2b_hash_file(file, blocksize=1024) + assert len(hashed) == 128 + + +def test_generate_key_default(): + key = generate_key() + assert len(key) == 48 + + +def test_generate_key_custom_length(): + key = generate_key(key_length=32) + assert len(key) == 32 + + +def test_generate_key_with_prefix(): + key = generate_key(prefix="test_") + assert key.startswith("test_") + assert len(key) == 53 # 48 + 5 (prefix length) + + +def test_generate_key_invalid_length(): + with pytest.raises(ValueError): + generate_key(key_length=15) + + +def test_generate_key_odd_length(): + with pytest.raises(ValueError): + generate_key(key_length=33) diff --git a/services/api/tests/utils/test_dates.py b/services/api/tests/utils/test_dates.py new file mode 100644 index 0000000..7010d6d --- /dev/null +++ b/services/api/tests/utils/test_dates.py @@ -0,0 +1,109 @@ +import unittest +from datetime import date, datetime, timezone +from zoneinfo import ZoneInfo + +from freezegun import freeze_time + +from owl.utils.dates import ( + date_to_utc, + date_to_utc_iso, + ensure_utc_timezone, + now, + now_iso, + utc_iso_from_datetime, + utc_iso_from_string, + utc_iso_from_uuid7, + utc_iso_from_uuid7_draft2, +) + + +class TestDateTimeFunctions(unittest.TestCase): + @freeze_time("2023-05-01 12:00:00+00:00") + def test_now_iso(self): + self.assertEqual(now_iso(), "2023-05-01T12:00:00+00:00") + self.assertEqual(now_iso("America/New_York"), "2023-05-01T08:00:00-04:00") + + @freeze_time("2023-05-01 12:00:00+00:00") + def test_now(self): + self.assertEqual(now(), datetime(2023, 5, 1, 12, 0, 0, tzinfo=timezone.utc)) + expected_ny_time = datetime(2023, 5, 1, 12, 0, 0, tzinfo=timezone.utc).astimezone( + ZoneInfo("America/New_York") + ) + self.assertEqual( + now("America/New_York"), + expected_ny_time, + ) + + def test_utc_iso_from_string(self): + self.assertEqual( + utc_iso_from_string("2023-05-01T12:00:00+02:00"), "2023-05-01T10:00:00+00:00" + ) + with self.assertRaises(ValueError): + utc_iso_from_string("2023-05-01T12:00:00") # No timezone + + def test_utc_iso_from_datetime(self): + dt = datetime(2023, 5, 1, 12, 0, 0, tzinfo=ZoneInfo("Europe/Berlin")) + self.assertEqual(utc_iso_from_datetime(dt), "2023-05-01T10:00:00+00:00") + + def test_utc_iso_from_uuid7(self): + uuid7_str = "018859e1-6a62-7f60-b6e1-f6e4b8ec6b66" + result = utc_iso_from_uuid7(uuid7_str) + self.assertTrue(result.startswith("2023-")) # Check if it's at least from 2023 + + # def test_utc_iso_from_uuid7_draft2(self): + # uuid7_str = "018859e1-6a62-7f60-b6e1-f6e4b8ec6b66" + # result = utc_iso_from_uuid7_draft2(uuid7_str) + # self.assertTrue(result.startswith("2023-")) # Check if it's at least from 2023 + + def test_date_to_utc_iso(self): + d = date(2023, 5, 1) + self.assertEqual(date_to_utc_iso(d), "2023-05-01T00:00:00+00:00") + self.assertEqual(date_to_utc_iso(d, "America/New_York"), "2023-05-01T04:00:00+00:00") + + def test_date_to_utc(self): + d = date(2023, 5, 1) + self.assertEqual(date_to_utc(d), datetime(2023, 5, 1, 0, 0, 0, tzinfo=timezone.utc)) + self.assertEqual( + date_to_utc(d, "America/New_York"), + datetime(2023, 5, 1, 0, 0, 0, tzinfo=ZoneInfo("America/New_York")), + ) + + def test_ensure_utc_timezone(self): + self.assertEqual( + ensure_utc_timezone("2023-05-01T12:00:00+00:00"), "2023-05-01T12:00:00+00:00" + ) + with self.assertRaises(ValueError): + ensure_utc_timezone("2023-05-01T12:00:00+02:00") + + # Edge cases + def test_utc_iso_from_string_edge_cases(self): + with self.assertRaises(ValueError): + 
utc_iso_from_string("invalid_datetime") + with self.assertRaises(ValueError): + utc_iso_from_string("2023-05-01") # No time + + def test_utc_iso_from_datetime_edge_cases(self): + with self.assertRaises(ValueError): + utc_iso_from_datetime(datetime(2023, 5, 1)) # No timezone + + def test_utc_iso_from_uuid7_edge_cases(self): + with self.assertRaises(ValueError): + utc_iso_from_uuid7("invalid_uuid") + + def test_utc_iso_from_uuid7_draft2_edge_cases(self): + with self.assertRaises(ValueError): + utc_iso_from_uuid7_draft2("invalid_uuid") + + def test_date_to_utc_iso_edge_cases(self): + with self.assertRaises(ValueError): + date_to_utc_iso(date(2023, 5, 1), "Invalid/Timezone") + + def test_ensure_utc_timezone_edge_cases(self): + with self.assertRaises(ValueError): + ensure_utc_timezone("invalid_datetime") + with self.assertRaises(ValueError): + ensure_utc_timezone("2023-05-01T12:00:00") # No timezone + + +if __name__ == "__main__": + unittest.main() diff --git a/services/api/tests/utils/test_file.py b/services/api/tests/utils/test_file.py new file mode 100644 index 0000000..f562d99 --- /dev/null +++ b/services/api/tests/utils/test_file.py @@ -0,0 +1,268 @@ +import os +import re +import tempfile +from dataclasses import dataclass +from io import BytesIO +from os.path import basename, dirname, join, realpath +from urllib.parse import urlparse + +import httpx +import numpy as np +import pytest +from PIL import Image + +from jamaibase import JamAI +from jamaibase.types import ( + FileUploadResponse, + GetURLResponse, + OrganizationCreate, +) +from jamaibase.utils.exceptions import BadInputError +from owl.types import Role +from owl.utils.test import ( + create_organization, + create_project, + create_user, + get_file_map, + upload_file, +) + +TEST_FILE_DIR = join(dirname(dirname(realpath(__file__))), "files") +FILES = get_file_map(TEST_FILE_DIR) +# Define the paths to your test image and audio files +IMAGE_FILES = [ + FILES["cifar10-deer.jpg"], + FILES["rabbit.png"], + FILES["rabbit_cifar10-deer.gif"], + FILES["rabbit_cifar10-deer.webp"], +] +AUDIO_FILES = [ + FILES["gutter.wav"], + FILES["gutter.mp3"], +] +DOC_FILES = [ + FILES["1970_PSS_ThAT_mechanism.pdf"], + FILES["Claims Form.xlsx"], +] +ALL_FILES = IMAGE_FILES + AUDIO_FILES + DOC_FILES + + +@dataclass(slots=True) +class FileContext: + superuser_id: str + user_id: str + org_id: str + project_id: str + + +def _read_file_content(file_path): + with open(file_path, "rb") as f: + return f.read() + + +@pytest.fixture(scope="module") +def setup(): + """ + Fixture to set up the necessary organization and projects for file tests. 
+ """ + with ( + # Create superuser + create_user() as superuser, + # Create user + create_user({"email": "testuser@example.com", "name": "Test User"}) as user, + # Create organization + create_organization( + body=OrganizationCreate(name="Clubhouse"), user_id=superuser.id + ) as org, + # Create project + create_project(dict(name="Bucket A"), user_id=superuser.id, organization_id=org.id) as p0, + ): + assert superuser.id == "0" + assert org.id == "0" + client = JamAI(user_id=superuser.id) + # Join organization and project + client.organizations.join_organization( + user_id=user.id, organization_id=org.id, role=Role.ADMIN + ) + client.projects.join_project(user_id=user.id, project_id=p0.id, role=Role.ADMIN) + + yield FileContext( + superuser_id=superuser.id, user_id=user.id, org_id=org.id, project_id=p0.id + ) + + +@pytest.mark.parametrize("image_file", IMAGE_FILES) +def test_upload_image(setup: FileContext, image_file: str): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # Ensure the image file exists + assert os.path.exists(image_file), f"Test image file does not exist: {image_file}" + # Upload the file + upload_response = upload_file(client, image_file) + assert isinstance(upload_response, FileUploadResponse) + assert upload_response.uri.startswith(("file://", "s3://")), ( + f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + ) + + filename = os.path.basename(image_file) + expected_uri_pattern = re.compile( + rf"(file|s3)://[^/]+/raw/{setup.org_id}/{setup.project_id}/[a-f0-9-]{{36}}/" + + re.escape(filename) + + "$" + ) + # Check if the returned URI matches the expected format + assert expected_uri_pattern.match(upload_response.uri), ( + f"Returned URI '{upload_response.uri}' does not match the expected format: " + f"(file|s3)://file/raw/{setup.org_id}/{setup.project_id}/{{UUID}}/{filename}" + ) + + +@pytest.mark.parametrize("audio_file", AUDIO_FILES) +def test_upload_audio(setup: FileContext, audio_file: str): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # Ensure the audio file exists + assert os.path.exists(audio_file), f"Test audio file does not exist: {audio_file}" + # Upload the file + upload_response = upload_file(client, audio_file) + assert isinstance(upload_response, FileUploadResponse) + assert upload_response.uri.startswith(("file://", "s3://")), ( + f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + ) + + filename = os.path.basename(audio_file) + expected_uri_pattern = re.compile( + rf"(file|s3)://[^/]+/raw/{setup.org_id}/{setup.project_id}/[a-f0-9-]{{36}}/" + + re.escape(filename) + + "$" + ) + # Check if the returned URI matches the expected format + assert expected_uri_pattern.match(upload_response.uri), ( + f"Returned URI '{upload_response.uri}' does not match the expected format: " + f"(file|s3)://file/raw/{setup.org_id}/{setup.project_id}/{{UUID}}/{filename}" + ) + + +@pytest.mark.parametrize("doc_file", DOC_FILES) +def test_upload_doc(setup: FileContext, doc_file: str): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # Ensure the doc file exists + assert os.path.exists(doc_file), f"Test doc file does not exist: {doc_file}" + # Upload the file + upload_response = upload_file(client, doc_file) + assert isinstance(upload_response, FileUploadResponse) + assert upload_response.uri.startswith(("file://", "s3://")), ( + f"Returned URI '{upload_response.uri}' does not start with 'file://' or 's3://'" + ) + + filename = 
os.path.basename(doc_file) + expected_uri_pattern = re.compile( + rf"(file|s3)://[^/]+/raw/{setup.org_id}/{setup.project_id}/[a-f0-9-]{{36}}/" + + re.escape(filename) + + "$" + ) + + # Check if the returned URI matches the expected format + assert expected_uri_pattern.match(upload_response.uri), ( + f"Returned URI '{upload_response.uri}' does not match the expected format: " + f"(file|s3)://file/raw/{setup.org_id}/{setup.project_id}/{{UUID}}/{filename}" + ) + + +def test_upload_large_image_file(setup: FileContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # Create 25MB image file, assuming 3 bytes per pixel (RGB) and 8 bits per byte + side_length = int(np.sqrt((25 * 1024 * 1024) / 3)) + data = np.random.randint(0, 256, (side_length, side_length, 3), dtype=np.uint8) + img = Image.fromarray(data, "RGB") + + with tempfile.TemporaryDirectory() as temp_dir: + file_path = os.path.join(temp_dir, "large_image.png") + img.save(file_path, format="PNG") + + pattern = re.compile("File size exceeds .+ limit") + with pytest.raises(BadInputError, match=pattern): + upload_file(client, file_path) + + +def test_get_raw_urls(setup: FileContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + # Upload files first + uploaded_uris = [] + for f in ALL_FILES: + response = upload_file(client, f) + uploaded_uris.append(response.uri) + + # Now test get_raw_urls + response = client.file.get_raw_urls(uploaded_uris) + assert isinstance(response, GetURLResponse) + assert len(response.urls) == len(ALL_FILES) + for original_file, url in zip(ALL_FILES, response.urls, strict=True): + downloaded_content = httpx.get(url).content + original_content = _read_file_content(original_file) + # Compare the contents + assert original_content == downloaded_content, ( + f"Content mismatch for file: {original_file}" + ) + + # Check if the returned URIs are absolute paths + for url in response.urls: + parsed_uri = urlparse(url) + + if parsed_uri.scheme in ("http", "https"): + assert parsed_uri.netloc, f"Invalid HTTP/HTTPS URL: {url}" + elif parsed_uri.scheme == "file" or not parsed_uri.scheme: + file_path = parsed_uri.path if parsed_uri.scheme == "file" else url + assert os.path.isabs(file_path), f"File path is not absolute: {url}" + else: + raise ValueError(f"Unsupported URI or file not found: {url}") + + +def test_get_thumbnail_urls(setup: FileContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + + # Upload files first + uploaded_uris = [upload_file(client, f).uri for f in ALL_FILES] + + # Test get_thumbnail_urls + response = client.file.get_thumbnail_urls(uploaded_uris) + assert isinstance(response, GetURLResponse) + assert len(response.urls) == len(ALL_FILES) + thumb_url_map = {basename(f): url for f, url in zip(ALL_FILES, response.urls, strict=True)} + + # Compare thumbnails + for file_path in ALL_FILES: + thumb_url = thumb_url_map[basename(file_path)] + if file_path in IMAGE_FILES + DOC_FILES: + expected_thumb = _read_file_content(f"{file_path}.thumb.webp") + elif file_path in AUDIO_FILES: + expected_thumb = _read_file_content(f"{file_path}.thumb.mp3") + else: + raise ValueError(f"Unexpected file: {file_path}") + if thumb_url.startswith(("http://", "https://")): + downloaded_thumb = httpx.get(thumb_url).content + else: + downloaded_thumb = _read_file_content(thumb_url) + # Compare the contents + if file_path in AUDIO_FILES: + # We could find a way to strip out ID3 tags but it's easier to just compare parts of it + expected_thumb = 
expected_thumb[-round(len(expected_thumb) * 0.9) :] + downloaded_thumb = downloaded_thumb[-round(len(downloaded_thumb) * 0.9) :] + assert expected_thumb == downloaded_thumb, f"Thumbnail mismatch for file: {file_path}" + + +def test_thumbnail_transparency(setup: FileContext): + client = JamAI(user_id=setup.user_id, project_id=setup.project_id) + response = upload_file(client, FILES["github-mark-white.png"]) + response = client.file.get_thumbnail_urls([response.uri]) + assert isinstance(response, GetURLResponse) + assert len(response.urls) == 1 + thumb_url = response.urls[0] + if thumb_url.startswith(("http://", "https://")): + downloaded_thumbnail = httpx.get(thumb_url).content + else: + downloaded_thumbnail = _read_file_content(thumb_url) + + image = Image.open(BytesIO(downloaded_thumbnail)) + assert image.mode == "RGBA" diff --git a/services/api/tests/utils/test_io.py b/services/api/tests/utils/test_io.py new file mode 100644 index 0000000..b8f54dd --- /dev/null +++ b/services/api/tests/utils/test_io.py @@ -0,0 +1,140 @@ +import pickle +import unittest +from unittest.mock import MagicMock, mock_open, patch + +import numpy as np +import pandas as pd +from PIL import ExifTags, Image + +from owl.utils.io import ( + csv_to_df, + df_to_csv, + dump_json, + dump_pickle, + dump_toml, + dump_yaml, + json_dumps, + json_loads, + load_pickle, + read_image, + read_json, + read_toml, + read_yaml, +) + + +class TestFileOperations(unittest.TestCase): + def test_load_pickle(self): + mock_data = {"key": "value"} + with patch("builtins.open", mock_open(read_data=pickle.dumps(mock_data))): + result = load_pickle("dummy_path") + self.assertEqual(result, mock_data) + + def test_dump_pickle(self): + mock_data = {"key": "value"} + mock_file = mock_open() + with patch("builtins.open", mock_file): + dump_pickle("dummy_path", mock_data) + mock_file().write.assert_called() + + def test_read_json(self): + mock_data = '{"key": "value"}' + with patch("builtins.open", mock_open(read_data=mock_data)): + result = read_json("dummy_path") + self.assertEqual(result, {"key": "value"}) + + def test_dump_json(self): + mock_data = {"key": "value"} + mock_file = mock_open() + with patch("builtins.open", mock_file): + result = dump_json(mock_data, "dummy_path") + self.assertEqual(result, "dummy_path") + mock_file().write.assert_called() + + def test_json_loads(self): + mock_data = '{"key": "value"}' + result = json_loads(mock_data) + self.assertEqual(result, {"key": "value"}) + + def test_json_dumps(self): + mock_data = {"key": "value"} + result = json_dumps(mock_data) + self.assertEqual(result, '{"key":"value"}') + + def test_read_yaml(self): + mock_data = "key: value" + with patch("builtins.open", mock_open(read_data=mock_data)): + result = read_yaml("dummy_path") + self.assertEqual(result, {"key": "value"}) + + def test_dump_yaml(self): + mock_data = {"key": "value"} + mock_file = mock_open() + with patch("builtins.open", mock_file): + result = dump_yaml(mock_data, "dummy_path") + self.assertEqual(result, "dummy_path") + mock_file().write.assert_called() + + def test_read_toml(self): + mock_data = 'key = "value"' + with patch("builtins.open", mock_open(read_data=mock_data)): + result = read_toml("dummy_path") + self.assertEqual(result, {"key": "value"}) + + def test_dump_toml(self): + mock_data = {"key": "value"} + mock_file = mock_open() + with patch("builtins.open", mock_file): + result = dump_toml(mock_data, "dummy_path") + self.assertEqual(result, "dummy_path") + mock_file().write.assert_called() + + def 
test_csv_to_df(self): + mock_data = "col1,col2\n1,2\n3,4" + result = csv_to_df(mock_data) + expected = pd.DataFrame({"col1": [1, 3], "col2": [2, 4]}) + pd.testing.assert_frame_equal(result, expected) + + def test_csv_to_df_with_column_names(self): + mock_data = "1,2\n3,4" + result = csv_to_df(mock_data, column_names=["A", "B"]) + expected = pd.DataFrame({"A": [1, 3], "B": [2, 4]}) + pd.testing.assert_frame_equal(result, expected) + + @patch("pandas.DataFrame.to_csv") + def test_df_to_csv(self, mock_to_csv): + df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + df_to_csv(df, "dummy_path") + mock_to_csv.assert_called_once() + + @patch("PIL.Image.open") + def test_read_image_rotated(self, mock_open): + mock_image = Image.new("RGB", (100, 100)) + mock_exif = {} + for key, value in ExifTags.TAGS.items(): + if value == "Orientation": + mock_exif[key] = 3 # 3 is the code for 180 degree rotation + break + mock_image.getexif = MagicMock(return_value=mock_exif) + mock_open.return_value.__enter__.return_value = mock_image + + result, is_rotated = read_image("dummy_path") + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (100, 100, 3)) + self.assertTrue(is_rotated) + + @patch("PIL.Image.open") + def test_read_image_not_rotated(self, mock_open): + mock_image = Image.new("RGB", (100, 100)) + mock_exif = {} # Empty EXIF data (no orientation) + mock_image.getexif = MagicMock(return_value=mock_exif) + mock_open.return_value.__enter__.return_value = mock_image + + result, is_rotated = read_image("dummy_path") + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (100, 100, 3)) + self.assertFalse(is_rotated) + + +if __name__ == "__main__": + unittest.main() diff --git a/services/api/tests/utils/test_jwt.py b/services/api/tests/utils/test_jwt.py new file mode 100644 index 0000000..d8b9c9f --- /dev/null +++ b/services/api/tests/utils/test_jwt.py @@ -0,0 +1,44 @@ +from datetime import datetime, timedelta, timezone +from time import sleep + +import pytest + +from owl.utils.exceptions import AuthorizationError +from owl.utils.jwt import decode_jwt, encode_jwt + + +def test_jwt_round_trip(): + data = {"user_id": 123, "role": "admin"} + expiry = datetime.now(timezone.utc) + timedelta(minutes=5) + token = encode_jwt(data, expiry) + decoded = decode_jwt(token, "expired", "invalid") + # Should contain original data plus 'iat' and 'exp' + assert decoded["user_id"] == 123 + assert decoded["role"] == "admin" + assert "iat" in decoded + assert "exp" in decoded + + +def test_jwt_expired(): + expiry = datetime.now(timezone.utc) - timedelta(seconds=1) + token = encode_jwt({"user_id": 456}, expiry) + sleep(2) + with pytest.raises(AuthorizationError, match="expired"): + decode_jwt(token, "expired", "invalid") + + +def test_jwt_invalid_signature(): + data = {"user_id": 789} + expiry = datetime.now(timezone.utc) + timedelta(minutes=1) + token = encode_jwt(data, expiry) + # Tamper with the token + bad_token = token + "abc" + with pytest.raises(AuthorizationError, match="invalid"): + decode_jwt(bad_token, "expired", "invalid") + + +def test_jwt_invalid_token_format(): + # Not even a JWT + bad_token = "not.a.jwt" + with pytest.raises(AuthorizationError, match="invalid"): + decode_jwt(bad_token, "expired", "invalid") diff --git a/services/api/tests/utils/test_mcp.py b/services/api/tests/utils/test_mcp.py new file mode 100644 index 0000000..657295d --- /dev/null +++ b/services/api/tests/utils/test_mcp.py @@ -0,0 +1,227 @@ +from contextlib import asynccontextmanager +from dataclasses 
import dataclass + +import pytest +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import ( + CallToolResult, + ClientNotification, + EmptyResult, + InitializedNotification, + ListToolsResult, +) + +from jamaibase import JamAIAsync +from jamaibase.types import ( + OrganizationCreate, + Page, + ProjectRead, + Role, +) +from owl.configs import ENV_CONFIG +from owl.utils.test import ( + create_organization, + create_project, + create_user, +) + + +@dataclass(slots=True) +class SetupContext: + superorg_id: str + org_id: str + superproject_id: str + project_id: str + superuser_id: str + user_id: str + guestuser_id: str + + +@pytest.fixture(scope="module") +async def setup(): + with ( + # Create superuser + create_user() as superuser, + # Create user + create_user({"email": "testuser@example.com", "name": "Test User"}) as user, + # Create guestuser + create_user({"email": "guest@example.com", "name": "Test Guest User"}) as guestuser, + # Create super organization + create_organization( + body=OrganizationCreate(name="Clubhouse"), user_id=superuser.id + ) as superorg, + # Create organization + create_organization(body=OrganizationCreate(name="CommonOrg"), user_id=user.id) as org, + # Create project + create_project( + dict(name="projA"), user_id=superuser.id, organization_id=superorg.id + ) as p0, + create_project(dict(name="projA"), user_id=user.id, organization_id=org.id) as p1, + ): + client = JamAIAsync(user_id=user.id) + # guest user join organization but not project + await client.organizations.join_organization( + user_id=guestuser.id, organization_id=org.id, role=Role.MEMBER + ) + yield SetupContext( + superorg_id=superorg.id, + org_id=org.id, + superproject_id=p0.id, + project_id=p1.id, + superuser_id=superuser.id, + user_id=user.id, + guestuser_id=guestuser.id, + ) + + +@asynccontextmanager +async def mcp_session(user_id: str, project_id: str | None = None): + # Connect to a streamable HTTP server + headers = { + "X-USER-ID": user_id, + "X-PROJECT-ID": project_id if project_id else "", + } + if ENV_CONFIG.is_cloud: + headers["Authorization"] = f"Bearer {ENV_CONFIG.service_key_plain}" + async with streamablehttp_client( + url=f"http://localhost:{ENV_CONFIG.port}/api/v1/mcp/http", + headers=headers, + ) as (read_stream, write_stream, _): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + # Initialize the connection + await session.initialize() + yield session + + +async def test_send_ping(setup: SetupContext): + async with mcp_session(setup.superuser_id) as session: + response = await session.send_ping() + assert isinstance(response, EmptyResult) + + +async def test_send_notification(setup: SetupContext): + async with mcp_session(setup.superuser_id) as session: + response = await session.send_notification( + ClientNotification(InitializedNotification(method="notifications/initialized")) + ) + assert response is None + + +async def test_list_tools(setup: SetupContext): + async with mcp_session(setup.superuser_id) as session: + tool_list = await session.list_tools() + assert isinstance(tool_list, ListToolsResult) + tools = {tool.name: tool for tool in tool_list.tools} + # Should have all tools available + assert ( + "list_organizations_api_v2_organizations_list_get" in tools + ) # need system membership + assert "model_info_api_v1_models_get" in tools # need project membership + assert "create_project_api_v2_projects_post" in tools # needs organization.admin + assert ( 
+ "create_action_table_api_v2_gen_tables_action_post" in tools + ) # needs organization or project member permission + assert ( + "create_conversation_api_v2_conversations_post" in tools + ) # needs project member permission + assert ( + "list_projects_api_v2_projects_list_get" in tools + ) # need system or organization guest permission + + async with mcp_session(setup.user_id) as session: + tool_list = await session.list_tools() + assert isinstance(tool_list, ListToolsResult) + tools = {tool.name: tool for tool in tool_list.tools} + assert "model_info_api_v1_models_get" in tools # need project membership + assert "create_project_api_v2_projects_post" in tools # needs organization.admin + assert ( + "create_action_table_api_v2_gen_tables_action_post" in tools + ) # needs organization or project member permission + assert ( + "create_conversation_api_v2_conversations_post" in tools + ) # needs project member permission + assert ( + "list_projects_api_v2_projects_list_get" in tools + ) # need system or organization guest permission + + async with mcp_session(setup.guestuser_id) as session: + tool_list = await session.list_tools() + assert isinstance(tool_list, ListToolsResult) + tools = {tool.name: tool for tool in tool_list.tools} + assert "model_info_api_v1_models_get" not in tools # need project membership + assert "create_project_api_v2_projects_post" not in tools # needs organization.admin + assert ( + "create_action_table_api_v2_gen_tables_action_post" in tools + ) # needs organization or project member permission + assert ( + "create_conversation_api_v2_conversations_post" not in tools + ) # needs project member permission + assert ( + "list_projects_api_v2_projects_list_get" in tools + ) # need system or organization guest permission + + +@pytest.mark.cloud +async def test_list_tools_system_membership(setup: SetupContext): + async with mcp_session(setup.superuser_id) as session: + tool_list = await session.list_tools() + assert isinstance(tool_list, ListToolsResult) + tools = {tool.name: tool for tool in tool_list.tools} + # Should have all tools available + assert ( + "list_organizations_api_v2_organizations_list_get" in tools + ) # need system membership + + async with mcp_session(setup.user_id) as session: + tool_list = await session.list_tools() + assert isinstance(tool_list, ListToolsResult) + tools = {tool.name: tool for tool in tool_list.tools} + assert ( + "list_organizations_api_v2_organizations_list_get" not in tools + ) # need system membership + + async with mcp_session(setup.guestuser_id) as session: + tool_list = await session.list_tools() + assert isinstance(tool_list, ListToolsResult) + tools = {tool.name: tool for tool in tool_list.tools} + assert ( + "list_organizations_api_v2_organizations_list_get" not in tools + ) # need system membership + + +async def test_call_tool(setup: SetupContext): + async with mcp_session(setup.superuser_id) as session: + # List projects + tool_result = await session.call_tool( + "list_projects_api_v2_projects_list_get", + dict( + organization_id=setup.superorg_id, + limit=2, + order_by="created_at", + order_ascending=False, + ), + ) + assert isinstance(tool_result, CallToolResult) + assert not tool_result.isError + assert isinstance(tool_result.content[0].text, str) + projects = Page[ProjectRead].model_validate_json(tool_result.content[0].text) + assert projects.total == 1 + assert projects.items[0].id == setup.superproject_id + # Create Proj + new_proj_name = "MCP proj" + tool_result = await session.call_tool( + 
"create_project_api_v2_projects_post", + dict(organization_id=setup.superorg_id, name=new_proj_name), + ) + assert isinstance(tool_result, CallToolResult) + assert isinstance(tool_result.content[0].text, str) + proj = ProjectRead.model_validate_json(tool_result.content[0].text) + assert proj.organization.id == setup.superorg_id + assert proj.name == new_proj_name + # Fetch the updated organization + client = JamAIAsync(user_id=setup.superuser_id) + p = await client.projects.get_project(proj.id) + assert isinstance(p, ProjectRead) + assert p.name == new_proj_name diff --git a/services/api/tests/utils/test_utils.py b/services/api/tests/utils/test_utils.py new file mode 100644 index 0000000..fdf37b5 --- /dev/null +++ b/services/api/tests/utils/test_utils.py @@ -0,0 +1,166 @@ +import numpy as np +import pytest + +from owl.utils import mask_content, mask_dict, merge_dict, validate_where_expr + + +def test_mask_content(): + # mask_content(x: str | list | dict | np.ndarray | Any) -> str | list | dict | None + x = "str" + assert mask_content(x) == "*** (str_len=3)" + x = "long-string" + assert mask_content(x) == "lo***ng (str_len=11)" + x = 0 + assert mask_content(x) is None + x = False + assert mask_content(x) is None + x = np.ones(3) + assert mask_content(x) == "array(shape=(3,), dtype=float64)" + x = ["long-string", np.ones(3), 0] + assert mask_content(x) == ["lo***ng (str_len=11)", "array(shape=(3,), dtype=float64)", None] + x = dict(x=["long-string", np.ones(3), 0], y=0, z=dict(a="str")) + assert mask_content(x) == dict( + x=["lo***ng (str_len=11)", "array(shape=(3,), dtype=float64)", None], + y=None, + z=dict(a="*** (str_len=3)"), + ) + + +def test_mask_dict(): + x = dict(a=0, b=1, c="", d="d") + assert mask_dict(x) == dict(a=0, b="***", c="", d="***") + + +def test_merge_dict(): + x = dict(a=1, b=dict(p=2, q=3)) + y = dict(b=dict(p=30)) + assert merge_dict(x, y) == dict(a=1, b=dict(p=30, q=3)) + + x = dict(a=1, b=dict(p=2, q=3)) + y = dict(b=dict(p=[])) + assert merge_dict(x, y) == dict(a=1, b=dict(p=[], q=3)) + + x = dict(a=1, b=[dict(p=2, q=3)]) + y = dict(b=[dict(p=30)]) + assert merge_dict(x, y) == dict(a=1, b=[dict(p=30)]) + + x = dict(a=1, b=dict(p=dict(r=3, t=None), q=3)) + y = dict(b=dict(p=30)) + assert merge_dict(x, y) == dict(a=1, b=dict(p=30, q=3)) + + x = dict(a=1, b=dict(p=dict(r=3, t=None), q=3)) + y = dict(b=dict(p=dict(t=True))) + assert merge_dict(x, y) == dict(a=1, b=dict(p=dict(r=3, t=True), q=3)) + + x = dict(a=1, b=dict(p=dict(r=3, t=None), q=3)) + y = dict(b=dict(p=dict(t={}))) + assert merge_dict(x, y) == dict(a=1, b=dict(p=dict(r=3, t={}), q=3)) + + x = dict(a=1, b=None) + y = dict(b=dict(p=3)) + assert merge_dict(x, y) == dict(a=1, b=dict(p=3)) + + x = dict(a=1, b=dict(p=2)) + y = dict(b=None) + assert merge_dict(x, y) == dict(a=1, b=None) + + x = dict(a=1, b=dict(p=2, q=3)) + y = dict(b=dict(p=30), c=True) + assert merge_dict(x, y) == dict(a=1, b=dict(p=30, q=3), c=True) + + x = dict(a=1, b=dict(p=2, q=3)) + y = dict(a="yes", b=dict(p=30), c=True) + assert merge_dict(x, y) == dict(a="yes", b=dict(p=30, q=3), c=True) + + +def test_validate_where_expr(): + # Basic cases + sql = validate_where_expr("WHERE a = 1") + assert sql == "a = 1" + sql = validate_where_expr("WHERE a =\n1") + assert sql == "a = 1" + sql = validate_where_expr("WHERE a = 'x'") + assert sql == "a = 'x'" + sql = validate_where_expr("WHERE (a = 'x')") + assert sql == "(a = 'x')" + sql = validate_where_expr("a = 1") + assert sql == "a = 1" + sql = validate_where_expr(""""a" = 'x'""") + assert sql == 
""""a" = 'x'""" + # Nested comparisons + sql = validate_where_expr( + """WHERE a = 1 OR ((b = NULL AND c = 9) OR ("b (1)" = TRUE) AND c = '9')""" + ) + assert sql == """a = 1 OR ((b = NULL AND c = 9) OR ("b (1)" = TRUE) AND c = '9')""" + # Comparison with a column + sql = validate_where_expr('WHERE (("ID" = 1 AND "Updated at" = 9) AND "Updated at" = "M")') + assert sql == '(("ID" = 1 AND "Updated at" = 9) AND "Updated at" = "M")' + # Wildcard + sql = validate_where_expr('"222 two three" ~* 3;') + assert sql == '"222 two three" ~* 3' + # Column name with parenthesis + sql = validate_where_expr('"text (en)" ~* 3;') + assert sql == '"text (en)" ~* 3' + sql = validate_where_expr(""""text (en)" ~* 'yes (no)';""") + assert sql == """"text (en)" ~* 'yes (no)'""" + + # ID mapping + sql = validate_where_expr("WHERE a = 'x'", id_map={"a": "b"}) + assert sql == """"b" = 'x'""" + sql = validate_where_expr(""""a" = 'x'""", id_map={"a": "b"}) + assert sql == """"b" = 'x'""" + sql = validate_where_expr( + """WHERE a = 1 OR ((b = NULL AND c = 9) OR ("b" = TRUE) AND c = '9')""", + id_map={"a": "b"}, + ) + assert sql == """"b" = 1 OR (("b" = NULL AND "c" = 9) OR ("b" = TRUE) AND "c" = '9')""" + sql = validate_where_expr( + 'WHERE (("ID" = 1 AND "Updated at" = 9) AND "Updated at" = "M")', + id_map={"ID": "a", "Updated at": "b", "M": "c"}, + ) + assert sql == '(("a" = 1 AND "b" = 9) AND "b" = "c")' + sql = validate_where_expr('"222 two three" ~* 3;', id_map={"222 two three": "a"}) + assert sql == '"a" ~* 3' + + # Illegal SQL + for stmt in [ + # Classic drop table + "DROP TABLE users; --", + # Update data for all users + "UPDATE users SET is_admin = 1", + # Insert a new admin user + "INSERT INTO users (username, is_admin) VALUES ('attacker', 1);", + # Comment + "email = 'a@a.com' --", + "email = 'a@a.com' /*", + # Shutdown the database (in some systems like SQL Server) + "SHUTDOWN", + # Attempt to alter a table + "name = 'x' OR 1 = (ALTER TABLE users ADD COLUMN hacked VARCHAR(100))", + "name = 'x' OR 1 = \n(ALTER TABLE users ADD COLUMN hacked VARCHAR(100))", + "name = 'x' OR 1 = \r(ALTER TABLE users ADD COLUMN hacked VARCHAR(100))", + # Keywords used directly + "id > 0 OR UPDATE users SET is_admin = 1", + "id = 1 OR MERGE INTO users", + # Attempt to drop a column + "ALTER TABLE users DROP COLUMN password_hash;", + # Truncate a table + "TRUNCATE TABLE logs;", + # Functions + "1=1 AND pg_sleep(10)", + "1=1 AND pg_sleep (10)", + "1=1 AND set_config(10)", + "1=1 AND BENCHMARK(50000000, ENCODE('key', 'val'))", + # Exec + "EXEC master.dbo.xp_cmdshell 'dir c:';", + # Using comments to break up keywords + "DR/**/OP TABLE users;", + # Using different character encodings or functions + "EXEC(CHAR(100) + CHAR(114) + CHAR(111) + CHAR(112) + ' TABLE users')", # SQL Server 'drop' + ]: + with pytest.raises(ValueError): + validate_where_expr(stmt) + with pytest.raises(ValueError): + validate_where_expr(f"{stmt}; id = 1") + sql = validate_where_expr(f"id = 1; {stmt}") + assert sql == "id = 1" diff --git a/services/app/.env.example b/services/app/.env.example old mode 100644 new mode 100755 index 0f2257e..8d84dec --- a/services/app/.env.example +++ b/services/app/.env.example @@ -1,27 +1,33 @@ # Sveltekit config BODY_SIZE_LIMIT="Infinity" -# Services URLs -JAMAI_URL="http://localhost:6969" +# Services +OWL_URL="http://localhost:6969" PUBLIC_JAMAI_URL="" +PUBLIC_ADMIN_ORGANIZATION_ID="0" -# Playwright test user -TEST_ACC_EMAIL="" -TEST_ACC_PW="" -TEST_ACC_USERID="" +# Auth config +AUTH_SECRET="changeme" # Set to false only if 
you have the secrets -PUBLIC_IS_LOCAL="true" PUBLIC_IS_SPA="false" # Generate as a single-page application, only works if running locally -BASE_URL="" -JAMAI_SERVICE_KEY="" +OWL_SERVICE_KEY="" +OWL_STRIPE_API_KEY="" +OWL_STRIPE_PUBLISHABLE_KEY_LIVE="" +OWL_STRIPE_PUBLISHABLE_KEY_TEST="" AUTH0_ISSUER_BASE_URL="" AUTH0_CLIENT_ID="" AUTH0_CLIENT_SECRET="" AUTH0_MGMTAPI_CLIENT_ID="" AUTH0_MGMTAPI_CLIENT_SECRET="" AUTH0_SECRET="" -PUBLIC_STRIPE_PUBLISHABLE_KEY="" -STRIPE_SECRET_KEY="" -STRIPE_WEBHOOK_SECRET="" -RESEND_API_KEY="" \ No newline at end of file +RESEND_API_KEY="" + +# Test config +CI_TEST_MODE="false" +TEST_USER_ID="" +TEST_USER_USERNAME="" +TEST_USER_PASSWORD="" +TEST_FREE_PLAN_ID="" +TEST_PRO_PLAN_ID="" +TEST_TEAM_PLAN_ID="" \ No newline at end of file diff --git a/services/app/.eslintignore b/services/app/.eslintignore old mode 100644 new mode 100755 diff --git a/services/app/.eslintrc.cjs b/services/app/.eslintrc.cjs old mode 100644 new mode 100755 diff --git a/services/app/.gitignore b/services/app/.gitignore old mode 100644 new mode 100755 index 9c46f28..4f34e08 --- a/services/app/.gitignore +++ b/services/app/.gitignore @@ -10,4 +10,5 @@ node_modules vite.config.js.timestamp-* vite.config.ts.timestamp-* *.db -playwright \ No newline at end of file +playwright +.vscode \ No newline at end of file diff --git a/services/app/.npmrc b/services/app/.npmrc old mode 100644 new mode 100755 diff --git a/services/app/.prettierignore b/services/app/.prettierignore old mode 100644 new mode 100755 diff --git a/services/app/.prettierrc b/services/app/.prettierrc old mode 100644 new mode 100755 index 9573023..664a09d --- a/services/app/.prettierrc +++ b/services/app/.prettierrc @@ -3,6 +3,11 @@ "singleQuote": true, "trailingComma": "none", "printWidth": 100, - "plugins": ["prettier-plugin-svelte"], + "tabWidth": 2, + "plugins": [ + "prettier-plugin-svelte", + "prettier-plugin-tailwindcss", + "prettier-plugin-organize-imports" + ], "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }] } diff --git a/services/app/README.md b/services/app/README.md old mode 100644 new mode 100755 diff --git a/services/app/build.bat b/services/app/build.bat old mode 100644 new mode 100755 index 1f9db98..498eaac --- a/services/app/build.bat +++ b/services/app/build.bat @@ -24,7 +24,7 @@ powershell -Command "Get-Content .env | ForEach-Object { $line = $_ -replace '#. 
echo %BASE_URL% echo %PUBLIC_JAMAI_URL% -echo %JAMAI_URL% +echo %OWL_URL% echo %PUBLIC_IS_SPA% rem Set the flag variable diff --git a/services/app/components.json b/services/app/components.json old mode 100644 new mode 100755 index 865be21..9c615d6 --- a/services/app/components.json +++ b/services/app/components.json @@ -1,14 +1,16 @@ { - "$schema": "https://shadcn-svelte.com/schema.json", - "style": "default", - "tailwind": { - "config": "tailwind.config.js", - "css": "src/app.css", - "baseColor": "neutral" - }, - "aliases": { - "components": "$lib/components", - "utils": "$lib/utils" - }, - "typescript": true -} \ No newline at end of file + "$schema": "https://next.shadcn-svelte.com/schema.json", + "tailwind": { + "css": "src/app.css", + "baseColor": "slate" + }, + "aliases": { + "components": "$lib/components", + "utils": "$lib/utils", + "ui": "$lib/components/ui", + "hooks": "$lib/hooks", + "lib": "$lib" + }, + "typescript": true, + "registry": "https://tw3.shadcn-svelte.com/registry/default" +} diff --git a/services/app/electron/icons/icon.icns b/services/app/electron/icons/icon.icns old mode 100644 new mode 100755 diff --git a/services/app/electron/icons/icon.ico b/services/app/electron/icons/icon.ico old mode 100644 new mode 100755 diff --git a/services/app/electron/icons/icon.png b/services/app/electron/icons/icon.png old mode 100644 new mode 100755 diff --git a/services/app/electron/main.js b/services/app/electron/main.js old mode 100644 new mode 100755 diff --git a/services/app/forge.config.cjs b/services/app/forge.config.cjs old mode 100644 new mode 100755 diff --git a/services/app/messages/en.json b/services/app/messages/en.json new file mode 100644 index 0000000..0a59ad6 --- /dev/null +++ b/services/app/messages/en.json @@ -0,0 +1,89 @@ +{ + "$schema": "https://inlang.com/schema/inlang-message-format", + "sortable_field_name": "Name", + "sortable_field_created_at": "Date created", + "sortable_field_updated_at": "Date modified", + "left_dock": { + "home": "Home", + "project": "Project", + "organization": "Organization", + "docs": "Docs", + "logout": "Log Out", + "upgrade": "Ready for more?
Upgrade to our premium plan", + "upgrade_btn": "Upgrade Plan", + "show_hide_btn": "Show/hide side navigation bar", + "analytics": "Analytics" + }, + "breadcrumbs": { + "org_btn": "Switch organizations", + "org_placeholder": "Unknown", + "org_create_btn": "Create organization", + "org_default": "Default Organization", + "org_join_btn": "Join organization" + }, + "organization": { + "navigation": { + "general": "General", + "team": "Team", + "secrets": "Secrets", + "billing": "Billing", + "usage": "Usage", + "heading": "Organization" + }, + "team_page": { + "subheading": "Organization Members", + "invite_btn": "Invite people", + "idx": "No.", + "uid": "User ID", + "email": "Email", + "role": "Role", + "created": "Created at", + "useredit_heading": "Edit user role", + "useredit_field_role": "User role", + "useredit_field_role_title": "Select user role" + } + }, + "save": "Save", + "cancel": "Cancel", + "close": "Close", + "confirm_message": "Are you sure?", + "project": { + "heading": "Projects", + "search_placeholder": "Search Project", + "create_btn": "New Project", + "import_btn": "Import Project", + "subheading": "All Projects", + "settings_rename": "Rename project", + "settings_export": "Export project", + "settings_delete": "Delete project", + "updated_at": "Last updated", + "settings_btn": "Project settings", + "edit": { + "heading": "Edit project name", + "field_name": "Project name" + }, + "create": { + "heading": "New project", + "field_name": "Project name" + }, + "delete": { + "heading": "Delete project", + "text_content": "Do you really want to delete project `{project_name}` ? This process cannot be undone.", + "text_confirm": "Enter project {confirm_text} to confirm", + "field_confirm": "Project {confirm_text}" + } + }, + "sortable": { + "name": "Name", + "created_at": "Date created", + "updated_at": "Date modified", + "direction_asc": "Ascending", + "direction_desc": "Descending" + }, + "field_required": "Required", + "add": "Add", + "create": "Create", + "delete": "Delete", + "project_export_confirm": "Export project `{project_name}`?", + "project_export_fail": "Failed to export project" +} diff --git a/services/app/package-lock.json b/services/app/package-lock.json old mode 100644 new mode 100755 index 3db7d75..09262a5 --- a/services/app/package-lock.json +++ b/services/app/package-lock.json @@ -1,24 +1,27 @@ { "name": "jamaibase-app", - "version": "0.2.0", + "version": "0.5.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "jamaibase-app", - "version": "0.2.0", + "version": "0.5.0", "dependencies": { + "@auth/sveltekit": "^1.9.2", "@fontsource-variable/roboto-flex": "^5.0.15", "@formkit/auto-animate": "^0.8.1", - "@stripe/stripe-js": "^3.4.0", + "@monaco-editor/loader": "^1.5.0", + "@stripe/stripe-js": "^3.5.0", "@tailwindcss/container-queries": "^0.1.1", "auth0": "^4.4.0", "axios": "^1.6.8", - "bits-ui": "^0.20.1", "chart.js": "^4.4.3", "chartjs-adapter-moment": "^1.0.1", "clsx": "^2.1.0", "cors": "^2.8.5", + "csvtojson": "^2.0.10", + "date-fns": "^4.1.0", "dexie": "^4.0.10", "dotenv": "^16.4.5", "electron-serve": "^2.0.0", @@ -29,22 +32,30 @@ "lodash": "^4.17.21", "lucide-svelte": "^0.359.0", "minio": "^7.1.3", - "mode-watcher": "^0.3.0", + "minisearch": "^7.1.2", + "monaco-editor": "^0.52.2", "node-cache": "^5.1.2", "nprogress": "^0.2.0", "overlayscrollbars-svelte": "^0.5.1", "papaparse": "^5.4.1", + "pdfjs-dist": "^4.10.38", + "pdfobject": "^2.3.1", "pretty-bytes": "^6.1.1", + "prosemirror-commands": "^1.7.1", + "prosemirror-history": "^1.4.1", + 
"prosemirror-keymap": "^1.2.3", + "prosemirror-model": "^1.25.3", + "prosemirror-state": "^1.4.3", + "prosemirror-view": "^1.41.0", "showdown": "^2.1.0", "showdown-htmlescape": "^0.1.9", - "stripe": "^15.5.0", + "stripe": "^15.12.0", "svelte-persisted-store": "^0.9.1", - "svelte-sonner": "^0.3.24", "tailwind-merge": "^2.2.2", "tailwind-variants": "^0.2.1", "undici": "^6.19.4", "uuid": "^9.0.1", - "zod": "^3.22.4" + "zod": "^3.25.67" }, "devDependencies": { "@electron-forge/cli": "^7.4.0", @@ -53,10 +64,13 @@ "@electron-forge/maker-squirrel": "^7.4.0", "@electron-forge/maker-zip": "^7.4.0", "@faker-js/faker": "^8.4.1", + "@inlang/cli": "^3.0.0", + "@inlang/paraglide-js": "2.0.13", + "@lucide/svelte": "^0.482.0", "@playwright/test": "^1.28.1", "@sveltejs/adapter-node": "^5.0.1", "@sveltejs/adapter-static": "^3.0.2", - "@sveltejs/kit": "^2.5.27", + "@sveltejs/kit": "^2.15.0", "@sveltejs/vite-plugin-svelte": "^4.0.0", "@types/cors": "^2.8.17", "@types/eslint": "^8.56.0", @@ -64,28 +78,38 @@ "@types/lodash": "^4.17.0", "@types/nprogress": "^0.2.3", "@types/papaparse": "^5.3.14", + "@types/pdfobject": "^2.2.5", "@types/showdown": "^2.0.6", "@types/uuid": "^9.0.8", "@typescript-eslint/eslint-plugin": "^7.0.0", "@typescript-eslint/parser": "^7.0.0", "autoprefixer": "^10.4.18", + "bits-ui": "^1.8.0", "concurrently": "^8.2.2", "cross-env": "^7.0.3", "electron": "^31.0.1", "eslint": "^8.56.0", "eslint-config-prettier": "^9.1.0", "eslint-plugin-svelte": "^2.45.1", + "mode-watcher": "^1.0.7", + "paneforge": "^1.0.0-next.5", "postcss": "^8.4.37", "prettier": "^3.1.1", + "prettier-plugin-organize-imports": "^4.1.0", "prettier-plugin-svelte": "^3.2.6", + "prettier-plugin-tailwindcss": "^0.6.12", "run-script-os": "^1.1.6", "svelte": "^5.0.0", "svelte-check": "^4.0.0", + "svelte-sonner": "^0.3.28", + "sveltekit-superforms": "^2.27.0", "tailwindcss": "^3.4.1", "tailwindcss-animate": "^1.0.7", "tslib": "^2.4.1", "typescript": "^5.5.0", + "vaul-svelte": "^1.0.0-next.7", "vite": "^5.4.4", + "vite-plugin-devtools-json": "^0.4.1", "vitest": "^1.2.0" } }, @@ -121,13 +145,96 @@ "node": ">=6.0.0" } }, - "node_modules/@babel/runtime": { - "version": "7.24.1", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.24.1.tgz", - "integrity": "sha512-+BIznRzyqBf+2wCTxcKE3wDjfGeCoVE61KSHGpkzqrLi8qxqFwBeUFyId2cxkTmm55fzDGnm0+yCxaxygrLUnQ==", + "node_modules/@ark/schema": { + "version": "0.46.0", + "resolved": "https://registry.npmjs.org/@ark/schema/-/schema-0.46.0.tgz", + "integrity": "sha512-c2UQdKgP2eqqDArfBqQIJppxJHvNNXuQPeuSPlDML4rjw+f1cu0qAlzOG4b8ujgm9ctIDWwhpyw6gjG5ledIVQ==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@ark/util": "0.46.0" + } + }, + "node_modules/@ark/util": { + "version": "0.46.0", + "resolved": "https://registry.npmjs.org/@ark/util/-/util-0.46.0.tgz", + "integrity": "sha512-JPy/NGWn/lvf1WmGCPw2VGpBg5utZraE84I7wli18EDF3p3zc/e9WolT35tINeZO3l7C77SjqRJeAUoT0CvMRg==", + "dev": true, + "license": "MIT", + "optional": true + }, + "node_modules/@auth/core": { + "version": "0.39.1", + "resolved": "https://registry.npmjs.org/@auth/core/-/core-0.39.1.tgz", + "integrity": "sha512-McD8slui0oOA1pjR5sPjLPl5Zm//nLP/8T3kr8hxIsvNLvsiudYvPHhDFPjh1KcZ2nFxCkZmP6bRxaaPd/AnLA==", + "license": "ISC", + "dependencies": { + "@panva/hkdf": "^1.2.1", + "jose": "^6.0.6", + "oauth4webapi": "^3.3.0", + "preact": "10.24.3", + "preact-render-to-string": "6.5.11" + }, + "peerDependencies": { + "@simplewebauthn/browser": "^9.0.1", + "@simplewebauthn/server": "^9.0.2", + "nodemailer": 
"^6.8.0" + }, + "peerDependenciesMeta": { + "@simplewebauthn/browser": { + "optional": true + }, + "@simplewebauthn/server": { + "optional": true + }, + "nodemailer": { + "optional": true + } + } + }, + "node_modules/@auth/core/node_modules/jose": { + "version": "6.0.11", + "resolved": "https://registry.npmjs.org/jose/-/jose-6.0.11.tgz", + "integrity": "sha512-QxG7EaliDARm1O1S8BGakqncGT9s25bKL1WSf6/oa17Tkqwi8D2ZNglqCF+DsYF88/rV66Q/Q2mFAy697E1DUg==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, + "node_modules/@auth/sveltekit": { + "version": "1.9.2", + "resolved": "https://registry.npmjs.org/@auth/sveltekit/-/sveltekit-1.9.2.tgz", + "integrity": "sha512-EqwLS6kFyhtDQ4HfLEeIROIK/rIRhdSosnofI6XT4woXKtctBt4UeKTcoBKumXwR7u04/lnB6rHoXuG1nnD52A==", + "license": "ISC", "dependencies": { - "regenerator-runtime": "^0.14.0" + "@auth/core": "0.39.1", + "set-cookie-parser": "^2.7.0" + }, + "peerDependencies": { + "@simplewebauthn/browser": "^9.0.1", + "@simplewebauthn/server": "^9.0.3", + "@sveltejs/kit": "^1.0.0 || ^2.0.0", + "nodemailer": "^6.6.5", + "svelte": "^3.54.0 || ^4.0.0 || ^5.0.0-0" }, + "peerDependenciesMeta": { + "@simplewebauthn/browser": { + "optional": true + }, + "@simplewebauthn/server": { + "optional": true + }, + "nodemailer": { + "optional": true + } + } + }, + "node_modules/@babel/runtime": { + "version": "7.27.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.6.tgz", + "integrity": "sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==", + "license": "MIT", "engines": { "node": ">=6.9.0" } @@ -1752,7 +1859,6 @@ "cpu": [ "ppc64" ], - "dev": true, "optional": true, "os": [ "aix" @@ -1768,7 +1874,6 @@ "cpu": [ "arm" ], - "dev": true, "optional": true, "os": [ "android" @@ -1784,7 +1889,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "android" @@ -1800,7 +1904,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "android" @@ -1816,7 +1919,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "darwin" @@ -1832,7 +1934,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "darwin" @@ -1848,7 +1949,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "freebsd" @@ -1864,7 +1964,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "freebsd" @@ -1880,7 +1979,6 @@ "cpu": [ "arm" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1896,7 +1994,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1912,7 +2009,6 @@ "cpu": [ "ia32" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1928,7 +2024,6 @@ "cpu": [ "loong64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1944,7 +2039,6 @@ "cpu": [ "mips64el" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1960,7 +2054,6 @@ "cpu": [ "ppc64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1976,7 +2069,6 @@ "cpu": [ "riscv64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1992,7 +2084,6 @@ "cpu": [ "s390x" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2008,7 +2099,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2024,7 +2114,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "netbsd" @@ -2040,7 +2129,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "openbsd" @@ -2056,7 +2144,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "sunos" @@ -2072,7 +2159,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "win32" @@ 
-2088,7 +2174,6 @@ "cpu": [ "ia32" ], - "dev": true, "optional": true, "os": [ "win32" @@ -2104,7 +2189,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "win32" @@ -2191,6 +2275,14 @@ "node": "^12.22.0 || ^14.17.0 || >=16.0.0" } }, + "node_modules/@exodus/schemasafe": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/@exodus/schemasafe/-/schemasafe-1.3.0.tgz", + "integrity": "sha512-5Aap/GaRupgNx/feGBwLLTVv8OQFfv3pq2lPRzPg9R+IOBnDgghTGW7l7EuVXOvg5cc/xSAlRW8rBrjIC3Nvqw==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/@faker-js/faker": { "version": "8.4.1", "resolved": "https://registry.npmjs.org/@faker-js/faker/-/faker-8.4.1.tgz", @@ -2208,26 +2300,32 @@ } }, "node_modules/@floating-ui/core": { - "version": "1.6.8", - "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.6.8.tgz", - "integrity": "sha512-7XJ9cPU+yI2QeLS+FCSlqNFZJq8arvswefkZrYI1yQBbftw6FyrZOxYSh+9S7z7TpeWlRt9zJ5IhM1WIL334jA==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.1.tgz", + "integrity": "sha512-azI0DrjMMfIug/ExbBaeDVJXcY0a7EPvPjb2xAJPa4HeimBX+Z18HK8QQR3jb6356SnDDdxx+hinMLcJEDdOjw==", + "dev": true, + "license": "MIT", "dependencies": { - "@floating-ui/utils": "^0.2.8" + "@floating-ui/utils": "^0.2.9" } }, "node_modules/@floating-ui/dom": { - "version": "1.6.12", - "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.6.12.tgz", - "integrity": "sha512-NP83c0HjokcGVEMeoStg317VD9W7eDlGK7457dMBANbKA6GJZdc7rjujdgqzTaz93jkGgc5P/jeWbaCHnMNc+w==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.7.1.tgz", + "integrity": "sha512-cwsmW/zyw5ltYTUeeYJ60CnQuPqmGwuGVhG9w0PRaRKkAyi38BT5CKrpIbb+jtahSwUl04cWzSx9ZOIxeS6RsQ==", + "dev": true, + "license": "MIT", "dependencies": { - "@floating-ui/core": "^1.6.0", - "@floating-ui/utils": "^0.2.8" + "@floating-ui/core": "^1.7.1", + "@floating-ui/utils": "^0.2.9" } }, "node_modules/@floating-ui/utils": { - "version": "0.2.8", - "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.8.tgz", - "integrity": "sha512-kym7SodPp8/wloecOpcmSnWJsK7M0E5Wg8UcFA+uO4B9s5d0ywXOEro/8HM9x0rW+TljRzul/14UYz3TleT3ig==" + "version": "0.2.9", + "resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.9.tgz", + "integrity": "sha512-MDWhGtE+eHw5JW7lq4qhc5yRLS11ERl1c7Z6Xd0a58DozHES6EnNNwUWbMiG4J9Cgj053Bhk8zvlhFYKVhULwg==", + "dev": true, + "license": "MIT" }, "node_modules/@fontsource-variable/roboto-flex": { "version": "5.0.15", @@ -2246,15 +2344,50 @@ "dev": true, "license": "MIT" }, + "node_modules/@gcornut/valibot-json-schema": { + "version": "0.42.0", + "resolved": "https://registry.npmjs.org/@gcornut/valibot-json-schema/-/valibot-json-schema-0.42.0.tgz", + "integrity": "sha512-4Et4AN6wmqeA0PfU5Clkv/IS27wiefsWf6TemAZrb75uzkClYEFavim7SboeKwbll9Nbsn2Iv0LT/HS5H7orZg==", + "dev": true, + "optional": true, + "dependencies": { + "valibot": "~0.42.0" + }, + "bin": { + "valibot-json-schema": "bin/index.js" + }, + "optionalDependencies": { + "@types/json-schema": ">= 7.0.14", + "esbuild-runner": ">= 2.2.2" + } + }, + "node_modules/@gcornut/valibot-json-schema/node_modules/valibot": { + "version": "0.42.1", + "resolved": "https://registry.npmjs.org/valibot/-/valibot-0.42.1.tgz", + "integrity": "sha512-3keXV29Ar5b//Hqi4MbSdV7lfVp6zuYLZuA9V1PvQUsXqogr+u5lvLPLk3A4f74VUXDnf/JfWMN6sB+koJ/FFw==", + "dev": true, + "license": "MIT", + "optional": true, + "peerDependencies": { + "typescript": ">=5" + }, + "peerDependenciesMeta": { 
+ "typescript": { + "optional": true + } + } + }, "node_modules/@hapi/hoek": { "version": "9.3.0", "resolved": "https://registry.npmjs.org/@hapi/hoek/-/hoek-9.3.0.tgz", - "integrity": "sha512-/c6rf4UJlmHlC9b5BaNvzAcFv7HZ2QHaV0D4/HNlBdvFnvQq8RI4kYdhyPCl7Xj+oWvTWQ8ujhqS53LIgAe6KQ==" + "integrity": "sha512-/c6rf4UJlmHlC9b5BaNvzAcFv7HZ2QHaV0D4/HNlBdvFnvQq8RI4kYdhyPCl7Xj+oWvTWQ8ujhqS53LIgAe6KQ==", + "license": "BSD-3-Clause" }, "node_modules/@hapi/topo": { "version": "5.1.0", "resolved": "https://registry.npmjs.org/@hapi/topo/-/topo-5.1.0.tgz", "integrity": "sha512-foQZKJig7Ob0BMAYBfcJk8d77QtOe7Wo4ox7ff1lQYoNNAb6jwcY1ncdoy2e9wQZzvNy7ODZCYJkK8kzmcAnAg==", + "license": "BSD-3-Clause", "dependencies": { "@hapi/hoek": "^9.0.0" } @@ -2314,10 +2447,106 @@ "integrity": "sha512-6EwiSjwWYP7pTckG6I5eyFANjPhmPjUX9JRLUSfNPC7FX7zK9gyZAfUEaECL6ALTpGX5AjnBq3C9XmVWPitNpw==", "dev": true }, + "node_modules/@inlang/cli": { + "version": "3.0.11", + "resolved": "https://registry.npmjs.org/@inlang/cli/-/cli-3.0.11.tgz", + "integrity": "sha512-JGyDrB7Jy0GRT6Z3QdenoJdxq+2Hob4pm4+wjrUa/bhXCTWAG+vbL+irP6OOS4EO+X8upn94NC39hbC0+72cHg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@inlang/sdk": "2.4.8", + "esbuild-wasm": "^0.19.2" + }, + "bin": { + "inlang": "bin/run.js" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@inlang/paraglide-js": { + "version": "2.0.13", + "resolved": "https://registry.npmjs.org/@inlang/paraglide-js/-/paraglide-js-2.0.13.tgz", + "integrity": "sha512-8tccsLzGa9uw0rufFqbHSM6GDF8+X1BgfBOyjG7PweBF2zGhN5fMu/nVNbsZiVKpXyR7lcfMxajIBwKhZ/zGKw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@inlang/recommend-sherlock": "0.2.1", + "@inlang/sdk": "2.4.8", + "commander": "11.1.0", + "consola": "3.4.0", + "json5": "2.2.3", + "unplugin": "^2.1.2", + "urlpattern-polyfill": "^10.0.0" + }, + "bin": { + "paraglide-js": "bin/run.js" + } + }, + "node_modules/@inlang/paraglide-js/node_modules/commander": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-11.1.0.tgz", + "integrity": "sha512-yPVavfyCcRhmorC7rWlkHn15b4wDVgVmBA7kV4QVBsF7kv/9TKJAbAXVTxvTnwP8HHKjRCJDClKbciiYS7p0DQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=16" + } + }, + "node_modules/@inlang/recommend-sherlock": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/@inlang/recommend-sherlock/-/recommend-sherlock-0.2.1.tgz", + "integrity": "sha512-ckv8HvHy/iTqaVAEKrr+gnl+p3XFNwe5D2+6w6wJk2ORV2XkcRkKOJ/XsTUJbPSiyi4PI+p+T3bqbmNx/rDUlg==", + "dev": true, + "license": "MIT", + "dependencies": { + "comment-json": "^4.2.3" + } + }, + "node_modules/@inlang/sdk": { + "version": "2.4.8", + "resolved": "https://registry.npmjs.org/@inlang/sdk/-/sdk-2.4.8.tgz", + "integrity": "sha512-tyXNe/5+1Vn/eDt3mVklVjZh5qxFwqdF9+hdB6wRUCexVRw6w/w854TIRFrHuaAwFq/0N/ij/yXzll9oScAB+Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@lix-js/sdk": "0.4.7", + "@sinclair/typebox": "^0.31.17", + "kysely": "^0.27.4", + "sqlite-wasm-kysely": "0.3.0", + "uuid": "^10.0.0" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@inlang/sdk/node_modules/@sinclair/typebox": { + "version": "0.31.28", + "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.31.28.tgz", + "integrity": "sha512-/s55Jujywdw/Jpan+vsy6JZs1z2ZTGxTmbZTPiuSL2wz9mfzA2gN1zzaqmvfi4pq+uOt7Du85fkiwv5ymW84aQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@inlang/sdk/node_modules/uuid": { + "version": "10.0.0", + "resolved": 
"https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz", + "integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/bin/uuid" + } + }, "node_modules/@internationalized/date": { - "version": "3.5.4", - "resolved": "https://registry.npmjs.org/@internationalized/date/-/date-3.5.4.tgz", - "integrity": "sha512-qoVJVro+O0rBaw+8HPjUB1iH8Ihf8oziEnqMnvhJUSuVIrHOuZ6eNLHNvzXJKUvAtaDiqMnRlg8Z2mgh09BlUw==", + "version": "3.8.1", + "resolved": "https://registry.npmjs.org/@internationalized/date/-/date-3.8.1.tgz", + "integrity": "sha512-PgVE6B6eIZtzf9Gu5HvJxRK3ufUFz9DhspELuhW/N0GuMGMTLvPQNRkHP2hTuP9lblOk+f+1xi96sPiPXANXAA==", + "dev": true, + "license": "Apache-2.0", "dependencies": { "@swc/helpers": "^0.5.0" } @@ -2368,6 +2597,7 @@ "resolved": "https://registry.npmjs.org/@jest/schemas/-/schemas-29.6.3.tgz", "integrity": "sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==", "dev": true, + "license": "MIT", "dependencies": { "@sinclair/typebox": "^0.27.8" }, @@ -2423,6 +2653,56 @@ "resolved": "https://registry.npmjs.org/@kurkle/color/-/color-0.3.2.tgz", "integrity": "sha512-fuscdXJ9G1qb7W8VdHi+IwRqij3lBkosAm4ydQtEmbY58OzHXqQhvlxqEkoz0yssNVn38bcpRWgA9PP+OGoisw==" }, + "node_modules/@lix-js/sdk": { + "version": "0.4.7", + "resolved": "https://registry.npmjs.org/@lix-js/sdk/-/sdk-0.4.7.tgz", + "integrity": "sha512-pRbW+joG12L0ULfMiWYosIW0plmW4AsUdiPCp+Z8rAsElJ+wJ6in58zhD3UwUcd4BNcpldEGjg6PdA7e0RgsDQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@lix-js/server-protocol-schema": "0.1.1", + "dedent": "1.5.1", + "human-id": "^4.1.1", + "js-sha256": "^0.11.0", + "kysely": "^0.27.4", + "sqlite-wasm-kysely": "0.3.0", + "uuid": "^10.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@lix-js/sdk/node_modules/uuid": { + "version": "10.0.0", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz", + "integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/bin/uuid" + } + }, + "node_modules/@lix-js/server-protocol-schema": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/@lix-js/server-protocol-schema/-/server-protocol-schema-0.1.1.tgz", + "integrity": "sha512-jBeALB6prAbtr5q4vTuxnRZZv1M2rKe8iNqRQhFJ4Tv7150unEa0vKyz0hs8Gl3fUGsWaNJBh3J8++fpbrpRBQ==", + "dev": true, + "license": "Apache-2.0" + }, + "node_modules/@lucide/svelte": { + "version": "0.482.0", + "resolved": "https://registry.npmjs.org/@lucide/svelte/-/svelte-0.482.0.tgz", + "integrity": "sha512-n2ycHU9cNcleRDwwpEHBJ6pYzVhHIaL3a+9dQa8kns9hB2g05bY+v2p2KP8v0pZwtNhYTHk/F2o2uZ1bVtQGhw==", + "dev": true, + "license": "ISC", + "peerDependencies": { + "svelte": "^5" + } + }, "node_modules/@malept/cross-spawn-promise": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/@malept/cross-spawn-promise/-/cross-spawn-promise-1.1.1.tgz", @@ -2447,65 +2727,259 @@ "node": ">= 10" } }, - "node_modules/@nodelib/fs.scandir": { - "version": "2.1.5", - "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", - "integrity": 
"sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "node_modules/@monaco-editor/loader": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.5.0.tgz", + "integrity": "sha512-hKoGSM+7aAc7eRTRjpqAZucPmoNOC4UUbknb/VNoTkEIkCPhqV8LfbsgM1webRM7S/z21eHEx9Fkwx8Z/C/+Xw==", + "license": "MIT", "dependencies": { - "@nodelib/fs.stat": "2.0.5", - "run-parallel": "^1.1.9" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.stat": { - "version": "2.0.5", - "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", - "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", - "engines": { - "node": ">= 8" + "state-local": "^1.0.6" } }, - "node_modules/@nodelib/fs.walk": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", - "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", - "dependencies": { - "@nodelib/fs.scandir": "2.1.5", - "fastq": "^1.6.0" - }, + "node_modules/@napi-rs/canvas": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.73.tgz", + "integrity": "sha512-9iwPZrNlCK4rG+vWyDvyvGeYjck9MoP0NVQP6N60gqJNFA1GsN0imG05pzNsqfCvFxUxgiTYlR8ff0HC1HXJiw==", + "license": "MIT", + "optional": true, + "workspaces": [ + "e2e/*" + ], "engines": { - "node": ">= 8" - } - }, - "node_modules/@npmcli/fs": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-2.1.2.tgz", - "integrity": "sha512-yOJKRvohFOaLqipNtwYB9WugyZKhC/DZC4VYPmpaCzDBrA8YpK3qHZ8/HGscMnE4GqbkLNuVcCnxkeQEdGt6LQ==", - "dev": true, - "license": "ISC", - "dependencies": { - "@gar/promisify": "^1.1.3", - "semver": "^7.3.5" + "node": ">= 10" }, + "optionalDependencies": { + "@napi-rs/canvas-android-arm64": "0.1.73", + "@napi-rs/canvas-darwin-arm64": "0.1.73", + "@napi-rs/canvas-darwin-x64": "0.1.73", + "@napi-rs/canvas-linux-arm-gnueabihf": "0.1.73", + "@napi-rs/canvas-linux-arm64-gnu": "0.1.73", + "@napi-rs/canvas-linux-arm64-musl": "0.1.73", + "@napi-rs/canvas-linux-riscv64-gnu": "0.1.73", + "@napi-rs/canvas-linux-x64-gnu": "0.1.73", + "@napi-rs/canvas-linux-x64-musl": "0.1.73", + "@napi-rs/canvas-win32-x64-msvc": "0.1.73" + } + }, + "node_modules/@napi-rs/canvas-android-arm64": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.73.tgz", + "integrity": "sha512-s8dMhfYIHVv7gz8BXg3Nb6cFi950Y0xH5R/sotNZzUVvU9EVqHfkqiGJ4UIqu+15UhqguT6mI3Bv1mhpRkmMQw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], "engines": { - "node": "^12.13.0 || ^14.15.0 || >=16.0.0" + "node": ">= 10" } }, - "node_modules/@npmcli/move-file": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-2.0.1.tgz", - "integrity": "sha512-mJd2Z5TjYWq/ttPLLGqArdtnC74J6bOzg4rMDnN+p1xTacZ2yPRCk2y0oSWQtygLR9YVQXgOcONrwtnk3JupxQ==", - "deprecated": "This functionality has been moved to @npmcli/fs", - "dev": true, + "node_modules/@napi-rs/canvas-darwin-arm64": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.73.tgz", + "integrity": "sha512-bLPCq8Yyq1vMdVdIpQAqmgf6VGUknk8e7NdSZXJJFOA9gxkJ1RGcHOwoXo7h0gzhHxSorg71hIxyxtwXpq10Rw==", + "cpu": [ + "arm64" + ], "license": "MIT", - "dependencies": { - 
"mkdirp": "^1.0.4", - "rimraf": "^3.0.2" - }, + "optional": true, + "os": [ + "darwin" + ], "engines": { - "node": "^12.13.0 || ^14.15.0 || >=16.0.0" + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-darwin-x64": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.73.tgz", + "integrity": "sha512-GR1CcehDjdNYXN3bj8PIXcXfYLUUOQANjQpM+KNnmpRo7ojsuqPjT7ZVH+6zoG/aqRJWhiSo+ChQMRazZlRU9g==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm-gnueabihf": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.73.tgz", + "integrity": "sha512-cM7F0kBJVFio0+U2iKSW4fWSfYQ8CPg4/DRZodSum/GcIyfB8+UPJSRM1BvvlcWinKLfX1zUYOwonZX9IFRRcw==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-gnu": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.73.tgz", + "integrity": "sha512-PMWNrMON9uz9klz1B8ZY/RXepQSC5dxxHQTowfw93Tb3fLtWO5oNX2k9utw7OM4ypT9BUZUWJnDQ5bfuXc/EUQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-musl": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.73.tgz", + "integrity": "sha512-lX0z2bNmnk1PGZ+0a9OZwI2lPPvWjRYzPqvEitXX7lspyLFrOzh2kcQiLL7bhyODN23QvfriqwYqp5GreSzVvA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-riscv64-gnu": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.73.tgz", + "integrity": "sha512-QDQgMElwxAoADsSR3UYvdTTQk5XOyD9J5kq15Z8XpGwpZOZsSE0zZ/X1JaOtS2x+HEZL6z1S6MF/1uhZFZb5ig==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-gnu": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.73.tgz", + "integrity": "sha512-wbzLJrTalQrpyrU1YRrO6w6pdr5vcebbJa+Aut5QfTaW9eEmMb1WFG6l1V+cCa5LdHmRr8bsvl0nJDU/IYDsmw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-musl": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.73.tgz", + "integrity": "sha512-xbfhYrUufoTAKvsEx2ZUN4jvACabIF0h1F5Ik1Rk4e/kQq6c+Dwa5QF0bGrfLhceLpzHT0pCMGMDeQKQrcUIyA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-win32-x64-msvc": { + "version": "0.1.73", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.73.tgz", + "integrity": "sha512-YQmHXBufFBdWqhx+ympeTPkMfs3RNxaOgWm59vyjpsub7Us07BwCcmu1N5kildhO8Fm0syoI2kHnzGkJBLSvsg==", + "cpu": [ + "x64" + ], + "license": "MIT", + 
"optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@npmcli/fs": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@npmcli/fs/-/fs-2.1.2.tgz", + "integrity": "sha512-yOJKRvohFOaLqipNtwYB9WugyZKhC/DZC4VYPmpaCzDBrA8YpK3qHZ8/HGscMnE4GqbkLNuVcCnxkeQEdGt6LQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "@gar/promisify": "^1.1.3", + "semver": "^7.3.5" + }, + "engines": { + "node": "^12.13.0 || ^14.15.0 || >=16.0.0" + } + }, + "node_modules/@npmcli/move-file": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@npmcli/move-file/-/move-file-2.0.1.tgz", + "integrity": "sha512-mJd2Z5TjYWq/ttPLLGqArdtnC74J6bOzg4rMDnN+p1xTacZ2yPRCk2y0oSWQtygLR9YVQXgOcONrwtnk3JupxQ==", + "deprecated": "This functionality has been moved to @npmcli/fs", + "dev": true, + "license": "MIT", + "dependencies": { + "mkdirp": "^1.0.4", + "rimraf": "^3.0.2" + }, + "engines": { + "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, "node_modules/@npmcli/move-file/node_modules/mkdirp": { @@ -2525,10 +2999,20 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/@panva/asn1.js/-/asn1.js-1.0.0.tgz", "integrity": "sha512-UdkG3mLEqXgnlKsWanWcgb6dOjUzJ+XC5f+aWw30qrtjxeNUSfKX1cd5FBzOaXQumoe9nIqeZUvrRJS03HCCtw==", + "license": "MIT", "engines": { "node": ">=10.13.0" } }, + "node_modules/@panva/hkdf": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@panva/hkdf/-/hkdf-1.2.1.tgz", + "integrity": "sha512-6oclG6Y3PiDFcoyk8srjLfVKyMfVCKJ27JwNPViuXziFpmdz+MZnZN/aKY0JGXgYuO/VghU0jcOAZgWXZ1Dmrw==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/@pkgjs/parseargs": { "version": "0.11.0", "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", @@ -2554,24 +3038,36 @@ } }, "node_modules/@polka/url": { - "version": "1.0.0-next.25", - "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.25.tgz", - "integrity": "sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==", - "dev": true + "version": "1.0.0-next.29", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.29.tgz", + "integrity": "sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==", + "license": "MIT" + }, + "node_modules/@poppinss/macroable": { + "version": "1.0.4", + "resolved": 
"https://registry.npmjs.org/@poppinss/macroable/-/macroable-1.0.4.tgz", + "integrity": "sha512-ct43jurbe7lsUX5eIrj4ijO3j/6zIPp7CDnFWXDs7UPAbw1Pu1iH3oAmFdP4jcskKJBURH5M9oTtyeiUXyHX8Q==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18.16.0" + } }, "node_modules/@rollup/plugin-commonjs": { - "version": "28.0.0", - "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-28.0.0.tgz", - "integrity": "sha512-BJcu+a+Mpq476DMXG+hevgPSl56bkUoi88dKT8t3RyUp8kGuOh+2bU8Gs7zXDlu+fyZggnJ+iOBGrb/O1SorYg==", + "version": "28.0.3", + "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-28.0.3.tgz", + "integrity": "sha512-pyltgilam1QPdn+Zd9gaCfOLcnjMEJ9gV+bTw6/r73INdvzf1ah9zLIJBm+kW7R6IUFIQ1YO+VqZtYxZNWFPEQ==", "dev": true, + "license": "MIT", "dependencies": { "@rollup/pluginutils": "^5.0.1", "commondir": "^1.0.1", "estree-walker": "^2.0.2", - "fdir": "^6.1.1", + "fdir": "^6.2.0", "is-reference": "1.2.1", "magic-string": "^0.30.3", - "picomatch": "^2.3.1" + "picomatch": "^4.0.2" }, "engines": { "node": ">=16.0.0 || 14 >= 14.17" @@ -2585,33 +3081,16 @@ } } }, - "node_modules/@rollup/plugin-commonjs/node_modules/estree-walker": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", - "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", - "dev": true - }, "node_modules/@rollup/plugin-commonjs/node_modules/is-reference": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/is-reference/-/is-reference-1.2.1.tgz", "integrity": "sha512-U82MsXXiFIrjCK4otLT+o2NA2Cd2g5MLoOVXUZjIOhLurrRxpEXzI8O0KZHr3IjLvlAH1kTPYSuqer5T9ZVBKQ==", "dev": true, + "license": "MIT", "dependencies": { "@types/estree": "*" } }, - "node_modules/@rollup/plugin-commonjs/node_modules/picomatch": { - "version": "2.3.1", - "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", - "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", - "dev": true, - "engines": { - "node": ">=8.6" - }, - "funding": { - "url": "https://github.com/sponsors/jonschlinkert" - } - }, "node_modules/@rollup/plugin-json": { "version": "6.1.0", "resolved": "https://registry.npmjs.org/@rollup/plugin-json/-/plugin-json-6.1.0.tgz", @@ -2633,10 +3112,11 @@ } }, "node_modules/@rollup/plugin-node-resolve": { - "version": "15.3.0", - "resolved": "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-15.3.0.tgz", - "integrity": "sha512-9eO5McEICxMzJpDW9OnMYSv4Sta3hmt7VtBFz5zR9273suNOydOyq/FrGeGy+KsTRFm8w0SLVhzig2ILFT63Ag==", + "version": "16.0.1", + "resolved": "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-16.0.1.tgz", + "integrity": "sha512-tk5YCxJWIG81umIvNkSod2qK5KyQW19qcBF/B78n1bjtOON6gzKoVeSzAE8yHCZEDmqkHKkxplExA8KzdJLJpA==", "dev": true, + "license": "MIT", "dependencies": { "@rollup/pluginutils": "^5.0.1", "@types/resolve": "1.20.2", @@ -2678,12 +3158,6 @@ } } }, - "node_modules/@rollup/pluginutils/node_modules/estree-walker": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", - "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", - "dev": true - }, "node_modules/@rollup/pluginutils/node_modules/picomatch": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", @@ -2703,7 +3177,6 @@ 
"cpu": [ "arm" ], - "dev": true, "optional": true, "os": [ "android" @@ -2716,7 +3189,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "android" @@ -2729,7 +3201,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "darwin" @@ -2742,7 +3213,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "darwin" @@ -2755,7 +3225,6 @@ "cpu": [ "arm" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2768,7 +3237,6 @@ "cpu": [ "arm" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2781,7 +3249,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2794,7 +3261,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2807,7 +3273,6 @@ "cpu": [ "ppc64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2820,7 +3285,6 @@ "cpu": [ "riscv64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2833,7 +3297,6 @@ "cpu": [ "s390x" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2846,7 +3309,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2859,7 +3321,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -2872,7 +3333,6 @@ "cpu": [ "arm64" ], - "dev": true, "optional": true, "os": [ "win32" @@ -2885,7 +3345,6 @@ "cpu": [ "ia32" ], - "dev": true, "optional": true, "os": [ "win32" @@ -2898,7 +3357,6 @@ "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "win32" @@ -2908,6 +3366,7 @@ "version": "4.1.5", "resolved": "https://registry.npmjs.org/@sideway/address/-/address-4.1.5.tgz", "integrity": "sha512-IqO/DUQHUkPeixNQ8n0JA6102hT9CmaljNTPmQ1u8MEhBo/R4Q8eKLN/vGZxuebwOroDB4cbpjheD4+/sKFK4Q==", + "license": "BSD-3-Clause", "dependencies": { "@hapi/hoek": "^9.0.0" } @@ -2915,18 +3374,21 @@ "node_modules/@sideway/formula": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/@sideway/formula/-/formula-3.0.1.tgz", - "integrity": "sha512-/poHZJJVjx3L+zVD6g9KgHfYnb443oi7wLu/XKojDviHy6HOEOA6z1Trk5aR1dGcmPenJEgb2sK2I80LeS3MIg==" + "integrity": "sha512-/poHZJJVjx3L+zVD6g9KgHfYnb443oi7wLu/XKojDviHy6HOEOA6z1Trk5aR1dGcmPenJEgb2sK2I80LeS3MIg==", + "license": "BSD-3-Clause" }, "node_modules/@sideway/pinpoint": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@sideway/pinpoint/-/pinpoint-2.0.0.tgz", - "integrity": "sha512-RNiOoTPkptFtSVzQevY/yWtZwf/RxyVnPy/OcA9HBM3MlGDnBEYL5B41H0MTn0Uec8Hi+2qUtTfG2WWZBmMejQ==" + "integrity": "sha512-RNiOoTPkptFtSVzQevY/yWtZwf/RxyVnPy/OcA9HBM3MlGDnBEYL5B41H0MTn0Uec8Hi+2qUtTfG2WWZBmMejQ==", + "license": "BSD-3-Clause" }, "node_modules/@sinclair/typebox": { "version": "0.27.8", "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", "integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/@sindresorhus/is": { "version": "4.6.0", @@ -2939,23 +3401,52 @@ "url": "https://github.com/sindresorhus/is?sponsor=1" } }, + "node_modules/@sqlite.org/sqlite-wasm": { + "version": "3.48.0-build4", + "resolved": "https://registry.npmjs.org/@sqlite.org/sqlite-wasm/-/sqlite-wasm-3.48.0-build4.tgz", + "integrity": "sha512-hI6twvUkzOmyGZhQMza1gpfqErZxXRw6JEsiVjUbo7tFanVD+8Oil0Ih3l2nGzHdxPI41zFmfUQG7GHqhciKZQ==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "sqlite-wasm": "bin/index.js" + } + }, + "node_modules/@standard-schema/spec": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.0.0.tgz", + "integrity": 
"sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/@stripe/stripe-js": { - "version": "3.4.0", - "resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-3.4.0.tgz", - "integrity": "sha512-a2kUP7OrsV0SSIk3UxWa+cnrW+PPIyuCbWIBH8vxfHIqmyeQN/d0lsplZJ2h7MlLsU/sB3EyhNBkhLLT+zHwKw==", + "version": "3.5.0", + "resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-3.5.0.tgz", + "integrity": "sha512-pKS3wZnJoL1iTyGBXAvCwduNNeghJHY6QSRSNNvpYnrrQrLZ6Owsazjyynu0e0ObRgks0i7Rv+pe2M7/MBTZpQ==", + "license": "MIT", "engines": { "node": ">=12.16" } }, + "node_modules/@sveltejs/acorn-typescript": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@sveltejs/acorn-typescript/-/acorn-typescript-1.0.5.tgz", + "integrity": "sha512-IwQk4yfwLdibDlrXVE04jTZYlLnwsTT2PIOQQGNLWfjavGifnk1JD1LcZjZaBTRcxZu2FfPfNLOE04DSu9lqtQ==", + "license": "MIT", + "peerDependencies": { + "acorn": "^8.9.0" + } + }, "node_modules/@sveltejs/adapter-node": { - "version": "5.2.5", - "resolved": "https://registry.npmjs.org/@sveltejs/adapter-node/-/adapter-node-5.2.5.tgz", - "integrity": "sha512-FVeysFqeIlKFpDF1Oj38gby34f6uA9FuXnV330Z0RHmSyOR9JzJs70/nFKy1Ue3fWtf7S0RemOrP66Vr9Jcmew==", + "version": "5.2.12", + "resolved": "https://registry.npmjs.org/@sveltejs/adapter-node/-/adapter-node-5.2.12.tgz", + "integrity": "sha512-0bp4Yb3jKIEcZWVcJC/L1xXp9zzJS4hDwfb4VITAkfT4OVdkspSHsx7YhqJDbb2hgLl6R9Vs7VQR+fqIVOxPUQ==", "dev": true, + "license": "MIT", "dependencies": { - "@rollup/plugin-commonjs": "^28.0.0", + "@rollup/plugin-commonjs": "^28.0.1", "@rollup/plugin-json": "^6.1.0", - "@rollup/plugin-node-resolve": "^15.3.0", + "@rollup/plugin-node-resolve": "^16.0.0", "rollup": "^4.9.5" }, "peerDependencies": { @@ -2963,33 +3454,34 @@ } }, "node_modules/@sveltejs/adapter-static": { - "version": "3.0.5", - "resolved": "https://registry.npmjs.org/@sveltejs/adapter-static/-/adapter-static-3.0.5.tgz", - "integrity": "sha512-kFJR7RxeB6FBvrKZWAEzIALatgy11ISaaZbcPup8JdWUdrmmfUHHTJ738YHJTEfnCiiXi6aX8Q6ePY7tnSMD6Q==", + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@sveltejs/adapter-static/-/adapter-static-3.0.8.tgz", + "integrity": "sha512-YaDrquRpZwfcXbnlDsSrBQNCChVOT9MGuSg+dMAyfsAa1SmiAhrA5jUYUiIMC59G92kIbY/AaQOWcBdq+lh+zg==", "dev": true, + "license": "MIT", "peerDependencies": { "@sveltejs/kit": "^2.0.0" } }, "node_modules/@sveltejs/kit": { - "version": "2.6.3", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.6.3.tgz", - "integrity": "sha512-baIAnmfMqAISrPtTC/22w6ay5kTEIQ/vq9bctiaQgRIoLCPBNhb6LEidTuWQS7OzPYCDBMuMX1t/fMvi4r3q/g==", - "dev": true, - "hasInstallScript": true, + "version": "2.21.3", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.21.3.tgz", + "integrity": "sha512-Bd05srNOaqP05qnytjg/KkWNlkcwEpE76s0xGSlgzL4I8pLyrK3c9+a7zMCquoiEEIZF2ecGTn6Fj/lELjaa8A==", + "license": "MIT", "dependencies": { + "@sveltejs/acorn-typescript": "^1.0.5", "@types/cookie": "^0.6.0", + "acorn": "^8.14.1", "cookie": "^0.6.0", "devalue": "^5.1.0", - "esm-env": "^1.0.0", - "import-meta-resolve": "^4.1.0", + "esm-env": "^1.2.2", "kleur": "^4.1.5", "magic-string": "^0.30.5", "mrmime": "^2.0.0", "sade": "^1.8.1", "set-cookie-parser": "^2.6.0", - "sirv": "^2.0.4", - "tiny-glob": "^0.2.9" + "sirv": "^3.0.0", + "vitefu": "^1.0.6" }, "bin": { "svelte-kit": "svelte-kit.js" @@ -2998,16 +3490,15 @@ "node": ">=18.13" }, "peerDependencies": { - "@sveltejs/vite-plugin-svelte": 
"^3.0.0 || ^4.0.0-next.1", + "@sveltejs/vite-plugin-svelte": "^3.0.0 || ^4.0.0-next.1 || ^5.0.0", "svelte": "^4.0.0 || ^5.0.0-next.0", - "vite": "^5.0.3" + "vite": "^5.0.3 || ^6.0.0" } }, "node_modules/@sveltejs/vite-plugin-svelte": { "version": "4.0.4", "resolved": "https://registry.npmjs.org/@sveltejs/vite-plugin-svelte/-/vite-plugin-svelte-4.0.4.tgz", "integrity": "sha512-0ba1RQ/PHen5FGpdSrW7Y3fAMQjrXantECALeOiOdBdzR5+5vPP6HVZRLmZaQL+W8m++o+haIAKq5qT+MiZ7VA==", - "dev": true, "dependencies": { "@sveltejs/vite-plugin-svelte-inspector": "^3.0.0-next.0||^3.0.0", "debug": "^4.3.7", @@ -3028,7 +3519,6 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/@sveltejs/vite-plugin-svelte-inspector/-/vite-plugin-svelte-inspector-3.0.1.tgz", "integrity": "sha512-2CKypmj1sM4GE7HjllT7UKmo4Q6L5xFRd7VMGEWhYnZ+wc6AUVU01IBd7yUi6WnFndEwWoMNOd6e8UjoN0nbvQ==", - "dev": true, "dependencies": { "debug": "^4.3.7" }, @@ -3042,11 +3532,13 @@ } }, "node_modules/@swc/helpers": { - "version": "0.5.11", - "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.11.tgz", - "integrity": "sha512-YNlnKRWF2sVojTpIyzwou9XoTNbzbzONwRhOoniEioF1AtaitTvVZblaQRrAzChWQ1bLYyYSWzM18y4WwgzJ+A==", + "version": "0.5.17", + "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.17.tgz", + "integrity": "sha512-5IKx/Y13RsYd+sauPb2x+U/xZikHjolzfuDgTAl/Tdf3Q8rslRvC19NKDLgAJQ6wsqADk10ntlv08nPFw/gO/A==", + "dev": true, + "license": "Apache-2.0", "dependencies": { - "tslib": "^2.4.0" + "tslib": "^2.8.0" } }, "node_modules/@szmarczak/http-timer": { @@ -3111,8 +3603,7 @@ "node_modules/@types/cookie": { "version": "0.6.0", "resolved": "https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==", - "dev": true + "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==" }, "node_modules/@types/cors": { "version": "2.8.17", @@ -3254,6 +3745,13 @@ "@types/node": "*" } }, + "node_modules/@types/pdfobject": { + "version": "2.2.5", + "resolved": "https://registry.npmjs.org/@types/pdfobject/-/pdfobject-2.2.5.tgz", + "integrity": "sha512-7gD5tqc/RUDq0PyoLemL0vEHxBYi+zY0WVaFAx/Y0jBsXFgot1vB9No1GhDZGwRGJMCIZbgAb74QG9MTyTNU/g==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/qs": { "version": "6.9.15", "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.15.tgz", @@ -3270,7 +3768,8 @@ "version": "1.20.2", "resolved": "https://registry.npmjs.org/@types/resolve/-/resolve-1.20.2.tgz", "integrity": "sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/@types/responselike": { "version": "1.0.3", @@ -3319,6 +3818,14 @@ "integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==", "dev": true }, + "node_modules/@types/validator": { + "version": "13.15.2", + "resolved": "https://registry.npmjs.org/@types/validator/-/validator-13.15.2.tgz", + "integrity": "sha512-y7pa/oEJJ4iGYBxOpfAKn5b9+xuihvzDVnC/OSvlVnGxVg0pOqmjiMafiJ1KVNQEaPZf9HsEp5icEwGg8uIe5Q==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/@types/yauzl": { "version": "2.10.3", "resolved": "https://registry.npmjs.org/@types/yauzl/-/yauzl-2.10.3.tgz", @@ -3330,6 +3837,41 @@ "@types/node": "*" } }, + "node_modules/@typeschema/class-validator": { + "version": "0.3.0", + "resolved": 
"https://registry.npmjs.org/@typeschema/class-validator/-/class-validator-0.3.0.tgz", + "integrity": "sha512-OJSFeZDIQ8EK1HTljKLT5CItM2wsbgczLN8tMEfz3I1Lmhc5TBfkZ0eikFzUC16tI3d1Nag7um6TfCgp2I2Bww==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@typeschema/core": "0.14.0" + }, + "peerDependencies": { + "class-validator": "^0.14.1" + }, + "peerDependenciesMeta": { + "class-validator": { + "optional": true + } + } + }, + "node_modules/@typeschema/core": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/@typeschema/core/-/core-0.14.0.tgz", + "integrity": "sha512-Ia6PtZHcL3KqsAWXjMi5xIyZ7XMH4aSnOQes8mfMLx+wGFGtGRNlwe6Y7cYvX+WfNK67OL0/HSe9t8QDygV0/w==", + "dev": true, + "license": "MIT", + "optional": true, + "peerDependencies": { + "@types/json-schema": "^7.0.15" + }, + "peerDependenciesMeta": { + "@types/json-schema": { + "optional": true + } + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "7.3.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-7.3.1.tgz", @@ -3526,14 +4068,75 @@ "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==", "dev": true }, + "node_modules/@vinejs/compiler": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@vinejs/compiler/-/compiler-3.0.0.tgz", + "integrity": "sha512-v9Lsv59nR56+bmy2p0+czjZxsLHwaibJ+SV5iK9JJfehlJMa501jUJQqqz4X/OqKXrxtE3uTQmSqjUqzF3B2mw==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@vinejs/vine": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/@vinejs/vine/-/vine-3.0.1.tgz", + "integrity": "sha512-ZtvYkYpZOYdvbws3uaOAvTFuvFXoQGAtmzeiXu+XSMGxi5GVsODpoI9Xu9TplEMuD/5fmAtBbKb9cQHkWkLXDQ==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@poppinss/macroable": "^1.0.4", + "@types/validator": "^13.12.2", + "@vinejs/compiler": "^3.0.0", + "camelcase": "^8.0.0", + "dayjs": "^1.11.13", + "dlv": "^1.1.3", + "normalize-url": "^8.0.1", + "validator": "^13.12.0" + }, + "engines": { + "node": ">=18.16.0" + } + }, + "node_modules/@vinejs/vine/node_modules/camelcase": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-8.0.0.tgz", + "integrity": "sha512-8WB3Jcas3swSvjIeA2yvCJ+Miyz5l1ZmB6HFb9R1317dt9LCQoswg/BGrmAmkWVEszSrrg4RwmO46qIm2OEnSA==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@vinejs/vine/node_modules/normalize-url": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/normalize-url/-/normalize-url-8.0.2.tgz", + "integrity": "sha512-Ee/R3SyN4BuynXcnTaekmaVdbDAEiNrHqjQIA37mHU8G9pf7aaAD4ZX3XjBLo6rsdcxA/gtkcNYZLt30ACgynw==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=14.16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/@vitest/expect": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-1.4.0.tgz", - "integrity": "sha512-Jths0sWCJZ8BxjKe+p+eKsoqev1/T8lYcrjavEaz8auEJ4jAVY0GwW3JKmdVU4mmNPLPHixh4GNXP7GFtAiDHA==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-1.6.1.tgz", + "integrity": "sha512-jXL+9+ZNIJKruofqXuuTClf44eSpcHlgj3CiuNihUF3Ioujtmc0zIa3UJOW5RjDK1YLBJZnWBlPuqhYycLioog==", "dev": true, + "license": 
"MIT", "dependencies": { - "@vitest/spy": "1.4.0", - "@vitest/utils": "1.4.0", + "@vitest/spy": "1.6.1", + "@vitest/utils": "1.6.1", "chai": "^4.3.10" }, "funding": { @@ -3541,12 +4144,13 @@ } }, "node_modules/@vitest/runner": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-1.4.0.tgz", - "integrity": "sha512-EDYVSmesqlQ4RD2VvWo3hQgTJ7ZrFQ2VSJdfiJiArkCerDAGeyF1i6dHkmySqk573jLp6d/cfqCN+7wUB5tLgg==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-1.6.1.tgz", + "integrity": "sha512-3nSnYXkVkf3mXFfE7vVyPmi3Sazhb/2cfZGGs0JRzFsPFvAMBEcrweV1V1GsrstdXeKCTXlJbvnQwGWgEIHmOA==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/utils": "1.4.0", + "@vitest/utils": "1.6.1", "p-limit": "^5.0.0", "pathe": "^1.1.1" }, @@ -3559,6 +4163,7 @@ "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-5.0.0.tgz", "integrity": "sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==", "dev": true, + "license": "MIT", "dependencies": { "yocto-queue": "^1.0.0" }, @@ -3570,10 +4175,11 @@ } }, "node_modules/@vitest/runner/node_modules/yocto-queue": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-1.0.0.tgz", - "integrity": "sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-1.2.1.tgz", + "integrity": "sha512-AyeEbWOu/TAXdxlV9wmGcR0+yh2j3vYPGOECcIj2S7MkrLyC7ne+oye2BKTItt0ii2PHk4cDy+95+LshzbXnGg==", "dev": true, + "license": "MIT", "engines": { "node": ">=12.20" }, @@ -3582,10 +4188,11 @@ } }, "node_modules/@vitest/snapshot": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-1.4.0.tgz", - "integrity": "sha512-saAFnt5pPIA5qDGxOHxJ/XxhMFKkUSBJmVt5VgDsAqPTX6JP326r5C/c9UuCMPoXNzuudTPsYDZCoJ5ilpqG2A==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-1.6.1.tgz", + "integrity": "sha512-WvidQuWAzU2p95u8GAKlRMqMyN1yOJkGHnx3M1PL9Raf7AQ1kwLKg04ADlCa3+OXUZE7BceOhVZiuWAbzCKcUQ==", "dev": true, + "license": "MIT", "dependencies": { "magic-string": "^0.30.5", "pathe": "^1.1.1", @@ -3596,10 +4203,11 @@ } }, "node_modules/@vitest/spy": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-1.4.0.tgz", - "integrity": "sha512-Ywau/Qs1DzM/8Uc+yA77CwSegizMlcgTJuYGAi0jujOteJOUf1ujunHThYo243KG9nAyWT3L9ifPYZ5+As/+6Q==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-1.6.1.tgz", + "integrity": "sha512-MGcMmpGkZebsMZhbQKkAf9CX5zGvjkBTqf8Zx3ApYWXr3wG+QvEu2eXWfnIIWYSJExIp4V9FCKDEeygzkYrXMw==", "dev": true, + "license": "MIT", "dependencies": { "tinyspy": "^2.2.0" }, @@ -3608,10 +4216,11 @@ } }, "node_modules/@vitest/utils": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-1.4.0.tgz", - "integrity": "sha512-mx3Yd1/6e2Vt/PUC98DcqTirtfxUyAZ32uK82r8rZzbtBeBo+nqgnjx/LvqQdWsrvNtm14VmurNgcf4nqY5gJg==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-1.6.1.tgz", + "integrity": "sha512-jOrrUvXM4Av9ZWiG1EajNto0u96kWAhJ1LmPmJhXXQx/32MecEKd10pOLYgS2BQx1TgkGhloPU1ArDW2vvaY6g==", "dev": true, + "license": "MIT", "dependencies": { "diff-sequences": "^29.6.3", "estree-walker": "^3.0.3", @@ -3622,6 +4231,16 @@ "url": "https://opencollective.com/vitest" } }, + "node_modules/@vitest/utils/node_modules/estree-walker": { + 
"version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/@xmldom/xmldom": { "version": "0.8.10", "resolved": "https://registry.npmjs.org/@xmldom/xmldom/-/xmldom-0.8.10.tgz", @@ -3658,9 +4277,10 @@ } }, "node_modules/acorn": { - "version": "8.14.0", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz", - "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", + "version": "8.14.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.1.tgz", + "integrity": "sha512-OvQ/2pUDKmgfCg++xsTX1wGxfTaszcHVcTctW4UJB4hibJx2HXxxO5UmVgyjMa+ZDsiaf5wWLXYpRWMmBI0QHg==", + "license": "MIT", "bin": { "acorn": "bin/acorn" }, @@ -3677,14 +4297,6 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, - "node_modules/acorn-typescript": { - "version": "1.4.13", - "resolved": "https://registry.npmjs.org/acorn-typescript/-/acorn-typescript-1.4.13.tgz", - "integrity": "sha512-xsc9Xv0xlVfwp2o7sQ+GCQ1PgbkdcpWdTzrwXxO3xDMTAywVS3oXVOcOHuRjAPkS4P9b+yc/qNF15460v+jp4Q==", - "peerDependencies": { - "acorn": ">=8.9.0" - } - }, "node_modules/acorn-walk": { "version": "8.3.2", "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.2.tgz", @@ -3866,9 +4478,9 @@ "optional": true }, "node_modules/appdmg/node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "license": "MIT", "optional": true, @@ -4052,11 +4664,30 @@ "node": ">= 0.4" } }, + "node_modules/arktype": { + "version": "2.1.20", + "resolved": "https://registry.npmjs.org/arktype/-/arktype-2.1.20.tgz", + "integrity": "sha512-IZCEEXaJ8g+Ijd59WtSYwtjnqXiwM8sWQ5EjGamcto7+HVN9eK0C4p0zDlCuAwWhpqr6fIBkxPuYDl4/Mcj/+Q==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@ark/schema": "0.46.0", + "@ark/util": "0.46.0" + } + }, "node_modules/array-flatten": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==" }, + "node_modules/array-timsort": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/array-timsort/-/array-timsort-1.0.3.tgz", + "integrity": "sha512-/+3GRL7dDAGEfM6TseQk/U+mi18TU2Ms9I3UlLdUMhz2hbvGNTKdj9xniwXfUqgYhHxRx0+8UnKkvlNwVU+cWQ==", + "dev": true, + "license": "MIT" + }, "node_modules/array-union": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", @@ -4132,6 +4763,7 @@ "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-1.1.0.tgz", "integrity": "sha512-jgsaNduz+ndvGyFt3uSuWqvy4lCnIJiovtouQN5JZHOKCS2QuhEdbcQHFhVksz2N2U9hXJo8odG7ETyWlEeuDw==", "dev": true, + "license": "MIT", "engines": { "node": "*" } @@ -4157,26 +4789,24 @@ } }, "node_modules/auth0": { - "version": "4.4.0", - "resolved": "https://registry.npmjs.org/auth0/-/auth0-4.4.0.tgz", - 
"integrity": "sha512-umlAogwQDUYvL1Pd4RnViyps7lkvntZ3+VVDW+/4ML7/GzkJcq6VGJ20Nb60eYosQkxa7up1n9lrA4Of+BZsUg==", + "version": "4.27.0", + "resolved": "https://registry.npmjs.org/auth0/-/auth0-4.27.0.tgz", + "integrity": "sha512-4FGgjzKCH/f7rQLQVR5dM30asjOObeW3PyHa8bQrS4rKkuv22JoNxox26fb1FZ3hI4zEgbVbPm9x7pHrljZzrw==", "license": "MIT", "dependencies": { "jose": "^4.13.2", + "undici-types": "^6.15.0", "uuid": "^9.0.0" }, "engines": { "node": ">=18" } }, - "node_modules/auth0/node_modules/jose": { - "version": "4.15.5", - "resolved": "https://registry.npmjs.org/jose/-/jose-4.15.5.tgz", - "integrity": "sha512-jc7BFxgKPKi94uOvEmzlSWFFe2+vASyXaKUpdQKatWAESU2MWjDfFf0fdfc83CDKcA5QecabZeNLyfhe3yKNkg==", - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/panva" - } + "node_modules/auth0/node_modules/undici-types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", + "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", + "license": "MIT" }, "node_modules/author-regex": { "version": "1.0.0", @@ -4240,9 +4870,10 @@ } }, "node_modules/axios": { - "version": "1.7.4", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.4.tgz", - "integrity": "sha512-DukmaFRnY6AzAALSH4J2M3k6PkaC+MfaAGdEERRWcC9q3/TWQwLpHR8ZRLKTdQ3aBDL64EdluRDjJqKw+BPZEw==", + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.9.0.tgz", + "integrity": "sha512-re4CqKTJaURpzbLHtIi6XpDv20/CnpXOtjRY5/CU32L8gU8ek9UIivcfvSWvmKEngmVbrUtPpdDwWDWL7DNHvg==", + "license": "MIT", "dependencies": { "follow-redirects": "^1.15.6", "form-data": "^4.0.0", @@ -4298,6 +4929,7 @@ "version": "3.0.1", "resolved": "https://registry.npmjs.org/base64url/-/base64url-3.0.1.tgz", "integrity": "sha512-ir1UPr3dkwexU7FdV8qBBbNDRUhMmIekYMFZfi+C/sLNnRESKPl23nB9b2pltqfOQNnGzsDdId90AEtG5tCx4A==", + "license": "MIT", "engines": { "node": ">=6.0.0" } @@ -4314,49 +4946,30 @@ } }, "node_modules/bits-ui": { - "version": "0.20.1", - "resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-0.20.1.tgz", - "integrity": "sha512-P0JRuWn+XpFYsAbGnPlyPVKab88v2S8Q57cUI3LZdh0nulO7fgxbXgBHgEAmmgNk63XxyvhmfYz44kZFRrHtLA==", + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-1.8.0.tgz", + "integrity": "sha512-CXD6Orp7l8QevNDcRPLXc/b8iMVgxDWT2LyTwsdLzJKh9CxesOmPuNePSPqAxKoT59FIdU4aFPS1k7eBdbaCxg==", + "dev": true, + "license": "MIT", "dependencies": { - "@internationalized/date": "^3.5.1", - "@melt-ui/svelte": "0.76.2", - "nanoid": "^5.0.5" + "@floating-ui/core": "^1.6.4", + "@floating-ui/dom": "^1.6.7", + "@internationalized/date": "^3.5.6", + "css.escape": "^1.5.1", + "esm-env": "^1.1.2", + "runed": "^0.23.2", + "svelte-toolbelt": "^0.7.1", + "tabbable": "^6.2.0" }, - "peerDependencies": { - "svelte": "^4.0.0" - } - }, - "node_modules/bits-ui/node_modules/@melt-ui/svelte": { - "version": "0.76.2", - "resolved": "https://registry.npmjs.org/@melt-ui/svelte/-/svelte-0.76.2.tgz", - "integrity": "sha512-7SbOa11tXUS95T3fReL+dwDs5FyJtCEqrqG3inRziDws346SYLsxOQ6HmX+4BkIsQh1R8U3XNa+EMmdMt38lMA==", - "dependencies": { - "@floating-ui/core": "^1.3.1", - "@floating-ui/dom": "^1.4.5", - "@internationalized/date": "^3.5.0", - "dequal": "^2.0.3", - "focus-trap": "^7.5.2", - "nanoid": "^5.0.4" + "engines": { + "node": ">=18", + "pnpm": ">=8.7.0" }, - "peerDependencies": { - "svelte": ">=3 <5" - } - }, - "node_modules/bits-ui/node_modules/nanoid": { - "version": "5.0.6", - "resolved": 
"https://registry.npmjs.org/nanoid/-/nanoid-5.0.6.tgz", - "integrity": "sha512-rRq0eMHoGZxlvaFOUdK1Ev83Bd1IgzzR+WJ3IbDJ7QOSdAxYjlurSPqFs9s4lJg29RT6nPwizFtJhQS6V5xgiA==", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "bin": { - "nanoid": "bin/nanoid.js" + "funding": { + "url": "https://github.com/sponsors/huntabyte" }, - "engines": { - "node": "^18 || >=20" + "peerDependencies": { + "svelte": "^5.11.0" } }, "node_modules/bl": { @@ -4383,7 +4996,6 @@ "version": "3.7.2", "resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz", "integrity": "sha512-XpNj6GDQzdfW+r2Wnn7xiSAd7TM3jzkxGXBGTtWKuSXv1xUV+azxAm8jdWZN06QTQk+2N2XB9jRDkvbmQmcRtg==", - "dev": true, "license": "MIT" }, "node_modules/body-parser": { @@ -4550,6 +5162,7 @@ "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", "integrity": "sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=8" } @@ -4757,10 +5370,11 @@ ] }, "node_modules/chai": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/chai/-/chai-4.4.1.tgz", - "integrity": "sha512-13sOfMv2+DWduEU+/xbun3LScLoqN17nBeTLUsmDfKdoiC1fr0n9PU4guu4AhRcOVFk/sW8LyZWHuhWtQZiF+g==", + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/chai/-/chai-4.5.0.tgz", + "integrity": "sha512-RITGBfijLkBddZvnn8jdqoTypxvqbOLYQkGGxXzeFjVHvudaPw0HNFD9x928/eUwYWd2dPCugVqspGALTZZQKw==", "dev": true, + "license": "MIT", "dependencies": { "assertion-error": "^1.1.0", "check-error": "^1.0.3", @@ -4768,7 +5382,7 @@ "get-func-name": "^2.0.2", "loupe": "^2.3.6", "pathval": "^1.1.1", - "type-detect": "^4.0.8" + "type-detect": "^4.1.0" }, "engines": { "node": ">=4" @@ -4815,6 +5429,7 @@ "resolved": "https://registry.npmjs.org/check-error/-/check-error-1.0.3.tgz", "integrity": "sha512-iKEoDYaRmd1mxM90a2OEfWhjsjPpYPuQ+lMYsoxB126+t8fw7ySEO48nmDg5COTjxDI65/Y2OWpeEHk3ZOe8zg==", "dev": true, + "license": "MIT", "dependencies": { "get-func-name": "^2.0.2" }, @@ -4884,6 +5499,19 @@ "license": "MIT", "optional": true }, + "node_modules/class-validator": { + "version": "0.14.2", + "resolved": "https://registry.npmjs.org/class-validator/-/class-validator-0.14.2.tgz", + "integrity": "sha512-3kMVRF2io8N8pY1IFIXlho9r8IPUUIfHe2hYVtiebvAzU2XeQFXTv+XI4WX+TnXmtwXMDcjngcpkiPM0O9PvLw==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@types/validator": "^13.11.8", + "libphonenumber-js": "^1.11.1", + "validator": "^13.9.0" + } + }, "node_modules/clean-stack": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz", @@ -5092,11 +5720,29 @@ "node": ">= 6" } }, + "node_modules/comment-json": { + "version": "4.2.5", + "resolved": "https://registry.npmjs.org/comment-json/-/comment-json-4.2.5.tgz", + "integrity": "sha512-bKw/r35jR3HGt5PEPm1ljsQQGyCrR8sFGNiN5L+ykDHdpO8Smxkrkla9Yi6NkQyUrb8V54PGhfMs6NrIwtxtdw==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-timsort": "^1.0.3", + "core-util-is": "^1.0.3", + "esprima": "^4.0.1", + "has-own-prop": "^2.0.0", + "repeat-string": "^1.6.1" + }, + "engines": { + "node": ">= 6" + } + }, "node_modules/commondir": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", "integrity": "sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/compare-version": { "version": "0.1.2", @@ 
-5155,6 +5801,23 @@ "node": ">=12" } }, + "node_modules/concurrently/node_modules/date-fns": { + "version": "2.30.0", + "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-2.30.0.tgz", + "integrity": "sha512-fnULvOpxnC5/Vg3NCiWelDsLiUc9bRwAPs/+LfTLNvetFCtCTN+yQz15C/fs4AwX1R9K5GLtLfn8QW+dWisaAw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.21.0" + }, + "engines": { + "node": ">=0.11" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/date-fns" + } + }, "node_modules/concurrently/node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", @@ -5243,6 +5906,16 @@ "node": ">=12" } }, + "node_modules/consola": { + "version": "3.4.0", + "resolved": "https://registry.npmjs.org/consola/-/consola-3.4.0.tgz", + "integrity": "sha512-EiPU8G6dQG0GFHNR8ljnZFki/8a+cQwEQ+7wpxdChl02Q8HXlwEZWD5lqAF8vC2sEC3Tehr8hy7vErz88LHyUA==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^14.18.0 || >=16.10.0" + } + }, "node_modules/console-control-strings": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", @@ -5273,7 +5946,6 @@ "version": "0.6.0", "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", - "dev": true, "engines": { "node": ">= 0.6" } @@ -5283,6 +5955,13 @@ "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==" }, + "node_modules/core-util-is": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", + "dev": true, + "license": "MIT" + }, "node_modules/cors": { "version": "2.8.5", "resolved": "https://registry.npmjs.org/cors/-/cors-2.8.5.tgz", @@ -5321,9 +6000,10 @@ } }, "node_modules/cross-spawn": { - "version": "7.0.3", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", - "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -5357,6 +6037,13 @@ "node": ">=12.10" } }, + "node_modules/css.escape": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/css.escape/-/css.escape-1.5.1.tgz", + "integrity": "sha512-YUifsXXuknHlUsmlgyY0PKzgPOr7/FjCePfHNt0jxm83wHZi44VDMQ7/fGNkjY3/jV1MC+1CmZbaHzugyeRtpg==", + "dev": true, + "license": "MIT" + }, "node_modules/cssesc": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", @@ -5368,22 +6055,53 @@ "node": ">=4" } }, - "node_modules/date-fns": { - "version": "2.30.0", - "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-2.30.0.tgz", - "integrity": "sha512-fnULvOpxnC5/Vg3NCiWelDsLiUc9bRwAPs/+LfTLNvetFCtCTN+yQz15C/fs4AwX1R9K5GLtLfn8QW+dWisaAw==", - "dev": true, + "node_modules/csvtojson": { + "version": "2.0.10", + "resolved": 
"https://registry.npmjs.org/csvtojson/-/csvtojson-2.0.10.tgz", + "integrity": "sha512-lUWFxGKyhraKCW8Qghz6Z0f2l/PqB1W3AO0HKJzGIQ5JRSlR651ekJDiGJbBT4sRNNv5ddnSGVEnsxP9XRCVpQ==", + "license": "MIT", "dependencies": { - "@babel/runtime": "^7.21.0" + "bluebird": "^3.5.1", + "lodash": "^4.17.3", + "strip-bom": "^2.0.0" + }, + "bin": { + "csvtojson": "bin/csvtojson" }, "engines": { - "node": ">=0.11" + "node": ">=4.0.0" + } + }, + "node_modules/csvtojson/node_modules/strip-bom": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-2.0.0.tgz", + "integrity": "sha512-kwrX1y7czp1E69n2ajbG65mIo9dqvJ+8aBQXOGVxqwvNbsXdFM6Lq37dLAY3mknUwru8CfcCbfOLL/gMo+fi3g==", + "license": "MIT", + "dependencies": { + "is-utf8": "^0.2.0" }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/date-fns": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/date-fns/-/date-fns-4.1.0.tgz", + "integrity": "sha512-Ukq0owbQXxa/U3EGtsdVBkR1w7KOQ5gIBqdH2hkvknzZPYvBxb/aa6E8L7tmjFtkwZBu3UXBbjIgPo/Ez4xaNg==", + "license": "MIT", "funding": { - "type": "opencollective", - "url": "https://opencollective.com/date-fns" + "type": "github", + "url": "https://github.com/sponsors/kossnocorp" } }, + "node_modules/dayjs": { + "version": "1.11.13", + "resolved": "https://registry.npmjs.org/dayjs/-/dayjs-1.11.13.tgz", + "integrity": "sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/debug": { "version": "4.4.0", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.0.tgz", @@ -5441,11 +6159,27 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/dedent": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/dedent/-/dedent-1.5.1.tgz", + "integrity": "sha512-+LxW+KLWxu3HW3M2w2ympwtqPrqYRzU8fqi6Fhd18fBALe15blJPI/I4+UHveMVG6lJqB4JNd4UG0S5cnVHwIg==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "babel-plugin-macros": "^3.1.0" + }, + "peerDependenciesMeta": { + "babel-plugin-macros": { + "optional": true + } + } + }, "node_modules/deep-eql": { - "version": "4.1.3", - "resolved": "https://registry.npmjs.org/deep-eql/-/deep-eql-4.1.3.tgz", - "integrity": "sha512-WaEtAOpRA1MQ0eohqZjpGD8zdI0Ovsm8mmFhaDN8dvDZzyoUMcYDnf5Y6iu7HTXxf8JDS23qWa4a+hKCDyOPzw==", + "version": "4.1.4", + "resolved": "https://registry.npmjs.org/deep-eql/-/deep-eql-4.1.4.tgz", + "integrity": "sha512-SUwdGfqdKOwxCPeVYjwSyRpJ7Z+fhpwIAtmCUdZIWZ/YP5R9WAsyuSgpLVDi9bjWoN2LXHNss/dk3urXtdQxGg==", "dev": true, + "license": "MIT", "dependencies": { "type-detect": "^4.0.0" }, @@ -5463,7 +6197,6 @@ "version": "4.3.1", "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz", "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==", - "dev": true, "engines": { "node": ">=0.10.0" } @@ -5557,14 +6290,6 @@ "node": ">= 0.8" } }, - "node_modules/dequal": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz", - "integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==", - "engines": { - "node": ">=6" - } - }, "node_modules/destroy": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.2.0.tgz", @@ -5595,8 +6320,7 @@ "node_modules/devalue": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/devalue/-/devalue-5.1.1.tgz", - "integrity": 
"sha512-maua5KUiapvEwiEAe+XnlZ3Rh0GD+qI1J/nb9vrJc3muPXvcF/8gXYTWF76+5DAqHyDUtOIImEuo0YKE9mshVw==", - "dev": true + "integrity": "sha512-maua5KUiapvEwiEAe+XnlZ3Rh0GD+qI1J/nb9vrJc3muPXvcF/8gXYTWF76+5DAqHyDUtOIImEuo0YKE9mshVw==" }, "node_modules/dexie": { "version": "4.0.10", @@ -5614,6 +6338,7 @@ "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.6.3.tgz", "integrity": "sha512-EjePK1srD3P08o2j4f0ExnylqRs5B9tJjcp9t1krH2qRi8CCdsYfwe9JgSLurFBWwq4uOlipzfk5fHNvwFKr8Q==", "dev": true, + "license": "MIT", "engines": { "node": "^14.15.0 || ^16.10.0 || >=18.0.0" } @@ -5681,6 +6406,18 @@ "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==" }, + "node_modules/effect": { + "version": "3.16.9", + "resolved": "https://registry.npmjs.org/effect/-/effect-3.16.9.tgz", + "integrity": "sha512-onKn21L/Us3G/x4BeUxiE4B/jNiJ09uRcYEfSYVPJE10dTUM3aDdO3g15PW6ccF1BJuOtQt1cxx4/1lACwX/bA==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@standard-schema/spec": "^1.0.0", + "fast-check": "^3.23.1" + } + }, "node_modules/electron": { "version": "31.0.1", "resolved": "https://registry.npmjs.org/electron/-/electron-31.0.1.tgz", @@ -6127,7 +6864,6 @@ "version": "0.21.5", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.21.5.tgz", "integrity": "sha512-mg3OPMV4hXywwpoDxu3Qda5xCKQi+vCTZq8S9J/EpkhB2HzKXq4SNFZE3+NK93JYxc8VMSep+lOUSC/RVKaBqw==", - "dev": true, "hasInstallScript": true, "bin": { "esbuild": "bin/esbuild" @@ -6161,6 +6897,45 @@ "@esbuild/win32-x64": "0.21.5" } }, + "node_modules/esbuild-runner": { + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/esbuild-runner/-/esbuild-runner-2.2.2.tgz", + "integrity": "sha512-fRFVXcmYVmSmtYm2mL8RlUASt2TDkGh3uRcvHFOKNr/T58VrfVeKD9uT9nlgxk96u0LS0ehS/GY7Da/bXWKkhw==", + "dev": true, + "license": "Apache License 2.0", + "optional": true, + "dependencies": { + "source-map-support": "0.5.21", + "tslib": "2.4.0" + }, + "bin": { + "esr": "bin/esr.js" + }, + "peerDependencies": { + "esbuild": "*" + } + }, + "node_modules/esbuild-runner/node_modules/tslib": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.4.0.tgz", + "integrity": "sha512-d6xOpEDfsi2CZVlPQzGeux8XMwLT9hssAsaPYExaQMuYskwb+x1x7J371tWlbBdWHroy99KnVB6qIkUbs5X3UQ==", + "dev": true, + "license": "0BSD", + "optional": true + }, + "node_modules/esbuild-wasm": { + "version": "0.19.12", + "resolved": "https://registry.npmjs.org/esbuild-wasm/-/esbuild-wasm-0.19.12.tgz", + "integrity": "sha512-Zmc4hk6FibJZBcTx5/8K/4jT3/oG1vkGTEeKJUQFCUQKimD6Q7+adp/bdVQyYJFolMKaXkQnVZdV4O5ZaTYmyQ==", + "dev": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=12" + } + }, "node_modules/escalade": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.2.tgz", @@ -6354,9 +7129,10 @@ } }, "node_modules/esm-env": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/esm-env/-/esm-env-1.2.1.tgz", - "integrity": "sha512-U9JedYYjCnadUlXk7e1Kr+aENQhtUaoaV9+gZm1T8LC/YBAPJx3NSPIAurFOC0U5vrdSevnUJS2/wUVxGwPhng==" + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/esm-env/-/esm-env-1.2.2.tgz", + "integrity": "sha512-Epxrv+Nr/CaL4ZcFGPJIYLWFom+YeV1DqMLHJoEd9SYRxNbaFruBwfEX/kkHUJf55j2+TUbmDcmuilbP1TmXHA==", + "license": "MIT" }, "node_modules/espree": { "version": "9.6.1", @@ -6375,6 +7151,20 @@ "url": 
"https://opencollective.com/eslint" } }, + "node_modules/esprima": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", + "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", + "dev": true, + "license": "BSD-2-Clause", + "bin": { + "esparse": "bin/esparse.js", + "esvalidate": "bin/esvalidate.js" + }, + "engines": { + "node": ">=4" + } + }, "node_modules/esquery": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.5.0.tgz", @@ -6388,9 +7178,10 @@ } }, "node_modules/esrap": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/esrap/-/esrap-1.3.2.tgz", - "integrity": "sha512-C4PXusxYhFT98GjLSmb20k9PREuUdporer50dhzGuJu9IJXktbMddVCMLAERl5dAHyAi73GWWCE4FVHGP1794g==", + "version": "1.4.6", + "resolved": "https://registry.npmjs.org/esrap/-/esrap-1.4.6.tgz", + "integrity": "sha512-F/D2mADJ9SHY3IwksD4DAXjTt7qt7GWUf3/8RhCNWmC/67tyb55dpimHmy7EplakFaflV0R/PC+fdSPqrRHAQw==", + "license": "MIT", "dependencies": { "@jridgewell/sourcemap-codec": "^1.4.15" } @@ -6417,13 +7208,11 @@ } }, "node_modules/estree-walker": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", - "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", + "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", "dev": true, - "dependencies": { - "@types/estree": "^1.0.0" - } + "license": "MIT" }, "node_modules/esutils": { "version": "2.0.3", @@ -6493,9 +7282,10 @@ "license": "Apache-2.0" }, "node_modules/express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", + "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", @@ -6516,7 +7306,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -6531,21 +7321,26 @@ }, "engines": { "node": ">= 0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" } }, "node_modules/express-openid-connect": { - "version": "2.17.1", - "resolved": "https://registry.npmjs.org/express-openid-connect/-/express-openid-connect-2.17.1.tgz", - "integrity": "sha512-5pVK6PNV09x6UN29R9Mer0XF3hwQq2HxiFsjZvLuIQ9ezeTUGbqrefzBOpzciz1S/1WWVaVPDIcj4EBpD8WB3Q==", + "version": "2.18.1", + "resolved": "https://registry.npmjs.org/express-openid-connect/-/express-openid-connect-2.18.1.tgz", + "integrity": "sha512-trHqgwXxWF0n/XrDsRzsvQtnBNbU03iCNXbKR/sHwBqXlvCgup341bW7B8t6nr3L/CMoDpK+9gsTnx3qLCqdjQ==", + "license": "MIT", "dependencies": { "base64url": "^3.0.1", "clone": "^2.1.2", - "cookie": "^0.5.0", + "cookie": "^0.7.1", "debug": "^4.3.4", "futoin-hkdf": "^1.5.1", "http-errors": "^1.8.1", "joi": "^17.7.0", - "jose": "^2.0.6", + "jose": "^2.0.7", "on-headers": "^1.0.2", "openid-client": "^4.9.1", "url-join": "^4.0.1" @@ -6558,9 +7353,10 
@@ } }, "node_modules/express-openid-connect/node_modules/cookie": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", - "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -6569,6 +7365,7 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", "integrity": "sha512-7emPTl6Dpo6JRXOXjLRxck+FlLRX5847cLKEn00PLAgc3g2hTZZgr+e4c2v6QpSmLeFP3n5yUo7ft6avBK/5jQ==", + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -6577,6 +7374,7 @@ "version": "1.8.1", "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-1.8.1.tgz", "integrity": "sha512-Kpk9Sm7NmI+RHhnj6OIWDI1d6fIoFAtFt9RLaTMRlg/8w49juAStsrBgp0Dp4OdxdVbRIeKhtCUvoi/RuAhO4g==", + "license": "MIT", "dependencies": { "depd": "~1.1.2", "inherits": "2.0.4", @@ -6588,10 +7386,26 @@ "node": ">= 0.6" } }, + "node_modules/express-openid-connect/node_modules/jose": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/jose/-/jose-2.0.7.tgz", + "integrity": "sha512-5hFWIigKqC+e/lRyQhfnirrAqUdIPMB7SJRqflJaO29dW7q5DFvH1XCSTmv6PQ6pb++0k6MJlLRoS0Wv4s38Wg==", + "license": "MIT", + "dependencies": { + "@panva/asn1.js": "^1.0.0" + }, + "engines": { + "node": ">=10.13.0 < 13 || >=13.7.0" + }, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/express-openid-connect/node_modules/statuses": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz", "integrity": "sha512-OpZ3zP+jT1PI7I8nemJX4AKmAX070ZkYPVWV/AaKTJl+tXCTGyVdC1a4SL8RUQYEwk/f34ZX8UTykN68FwrqAA==", + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -6662,6 +7476,30 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/fast-check": { + "version": "3.23.2", + "resolved": "https://registry.npmjs.org/fast-check/-/fast-check-3.23.2.tgz", + "integrity": "sha512-h5+1OzzfCC3Ef7VbtKdcv7zsstUQwUDlYpUTvjeUsJAssPgLn7QzbboPtL5ro04Mq0rPOsMzl7q5hIbRs2wD1A==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT", + "optional": true, + "dependencies": { + "pure-rand": "^6.1.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -6955,14 +7793,6 @@ "imul": "^1.0.0" } }, - "node_modules/focus-trap": { - "version": "7.6.2", - "resolved": "https://registry.npmjs.org/focus-trap/-/focus-trap-7.6.2.tgz", - "integrity": "sha512-9FhUxK1hVju2+AiQIDJ5Dd//9R2n2RAfJ0qfhF4IHGHgcoEUTMpbTeG/zbEuwaiYXfuAH6XE0/aCyxDdRM+W5w==", - "dependencies": { - "tabbable": "^6.2.0" - } - }, "node_modules/follow-redirects": { "version": "1.15.6", "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", @@ -7153,6 +7983,7 @@ "version": "1.5.3", "resolved": "https://registry.npmjs.org/futoin-hkdf/-/futoin-hkdf-1.5.3.tgz", "integrity": "sha512-SewY5KdMpaoCeh7jachEWFsh1nNlaDjNHZXWqL5IGwtpEYHTgkr2+AMCgNwKWkcc0wpSYrZfR7he4WdmHFtDxQ==", + "license": "Apache-2.0", "engines": { "node": ">=8" } @@ -7319,6 +8150,7 @@ "resolved": 
"https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true, + "license": "MIT", "engines": { "node": "*" } @@ -7546,12 +8378,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/globalyzer": { - "version": "0.1.0", - "resolved": "https://registry.npmjs.org/globalyzer/-/globalyzer-0.1.0.tgz", - "integrity": "sha512-40oNTM9UfG6aBmuKxk/giHn5nQ8RVz/SS4Ir6zgzOv9/qC3kKZ9v4etGTcJbEl/NyVQH7FGU7d+X1egr57Md2Q==", - "dev": true - }, "node_modules/globby": { "version": "11.1.0", "resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", @@ -7572,12 +8398,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/globrex": { - "version": "0.1.2", - "resolved": "https://registry.npmjs.org/globrex/-/globrex-0.1.2.tgz", - "integrity": "sha512-uHJgbwAMwNFf5mLst7IWLNg14x1CkeqglJb/K3doi4dw6q2IvAAmM/Y81kevy83wP+Sst+nutFTYOGg3d1lsxg==", - "dev": true - }, "node_modules/gopd": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", @@ -7625,11 +8445,21 @@ "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", "dev": true }, - "node_modules/has-flag": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", - "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/has-own-prop": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/has-own-prop/-/has-own-prop-2.0.0.tgz", + "integrity": "sha512-Pq0h+hvsVm6dDEa8x82GnLSYHOzNDt7f0ddFa3FqcQlgzEiptPqL+XrOJNavjOzSYiYWIrgeVYYgGlLmnxwilQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=8" } @@ -7781,6 +8611,16 @@ "node": ">= 6" } }, + "node_modules/human-id": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/human-id/-/human-id-4.1.1.tgz", + "integrity": "sha512-3gKm/gCSUipeLsRYZbbdA1BD83lBoWUkZ7G9VFrhWPAU76KwYo5KR8V28bpoPm/ygy0x5/GCbpRQdY7VLYCoIg==", + "dev": true, + "license": "MIT", + "bin": { + "human-id": "dist/cli.js" + } + }, "node_modules/human-signals": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-5.0.0.tgz", @@ -7871,16 +8711,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/import-meta-resolve": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/import-meta-resolve/-/import-meta-resolve-4.1.0.tgz", - "integrity": "sha512-I6fiaX09Xivtk+THaMfAwnA3MVA5Big1WHF1Dfx9hFuvNIWpXnorlkzhcQf6ehrqQiiZECRt1poOAkPmer3ruw==", - "dev": true, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/wooorm" - } - }, "node_modules/imul": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/imul/-/imul-1.0.1.tgz", @@ -7938,6 +8768,13 @@ "dev": true, "license": "ISC" }, + "node_modules/inline-style-parser": { + "version": "0.2.4", + "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.4.tgz", + "integrity": "sha512-0aO8FkhNZlj/ZIbNi7Lxxr12obT7cL1moPfE4tg1LkX7LlLfC6DeX4l2ZEud1ukP9jNQyNnfzQVqwbwmAATY4Q==", + "dev": true, + "license": "MIT" + 
}, "node_modules/interpret": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/interpret/-/interpret-3.1.1.tgz", @@ -8087,7 +8924,8 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/is-module/-/is-module-1.0.0.tgz", "integrity": "sha512-51ypPSPCoTEIN9dy5Oy+h4pShgJmPCygKfyRCISBI+JoWT/2oJvK8QPxmwv7b/p239jXrm9M1mlQbyKJ5A152g==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/is-my-ip-valid": { "version": "1.0.1", @@ -8184,6 +9022,12 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/is-utf8": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/is-utf8/-/is-utf8-0.2.1.tgz", + "integrity": "sha512-rMYPYvCzsXywIsldgLaSoPlw5PfoB/ssr7hY4pLfcodrA5M/eArza1a9VmTiNIBNMjOGr1Ow9mTyU2o69U6U9Q==", + "license": "MIT" + }, "node_modules/is-windows": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/is-windows/-/is-windows-1.0.2.tgz", @@ -8225,9 +9069,10 @@ } }, "node_modules/joi": { - "version": "17.13.1", - "resolved": "https://registry.npmjs.org/joi/-/joi-17.13.1.tgz", - "integrity": "sha512-vaBlIKCyo4FCUtCm7Eu4QZd/q02bWcxfUO6YSXAZOWF6gzcLBeba8kwotUdYJjDLW8Cz8RywsSOqiNJZW0mNvg==", + "version": "17.13.3", + "resolved": "https://registry.npmjs.org/joi/-/joi-17.13.3.tgz", + "integrity": "sha512-otDA4ldcIx+ZXsKHWmp0YizCweVRZG96J10b0FevjfuncLO1oX59THoAmHkNubYJ+9gWsYsp5k8v4ib6oDv1fA==", + "license": "BSD-3-Clause", "dependencies": { "@hapi/hoek": "^9.3.0", "@hapi/topo": "^5.1.0", @@ -8237,19 +9082,21 @@ } }, "node_modules/jose": { - "version": "2.0.7", - "resolved": "https://registry.npmjs.org/jose/-/jose-2.0.7.tgz", - "integrity": "sha512-5hFWIigKqC+e/lRyQhfnirrAqUdIPMB7SJRqflJaO29dW7q5DFvH1XCSTmv6PQ6pb++0k6MJlLRoS0Wv4s38Wg==", - "dependencies": { - "@panva/asn1.js": "^1.0.0" - }, - "engines": { - "node": ">=10.13.0 < 13 || >=13.7.0" - }, + "version": "4.15.9", + "resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz", + "integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==", + "license": "MIT", "funding": { "url": "https://github.com/sponsors/panva" } }, + "node_modules/js-sha256": { + "version": "0.11.1", + "resolved": "https://registry.npmjs.org/js-sha256/-/js-sha256-0.11.1.tgz", + "integrity": "sha512-o6WSo/LUvY2uC4j7mO50a2ms7E/EAdbP0swigLV+nzHKTTaYnaLIWJ02VdXrsJX0vGedDESQnLsOekr94ryfjg==", + "dev": true, + "license": "MIT" + }, "node_modules/js-tokens": { "version": "8.0.3", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-8.0.3.tgz", @@ -8280,6 +9127,21 @@ "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==" }, + "node_modules/json-schema-to-ts": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/json-schema-to-ts/-/json-schema-to-ts-3.1.1.tgz", + "integrity": "sha512-+DWg8jCJG2TEnpy7kOm/7/AxaYoaRbjVB4LFZLySZlWn8exGs3A4OLJR966cVvU26N7X9TWxl+Jsw7dzAqKT6g==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "@babel/runtime": "^7.18.3", + "ts-algebra": "^2.0.0" + }, + "engines": { + "node": ">=16" + } + }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -8305,6 +9167,19 @@ "license": "ISC", "optional": true }, + "node_modules/json5": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", + "integrity": 
"sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "dev": true, + "license": "MIT", + "bin": { + "json5": "lib/cli.js" + }, + "engines": { + "node": ">=6" + } + }, "node_modules/jsonc-parser": { "version": "3.2.1", "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", @@ -8354,7 +9229,6 @@ "version": "4.1.5", "resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz", "integrity": "sha512-o+NO+8WrRiQEE4/7nwRJhN1HWpVmJm511pBHUxPLtp0BUISzlBplORYSmTclCnJvQq2tKu/sgl3xVpkc7ZWuQQ==", - "dev": true, "engines": { "node": ">=6" } @@ -8365,6 +9239,16 @@ "integrity": "sha512-a/RAk2BfKk+WFGhhOCAYqSiFLc34k8Mt/6NWRI4joER0EYUzXIcFivjjnoD3+XU1DggLn/tZc3DOAgke7l8a4A==", "dev": true }, + "node_modules/kysely": { + "version": "0.27.6", + "resolved": "https://registry.npmjs.org/kysely/-/kysely-0.27.6.tgz", + "integrity": "sha512-FIyV/64EkKhJmjgC0g2hygpBv5RNWVPyNCqSAD7eTCv6eFWNIi4PN1UvdSJGicN/o35bnevgis4Y0UDC0qi8jQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/levn": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", @@ -8378,6 +9262,14 @@ "node": ">= 0.8.0" } }, + "node_modules/libphonenumber-js": { + "version": "1.12.9", + "resolved": "https://registry.npmjs.org/libphonenumber-js/-/libphonenumber-js-1.12.9.tgz", + "integrity": "sha512-VWwAdNeJgN7jFOD+wN4qx83DTPMVPPAUyx9/TUkBXKLiNkuWWk6anV0439tgdtwaJDrEdqkvdN22iA6J4bUCZg==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/lilconfig": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-2.1.0.tgz", @@ -8653,6 +9545,7 @@ "resolved": "https://registry.npmjs.org/loupe/-/loupe-2.3.7.tgz", "integrity": "sha512-zSMINGVYkdpYSOBmLi0D1Uo7JU9nVdQKrHxC8eYlV+9YKK9WePqAlL7lSlorG/U2Fw1w0hTBmaa/jrQ3UbPHtA==", "dev": true, + "license": "MIT", "dependencies": { "get-func-name": "^2.0.1" } @@ -8669,6 +9562,7 @@ "version": "6.0.0", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "license": "ISC", "dependencies": { "yallist": "^4.0.0" }, @@ -8710,7 +9604,8 @@ "node_modules/make-error": { "version": "1.3.6", "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz", - "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==" + "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==", + "license": "ISC" }, "node_modules/make-fetch-happen": { "version": "10.2.1", @@ -8823,6 +9718,13 @@ "node": ">=6" } }, + "node_modules/memoize-weak": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/memoize-weak/-/memoize-weak-1.0.2.tgz", + "integrity": "sha512-gj39xkrjEw7nCn4nJ1M5ms6+MyMlyiGmttzsqAUsAKn6bYKwuTHh/AO3cKPF8IBrTIYTxb0wWXFs3E//Y8VoWQ==", + "dev": true, + "license": "ISC" + }, "node_modules/merge-descriptors": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.3.tgz", @@ -9117,6 +10019,12 @@ "node": ">=8" } }, + "node_modules/minisearch": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minisearch/-/minisearch-7.1.2.tgz", + "integrity": "sha512-R1Pd9eF+MD5JYDDSPAp/q1ougKglm14uEkPMvQ/05RGmx6G9wvmLTrTI/Q5iPNJLYqNdsDQ7qTGIcNWR+FrHmA==", + "license": "MIT" + }, "node_modules/minizlib": { "version": "2.1.2", "resolved": 
"https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", @@ -9170,11 +10078,33 @@ } }, "node_modules/mode-watcher": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-0.3.0.tgz", - "integrity": "sha512-k8jjuTx94HaaRKWO6JDf8wL761hFatrTIHJKl+E+3JWcnv+GnMBH062zcLsy0lbCI3n7RZxxHaWi66auFnUO4g==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.0.7.tgz", + "integrity": "sha512-ZGA7ZGdOvBJeTQkzdBOnXSgTkO6U6iIFWJoyGCTt6oHNg9XP9NBvS26De+V4W2aqI+B0yYXUskFG2VnEo3zyMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "runed": "^0.25.0", + "svelte-toolbelt": "^0.7.1" + }, + "peerDependencies": { + "svelte": "^5.27.0" + } + }, + "node_modules/mode-watcher/node_modules/runed": { + "version": "0.25.0", + "resolved": "https://registry.npmjs.org/runed/-/runed-0.25.0.tgz", + "integrity": "sha512-7+ma4AG9FT2sWQEA0Egf6mb7PBT2vHyuHail1ie8ropfSjvZGtEAx8YTmUjv/APCsdRRxEVvArNjALk9zFSOrg==", + "dev": true, + "funding": [ + "https://github.com/sponsors/huntabyte", + "https://github.com/sponsors/tglide" + ], + "dependencies": { + "esm-env": "^1.0.0" + }, "peerDependencies": { - "svelte": "^4.0.0" + "svelte": "^5.7.0" } }, "node_modules/moment": { @@ -9186,20 +10116,25 @@ "node": "*" } }, + "node_modules/monaco-editor": { + "version": "0.52.2", + "resolved": "https://registry.npmjs.org/monaco-editor/-/monaco-editor-0.52.2.tgz", + "integrity": "sha512-GEQWEZmfkOGLdd3XK8ryrfWz3AIP8YymVXiPHEdewrUq7mh0qrKrfHLNCXcbB6sTnMLnOZ3ztSiKcciFUkIJwQ==", + "license": "MIT" + }, "node_modules/mri": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/mri/-/mri-1.2.0.tgz", "integrity": "sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==", - "dev": true, "engines": { "node": ">=4" } }, "node_modules/mrmime": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-2.0.0.tgz", - "integrity": "sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==", - "dev": true, + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-2.0.1.tgz", + "integrity": "sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==", + "license": "MIT", "engines": { "node": ">=10" } @@ -9241,15 +10176,16 @@ "optional": true }, "node_modules/nanoid": { - "version": "3.3.7", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz", - "integrity": "sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==", + "version": "3.3.11", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.11.tgz", + "integrity": "sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==", "funding": [ { "type": "github", "url": "https://github.com/sponsors/ai" } ], + "license": "MIT", "bin": { "nanoid": "bin/nanoid.cjs" }, @@ -9488,6 +10424,15 @@ "resolved": "https://registry.npmjs.org/nprogress/-/nprogress-0.2.0.tgz", "integrity": "sha512-I19aIingLgR1fmhftnbWWO3dXc0hSxqHQHQb3H8m+K3TnEn/iSeTZZOyvKXWqQESMwuUVnatlCnZdLBZZt2VSA==" }, + "node_modules/oauth4webapi": { + "version": "3.5.2", + "resolved": "https://registry.npmjs.org/oauth4webapi/-/oauth4webapi-3.5.2.tgz", + "integrity": "sha512-VYz5BaP3izIrUc1GAVzIoz4JnljiW0YAUFObMBwsqDnfHxz2sjLu3W7/8vE8Ms9IbMewN9+1kcvhY3tMscAeGQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/object-assign": { "version": "4.1.1", 
"resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", @@ -9524,9 +10469,10 @@ } }, "node_modules/oidc-token-hash": { - "version": "5.0.3", - "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.0.3.tgz", - "integrity": "sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==", + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/oidc-token-hash/-/oidc-token-hash-5.1.0.tgz", + "integrity": "sha512-y0W+X7Ppo7oZX6eovsRkuzcSM40Bicg2JEJkDJ4irIt1wsYAP5MLSNv+QAogO8xivMffw/9OvV3um1pxXgt1uA==", + "license": "MIT", "engines": { "node": "^10.13.0 || >=12.0.0" } @@ -9546,6 +10492,7 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -9577,6 +10524,7 @@ "version": "4.9.1", "resolved": "https://registry.npmjs.org/openid-client/-/openid-client-4.9.1.tgz", "integrity": "sha512-DYUF07AHjI3QDKqKbn2F7RqozT4hyi4JvmpodLrq0HHoNP7t/AjeG/uqiBK1/N2PZSAQEThVjDLHSmJN4iqu/w==", + "license": "MIT", "dependencies": { "aggregate-error": "^3.1.0", "got": "^11.8.0", @@ -9593,10 +10541,26 @@ "url": "https://github.com/sponsors/panva" } }, + "node_modules/openid-client/node_modules/jose": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/jose/-/jose-2.0.7.tgz", + "integrity": "sha512-5hFWIigKqC+e/lRyQhfnirrAqUdIPMB7SJRqflJaO29dW7q5DFvH1XCSTmv6PQ6pb++0k6MJlLRoS0Wv4s38Wg==", + "license": "MIT", + "dependencies": { + "@panva/asn1.js": "^1.0.0" + }, + "engines": { + "node": ">=10.13.0 < 13 || >=13.7.0" + }, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/openid-client/node_modules/object-hash": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz", "integrity": "sha512-gScRMn0bS5fH+IuwyIFgnh9zBdo4DV+6GhygmWM9HyNJSgS0hScp1f5vjtm7oIIOiT9trXrShAkLFSc2IqKNgw==", + "license": "MIT", "engines": { "node": ">= 6" } @@ -9702,6 +10666,12 @@ "dev": true, "license": "ISC" }, + "node_modules/orderedmap": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/orderedmap/-/orderedmap-2.1.1.tgz", + "integrity": "sha512-TvAWxi0nDe1j/rtMcWcIj94+Ffe6n7zhow33h40SKxmsmozs6dz/e+EajymfoFcHd7sxNn8yHM8839uixMOV6g==", + "license": "MIT" + }, "node_modules/overlayscrollbars": { "version": "2.6.1", "resolved": "https://registry.npmjs.org/overlayscrollbars/-/overlayscrollbars-2.6.1.tgz", @@ -9809,6 +10779,20 @@ "node": ">=6" } }, + "node_modules/paneforge": { + "version": "1.0.0-next.5", + "resolved": "https://registry.npmjs.org/paneforge/-/paneforge-1.0.0-next.5.tgz", + "integrity": "sha512-1ArDM+GMEO+o6pixEAFobhTkWkyxUDdHyw2bKruvQIXBStJmdRP7HoV4jNBZ/2i9UHDzmczxJzA3D2tKa91phw==", + "dev": true, + "license": "MIT", + "dependencies": { + "runed": "^0.23.4", + "svelte-toolbelt": "^0.7.1" + }, + "peerDependencies": { + "svelte": "^5.20.0" + } + }, "node_modules/papaparse": { "version": "5.4.1", "resolved": "https://registry.npmjs.org/papaparse/-/papaparse-5.4.1.tgz", @@ -9944,9 +10928,10 @@ } }, "node_modules/path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==" + "version": "0.1.12", + "resolved": 
"https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", + "license": "MIT" }, "node_modules/path-type": { "version": "4.0.0", @@ -9968,10 +10953,29 @@ "resolved": "https://registry.npmjs.org/pathval/-/pathval-1.1.1.tgz", "integrity": "sha512-Dp6zGqpTdETdR63lehJYPeIOqpiNBNtc7BpWSLrOje7UaIsE5aY92r/AunQA7rsXvet3lrJ3JnZX29UPTKXyKQ==", "dev": true, + "license": "MIT", "engines": { "node": "*" } }, + "node_modules/pdfjs-dist": { + "version": "4.10.38", + "resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-4.10.38.tgz", + "integrity": "sha512-/Y3fcFrXEAsMjJXeL9J8+ZG9U01LbuWaYypvDW2ycW1jL269L3js3DVBjDJ0Up9Np1uqDXsDrRihHANhZOlwdQ==", + "license": "Apache-2.0", + "engines": { + "node": ">=20" + }, + "optionalDependencies": { + "@napi-rs/canvas": "^0.1.65" + } + }, + "node_modules/pdfobject": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/pdfobject/-/pdfobject-2.3.1.tgz", + "integrity": "sha512-vluuGiSDmMGpOvWFGiUY4trNB8aGKLDVxIXuuGHjX0kK3bMxCANUVtLivctE7uejLBScWCnbVarKatFVvdwXaQ==", + "license": "MIT" + }, "node_modules/pe-library": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/pe-library/-/pe-library-1.0.1.tgz", @@ -10004,8 +11008,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", "dev": true, - "optional": true, - "peer": true, "engines": { "node": ">=12" }, @@ -10365,6 +11367,25 @@ "node": "^12.20.0 || >=14" } }, + "node_modules/preact": { + "version": "10.24.3", + "resolved": "https://registry.npmjs.org/preact/-/preact-10.24.3.tgz", + "integrity": "sha512-Z2dPnBnMUfyQfSQ+GBdsGa16hz35YmLmtTLhM169uW944hYL6xzTYkJjC07j+Wosz733pMWx0fgON3JNw1jJQA==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/preact" + } + }, + "node_modules/preact-render-to-string": { + "version": "6.5.11", + "resolved": "https://registry.npmjs.org/preact-render-to-string/-/preact-render-to-string-6.5.11.tgz", + "integrity": "sha512-ubnauqoGczeGISiOh6RjX0/cdaF8v/oDXIjO85XALCQjwQP+SB4RDXXtvZ6yTYSjG+PC1QRP2AhPgCEsM2EvUw==", + "license": "MIT", + "peerDependencies": { + "preact": ">=10" + } + }, "node_modules/prelude-ls": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", @@ -10389,6 +11410,23 @@ "url": "https://github.com/prettier/prettier?sponsor=1" } }, + "node_modules/prettier-plugin-organize-imports": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/prettier-plugin-organize-imports/-/prettier-plugin-organize-imports-4.1.0.tgz", + "integrity": "sha512-5aWRdCgv645xaa58X8lOxzZoiHAldAPChljr/MT0crXVOWTZ+Svl4hIWlz+niYSlO6ikE5UXkN1JrRvIP2ut0A==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "prettier": ">=2.0", + "typescript": ">=2.9", + "vue-tsc": "^2.1.0" + }, + "peerDependenciesMeta": { + "vue-tsc": { + "optional": true + } + } + }, "node_modules/prettier-plugin-svelte": { "version": "3.3.2", "resolved": "https://registry.npmjs.org/prettier-plugin-svelte/-/prettier-plugin-svelte-3.3.2.tgz", @@ -10399,6 +11437,85 @@ "svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0" } }, + "node_modules/prettier-plugin-tailwindcss": { + "version": "0.6.12", + "resolved": "https://registry.npmjs.org/prettier-plugin-tailwindcss/-/prettier-plugin-tailwindcss-0.6.12.tgz", + "integrity": 
"sha512-OuTQKoqNwV7RnxTPwXWzOFXy6Jc4z8oeRZYGuMpRyG3WbuR3jjXdQFK8qFBMBx8UHWdHrddARz2fgUenild6aw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.21.3" + }, + "peerDependencies": { + "@ianvs/prettier-plugin-sort-imports": "*", + "@prettier/plugin-pug": "*", + "@shopify/prettier-plugin-liquid": "*", + "@trivago/prettier-plugin-sort-imports": "*", + "@zackad/prettier-plugin-twig": "*", + "prettier": "^3.0", + "prettier-plugin-astro": "*", + "prettier-plugin-css-order": "*", + "prettier-plugin-import-sort": "*", + "prettier-plugin-jsdoc": "*", + "prettier-plugin-marko": "*", + "prettier-plugin-multiline-arrays": "*", + "prettier-plugin-organize-attributes": "*", + "prettier-plugin-organize-imports": "*", + "prettier-plugin-sort-imports": "*", + "prettier-plugin-style-order": "*", + "prettier-plugin-svelte": "*" + }, + "peerDependenciesMeta": { + "@ianvs/prettier-plugin-sort-imports": { + "optional": true + }, + "@prettier/plugin-pug": { + "optional": true + }, + "@shopify/prettier-plugin-liquid": { + "optional": true + }, + "@trivago/prettier-plugin-sort-imports": { + "optional": true + }, + "@zackad/prettier-plugin-twig": { + "optional": true + }, + "prettier-plugin-astro": { + "optional": true + }, + "prettier-plugin-css-order": { + "optional": true + }, + "prettier-plugin-import-sort": { + "optional": true + }, + "prettier-plugin-jsdoc": { + "optional": true + }, + "prettier-plugin-marko": { + "optional": true + }, + "prettier-plugin-multiline-arrays": { + "optional": true + }, + "prettier-plugin-organize-attributes": { + "optional": true + }, + "prettier-plugin-organize-imports": { + "optional": true + }, + "prettier-plugin-sort-imports": { + "optional": true + }, + "prettier-plugin-style-order": { + "optional": true + }, + "prettier-plugin-svelte": { + "optional": true + } + } + }, "node_modules/pretty-bytes": { "version": "6.1.1", "resolved": "https://registry.npmjs.org/pretty-bytes/-/pretty-bytes-6.1.1.tgz", @@ -10415,6 +11532,7 @@ "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", "integrity": "sha512-Pdlw/oPxN+aXdmM9R00JVC9WVFoCLTKJvDVLgmJ+qAffBMxsV85l/Lu7sNx4zSzPyoL2euImuEwHhOXdEgNFZQ==", "dev": true, + "license": "MIT", "dependencies": { "@jest/schemas": "^29.6.3", "ansi-styles": "^5.0.0", @@ -10429,6 +11547,7 @@ "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", "dev": true, + "license": "MIT", "engines": { "node": ">=10" }, @@ -10467,6 +11586,87 @@ "node": ">=10" } }, + "node_modules/property-expr": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/property-expr/-/property-expr-2.0.6.tgz", + "integrity": "sha512-SVtmxhRE/CGkn3eZY1T6pC8Nln6Fr/lu1mKSgRud0eC73whjGfoAogbn78LkD8aFL0zz3bAFerKSnOl7NlErBA==", + "dev": true, + "license": "MIT", + "optional": true + }, + "node_modules/prosemirror-commands": { + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/prosemirror-commands/-/prosemirror-commands-1.7.1.tgz", + "integrity": "sha512-rT7qZnQtx5c0/y/KlYaGvtG411S97UaL6gdp6RIZ23DLHanMYLyfGBV5DtSnZdthQql7W+lEVbpSfwtO8T+L2w==", + "license": "MIT", + "dependencies": { + "prosemirror-model": "^1.0.0", + "prosemirror-state": "^1.0.0", + "prosemirror-transform": "^1.10.2" + } + }, + "node_modules/prosemirror-history": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/prosemirror-history/-/prosemirror-history-1.4.1.tgz", + "integrity": 
"sha512-2JZD8z2JviJrboD9cPuX/Sv/1ChFng+xh2tChQ2X4bB2HeK+rra/bmJ3xGntCcjhOqIzSDG6Id7e8RJ9QPXLEQ==", + "license": "MIT", + "dependencies": { + "prosemirror-state": "^1.2.2", + "prosemirror-transform": "^1.0.0", + "prosemirror-view": "^1.31.0", + "rope-sequence": "^1.3.0" + } + }, + "node_modules/prosemirror-keymap": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/prosemirror-keymap/-/prosemirror-keymap-1.2.3.tgz", + "integrity": "sha512-4HucRlpiLd1IPQQXNqeo81BGtkY8Ai5smHhKW9jjPKRc2wQIxksg7Hl1tTI2IfT2B/LgX6bfYvXxEpJl7aKYKw==", + "license": "MIT", + "dependencies": { + "prosemirror-state": "^1.0.0", + "w3c-keyname": "^2.2.0" + } + }, + "node_modules/prosemirror-model": { + "version": "1.25.3", + "resolved": "https://registry.npmjs.org/prosemirror-model/-/prosemirror-model-1.25.3.tgz", + "integrity": "sha512-dY2HdaNXlARknJbrManZ1WyUtos+AP97AmvqdOQtWtrrC5g4mohVX5DTi9rXNFSk09eczLq9GuNTtq3EfMeMGA==", + "license": "MIT", + "dependencies": { + "orderedmap": "^2.0.0" + } + }, + "node_modules/prosemirror-state": { + "version": "1.4.3", + "resolved": "https://registry.npmjs.org/prosemirror-state/-/prosemirror-state-1.4.3.tgz", + "integrity": "sha512-goFKORVbvPuAQaXhpbemJFRKJ2aixr+AZMGiquiqKxaucC6hlpHNZHWgz5R7dS4roHiwq9vDctE//CZ++o0W1Q==", + "license": "MIT", + "dependencies": { + "prosemirror-model": "^1.0.0", + "prosemirror-transform": "^1.0.0", + "prosemirror-view": "^1.27.0" + } + }, + "node_modules/prosemirror-transform": { + "version": "1.10.4", + "resolved": "https://registry.npmjs.org/prosemirror-transform/-/prosemirror-transform-1.10.4.tgz", + "integrity": "sha512-pwDy22nAnGqNR1feOQKHxoFkkUtepoFAd3r2hbEDsnf4wp57kKA36hXsB3njA9FtONBEwSDnDeCiJe+ItD+ykw==", + "license": "MIT", + "dependencies": { + "prosemirror-model": "^1.21.0" + } + }, + "node_modules/prosemirror-view": { + "version": "1.41.0", + "resolved": "https://registry.npmjs.org/prosemirror-view/-/prosemirror-view-1.41.0.tgz", + "integrity": "sha512-FatMIIl0vRHMcNc3sPy3cMw5MMyWuO1nWQxqvYpJvXAruucGvmQ2tyyjT2/Lbok77T9a/qZqBVCq4sj43V2ihw==", + "license": "MIT", + "dependencies": { + "prosemirror-model": "^1.20.0", + "prosemirror-state": "^1.0.0", + "prosemirror-transform": "^1.1.0" + } + }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", @@ -10510,6 +11710,24 @@ "node": ">=6" } }, + "node_modules/pure-rand": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/pure-rand/-/pure-rand-6.1.0.tgz", + "integrity": "sha512-bVWawvoZoBYpp6yIoQtQXHZjmz35RSVHnUOTefl8Vcjr8snTPY1wnpSPMWekcFwbxI6gtmT7rSYPFvz71ldiOA==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT", + "optional": true + }, "node_modules/qs": { "version": "6.13.0", "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", @@ -10606,10 +11824,11 @@ } }, "node_modules/react-is": { - "version": "18.2.0", - "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz", - "integrity": "sha512-xWGDIW6x921xtzPkhiULtthJHoJvBbF3q26fzloPCK0hsvxtPVelvftw3zjbHWSkR2km9Z+4uxbDDK/6Zw9B8w==", - "dev": true + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.3.1.tgz", + "integrity": "sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==", + "dev": true, + "license": "MIT" }, "node_modules/read-binary-file-arch": { "version": "1.0.6", @@ 
-10795,18 +12014,12 @@ "node": ">= 10.13.0" } }, - "node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==" - }, "node_modules/repeat-string": { "version": "1.6.1", "resolved": "https://registry.npmjs.org/repeat-string/-/repeat-string-1.6.1.tgz", "integrity": "sha512-PV0dzCYDNfRi1jCDbJzpW7jNNDRuCOG/jI5ctQcGKt/clZD+YcPS3yIlWuTJMmESC8aevCFmWJy5wjAFgNqN6w==", "dev": true, "license": "MIT", - "optional": true, "engines": { "node": ">=0.10" } @@ -11025,7 +12238,6 @@ "version": "4.24.0", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.24.0.tgz", "integrity": "sha512-DOmrlGSXNk1DM0ljiQA+i+o0rSLhtii1je5wgk60j49d1jHT5YYttBv1iWOnYSTG+fZZESUOSNiAl89SIet+Cg==", - "dev": true, "dependencies": { "@types/estree": "1.0.6" }, @@ -11056,6 +12268,12 @@ "fsevents": "~2.3.2" } }, + "node_modules/rope-sequence": { + "version": "1.3.4", + "resolved": "https://registry.npmjs.org/rope-sequence/-/rope-sequence-1.3.4.tgz", + "integrity": "sha512-UT5EDe2cu2E/6O4igUr5PSFs23nvvukicWHx6GnOPlHAiiYbzNuCRQCuiUdHJQcqKalLKlrYJnjY0ySGsXNQXQ==", + "license": "MIT" + }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", @@ -11088,6 +12306,22 @@ "run-script-os": "index.js" } }, + "node_modules/runed": { + "version": "0.23.4", + "resolved": "https://registry.npmjs.org/runed/-/runed-0.23.4.tgz", + "integrity": "sha512-9q8oUiBYeXIDLWNK5DfCWlkL0EW3oGbk845VdKlPeia28l751VpfesaB/+7pI6rnbx1I6rqoZ2fZxptOJLxILA==", + "dev": true, + "funding": [ + "https://github.com/sponsors/huntabyte", + "https://github.com/sponsors/tglide" + ], + "dependencies": { + "esm-env": "^1.0.0" + }, + "peerDependencies": { + "svelte": "^5.7.0" + } + }, "node_modules/rxjs": { "version": "7.8.1", "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.1.tgz", @@ -11101,7 +12335,6 @@ "version": "1.8.1", "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz", "integrity": "sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==", - "dev": true, "dependencies": { "mri": "^1.1.0" }, @@ -11253,10 +12486,10 @@ "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==" }, "node_modules/set-cookie-parser": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.6.0.tgz", - "integrity": "sha512-RVnVQxTXuerk653XfuliOxBP81Sf0+qfQE73LIYKcyMYHG94AuH0kgrQpRDuTZnSmjpysHmzxJXKNfa6PjFhyQ==", - "dev": true + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.1.tgz", + "integrity": "sha512-IOc8uWeOZgnb3ptbCURJWNjWUPcO3ZnTTdzsurqERrP6nPyv+paC55vJM0LpOlT2ne+Ix+9+CRG1MNLlyZ4GjQ==", + "license": "MIT" }, "node_modules/set-function-length": { "version": "1.2.2", @@ -11390,17 +12623,17 @@ } }, "node_modules/sirv": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/sirv/-/sirv-2.0.4.tgz", - "integrity": "sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==", - "dev": true, + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-3.0.1.tgz", + "integrity": "sha512-FoqMu0NCGBLCcAkS1qA+XJIQTR6/JHfQXl+uGteNCQ76T91DMUjPa9xfmeqMY3z80nLSg9yQmNjK0Px6RWsH/A==", + "license": "MIT", "dependencies": { "@polka/url": 
"^1.0.0-next.24", "mrmime": "^2.0.0", "totalist": "^3.0.0" }, "engines": { - "node": ">= 10" + "node": ">=18" } }, "node_modules/slash": { @@ -11539,6 +12772,18 @@ "dev": true, "license": "BSD-3-Clause" }, + "node_modules/sqlite-wasm-kysely": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/sqlite-wasm-kysely/-/sqlite-wasm-kysely-0.3.0.tgz", + "integrity": "sha512-TzjBNv7KwRw6E3pdKdlRyZiTmUIE0UttT/Sl56MVwVARl/u5gp978KepazCJZewFUnlWHz9i3NQd4kOtP/Afdg==", + "dev": true, + "dependencies": { + "@sqlite.org/sqlite-wasm": "^3.48.0-build2" + }, + "peerDependencies": { + "kysely": "*" + } + }, "node_modules/ssri": { "version": "9.0.1", "resolved": "https://registry.npmjs.org/ssri/-/ssri-9.0.1.tgz", @@ -11571,6 +12816,12 @@ "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", "dev": true }, + "node_modules/state-local": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/state-local/-/state-local-1.0.7.tgz", + "integrity": "sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==", + "license": "MIT" + }, "node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", @@ -11775,9 +13026,10 @@ } }, "node_modules/stripe": { - "version": "15.5.0", - "resolved": "https://registry.npmjs.org/stripe/-/stripe-15.5.0.tgz", - "integrity": "sha512-c04ToET4ZUzoeSh2rWarXCPNa2+6YzkwNAcWaT4axYRlN/u1XMkz9+inouNsXWjeT6ttBrp1twz10x/sCbWLpQ==", + "version": "15.12.0", + "resolved": "https://registry.npmjs.org/stripe/-/stripe-15.12.0.tgz", + "integrity": "sha512-slTbYS1WhRJXVB8YXU8fgHizkUrM9KJyrw4Dd8pLEwzKHYyQTIE46EePC2MVbSDZdE24o1GdNtzmJV4PrPpmJA==", + "license": "MIT", "dependencies": { "@types/node": ">=8.1.0", "qs": "^6.11.0" @@ -11791,6 +13043,16 @@ "resolved": "https://registry.npmjs.org/strnum/-/strnum-1.0.5.tgz", "integrity": "sha512-J8bbNyKKXl5qYcR36TIO8W3mVGVHrmmxsd5PAItGkmyzwJvybiw2IVq5nqd0i4LSNSkB/sx9VHllbfFdr9k1JA==" }, + "node_modules/style-to-object": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-1.0.8.tgz", + "integrity": "sha512-xT47I/Eo0rwJmaXC4oilDGDWLohVhR6o/xAQcPQN8q6QBuZVL8qMYL85kLmST5cPjAorwvqIA4qXTRQoYHaL6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "inline-style-parser": "0.2.4" + } + }, "node_modules/sucrase": { "version": "3.35.0", "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.0.tgz", @@ -11853,6 +13115,17 @@ "node": ">= 8.0" } }, + "node_modules/superstruct": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/superstruct/-/superstruct-2.0.2.tgz", + "integrity": "sha512-uV+TFRZdXsqXTL2pRvujROjdZQ4RAlBUS5BTh9IGm+jTqQntYThciG/qu57Gs69yjnVUSqdxF9YLmSnpupBW9A==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/supports-color": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", @@ -11877,20 +13150,21 @@ } }, "node_modules/svelte": { - "version": "5.16.5", - "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.16.5.tgz", - "integrity": "sha512-zTG45crJUGjNYQgmQ0YDxFJ7ge1O6ZwevPxGgGOxuMOXOQhcH9LC9GEx2JS9/BlkhxdsO8ETofQ76ouFwDVpCQ==", + "version": "5.33.14", + "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.33.14.tgz", + "integrity": "sha512-kRlbhIlMTijbFmVDQFDeKXPLlX1/ovXwV0I162wRqQhRcygaqDIcu1d/Ese3H2uI+yt3uT8E7ndgDthQv5v5BA==", + "license": "MIT", "dependencies": { 
"@ampproject/remapping": "^2.3.0", "@jridgewell/sourcemap-codec": "^1.5.0", + "@sveltejs/acorn-typescript": "^1.0.5", "@types/estree": "^1.0.5", "acorn": "^8.12.1", - "acorn-typescript": "^1.4.13", "aria-query": "^5.3.1", "axobject-query": "^4.1.0", "clsx": "^2.1.1", "esm-env": "^1.2.1", - "esrap": "^1.3.2", + "esrap": "^1.4.6", "is-reference": "^3.0.3", "locate-character": "^3.0.0", "magic-string": "^0.30.11", @@ -11990,17 +13264,147 @@ } }, "node_modules/svelte-sonner": { - "version": "0.3.24", - "resolved": "https://registry.npmjs.org/svelte-sonner/-/svelte-sonner-0.3.24.tgz", - "integrity": "sha512-txuL0JBUs0v6qGrr0PGCsbXmKHuthdrAkfISYi8umuveF7+gINb6EXl6VmKY9aHhyxCqvVgqd6yophQNrnor4w==", + "version": "0.3.28", + "resolved": "https://registry.npmjs.org/svelte-sonner/-/svelte-sonner-0.3.28.tgz", + "integrity": "sha512-K3AmlySeFifF/cKgsYNv5uXqMVNln0NBAacOYgmkQStLa/UoU0LhfAACU6Gr+YYC8bOCHdVmFNoKuDbMEsppJg==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "svelte": "^3.0.0 || ^4.0.0 || ^5.0.0-next.1" + } + }, + "node_modules/svelte-toolbelt": { + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/svelte-toolbelt/-/svelte-toolbelt-0.7.1.tgz", + "integrity": "sha512-HcBOcR17Vx9bjaOceUvxkY3nGmbBmCBBbuWLLEWO6jtmWH8f/QoWmbyUfQZrpDINH39en1b8mptfPQT9VKQ1xQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/huntabyte" + ], + "dependencies": { + "clsx": "^2.1.1", + "runed": "^0.23.2", + "style-to-object": "^1.0.8" + }, + "engines": { + "node": ">=18", + "pnpm": ">=8.7.0" + }, + "peerDependencies": { + "svelte": "^5.0.0" + } + }, + "node_modules/sveltekit-superforms": { + "version": "2.27.0", + "resolved": "https://registry.npmjs.org/sveltekit-superforms/-/sveltekit-superforms-2.27.0.tgz", + "integrity": "sha512-FXIdUg4VRVZeAdVH/zB7JtHvuoC6RmHDw032meEasqB5v+i1ud4pwU/Big+6eJ2SysqrzCahBbCvLN2qzRPVUw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ciscoheat" + }, + { + "type": "ko-fi", + "url": "https://ko-fi.com/ciscoheat" + }, + { + "type": "paypal", + "url": "https://www.paypal.com/donate/?hosted_button_id=NY7F5ALHHSVQS" + } + ], + "license": "MIT", + "dependencies": { + "devalue": "^5.1.1", + "memoize-weak": "^1.0.2", + "ts-deepmerge": "^7.0.3" + }, + "optionalDependencies": { + "@exodus/schemasafe": "^1.3.0", + "@gcornut/valibot-json-schema": "^0.42.0", + "@sinclair/typebox": "^0.34.35", + "@typeschema/class-validator": "^0.3.0", + "@vinejs/vine": "^3.0.1", + "arktype": "^2.1.20", + "class-validator": "^0.14.2", + "effect": "^3.16.7", + "joi": "^17.13.3", + "json-schema-to-ts": "^3.1.1", + "superstruct": "^2.0.2", + "valibot": "^1.1.0", + "yup": "^1.6.1", + "zod": "^3.25.64", + "zod-to-json-schema": "^3.24.5" + }, "peerDependencies": { - "svelte": ">=3 <5" + "@exodus/schemasafe": "^1.3.0", + "@sinclair/typebox": "^0.34.28", + "@sveltejs/kit": "1.x || 2.x", + "@typeschema/class-validator": "^0.3.0", + "@vinejs/vine": "^1.8.0 || ^2.0.0 || ^3.0.0", + "arktype": ">=2.0.0-rc.23", + "class-validator": "^0.14.1", + "effect": "^3.13.7", + "joi": "^17.13.1", + "superstruct": "^2.0.2", + "svelte": "3.x || 4.x || >=5.0.0-next.51", + "valibot": "^1.0.0", + "yup": "^1.4.0", + "zod": "^3.25.0" + }, + "peerDependenciesMeta": { + "@exodus/schemasafe": { + "optional": true + }, + "@sinclair/typebox": { + "optional": true + }, + "@typeschema/class-validator": { + "optional": true + }, + "@vinejs/vine": { + "optional": true + }, + "arktype": { + "optional": true + }, + "class-validator": { + "optional": true + }, + "effect": { + 
"optional": true + }, + "joi": { + "optional": true + }, + "superstruct": { + "optional": true + }, + "valibot": { + "optional": true + }, + "yup": { + "optional": true + }, + "zod": { + "optional": true + } } }, + "node_modules/sveltekit-superforms/node_modules/@sinclair/typebox": { + "version": "0.34.37", + "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.34.37.tgz", + "integrity": "sha512-2TRuQVgQYfy+EzHRTIvkhv2ADEouJ2xNS/Vq+W5EuuewBdOrvATvljZTxHWZSTYr2sTjTHpGvucaGAt67S2akw==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/tabbable": { "version": "6.2.0", "resolved": "https://registry.npmjs.org/tabbable/-/tabbable-6.2.0.tgz", - "integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==" + "integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==", + "dev": true, + "license": "MIT" }, "node_modules/tailwind-merge": { "version": "2.2.2", @@ -12070,6 +13474,7 @@ "resolved": "https://registry.npmjs.org/tailwindcss-animate/-/tailwindcss-animate-1.0.7.tgz", "integrity": "sha512-bl6mpH3T7I3UFxuvDEXLxy/VuFxBk5bbzplh7tXI68mwMokNYd1t9qPBHlnyTwfa4JGC4zP516I1hYYtQ/vspA==", "dev": true, + "license": "MIT", "peerDependencies": { "tailwindcss": ">=3.0.0 || insiders" } @@ -12234,6 +13639,14 @@ "readable-stream": "3" } }, + "node_modules/tiny-case": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/tiny-case/-/tiny-case-1.0.3.tgz", + "integrity": "sha512-Eet/eeMhkO6TX8mnUteS9zgPbUMQa4I6Kkp5ORiBD5476/m+PIRiumP5tmh5ioJpH7k51Kehawy2UDfsnxxY8Q==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/tiny-each-async": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/tiny-each-async/-/tiny-each-async-2.0.3.tgz", @@ -12242,16 +13655,6 @@ "license": "MIT", "optional": true }, - "node_modules/tiny-glob": { - "version": "0.2.9", - "resolved": "https://registry.npmjs.org/tiny-glob/-/tiny-glob-0.2.9.tgz", - "integrity": "sha512-g/55ssRPUjShh+xkfx9UPDXqhckHEsHr4Vd9zX55oSdGZc/MD0m3sferOkwWtp98bv+kcVfEHtRJgBVJzelrzg==", - "dev": true, - "dependencies": { - "globalyzer": "0.1.0", - "globrex": "^0.1.2" - } - }, "node_modules/tinybench": { "version": "2.6.0", "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.6.0.tgz", @@ -12259,10 +13662,11 @@ "dev": true }, "node_modules/tinypool": { - "version": "0.8.2", - "resolved": "https://registry.npmjs.org/tinypool/-/tinypool-0.8.2.tgz", - "integrity": "sha512-SUszKYe5wgsxnNOVlBYO6IC+8VGWdVGZWAqUxp3UErNBtptZvWbwyUOyzNL59zigz2rCA92QiL3wvG+JDSdJdQ==", + "version": "0.8.4", + "resolved": "https://registry.npmjs.org/tinypool/-/tinypool-0.8.4.tgz", + "integrity": "sha512-i11VH5gS6IFeLY3gMBQ00/MmLncVP7JLXOw1vlgkytLmJK7QnEr7NXf0LBdxfmNPAeyetukOk0bOYrJrFGjYJQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=14.0.0" } @@ -12272,6 +13676,7 @@ "resolved": "https://registry.npmjs.org/tinyspy/-/tinyspy-2.2.1.tgz", "integrity": "sha512-KYad6Vy5VDWV4GH3fjpseMQ/XU2BhIYP7Vzd0LG44qRWm/Yt2WCOTicFdvmgo6gWaqooMQCawTtILVQJupKu7A==", "dev": true, + "license": "MIT", "engines": { "node": ">=14.0.0" } @@ -12339,11 +13744,19 @@ "node": ">=0.6" } }, + "node_modules/toposort": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/toposort/-/toposort-2.0.2.tgz", + "integrity": "sha512-0a5EOkAUp8D4moMi2W8ZF8jcga7BgZd91O/yabJCFY8az+XSzeGyTKs0Aoo897iV1Nj6guFq8orWDS96z91oGg==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/totalist": { "version": 
"3.0.1", "resolved": "https://registry.npmjs.org/totalist/-/totalist-3.0.1.tgz", "integrity": "sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==", - "dev": true, + "license": "MIT", "engines": { "node": ">=6" } @@ -12387,6 +13800,14 @@ "node": ">=0.8.0" } }, + "node_modules/ts-algebra": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ts-algebra/-/ts-algebra-2.0.0.tgz", + "integrity": "sha512-FPAhNPFMrkwz76P7cdjdmiShwMynZYN6SgOujD1urY4oNm80Ou9oMdmbR45LotcKOXoy7wSmHkRFE6Mxbrhefw==", + "dev": true, + "license": "MIT", + "optional": true + }, "node_modules/ts-api-utils": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.3.0.tgz", @@ -12399,15 +13820,27 @@ "typescript": ">=4.2.0" } }, + "node_modules/ts-deepmerge": { + "version": "7.0.3", + "resolved": "https://registry.npmjs.org/ts-deepmerge/-/ts-deepmerge-7.0.3.tgz", + "integrity": "sha512-Du/ZW2RfwV/D4cmA5rXafYjBQVuvu4qGiEEla4EmEHVHgRdx68Gftx7i66jn2bzHPwSVZY36Ae6OuDn9el4ZKA==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=14.13.1" + } + }, "node_modules/ts-interface-checker": { "version": "0.1.13", "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==" }, "node_modules/tslib": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", - "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD" }, "node_modules/type-check": { "version": "0.4.0", @@ -12422,10 +13855,11 @@ } }, "node_modules/type-detect": { - "version": "4.0.8", - "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", - "integrity": "sha512-0fr/mIH1dlO+x7TlcMy+bIDqKPsw/70tVyeHW787goQjhmqaZe10uwLujubK9q9Lg6Fiho1KUKDYz0Z7k7g5/g==", + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.1.0.tgz", + "integrity": "sha512-Acylog8/luQ8L7il+geoSxhEkazvkslg7PSNKOX59mbB9cOveP5aq9h74Y7YU8yDpJwetzQQrfIwtf4Wp4LKcw==", "dev": true, + "license": "MIT", "engines": { "node": ">=4" } @@ -12474,9 +13908,10 @@ "dev": true }, "node_modules/undici": { - "version": "6.19.4", - "resolved": "https://registry.npmjs.org/undici/-/undici-6.19.4.tgz", - "integrity": "sha512-i3uaEUwNdkRq2qtTRRJb13moW5HWqviu7Vl7oYRYz++uPtGHJj+x7TGjcEuwS5Mt2P4nA0U9dhIX3DdB6JGY0g==", + "version": "6.21.3", + "resolved": "https://registry.npmjs.org/undici/-/undici-6.21.3.tgz", + "integrity": "sha512-gBLkYIlEnSp8pFbT64yFgGE6UIB9tAkhukC23PmMDCe5Nd+cRqKxSjw5y54MK2AZMgZfJWMaNE4nYUHgi1XEOw==", + "license": "MIT", "engines": { "node": ">=18.17" } @@ -12541,6 +13976,21 @@ "node": ">= 0.8" } }, + "node_modules/unplugin": { + "version": "2.3.5", + "resolved": "https://registry.npmjs.org/unplugin/-/unplugin-2.3.5.tgz", + "integrity": "sha512-RyWSb5AHmGtjjNQ6gIlA67sHOsWpsbWpwDokLwTcejVdOjEkJZh7QKu14J00gDDVSh8kGH4KYC/TNBceXFZhtw==", + "dev": true, + "license": "MIT", + "dependencies": { + "acorn": "^8.14.1", + "picomatch": "^4.0.2", + "webpack-virtual-modules": "^0.6.2" + }, + "engines": { + "node": ">=18.12.0" + } + }, "node_modules/update-browserslist-db": { "version": "1.0.13", 
"resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.0.13.tgz", @@ -12583,7 +14033,15 @@ "node_modules/url-join": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/url-join/-/url-join-4.0.1.tgz", - "integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==" + "integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==", + "license": "MIT" + }, + "node_modules/urlpattern-polyfill": { + "version": "10.1.0", + "resolved": "https://registry.npmjs.org/urlpattern-polyfill/-/urlpattern-polyfill-10.1.0.tgz", + "integrity": "sha512-IGjKp/o0NL3Bso1PymYURCJxMPNAf/ILOpendP9f5B6e1rTJgdgiOvgfoT8VxCAdY+Wisb9uhGaJJf3yZ2V9nw==", + "dev": true, + "license": "MIT" }, "node_modules/username": { "version": "5.1.0", @@ -12600,9 +14058,9 @@ } }, "node_modules/username/node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "license": "MIT", "dependencies": { @@ -12771,6 +14229,22 @@ "uuid": "dist/bin/uuid" } }, + "node_modules/valibot": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/valibot/-/valibot-1.1.0.tgz", + "integrity": "sha512-Nk8lX30Qhu+9txPYTwM0cFlWLdPFsFr6LblzqIySfbZph9+BFsAHsNvHOymEviUepeIW6KFHzpX8TKhbptBXXw==", + "dev": true, + "license": "MIT", + "optional": true, + "peerDependencies": { + "typescript": ">=5" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, "node_modules/validate-npm-package-license": { "version": "3.0.4", "resolved": "https://registry.npmjs.org/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz", @@ -12782,6 +14256,17 @@ "spdx-expression-parse": "^3.0.0" } }, + "node_modules/validator": { + "version": "13.15.15", + "resolved": "https://registry.npmjs.org/validator/-/validator-13.15.15.tgz", + "integrity": "sha512-BgWVbCI72aIQy937xbawcs+hrVaN/CZ2UwutgaJ36hGqRrLNM+f5LUT/YPRbo8IV/ASeFzXszezV+y2+rq3l8A==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 0.10" + } + }, "node_modules/vary": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", @@ -12790,11 +14275,29 @@ "node": ">= 0.8" } }, - "node_modules/vite": { - "version": "5.4.8", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.8.tgz", - "integrity": "sha512-FqrItQ4DT1NC4zCUqMB4c4AZORMKIa0m8/URVCZ77OZ/QSNeJ54bU1vrFADbDsuwfIPcgknRkmqakQcgnL4GiQ==", + "node_modules/vaul-svelte": { + "version": "1.0.0-next.7", + "resolved": "https://registry.npmjs.org/vaul-svelte/-/vaul-svelte-1.0.0-next.7.tgz", + "integrity": "sha512-7zN7Bi3dFQixvvbUJY9uGDe7Ws/dGZeBQR2pXdXmzQiakjrxBvWo0QrmsX3HK+VH+SZOltz378cmgmCS9f9rSg==", "dev": true, + "license": "MIT", + "dependencies": { + "runed": "^0.23.2", + "svelte-toolbelt": "^0.7.1" + }, + "engines": { + "node": ">=18", + "pnpm": ">=8.7.0" + }, + "peerDependencies": { + "svelte": "^5.0.0" + } + }, + "node_modules/vite": { + "version": "5.4.19", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.19.tgz", + "integrity": 
"sha512-qO3aKv3HoQC8QKiNSTuUM1l9o/XX3+c+VTgLHbJWHZGeTPVAg2XwazI9UWzoxjIJCGCV2zU60uqMzjeLZuULqA==", + "license": "MIT", "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -12850,10 +14353,11 @@ } }, "node_modules/vite-node": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-1.4.0.tgz", - "integrity": "sha512-VZDAseqjrHgNd4Kh8icYHWzTKSCZMhia7GyHfhtzLW33fZlG9SwsB6CEhgyVOWkJfJ2pFLrp/Gj1FSfAiqH9Lw==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-1.6.1.tgz", + "integrity": "sha512-YAXkfvGtuTzwWbDSACdJSg4A4DZiAqckWe90Zapc/sEX3XvHcw1NdurM/6od8J207tSDqNbSsgdCacBgvJKFuA==", "dev": true, + "license": "MIT", "dependencies": { "cac": "^6.7.14", "debug": "^4.3.4", @@ -12871,11 +14375,37 @@ "url": "https://opencollective.com/vitest" } }, + "node_modules/vite-plugin-devtools-json": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/vite-plugin-devtools-json/-/vite-plugin-devtools-json-0.4.1.tgz", + "integrity": "sha512-pN+QJL+NwZUV+Via8w/Sh6X2pDrVClIMDAXdl7+EteXKB6mcHhsFGGclmxrPx6ZPGKSK5ez5ns64oRpjE5wFCg==", + "dev": true, + "license": "MIT", + "dependencies": { + "uuid": "^11.1.0" + }, + "peerDependencies": { + "vite": "^2.7.0 || ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" + } + }, + "node_modules/vite-plugin-devtools-json/node_modules/uuid": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.0.tgz", + "integrity": "sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==", + "dev": true, + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/esm/bin/uuid" + } + }, "node_modules/vite/node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", - "dev": true, "hasInstallScript": true, "optional": true, "os": [ @@ -12886,10 +14416,10 @@ } }, "node_modules/vitefu": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/vitefu/-/vitefu-1.0.5.tgz", - "integrity": "sha512-h4Vflt9gxODPFNGPwp4zAMZRpZR7eslzwH2c5hn5kNZ5rhnKyRJ50U+yGCdc2IRaBs8O4haIgLNGrV5CrpMsCA==", - "dev": true, + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/vitefu/-/vitefu-1.0.6.tgz", + "integrity": "sha512-+Rex1GlappUyNN6UfwbVZne/9cYC4+R2XDk9xkNXBKMw6HQagdX9PgZ8V2v1WUSK1wfBLp7qbI1+XSNIlB1xmA==", + "license": "MIT", "workspaces": [ "tests/deps/*", "tests/projects/*" @@ -12904,16 +14434,17 @@ } }, "node_modules/vitest": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/vitest/-/vitest-1.4.0.tgz", - "integrity": "sha512-gujzn0g7fmwf83/WzrDTnncZt2UiXP41mHuFYFrdwaLRVQ6JYQEiME2IfEjU3vcFL3VKa75XhI3lFgn+hfVsQw==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-1.6.1.tgz", + "integrity": "sha512-Ljb1cnSJSivGN0LqXd/zmDbWEM0RNNg2t1QW/XUhYl/qPqyu7CsqeWtqQXHVaJsecLPuDoak2oJcZN2QoRIOag==", "dev": true, + "license": "MIT", "dependencies": { - "@vitest/expect": "1.4.0", - "@vitest/runner": "1.4.0", - "@vitest/snapshot": "1.4.0", - "@vitest/spy": "1.4.0", - "@vitest/utils": "1.4.0", + "@vitest/expect": "1.6.1", + "@vitest/runner": "1.6.1", + "@vitest/snapshot": "1.6.1", + "@vitest/spy": "1.6.1", + "@vitest/utils": "1.6.1", "acorn-walk": "^8.3.2", "chai": "^4.3.10", "debug": "^4.3.4", @@ -12925,9 +14456,9 @@ "std-env": "^3.5.0", "strip-literal": 
"^2.0.0", "tinybench": "^2.5.1", - "tinypool": "^0.8.2", + "tinypool": "^0.8.3", "vite": "^5.0.0", - "vite-node": "1.4.0", + "vite-node": "1.6.1", "why-is-node-running": "^2.2.2" }, "bin": { @@ -12942,8 +14473,8 @@ "peerDependencies": { "@edge-runtime/vm": "*", "@types/node": "^18.0.0 || >=20.0.0", - "@vitest/browser": "1.4.0", - "@vitest/ui": "1.4.0", + "@vitest/browser": "1.6.1", + "@vitest/ui": "1.6.1", "happy-dom": "*", "jsdom": "*" }, @@ -12968,6 +14499,12 @@ } } }, + "node_modules/w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", + "license": "MIT" + }, "node_modules/wcwidth": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/wcwidth/-/wcwidth-1.0.1.tgz", @@ -12996,6 +14533,13 @@ "dev": true, "license": "BSD-2-Clause" }, + "node_modules/webpack-virtual-modules": { + "version": "0.6.2", + "resolved": "https://registry.npmjs.org/webpack-virtual-modules/-/webpack-virtual-modules-0.6.2.tgz", + "integrity": "sha512-66/V2i5hQanC51vBQKPH4aI8NMAcBW59FVBs+rC7eGHupMyfn34q7rZIE+ETlJ+XTevqfUhVVBgSUNSW2flEUQ==", + "dev": true, + "license": "MIT" + }, "node_modules/whatwg-url": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", @@ -13397,9 +14941,9 @@ } }, "node_modules/yarn-or-npm/node_modules/cross-spawn": { - "version": "6.0.5", - "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz", - "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==", + "version": "6.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz", + "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==", "dev": true, "license": "MIT", "dependencies": { @@ -13492,18 +15036,58 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/yup": { + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/yup/-/yup-1.6.1.tgz", + "integrity": "sha512-JED8pB50qbA4FOkDol0bYF/p60qSEDQqBD0/qeIrUCG1KbPBIQ776fCUNb9ldbPcSTxA69g/47XTo4TqWiuXOA==", + "dev": true, + "license": "MIT", + "optional": true, + "dependencies": { + "property-expr": "^2.0.5", + "tiny-case": "^1.0.3", + "toposort": "^2.0.2", + "type-fest": "^2.19.0" + } + }, + "node_modules/yup/node_modules/type-fest": { + "version": "2.19.0", + "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-2.19.0.tgz", + "integrity": "sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==", + "dev": true, + "license": "(MIT OR CC0-1.0)", + "optional": true, + "engines": { + "node": ">=12.20" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/zimmerframe": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/zimmerframe/-/zimmerframe-1.1.2.tgz", "integrity": "sha512-rAbqEGa8ovJy4pyBxZM70hg4pE6gDgaQ0Sl9M3enG3I0d6H4XSAM3GeNGLKnsBpuijUow064sf7ww1nutC5/3w==" }, "node_modules/zod": { - "version": "3.22.4", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.22.4.tgz", - "integrity": "sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==", + "version": "3.25.67", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz", + "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==", + 
"license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.5", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz", + "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", + "dev": true, + "license": "ISC", + "optional": true, + "peerDependencies": { + "zod": "^3.24.1" + } } } } diff --git a/services/app/package.json b/services/app/package.json old mode 100644 new mode 100755 index 9312519..a885fc2 --- a/services/app/package.json +++ b/services/app/package.json @@ -1,6 +1,6 @@ { "name": "jamaibase-app", - "version": "0.2.0", + "version": "0.5.0", "private": true, "main": "electron/main.js", "author": "EmbeddedLLM", @@ -15,7 +15,7 @@ "make": "npm run build && electron-forge make", "preview": "vite preview", "start": "node server", - "devstart": "ORIGIN=http://localhost:4173 HOST=localhost FRONTEND_PORT=4173 NODE_ENV=development node server", + "devstart": "cross-env ORIGIN=http://localhost:4173 HOST=localhost FRONTEND_PORT=4173 NODE_ENV=development node server", "test": "npm run test:integration && npm run test:unit", "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json", "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch", @@ -23,7 +23,8 @@ "format": "prettier --write .", "test:integration": "playwright test", "test:unit": "vitest", - "start:debug_electron": "electron-forge start --inspect-electron" + "start:debug_electron": "electron-forge start --inspect-electron", + "machine-translate": "inlang machine translate --project project.inlang" }, "devDependencies": { "@electron-forge/cli": "^7.4.0", @@ -32,10 +33,13 @@ "@electron-forge/maker-squirrel": "^7.4.0", "@electron-forge/maker-zip": "^7.4.0", "@faker-js/faker": "^8.4.1", + "@inlang/cli": "^3.0.0", + "@inlang/paraglide-js": "2.0.13", + "@lucide/svelte": "^0.482.0", "@playwright/test": "^1.28.1", "@sveltejs/adapter-node": "^5.0.1", "@sveltejs/adapter-static": "^3.0.2", - "@sveltejs/kit": "^2.5.27", + "@sveltejs/kit": "^2.15.0", "@sveltejs/vite-plugin-svelte": "^4.0.0", "@types/cors": "^2.8.17", "@types/eslint": "^8.56.0", @@ -43,43 +47,56 @@ "@types/lodash": "^4.17.0", "@types/nprogress": "^0.2.3", "@types/papaparse": "^5.3.14", + "@types/pdfobject": "^2.2.5", "@types/showdown": "^2.0.6", "@types/uuid": "^9.0.8", "@typescript-eslint/eslint-plugin": "^7.0.0", "@typescript-eslint/parser": "^7.0.0", "autoprefixer": "^10.4.18", + "bits-ui": "^1.8.0", "concurrently": "^8.2.2", "cross-env": "^7.0.3", "electron": "^31.0.1", "eslint": "^8.56.0", "eslint-config-prettier": "^9.1.0", "eslint-plugin-svelte": "^2.45.1", + "mode-watcher": "^1.0.7", + "paneforge": "^1.0.0-next.5", "postcss": "^8.4.37", "prettier": "^3.1.1", + "prettier-plugin-organize-imports": "^4.1.0", "prettier-plugin-svelte": "^3.2.6", + "prettier-plugin-tailwindcss": "^0.6.12", "run-script-os": "^1.1.6", "svelte": "^5.0.0", "svelte-check": "^4.0.0", + "svelte-sonner": "^0.3.28", + "sveltekit-superforms": "^2.27.0", "tailwindcss": "^3.4.1", "tailwindcss-animate": "^1.0.7", "tslib": "^2.4.1", "typescript": "^5.5.0", + "vaul-svelte": "^1.0.0-next.7", "vite": "^5.4.4", + "vite-plugin-devtools-json": "^0.4.1", "vitest": "^1.2.0" }, "type": "module", "dependencies": { + "@auth/sveltekit": "^1.9.2", "@fontsource-variable/roboto-flex": "^5.0.15", "@formkit/auto-animate": "^0.8.1", - "@stripe/stripe-js": "^3.4.0", + "@monaco-editor/loader": "^1.5.0", + 
"@stripe/stripe-js": "^3.5.0", "@tailwindcss/container-queries": "^0.1.1", "auth0": "^4.4.0", "axios": "^1.6.8", - "bits-ui": "^0.20.1", "chart.js": "^4.4.3", "chartjs-adapter-moment": "^1.0.1", "clsx": "^2.1.0", "cors": "^2.8.5", + "csvtojson": "^2.0.10", + "date-fns": "^4.1.0", "dexie": "^4.0.10", "dotenv": "^16.4.5", "electron-serve": "^2.0.0", @@ -90,21 +107,29 @@ "lodash": "^4.17.21", "lucide-svelte": "^0.359.0", "minio": "^7.1.3", - "mode-watcher": "^0.3.0", + "minisearch": "^7.1.2", + "monaco-editor": "^0.52.2", "node-cache": "^5.1.2", "nprogress": "^0.2.0", "overlayscrollbars-svelte": "^0.5.1", "papaparse": "^5.4.1", + "pdfjs-dist": "^4.10.38", + "pdfobject": "^2.3.1", "pretty-bytes": "^6.1.1", + "prosemirror-commands": "^1.7.1", + "prosemirror-history": "^1.4.1", + "prosemirror-keymap": "^1.2.3", + "prosemirror-model": "^1.25.3", + "prosemirror-state": "^1.4.3", + "prosemirror-view": "^1.41.0", "showdown": "^2.1.0", "showdown-htmlescape": "^0.1.9", - "stripe": "^15.5.0", + "stripe": "^15.12.0", "svelte-persisted-store": "^0.9.1", - "svelte-sonner": "^0.3.24", "tailwind-merge": "^2.2.2", "tailwind-variants": "^0.2.1", "undici": "^6.19.4", "uuid": "^9.0.1", - "zod": "^3.22.4" + "zod": "^3.25.67" } } diff --git a/services/app/playwright.config.ts b/services/app/playwright.config.ts old mode 100644 new mode 100755 diff --git a/services/app/postcss.config.js b/services/app/postcss.config.js old mode 100644 new mode 100755 diff --git a/services/app/project.inlang/.gitignore b/services/app/project.inlang/.gitignore new file mode 100644 index 0000000..5e46596 --- /dev/null +++ b/services/app/project.inlang/.gitignore @@ -0,0 +1 @@ +cache \ No newline at end of file diff --git a/services/app/project.inlang/project_id b/services/app/project.inlang/project_id new file mode 100644 index 0000000..47e2976 --- /dev/null +++ b/services/app/project.inlang/project_id @@ -0,0 +1 @@ +7kdRYkHy8FwuNcDFad \ No newline at end of file diff --git a/services/app/project.inlang/settings.json b/services/app/project.inlang/settings.json new file mode 100644 index 0000000..67d14e4 --- /dev/null +++ b/services/app/project.inlang/settings.json @@ -0,0 +1,12 @@ +{ + "$schema": "https://inlang.com/schema/project-settings", + "baseLocale": "en", + "locales": ["en"], + "modules": [ + "https://cdn.jsdelivr.net/npm/@inlang/plugin-message-format@4/dist/index.js", + "https://cdn.jsdelivr.net/npm/@inlang/plugin-m-function-matcher@2/dist/index.js" + ], + "plugin.inlang.messageFormat": { + "pathPattern": "./messages/{locale}.json" + } +} diff --git a/services/app/server/index.js b/services/app/server/index.js old mode 100644 new mode 100755 index f927ec2..5dd0671 --- a/services/app/server/index.js +++ b/services/app/server/index.js @@ -1,16 +1,16 @@ +import cors from 'cors'; import 'dotenv/config'; -import { handler } from '../build/handler.js'; import express from 'express'; -import cors from 'cors'; import expressOpenIdConnect from 'express-openid-connect'; +import { handler } from '../build/handler.js'; -const { NODE_ENV, BASE_URL } = process.env; +const { NODE_ENV, ORIGIN } = process.env; const FRONTEND_PORT = process.env.FRONTEND_PORT || 4000; const app = express(); app.use(cors()); -if (process.env.PUBLIC_IS_LOCAL === 'false') { +if (!!process.env.OWL_SERVICE_KEY && !!process.env.AUTH0_CLIENT_SECRET) { // The `auth` router attaches /login, /logout and /callback routes to the baseURL app.use( expressOpenIdConnect.auth({ @@ -20,7 +20,7 @@ if (process.env.PUBLIC_IS_LOCAL === 'false') { }, authRequired: false, 
auth0Logout: true, - baseURL: NODE_ENV === 'production' ? BASE_URL : `http://localhost:${FRONTEND_PORT}`, + baseURL: NODE_ENV === 'production' ? ORIGIN : `http://localhost:${FRONTEND_PORT}`, clientID: process.env.AUTH0_CLIENT_ID, clientSecret: process.env.AUTH0_CLIENT_SECRET, issuerBaseURL: process.env.AUTH0_ISSUER_BASE_URL, diff --git a/services/app/src/app.css b/services/app/src/app.css old mode 100644 new mode 100755 diff --git a/services/app/src/app.d.ts b/services/app/src/app.d.ts old mode 100644 new mode 100755 index e337c8d..8618f54 --- a/services/app/src/app.d.ts +++ b/services/app/src/app.d.ts @@ -1,30 +1,23 @@ /* eslint-disable @typescript-eslint/ban-types */ // See https://kit.svelte.dev/docs/types#app + +import type { Auth0User, User } from '$lib/types'; + // for information about these interfaces declare global { namespace App { // interface Error {} interface Locals { - user?: User; + ossMode: boolean; + auth0Mode: boolean; + user?: Partial & User; } - interface PageData { - user?: User; + // interface PageData {} + interface PageState { + page?: number; } // interface Platform {} } } -type User = { - sid: string; - given_name?: string; - nickname: string; - name: string; - picture: string; - locale?: string; - updated_at: '2024-05-06T17:16:18.952Z'; - email: string; - email_verified: boolean; - sub: string; -}; - export {}; diff --git a/services/app/src/app.html b/services/app/src/app.html old mode 100644 new mode 100755 index 9b1b3b8..c0e63a6 --- a/services/app/src/app.html +++ b/services/app/src/app.html @@ -1,5 +1,5 @@ - + diff --git a/services/app/src/globalStore.ts b/services/app/src/globalStore.ts old mode 100644 new mode 100755 index 1ca4e21..b5b04de --- a/services/app/src/globalStore.ts +++ b/services/app/src/globalStore.ts @@ -1,4 +1,4 @@ -import type { AvailableModel, Organization, Project, UploadQueue } from '$lib/types'; +import type { ModelConfig, OrganizationReadRes, Project, UploadQueue } from '$lib/types'; import { serializer } from '$lib/utils'; import { persisted } from 'svelte-persisted-store'; import { writable } from 'svelte/store'; @@ -14,6 +14,11 @@ type SortOptions = { orderBy: string; order: 'asc' | 'desc'; }; +export const modelConfigSort = persisted( + 'modelConfigSort', + { orderBy: 'created_at', order: 'desc', filter: 'all' }, + { serializer } +); export const projectSort = persisted( 'projectSort', { orderBy: 'updated_at', order: 'desc' }, @@ -35,7 +40,7 @@ export const cTableSort = persisted( { serializer } ); -export const modelsAvailable = writable([]); +export const modelsAvailable = writable([]); export const uploadQueue = writable({ activeFile: null, @@ -45,7 +50,23 @@ export const uploadQueue = writable({ export const uploadController = writable(null); //* Non-local -export const activeOrganization = writable(null); +function createActiveOrgStore() { + const { subscribe, set, update } = writable(null); + + return { + subscribe, + set, + update, + setOrgCookie: (id: string | null) => { + if (id) { + document.cookie = `activeOrganizationId=${id}; path=/; max-age=604800; samesite=strict`; + } else { + document.cookie = `activeOrganizationId=; path=/; max-age=604800; samesite=strict`; + } + } + }; +} +export const activeOrganization = createActiveOrgStore(); export const activeProject = writable(null); export const loadingProjectData = writable<{ loading: boolean; error?: string }>({ loading: true, diff --git a/services/app/src/hljs-theme.css b/services/app/src/hljs-theme.css old mode 100644 new mode 100755 diff --git 
a/services/app/src/hooks.server.ts b/services/app/src/hooks.server.ts old mode 100644 new mode 100755 index ca9c53f..2d2dff5 --- a/services/app/src/hooks.server.ts +++ b/services/app/src/hooks.server.ts @@ -1,65 +1,87 @@ -import { PUBLIC_IS_LOCAL } from '$env/static/public'; -import { JAMAI_URL, JAMAI_SERVICE_KEY } from '$env/static/private'; import { dev } from '$app/environment'; -import { json, redirect, type Handle } from '@sveltejs/kit'; -import { Agent } from 'undici'; -import { getPrices } from '$lib/server/nodeCache'; +import { env } from '$env/dynamic/private'; +import { handle as authenticationHandle } from '$lib/auth'; import logger from '$lib/logger'; +import { paraglideMiddleware } from '$lib/paraglide/server'; +import { getPrices } from '$lib/server/nodeCache'; +import type { Auth0User, User } from '$lib/types'; +import type { Session } from '@auth/sveltekit'; +import { error, redirect, type Handle } from '@sveltejs/kit'; +import { sequence } from '@sveltejs/kit/hooks'; +import { Agent } from 'undici'; + +const { AUTH0_CLIENT_SECRET, OWL_SERVICE_KEY, OWL_URL } = env; +const ossMode = !OWL_SERVICE_KEY; +const auth0Mode = !!OWL_SERVICE_KEY && !!AUTH0_CLIENT_SECRET; -const PROXY_PATHS: { path: string; target: string }[] = [ +const PROXY_PATHS: { path: string; exclude?: string[]; target: string }[] = [ + { + path: '/api/owl/organizations', + exclude: ['/api/owl/organizations/webhooks/stripe'], + target: `${OWL_URL}/api/v2/organizations` + }, + { + path: '/api/owl/projects', + // exclude: ['/api/owl/projects/export', '/api/owl/projects/import'], + target: `${OWL_URL}/api/v2/projects` + }, { - path: '/api/v1/gen_tables', - target: JAMAI_URL + path: '/api/owl/gen_tables', + target: `${OWL_URL}/api/v2/gen_tables` }, { - path: '/api/v1/models', - target: JAMAI_URL + path: '/api/owl/models', + target: `${OWL_URL}/api/v2/models` }, { - path: '/api/v1/model_names', - target: JAMAI_URL + path: '/api/owl/model_names', + target: `${OWL_URL}/api/v2/model_names` }, { - path: '/api/v1/chat/completions', - target: JAMAI_URL + path: '/api/owl/chat/completions', + target: `${OWL_URL}/api/v2/chat/completions` }, { - path: '/api/v1/files', - target: JAMAI_URL + path: '/api/owl/conversations', + target: `${OWL_URL}/api/v2/conversations` + }, + { + path: '/api/owl/files', + target: `${OWL_URL}/api/v2/files` }, { path: '/api/file', - target: JAMAI_URL + target: `${OWL_URL}/api/v2/file` }, { - path: '/api/public/v1/templates', - target: JAMAI_URL + path: '/api/owl/templates', + target: `${OWL_URL}/api/v2/templates` } ]; const handleApiProxy: Handle = async ({ event }) => { const proxyPath = PROXY_PATHS.find((p) => event.url.pathname.startsWith(p.path))!; - const urlPath = `${proxyPath!.target}${event.url.pathname}${event.url.search}`; + const urlPath = `${proxyPath.target}${event.url.pathname.replace(proxyPath.path, '')}${event.url.search}`; const proxiedUrl = new URL(urlPath); event.request.headers.delete('connection'); - if (PUBLIC_IS_LOCAL === 'false') { - if (event.locals.user) { - event.request.headers.append('Authorization', `Bearer ${JAMAI_SERVICE_KEY}`); - event.request.headers.append('x-user-id', event.locals.user.sub); + if (event.locals.user) { + if (!ossMode) { + event.request.headers.append('Authorization', `Bearer ${OWL_SERVICE_KEY}`); } + event.request.headers.append('x-user-id', event.locals.user.id); } - const projectId = - event.request.headers.get('x-project-id') || event.cookies.get('activeProjectId'); - if (!projectId) { - return json({ message: 'Missing project ID' }, { status: 
400 }); + if (!event.request.headers.get('x-project-id') && event.cookies.get('activeProjectId')) { + event.request.headers.append('x-project-id', event.cookies.get('activeProjectId')!); } - if (!event.request.headers.get('x-project-id')) { - event.request.headers.append('x-project-id', projectId); - } + // const projectId = + // event.request.headers.get('x-project-id') || event.cookies.get('activeProjectId'); + // if (!projectId) { + // return json({ message: 'Missing project ID' }, { status: 400 }); + // } return fetch(proxiedUrl.toString(), { body: event.request.body, @@ -81,11 +103,16 @@ const handleApiProxy: Handle = async ({ event }) => { }); }; -export const handle: Handle = async ({ event, resolve }) => { +export const mainHandle: Handle = async ({ event, resolve }) => { const { cookies, locals, request, url } = event; - if (dev && !request.url.includes('/api/v1/files')) console.log('Connecting', request.url); + if (dev && !request.url.includes('/api/owl/files')) console.log('Connecting', request.url); + + locals.ossMode = ossMode; + locals.auth0Mode = auth0Mode; - if (PUBLIC_IS_LOCAL === 'false') { + let auth0UserData: Auth0User; + let session: Session | null; + if (auth0Mode) { //? Workaround for event.platform unavailable in development if (dev) { const user = await ( @@ -93,38 +120,149 @@ export const handle: Handle = async ({ event, resolve }) => { headers: { cookie: `appSession=${cookies.get('appSession')}` } }) ).json(); - locals.user = Object.keys(user).length ? user : undefined; + auth0UserData = Object.keys(user).length ? user : undefined; } else { // @ts-expect-error missing type - locals.user = event.platform?.req?.res?.locals?.user; + auth0UserData = event.platform?.req?.res?.locals?.user; } + } else { + session = await locals.auth(); } - if (PUBLIC_IS_LOCAL === 'false' && !url.pathname.startsWith('/api')) { + //@ts-expect-error asd + if (auth0UserData || session) { + //@ts-expect-error asd + let userApiData = await getUserApiData(auth0UserData?.sub ?? session?.user?.id); + if (!userApiData.data) { + if (auth0Mode && userApiData.status === 404) { + const userUpsertRes = await fetch(`${OWL_URL}/api/v2/users`, { + method: 'POST', + headers: { + Authorization: `Bearer ${OWL_SERVICE_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + id: auth0UserData!.sub, + name: + auth0UserData!.email === auth0UserData!.name + ? auth0UserData!.nickname + : auth0UserData!.name, + email: auth0UserData!.email, + email_verified: true + }) + }); + const userUpsertBody = (await userUpsertRes.json()) as User; + + if (!userUpsertRes.ok) { + logger.error('APP_USER_UPSERT', userUpsertBody); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + throw error(userUpsertRes.status, userUpsertBody as any); + } else { + userApiData = { status: 200, data: userUpsertBody }; + } + } else { + // logger.error('APP_USER_GET', `User not found: ${session.user.id}`); + if (!url.pathname.startsWith('/login') && !url.pathname.startsWith('/register')) { + throw redirect(302, '/login'); + } + } + } + locals.user = { + ...(auth0UserData! ?? {}), + ...userApiData.data!, + email_verified: (auth0UserData! ?? {}).email_verified ?? userApiData.data!.email_verified + }; + } + + //? 
Bandaid fix for email verification - REMOVE LATER + /* if (auth0Mode && locals.user) { + await fetch( + `${OWL_URL}/api/v2/users/verify/email/code?${new URLSearchParams([ + ['user_email', locals.user.email], + ['valid_days', '7'] + ])}`, + { + method: 'POST', + headers: { Authorization: `Bearer ${OWL_SERVICE_KEY}`, 'x-user-id': locals.user.sub! } + } + ) + .then((r) => r.json()) + .then((emailCode) => + fetch( + `${OWL_URL}/api/v2/users/verify/email?${new URLSearchParams([['verification_code', emailCode.id]])}`, + { + method: 'POST', + headers: { Authorization: `Bearer ${OWL_SERVICE_KEY}`, 'x-user-id': locals.user!.sub! } + } + ) + ); + } */ + + if ( + !url.pathname.startsWith('/api') && + !url.pathname.startsWith('/login') && + !url.pathname.startsWith('/register') + ) { if (!locals.user) { const originalUrl = url.pathname + (url.searchParams.size > 0 ? `?${url.searchParams.toString()}` : ''); - throw redirect(302, `/login${originalUrl ? `?returnTo=${originalUrl}` : ''}`); - } else { - if (!locals.user.email_verified && !url.pathname.startsWith('/verify-email')) { - throw redirect( - 302, - `/verify-email${url.searchParams.size > 0 ? `?${url.searchParams.toString()}` : ''}` - ); - } + throw redirect( + 302, + `/login${originalUrl ? `?returnTo=${encodeURIComponent(originalUrl)}` : ''}` + ); } } - if (PROXY_PATHS.some((p) => url.pathname.startsWith(p.path))) { + if ( + PROXY_PATHS.some( + (p) => + url.pathname.startsWith(p.path) && + (!p.exclude || !p.exclude.some((ex) => url.pathname.startsWith(ex))) + ) + ) { return await handleApiProxy({ event, resolve }); } return await resolve(event); }; +const paraglideHandle: Handle = ({ event, resolve }) => + paraglideMiddleware(event.request, ({ request: localizedRequest, locale }) => { + event.request = localizedRequest; + return resolve(event, { + transformPageChunk: ({ html }) => { + return html.replace('%lang%', locale); + } + }); + }); + +export const handle: Handle = sequence(authenticationHandle, mainHandle, paraglideHandle); + //* Server startup script -if (PUBLIC_IS_LOCAL === 'false') { - (async function () { - await getPrices(); - })(); +(async function () { + await getPrices(); +})(); + +async function getUserApiData(userId: string) { + const userApiRes = await fetch( + `${OWL_URL}/api/v2/users?${new URLSearchParams([['user_id', userId]])}`, + { + headers: { + Authorization: `Bearer ${OWL_SERVICE_KEY}`, + 'x-user-id': userId + } + } + ); + + const userApiBody = await userApiRes.json(); + if (userApiRes.ok) { + return { status: 200, data: userApiBody as User }; + } else { + if (!/User "([^"]*)" is not found\./.test(userApiBody.message)) { + logger.error('APP_USER_GET', userApiBody); + return { status: userApiRes.status, data: undefined }; + } else { + return { status: 404, data: undefined }; + } + } } diff --git a/services/app/src/hooks.ts b/services/app/src/hooks.ts new file mode 100644 index 0000000..fd4a845 --- /dev/null +++ b/services/app/src/hooks.ts @@ -0,0 +1,6 @@ +import { deLocalizeUrl } from '$lib/paraglide/runtime'; +import type { Reroute } from '@sveltejs/kit'; + +export const reroute: Reroute = (request) => { + return deLocalizeUrl(request.url).pathname; +}; diff --git a/services/app/src/lib/assets/Black-Long-Main.svg b/services/app/src/lib/assets/Black-Long-Main.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/Black-Long.svg b/services/app/src/lib/assets/Black-Long.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/Black-Main.svg 
b/services/app/src/lib/assets/Black-Main.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/Black.svg b/services/app/src/lib/assets/Black.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/Jamai-Long-Black-Main.svg b/services/app/src/lib/assets/Jamai-Long-Black-Main.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/Jamai-Long-White-Main.svg b/services/app/src/lib/assets/Jamai-Long-White-Main.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/White-Long-Main.svg b/services/app/src/lib/assets/White-Long-Main.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/White-Long.svg b/services/app/src/lib/assets/White-Long.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/White-Main.svg b/services/app/src/lib/assets/White-Main.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/White.svg b/services/app/src/lib/assets/White.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/dark-mode.svg b/services/app/src/lib/assets/dark-mode.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/jamai-onboarding-bg.svg b/services/app/src/lib/assets/jamai-onboarding-bg.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/light-mode.svg b/services/app/src/lib/assets/light-mode.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/assets/model-icons/allenai.png b/services/app/src/lib/assets/model-icons/allenai.png new file mode 100644 index 0000000..60976e2 Binary files /dev/null and b/services/app/src/lib/assets/model-icons/allenai.png differ diff --git a/services/app/src/lib/assets/model-icons/anthropic.png b/services/app/src/lib/assets/model-icons/anthropic.png new file mode 100644 index 0000000..9d94862 Binary files /dev/null and b/services/app/src/lib/assets/model-icons/anthropic.png differ diff --git a/services/app/src/lib/assets/model-icons/cohere.png b/services/app/src/lib/assets/model-icons/cohere.png new file mode 100644 index 0000000..3da0b83 Binary files /dev/null and b/services/app/src/lib/assets/model-icons/cohere.png differ diff --git a/services/app/src/lib/assets/model-icons/deepseek.png b/services/app/src/lib/assets/model-icons/deepseek.png new file mode 100644 index 0000000..a2bf71a Binary files /dev/null and b/services/app/src/lib/assets/model-icons/deepseek.png differ diff --git a/services/app/src/lib/assets/model-icons/gemini.png b/services/app/src/lib/assets/model-icons/gemini.png new file mode 100644 index 0000000..f1f675b Binary files /dev/null and b/services/app/src/lib/assets/model-icons/gemini.png differ diff --git a/services/app/src/lib/assets/model-icons/generic.png b/services/app/src/lib/assets/model-icons/generic.png new file mode 100644 index 0000000..0b9d2ae Binary files /dev/null and b/services/app/src/lib/assets/model-icons/generic.png differ diff --git a/services/app/src/lib/assets/model-icons/generic2.png b/services/app/src/lib/assets/model-icons/generic2.png new file mode 100644 index 0000000..ff07b5e Binary files /dev/null and b/services/app/src/lib/assets/model-icons/generic2.png differ diff --git a/services/app/src/lib/assets/model-icons/index.ts b/services/app/src/lib/assets/model-icons/index.ts new file mode 100644 index 0000000..02cceaa --- /dev/null +++ b/services/app/src/lib/assets/model-icons/index.ts @@ -0,0 +1,25 @@ +import allenai from './allenai.png'; +import anthropic from './anthropic.png'; +import 
cohere from './cohere.png'; +import deepseek from './deepseek.png'; +import gemini from './gemini.png'; +import generic from './generic.png'; +import generic2 from './generic2.png'; +import meta from './meta.png'; +import mistral from './mistral.png'; +import openai from './openai.png'; +import qwen from './qwen.png'; + +export { + allenai, + anthropic, + cohere, + deepseek, + gemini, + generic, + generic2, + meta, + mistral, + openai, + qwen +}; diff --git a/services/app/src/lib/assets/model-icons/meta.png b/services/app/src/lib/assets/model-icons/meta.png new file mode 100644 index 0000000..9d3ed82 Binary files /dev/null and b/services/app/src/lib/assets/model-icons/meta.png differ diff --git a/services/app/src/lib/assets/model-icons/mistral.png b/services/app/src/lib/assets/model-icons/mistral.png new file mode 100644 index 0000000..adbc3a8 Binary files /dev/null and b/services/app/src/lib/assets/model-icons/mistral.png differ diff --git a/services/app/src/lib/assets/model-icons/openai.png b/services/app/src/lib/assets/model-icons/openai.png new file mode 100644 index 0000000..01eb4cc Binary files /dev/null and b/services/app/src/lib/assets/model-icons/openai.png differ diff --git a/services/app/src/lib/assets/model-icons/qwen.png b/services/app/src/lib/assets/model-icons/qwen.png new file mode 100644 index 0000000..c805b63 Binary files /dev/null and b/services/app/src/lib/assets/model-icons/qwen.png differ diff --git a/services/app/src/lib/assets/system-mode.svg b/services/app/src/lib/assets/system-mode.svg old mode 100644 new mode 100755 diff --git a/services/app/src/lib/auth.ts b/services/app/src/lib/auth.ts new file mode 100644 index 0000000..4ddd4e7 --- /dev/null +++ b/services/app/src/lib/auth.ts @@ -0,0 +1,205 @@ +import { env } from '$env/dynamic/private'; +import { CredentialsSignin, SvelteKitAuth, type DefaultSession } from '@auth/sveltekit'; + +import Credentials from '@auth/sveltekit/providers/credentials'; +import logger from './logger'; +import type { User } from './types'; + +const { AUTH_SECRET, OWL_URL, USE_SECURE_COOKIES, IDLE_AUTH_TIMEOUT, ABSOLUTE_AUTH_TIMEOUT } = env; + +const DEFAULT_AUTH_ABSOLUTE_TIMEOUT = 86400; +const DEFAULT_AUTH_IDLE_TIMEOUT = 900; + +const ABSOLUTE_MAX_LIFETIME = + (Number(ABSOLUTE_AUTH_TIMEOUT) || DEFAULT_AUTH_ABSOLUTE_TIMEOUT) * 1000; + +type SessionUser = Pick< + User, + | 'id' + | 'email' + | 'name' + | 'preferred_name' + | 'preferred_email' + | 'picture_url' + | 'preferred_picture_url' +>; + +declare module '@auth/sveltekit' { + interface Session { + user: SessionUser & DefaultSession['user']; + } + interface User extends SessionUser {} +} + +class InvalidCredentials extends CredentialsSignin { + code = 'invalid_credentials'; +} +class InsufficientCredentials extends CredentialsSignin { + code = 'insufficient_credentials'; +} +class UserExists extends CredentialsSignin { + code = 'user_exists'; +} +class UserNotFound extends CredentialsSignin { + code = 'user_not_found'; +} + +export const { handle } = SvelteKitAuth({ + trustHost: true, + providers: [ + Credentials({ + id: 'credentials', + name: 'Credentials', + credentials: { + email: {}, + name: {}, + password: {}, + isNewAccount: {} + }, + authorize: async (credentials) => { + if (!credentials?.email || !credentials?.password) { + throw new InsufficientCredentials('Email and password are required'); + } + + if (credentials.isNewAccount === 'true') { + if (!credentials?.email || !credentials?.name) { + throw new InsufficientCredentials(); + } + const response = await 
fetch(`${OWL_URL}/api/v2/auth/register/password`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + email: credentials.email, + name: credentials.name, + password: credentials.password + }) + }); + + const data = await response.json(); + if (!response.ok) { + if (data.error !== 'resource_exists' || data.error !== 'unauthorized') { + logger.error('AUTH_SIGNUP_ERROR', data); + } + + if (data.error === 'resource_exists') { + throw new UserExists(data?.message); + } + if (data.error === 'unauthorized') { + throw new InvalidCredentials(data?.message); + } + throw new CredentialsSignin(); + } + + if (data.id) { + delete data.password_hash; + return data; + } + } else { + const response = await fetch(`${OWL_URL}/api/v2/auth/login/password`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + email: credentials.email, + password: credentials.password + }) + }); + + const data = await response.json(); + + if (!response.ok) { + if (data.message !== 'User not found.' || data.error !== 'unauthorized') { + logger.error('AUTH_LOGIN_ERROR', data); + } + + if (data.message === 'User not found.') { + throw new UserNotFound(data?.message); + } + if (data.error === 'unauthorized') { + throw new InvalidCredentials(data?.message); + } + throw new CredentialsSignin(); + } + + if (data.id) { + delete data.password_hash; + return data; + } + } + + throw new CredentialsSignin(); + } + }) + ], + callbacks: { + // @ts-expect-error ignore + async session({ session, user, token }) { + if (token.forceLogout) { + return null; + } + + if (user) { + assignProperties(user, session.user); + } + if (token) { + assignProperties(token, session.user); + } + return session; + }, + + async redirect({ url, baseUrl }) { + // Allows relative callback URLs + if (url.startsWith('/')) return `${baseUrl}${url}`; + // Allows callback URLs on the same origin + return url; + }, + + async jwt({ token, user }) { + if (user) { + assignProperties(user, token); + } + + if (token.createdAt) { + const age = Date.now() - Number(token.createdAt); + if (age > ABSOLUTE_MAX_LIFETIME) { + token.forceLogout = true; + } + } else { + token.createdAt = Date.now(); + } + + return token; + } + }, + pages: { + signIn: '/login', + newUser: '/register' + }, + session: { + strategy: 'jwt' + // maxAge: Number(IDLE_AUTH_TIMEOUT) || DEFAULT_AUTH_IDLE_TIMEOUT + }, + + secret: AUTH_SECRET, + useSecureCookies: USE_SECURE_COOKIES === 'true' ? 
true : undefined +}); + +function assignProperties(source: Record, target: Record) { + const properties = [ + 'id', + 'email', + 'name', + 'preferred_name', + 'picture_url', + 'preferred_picture_url' + ]; + + properties.forEach((property) => { + if (source[property] !== undefined) { + target[property] = source[property]; + } + }); +} diff --git a/services/app/src/lib/components/Checkbox.svelte b/services/app/src/lib/components/Checkbox.svelte old mode 100644 new mode 100755 index 119805c..5915af7 --- a/services/app/src/lib/components/Checkbox.svelte +++ b/services/app/src/lib/components/Checkbox.svelte @@ -7,17 +7,30 @@ checkedChange: { event: MouseEvent; value: boolean }; }>(); - let className: string | undefined | null = undefined; - export { className as class }; + - export let id: string | undefined = undefined; - export let defaultChecked: boolean = false; - export let disabled: boolean | undefined = undefined; - export let required: boolean | undefined = undefined; - export let name: string | undefined = undefined; - export let checked: boolean = defaultChecked; - export let validateBeforeChange: (e: MouseEvent) => boolean = () => true; + interface Props { + class?: string | undefined | null; + id?: string | undefined; + defaultChecked?: boolean; + disabled?: boolean | undefined; + required?: boolean | undefined; + name?: string | undefined; + checked?: boolean; + validateBeforeChange?: (e: MouseEvent) => boolean; + } + + let { + class: className = undefined, + id = undefined, + defaultChecked = false, + disabled = undefined, + required = undefined, + name = undefined, + checked = $bindable(defaultChecked), + validateBeforeChange = () => true + }: Props = $props(); function toggle(e: MouseEvent) { if (validateBeforeChange(e) == false) return; @@ -31,8 +44,8 @@ + {/snippet} - + + + Sort by + {#each sortableFields as { id, title, Icon }} - {#each ['asc', 'desc'] as direction} - - - {title} - {direction === 'asc' ? '(Ascending)' : '(Descending)'} - - {/each} + + + {title} + + {/each} + +
+ + + Order + + {#each ['asc', 'desc'] as direction} + + + {direction === 'asc' + ? `${m['sortable.direction_asc']()}` + : `${m['sortable.direction_desc']()}`} + {/each}
diff --git a/services/app/src/lib/components/preset/UserDetailsBtn.svelte b/services/app/src/lib/components/preset/UserDetailsBtn.svelte new file mode 100644 index 0000000..b4d64b6 --- /dev/null +++ b/services/app/src/lib/components/preset/UserDetailsBtn.svelte @@ -0,0 +1,156 @@ + + + + + {#snippet child({ props })} + + {/snippet} + + + +
+ {page.data.user?.email} +
+ + + + {#snippet child({ props })} + + {@render joinOrgIcon('h-4 w-4')} + Join organization + + {/snippet} + + + {#snippet child({ props })} + + + Join project + + {/snippet} + + + {#snippet child({ props })} + + + Create organization + + {/snippet} + + + {#snippet child({ props })} + + + Create project + + {/snippet} + + + + + + + + {#snippet child({ props })} + + + Account Settings + + {/snippet} + + goto('/logout') : () => signOut()} + class="!text-[#F04438]" + > + + Sign out + + +
+
+ +{#snippet joinOrgIcon(className = '')} + + + + + + + +{/snippet} diff --git a/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte b/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte old mode 100644 new mode 100755 index 4bbb704..d9a9a34 --- a/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte +++ b/services/app/src/lib/components/tables/(sub)/ColumnDropdown.svelte @@ -1,13 +1,15 @@ - - + + {#snippet child({ props })} + + {/snippet} - + {#if colType === 'output'} - tableState.setColumnSettings({ column, isOpen: true })}> - + tableState.setColumnSettings({ column, isOpen: true })}> + Open settings @@ -221,20 +253,36 @@ {#if !readonly && !tableStaticCols[tableType].includes(column.id)} - {#if colType === 'output' && $tableState.selectedRows.length > 0} + {#if colType === 'output'} - - + + {#if tableState.selectedRows.length === 0} + + +
{ + if (animationFrameId) cancelAnimationFrame(animationFrameId); + tooltipPos.visible = false; + }} + class="pointer-events-auto absolute -bottom-1 -top-1 left-0 right-0 cursor-default" + >
+ {/if} + + Regenerate
- handleRegen('run_selected')}> + handleRegen('run_selected')}> This column - handleRegen('run_before')}> + handleRegen('run_before')}> Up to this column - handleRegen('run_after')}> + handleRegen('run_after')}> This column onwards @@ -242,23 +290,36 @@ {/if} { + onclick={async () => { tableState.setRenamingCol(column.id); //? Tick doesn't work - setTimeout(() => document.getElementById('column-id-edit')?.focus(), 100); + setTimeout(() => document.getElementById('column-id-edit')?.focus(), 200); }} > - + Rename tableState.setDeletingCol(column.id)} + onclick={() => tableState.setDeletingCol(column.id)} class="!text-[#F04438]" > - + Delete column
{/if}
+ + + + Select at least one row to regenerate + + diff --git a/services/app/src/lib/components/tables/(sub)/ColumnHeader.svelte b/services/app/src/lib/components/tables/(sub)/ColumnHeader.svelte old mode 100644 new mode 100755 index aa427ca..3f99a53 --- a/services/app/src/lib/components/tables/(sub)/ColumnHeader.svelte +++ b/services/app/src/lib/components/tables/(sub)/ColumnHeader.svelte @@ -1,15 +1,17 @@ { - if ($tableState.resizingCol) { + onmousemove={handleColResize} + onmouseup={() => { + if (tableState.resizingCol) { db[`${tableType}_table`].put({ id: tableData.id, - columns: $tableState.colSizes + columns: $state.snapshot(tableState.colSizes) }); - $tableState.resizingCol = null; + tableState.resizingCol = null; } }} /> @@ -228,42 +245,44 @@ {#each tableData.cols as column, index (column.id)} {@const colType = !column.gen_config ? 'input' : 'output'} {@const isCustomCol = column.id !== 'ID' && column.id !== 'Updated at'} - - + +
handleColumnHeaderClick(column)} - on:dragover={(e) => { + onclick={() => handleColumnHeaderClick(column)} + ondragover={(e) => { if (isCustomCol) { e.preventDefault(); hoveredColumnIndex = index; } }} class={cn( - 'relative [&>*]:z-[-5] flex items-center gap-1 [&:not(:last-child)]:border-r border-[#E4E7EC] data-dark:border-[#333] cursor-default', + 'relative flex cursor-default items-center gap-1 border-[#E4E7EC] data-dark:border-[#333] [&:not(:last-child)]:border-r [&>*]:z-[-5]', isCustomCol && !readonly ? 'px-1' : 'pl-2 pr-1', - $tableState.columnSettings.column?.id == column.id && - $tableState.columnSettings.isOpen && + tableState.columnSettings.column?.id == column.id && + tableState.columnSettings.isOpen && 'bg-[#30A8FF33]', - draggingColumn?.id == column.id && 'opacity-0' + draggingColumn?.id == column.id && 'opacity-0', + tableState.renamingCol && 'pointer-events-none' )} > {#if isCustomCol} {/if} @@ -271,8 +290,8 @@ - {#if !$tableState.colSizes[draggingColumn.id] || $tableState.colSizes[draggingColumn.id] >= 150} + {#if !tableState.colSizes[draggingColumn.id] || tableState.colSizes[draggingColumn.id] >= 150} - + {colType} - {#if !$tableState.colSizes[draggingColumn.id] || $tableState.colSizes[draggingColumn.id] >= 220} + {#if !tableState.colSizes[draggingColumn.id] || tableState.colSizes[draggingColumn.id] >= 220} {draggingColumn.dtype} @@ -413,13 +437,13 @@ {/if} - + {draggingColumn.id} -
- -{/if} + {/if} + diff --git a/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte b/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte index f681e00..955dc8c 100644 --- a/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte +++ b/services/app/src/lib/components/tables/(sub)/ColumnSettings.svelte @@ -1,227 +1,152 @@ { - if ($tableState.columnSettings.isOpen && e.key === 'Escape') { + onkeydown={(e) => { + if (tableState.columnSettings.isOpen && e.key === 'Escape') { closeColumnSettings(); } }} /> - - + +
-{#if $tableState.columnSettings.isOpen || showActual} +{#if tableState.columnSettings.isOpen || showActual}
{ - if ($tableState.columnSettings.isOpen) { + inert={!tableState.columnSettings.isOpen} + onanimationstart={() => { + if (tableState.columnSettings.isOpen) { showActual = true; } }} - on:animationend={() => { - if (!$tableState.columnSettings.isOpen) { + onanimationend={() => { + if (!tableState.columnSettings.isOpen) { showActual = false; } }} - class="absolute z-40 bottom-0 {$tableState.columnSettings.column?.gen_config + class="absolute bottom-0 z-40 px-4 py-3 {tableState.columnSettings.column?.gen_config ? 'column-settings max-h-full' - : 'h-16 max-h-16'} w-full bg-white data-dark:bg-[#0D0E11] {$tableState.columnSettings.isOpen + : 'h-16 max-h-16'} w-full {tableState.columnSettings.isOpen ? 'animate-in slide-in-from-bottom-full' : 'animate-out slide-out-to-bottom-full'} duration-300 ease-in-out" > -
+
- {#if showPromptTab && selectedGenConfigObj !== 'gen_config.code'} + {#if selectedGenConfig?.object !== 'gen_config.python'} + {#if showPromptTab && selectedGenConfig?.object !== 'gen_config.code'} + + {/if} + + {:else} + {/if} - -
- {#if selectedGenConfigObj} + {#if selectedGenConfig?.object}
- - {!$tableState.columnSettings.column?.gen_config ? 'input' : 'output'} + + {!tableState.columnSettings.column?.gen_config ? 'input' : 'output'} - {$tableState.columnSettings.column?.dtype} + {tableState.columnSettings.column?.dtype} - {#if $tableState.columnSettings.column?.gen_config?.object === 'gen_config.llm' && $tableState.columnSettings.column.gen_config.multi_turn} -
+ {#if tableState.columnSettings.column?.gen_config?.object === 'gen_config.llm' && tableState.columnSettings.column.gen_config.multi_turn} +
- +
{/if}
- {$tableState.columnSettings.column?.id} + {tableState.columnSettings.column?.id}
-
- {#if (tableType !== 'knowledge' || showPromptTab) && selectedGenConfigObj !== 'gen_config.code'} -
+
+ {#if (tableType !== 'knowledge' || showPromptTab) && selectedGenConfig.object === 'gen_config.llm'} +
+ +
+ { + const modelDetails = $modelsAvailable.find((val) => val.id == model); + if ( + modelDetails && + (selectedGenConfig.max_tokens ?? 0) > modelDetails.context_length + ) { + selectedGenConfig.max_tokens = modelDetails.context_length; + } + }} + class="w-64 border-transparent bg-[#F9FAFB] hover:bg-[#e1e2e6] data-dark:bg-[#42464e]" />
{/if} - - - - - {#each tableData?.cols.filter((col) => !['ID', 'Updated at'].includes(col.id) && col.id !== $tableState.columnSettings.column?.id && col.dtype === 'str') ?? [] as column} - (selectedSourceColumn = column.id)} - value={column.id} - label={column.id} - class="flex justify-between gap-10 cursor-pointer" - > - {column.id} - - {/each} - - -
+ {#snippet children()} + + {selectedGenConfig.source_column || 'Select source column'} + + {/snippet} + + + {#each tableData?.cols.filter((col) => !['ID', 'Updated at'].includes(col.id) && col.id !== tableState.columnSettings.column?.id && col.dtype === 'str') ?? [] as column} + + {column.id} + + {/each} + + +
+ {/if}
{#if showPromptTab && !readonly} - {/if}
{:else if selectedTab === 'prompt'} -
-
- Customize prompt +
+
+
+
+ Columns: + {#each [...usableColumns, ...(selectedGenConfig.object === 'gen_config.python' && originalCol ? [originalCol] : [])] as column} + + {/each} +
+ + {#if selectedGenConfig.object === 'gen_config.llm'} + { + if (selectedGenConfig?.object === 'gen_config.llm') { + return selectedGenConfig?.prompt ?? ''; + } else { + return ''; + } + }, + (v) => { + if (selectedGenConfig?.object === 'gen_config.llm') + selectedGenConfig.prompt = v; }} - > - {column.id} - - {/each} + {usableColumns} + /> + {:else if selectedGenConfig.object === 'gen_config.python'} + + {/if}
- -
- -
- - {#if !readonly} - - {/if} -
-
- {:else} -
-
-
+ {#if selectedGenConfig.object === 'gen_config.llm'}
- Customize system prompt - - -
+ - {#if isRAGEnabled}
- RAG Settings + -
-
- k +
+
+ - (editRAGk = - parseInt(editRAGk) <= 0 ? '1' : parseInt(editRAGk).toString())} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent data-dark:border-[#42464E] placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" + id="temperature" + name="temperature" + step=".01" + bind:value={selectedGenConfig.temperature} + onchange={(e) => { + const value = parseFloat(e.currentTarget.value); + + if (isNaN(value)) { + selectedGenConfig.temperature = 1; + } else if (value < 0.01) { + selectedGenConfig.temperature = 0.01; + } else if (value > 1) { + selectedGenConfig.temperature = 1; + } else { + selectedGenConfig.temperature = Number(value.toFixed(2)); + } + }} + class="rounded-md border border-[#E3E3E3] bg-white px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:border-[#42464E] data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" />
-
-
- - Reranking Model - +
+ - - (selectedRerankModel = selectedRerankModel === model ? '' : model)} - selectedModel={selectedRerankModel} - buttonText={($modelsAvailable.find( - (model) => model.id == selectedRerankModel - )?.name ?? - selectedRerankModel) || - 'Select model (optional)'} - class="h-10 bg-[#F2F4F7] data-dark:bg-[#42464e] hover:bg-[#e1e2e6] border-transparent" - /> -
+ { + const value = parseInt(e.currentTarget.value); + const model = $modelsAvailable.find( + (model) => model.id == selectedGenConfig.model + ); + + if (isNaN(value)) { + selectedGenConfig.max_tokens = 1; + } else if (value < 1 || value > 1e20) { + selectedGenConfig.max_tokens = 1; + } else if (model && value > model.context_length) { + selectedGenConfig.max_tokens = model.context_length; + } else { + selectedGenConfig.max_tokens = value; + } + }} + class="rounded-md border border-[#E3E3E3] bg-white px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:border-[#42464E] data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" + /> -
- - Knowledge tables - + model.id == selectedGenConfig.model) + ?.context_length} + step="1" + /> +
-
- +
+ { + const value = parseFloat(e.currentTarget.value); + + if (isNaN(value)) { + selectedGenConfig.top_p = 1; + } else if (value < 0.01) { + selectedGenConfig.top_p = 0.001; + } else if (value > 1) { + selectedGenConfig.top_p = 1; + } else { + selectedGenConfig.top_p = Number(value.toFixed(3)); + } + }} + class="rounded-md border border-[#E3E3E3] bg-white px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:border-[#42464E] data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" /> - + bind:value={selectedGenConfig.top_p} + min=".001" + max="1" + step=".001" + />
- {/if} -
+
+ {/if} +
-
+
+ + {#if !readonly} + + {/if} +
+
+ {:else if selectedGenConfig.object === 'gen_config.llm'} +
+
+
+

RAG Settings

-
-
Settings
+
+
+ -
- Model +

+ Model will retrieve relevant context from the Knowledge Table for an accurate + response +

+
- { - selectedModel = model; - - const modelDetails = $modelsAvailable.find((val) => val.id == model); - if (modelDetails && parseInt(editMaxTokens) > modelDetails.context_length) { - editMaxTokens = modelDetails.context_length.toString(); + id="rag-enabled" + name="rag-enabled" + class="" + bind:checked={() => !!selectedGenConfig.rag_params, + (v) => { + if (v) { + selectedGenConfig.rag_params = { + table_id: '', + k: 1, + reranking_model: null + }; + } else { + selectedGenConfig.rag_params = null; } }} - buttonText={($modelsAvailable.find((model) => model.id == selectedModel) - ?.name ?? - selectedModel) || - 'Select model'} - class="w-full bg-[#F2F4F7] data-dark:bg-[#42464e] hover:bg-[#e1e2e6] border-transparent" />
-
+
- - Temperature - + - { - const value = parseFloat(e.currentTarget.value); - - if (isNaN(value)) { - editTemperature = '1'; - } else if (value < 0.01) { - editTemperature = '0.01'; - } else if (value > 1) { - editTemperature = '1'; - } else { - editTemperature = value.toFixed(2); - } - }} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent data-dark:border-[#42464E] placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" - /> - - +

+ Model will cite its sources [1] as it writes +

-
- - Max tokens - + selectedGenConfig.rag_params?.inline_citations ?? false, + (v) => { + if (selectedGenConfig.rag_params) { + selectedGenConfig.rag_params.inline_citations = v; + } + }} + /> +
- { - const value = parseInt(e.currentTarget.value); - const model = $modelsAvailable.find((model) => model.id == selectedModel); - - if (isNaN(value)) { - editMaxTokens = '1'; - } else if (value < 1 || value > 1e20) { - editMaxTokens = '1'; - } else if (model && value > model.context_length) { - editMaxTokens = model.context_length.toString(); - } else { - editMaxTokens = value.toString(); - } - }} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent data-dark:border-[#42464E] placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" - /> +
- model.id == selectedModel) - ?.context_length} - step="1" - /> -
+
+
+
+ -
- Top-p +

Number of chunks or documents in context

+
{ - const value = parseFloat(e.currentTarget.value); - - if (isNaN(value)) { - editTopP = '1'; - } else if (value < 0.01) { - editTopP = '0.001'; - } else if (value > 1) { - editTopP = '1'; - } else { - editTopP = value.toFixed(3); + id="rag-k" + name="rag-k" + bind:value={() => selectedGenConfig.rag_params?.k ?? 1, + (v) => { + if (selectedGenConfig.rag_params) { + selectedGenConfig.rag_params.k = v; + } + }} + onblur={() => { + if (selectedGenConfig.rag_params) { + selectedGenConfig.rag_params.k = + //@ts-ignore + parseInt(selectedGenConfig.rag_params.k) <= 0 + ? 1 + : //@ts-ignore + parseInt(selectedGenConfig.rag_params.k); } }} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent data-dark:border-[#42464E] placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" + class="w-16 rounded-md border border-transparent bg-[#F2F4F7] px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:border-[#42464E] data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" /> +
- + selectedGenConfig.rag_params?.k ?? 0, + (v) => { + if (selectedGenConfig.rag_params) { + selectedGenConfig.rag_params.k = v; + } + }} + min="1" + max="1024" + step="1" + /> +
+ +
+ + +

+ Model to reorder retrieved chunks or documents based on relevance +

+ + selectedGenConfig.rag_params?.reranking_model ?? '', + (v) => { + if (selectedGenConfig.rag_params) { + selectedGenConfig.rag_params.reranking_model = v; + } + }} + class="h-10 border-transparent bg-[#F2F4F7] hover:bg-[#e1e2e6] disabled:hover:bg-[#F2F4F7] data-dark:bg-[#42464e]" + /> +
+
+ +
+

Search Knowledge Table

+ +
+ + + + + +
+ +
+
+ Selected knowledge table ({selectedGenConfig.rag_params?.table_id ? '1' : '0'})
+ + {#if selectedGenConfig.rag_params?.table_id} +
+ + + {selectedGenConfig.rag_params?.table_id} +
+ {:else} +
+ No knowledge table selected +
+ {/if}
- + {#if !readonly} - {/if} @@ -820,8 +893,16 @@
- {#if !readonly} - + {#if !readonly && selectedGenConfig?.object === 'gen_config.llm'} + selectedGenConfig.rag_params?.table_id ?? '', + (v) => { + if (selectedGenConfig.rag_params) { + selectedGenConfig.rag_params.table_id = v; + } + }} + /> {/if} {/if} diff --git a/services/app/src/lib/components/tables/(sub)/ConvList.svelte b/services/app/src/lib/components/tables/(sub)/ConvList.svelte old mode 100644 new mode 100755 index 1ea8f92..c198df4 --- a/services/app/src/lib/components/tables/(sub)/ConvList.svelte +++ b/services/app/src/lib/components/tables/(sub)/ConvList.svelte @@ -5,8 +5,8 @@ import { Button } from '$lib/components/ui/button'; import AddIcon from '$lib/icons/AddIcon.svelte'; - let rightDockButton: HTMLButtonElement; - let showRightDockButton = false; + let rightDockButton: HTMLButtonElement | undefined = $state(); + let showRightDockButton = $state(false); function mouseMoveListener(e: MouseEvent) { const chatWindow = document.getElementById('chat-table'); @@ -14,7 +14,7 @@ //* Show/hide the right dock button on hover right side if ( - rightDockButton.contains(el) || + rightDockButton?.contains(el) || (chatWindow?.contains(el) && chatWindow?.offsetWidth - (e.clientX - chatWindow?.offsetLeft) < 75) ) { @@ -27,24 +27,24 @@ function handleNewConv() {} - + -
+
@@ -52,10 +52,10 @@ disabled={!$showRightDock} variant="outline" title="New conversation" - on:click={handleNewConv} - class="flex items-center gap-3 mt-6 p-4 w-full text-center bg-transparent whitespace-nowrap overflow-hidden" + onclick={handleNewConv} + class="mt-6 flex w-full items-center gap-3 overflow-hidden whitespace-nowrap bg-transparent p-4 text-center" > - + New conversation diff --git a/services/app/src/lib/components/tables/(sub)/Conversations.svelte b/services/app/src/lib/components/tables/(sub)/Conversations.svelte old mode 100644 new mode 100755 index 112fe53..9cdf3f3 --- a/services/app/src/lib/components/tables/(sub)/Conversations.svelte +++ b/services/app/src/lib/components/tables/(sub)/Conversations.svelte @@ -6,7 +6,7 @@ import { OverlayScrollbarsComponent } from 'overlayscrollbars-svelte'; import { browser } from '$app/environment'; import { beforeNavigate } from '$app/navigation'; - import { page } from '$app/stores'; + import { page } from '$app/state'; // import { activeConversation, pastConversations, type DBConversation } from './conversationsStore'; import { showRightDock } from '$globalStore'; import logger from '$lib/logger'; @@ -33,28 +33,28 @@ older: 'Older' }; - let autoAnimateController: ReturnType; - let pastConversations: GenTable[] = []; + let autoAnimateController: ReturnType | undefined = $state(); + let pastConversations: GenTable[] = $state([]); let searchResults: typeof pastConversations = []; - let isEditingTitle: string | null = null; - let editedTitle: string; - let saveEditBtn: HTMLButtonElement; + let isEditingTitle: string | null = $state(null); + let editedTitle: string = $state(''); + let saveEditBtn: HTMLButtonElement | undefined = $state(); - let isDeletingConv: string | null = null; + let isDeletingConv: string | null = $state(null); let fetchConvController: AbortController | null = null; - let isFilterByAgent = false; - let isLoadingMoreConversations = false; - let moreConversationsFinished = false; //FIXME: Bandaid fix for infinite loop caused by loading circle - let currentOffset = 0; + let isFilterByAgent = $state(false); + let isLoadingMoreConversations = $state(false); + let moreConversationsFinished = $state(false); //FIXME: Bandaid fix for infinite loop caused by loading circle + let currentOffset = $state(0); const limit = 50; - let searchQuery: string; + let searchQuery: string = $state(''); let isNoResults = false; async function getPastConversations() { - const tableData = (await $page.data.table) as + const tableData = (await page.data.table) as | { error: number; message: any; @@ -82,11 +82,11 @@ searchParams.append('parent_id', tableData.data.parent_id ?? 
''); } - const response = await fetch(`${PUBLIC_JAMAI_URL}/api/v1/gen_tables/chat?` + searchParams, { + const response = await fetch(`${PUBLIC_JAMAI_URL}/api/owl/gen_tables/chat?` + searchParams, { credentials: 'same-origin', signal: fetchConvController?.signal, headers: { - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id } }); currentOffset += limit; @@ -117,7 +117,7 @@ } onMount(() => { - $page.data.table.then(() => { + page.data.table.then(() => { if (browser) { currentOffset = 0; moreConversationsFinished = false; @@ -146,7 +146,7 @@ } }; - let timestamps: Timestamp = { + let timestamps: Timestamp = $state({ today: null, yesterday: null, two_days: null, @@ -154,11 +154,11 @@ last_week: null, last_month: null, older: null - }; + }); let timestampKeys = Object.keys(timestamps) as Array; - $: { + $effect(() => { timestampKeys.forEach((key) => (timestamps[key] = null)); pastConversations.forEach((conversation, index) => { const timeDiff = Date.now() - new Date(conversation.updated_at).getTime(); @@ -190,7 +190,7 @@ timestamps.older = index; } }); - } + }); beforeNavigate(() => (isEditingTitle = null)); @@ -205,7 +205,7 @@ const debouncedSearchConv = () => {}; -Chat history +Chat history
- + {#snippet leading()} {#if isLoadingSearch} -
+
{:else} {/if} - + {/snippet}
-
+
+
No results found
{:else} @@ -271,14 +271,14 @@ on:osInitialized={(e) => { autoAnimateController = autoAnimate(e.detail[0].elements().viewport); }} - class="grow flex flex-col my-3 rounded-md overflow-auto os-dark" + class="os-dark my-3 flex grow flex-col overflow-auto rounded-md" > {#each !searchResults.length && !isNoResults ? pastConversations : searchResults as conversation, index (conversation.id)} {#if !searchResults.length && !isNoResults} {#each timestampKeys as time (time)} {#if timestamps[time] == index}
- + {timestampsDisplayName[time]}
@@ -287,40 +287,40 @@ {/if} {#if isEditingTitle === conversation.id}
- +
@@ -329,30 +329,29 @@ - + -
- +
+ {conversation.id}
- + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/lib/components/tables/(sub)/FileColumnView.svelte b/services/app/src/lib/components/tables/(sub)/FileColumnView.svelte old mode 100644 new mode 100755 index 373c3b1..f62ad76 --- a/services/app/src/lib/components/tables/(sub)/FileColumnView.svelte +++ b/services/app/src/lib/components/tables/(sub)/FileColumnView.svelte @@ -1,6 +1,5 @@ {#if fileUri && isValidUri(fileUri)} - {@const fileType = fileColumnFiletypes.find(({ ext }) => fileUri.endsWith(ext))?.type} + {@const fileType = fileColumnFiletypes.find(({ ext }) => + fileUri.toLowerCase().endsWith(ext) + )?.type}
{#if fileUrl && isValidUri(fileUrl)?.protocol.startsWith('http') && fileType !== undefined} {#if fileType === 'image'} - + {:else if fileType === 'audio'} - + + {:else if fileType === 'document'} +
+
+ + {fileUri.split('/').pop()} +
+
{/if} {:else} -
-
+
+
- {fileUri.split('/').pop()} + + {fileUri.split('/').pop()} +
{/if}
+ {#if !readonly} + + {/if} + - {#if fileUrl && isValidUri(fileUrl)?.protocol.startsWith('http') && fileType === 'file'} + {#if fileUrl && isValidUri(fileUrl)?.protocol.startsWith('http') && fileType === 'image'} - - + + {#snippet child({ props })} + + {/snippet} -
+
{fileUri.split('/').pop()}
- +
- - - - - + + {#snippet child({ props })} + + {/snippet} + {#if !readonly} + + {#snippet child({ props })} + + {/snippet} + + {/if}
- {:else if fileUri} - {/if}
diff --git a/services/app/src/lib/components/tables/(sub)/FileSelect.svelte b/services/app/src/lib/components/tables/(sub)/FileSelect.svelte old mode 100644 new mode 100755 index fb271e5..a1f3387 --- a/services/app/src/lib/components/tables/(sub)/FileSelect.svelte +++ b/services/app/src/lib/components/tables/(sub)/FileSelect.svelte @@ -3,7 +3,7 @@ import axios from 'axios'; import debounce from 'lodash/debounce'; import toUpper from 'lodash/toUpper'; - import { page } from '$app/stores'; + import { page } from '$app/state'; import { fileColumnFiletypes } from '$lib/constants'; import logger from '$lib/logger'; import type { GenTableCol } from '$lib/types'; @@ -13,18 +13,29 @@ import CloseIcon from '$lib/icons/CloseIcon.svelte'; import LoadingSpinner from '$lib/icons/LoadingSpinner.svelte'; - export let tableType: 'action' | 'knowledge' | 'chat'; - export let controller: (string | AbortController) | (AbortController | undefined); - export let selectCb: (files: File[]) => void = handleSaveEditFile; - export let column: GenTableCol; - /** Edit cell function for tables */ - export let saveEditCell: - | ((cellToUpdate: { rowID: string; columnID: string }, editedValue: string) => Promise) - | undefined = undefined; - export let cellToUpdate: { rowID: string; columnID: string } | undefined = undefined; + interface Props { + tableType: 'action' | 'knowledge' | 'chat'; + controller: (string | AbortController) | (AbortController | undefined); + selectCb?: (files: File[]) => void; + column: GenTableCol; + /** Edit cell function for tables */ + saveEditCell?: + | ((cellToUpdate: { rowID: string; columnID: string }, editedValue: string) => Promise) + | undefined; + cellToUpdate?: { rowID: string; columnID: string } | undefined; + } + + let { + tableType, + controller = $bindable(), + selectCb = handleSaveEditFile, + column, + saveEditCell = undefined, + cellToUpdate = undefined + }: Props = $props(); - let container: HTMLDivElement; - let filesDragover = false; + let container: HTMLDivElement | undefined = $state(); + let filesDragover = $state(false); /** Validate before upload */ function handleSelectFiles(files: File[]) { @@ -76,10 +87,10 @@ formData.append('file', files[0]); try { - const uploadRes = await axios.post(`${PUBLIC_JAMAI_URL}/api/v1/files/upload`, formData, { + const uploadRes = await axios.post(`${PUBLIC_JAMAI_URL}/api/owl/files/upload`, formData, { headers: { 'Content-Type': 'multipart/form-data', - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id }, signal: controller.signal }); @@ -118,18 +129,20 @@ const handleDragLeave = () => (filesDragover = false); - +
{ + ondragover={(e) => { + e.preventDefault(); if (e.dataTransfer?.items) { if ([...e.dataTransfer.items].some((item) => item.kind === 'file')) { filesDragover = true; } } }} - on:dragleave={debounce(handleDragLeave, 50)} - on:drop|preventDefault={(e) => { + ondragleave={debounce(handleDragLeave, 50)} + ondrop={(e) => { + e.preventDefault(); filesDragover = false; if (e.dataTransfer?.items) { handleSelectFiles( @@ -152,17 +165,17 @@ handleSelectFiles([...(e.dataTransfer?.files ?? [])]); } }} - class="flex flex-col gap-1 px-2 py-2 h-full w-full" + class="flex h-full w-full flex-col gap-1 px-2 py-2" > @@ -196,9 +209,12 @@ .filter(({ type }) => column.dtype === type) .map(({ ext }) => ext) .join(',')} - on:change|preventDefault={(e) => handleSelectFiles([...(e.currentTarget.files ?? [])])} + onchange={(e) => { + e.preventDefault(); + handleSelectFiles([...(e.currentTarget.files ?? [])]); + }} multiple={false} - class="fixed max-h-[0] max-w-0 !p-0 !border-none overflow-hidden" + class="fixed max-h-[0] max-w-0 overflow-hidden !border-none !p-0" /> diff --git a/services/app/src/lib/components/tables/(sub)/FileThumbsFetch.svelte b/services/app/src/lib/components/tables/(sub)/FileThumbsFetch.svelte old mode 100644 new mode 100755 index 5714ea5..20b2055 --- a/services/app/src/lib/components/tables/(sub)/FileThumbsFetch.svelte +++ b/services/app/src/lib/components/tables/(sub)/FileThumbsFetch.svelte @@ -1,41 +1,47 @@ diff --git a/services/app/src/lib/components/tables/(sub)/NewRow.svelte b/services/app/src/lib/components/tables/(sub)/NewRow.svelte old mode 100644 new mode 100755 index ff6bb73..b74f3a8 --- a/services/app/src/lib/components/tables/(sub)/NewRow.svelte +++ b/services/app/src/lib/components/tables/(sub)/NewRow.svelte @@ -4,8 +4,8 @@ import axios from 'axios'; import toUpper from 'lodash/toUpper'; import { v4 as uuidv4 } from 'uuid'; - import { page } from '$app/stores'; - import { genTableRows, tableState } from '$lib/components/tables/tablesStore'; + import { page } from '$app/state'; + import { getTableState, getTableRowsState } from '$lib/components/tables/tablesState.svelte'; import { cn } from '$lib/utils'; import logger from '$lib/logger'; import type { GenTable, GenTableRow, GenTableStreamEvent } from '$lib/types'; @@ -17,26 +17,47 @@ import AddIcon from '$lib/icons/AddIcon.svelte'; import StarIcon from '$lib/icons/StarIcon.svelte'; - export let tableType: 'action' | 'knowledge' | 'chat'; - export let tableData: GenTable; - export let focusedCol: string | null; - export let refetchTable: () => Promise; - let className: string | undefined | null = undefined; - export { className as class }; - - let newRowForm: HTMLFormElement; - let maxInputHeight = 36; - let isAddingRow = false; - let uploadColumns: Record = {}; - let isLoadingAddRow = false; - let inputValues: Record = {}; - let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = null; + const tableState = getTableState(); + const tableRowsState = getTableRowsState(); + + interface Props { + tableType: 'action' | 'knowledge' | 'chat'; + tableData: GenTable; + focusedCol: string | null; + refetchTable: () => Promise; + class?: string | undefined | null; + } - $: tableData, isAddingRow, uploadColumns, resetMaxInputHeight(); + let { + tableType, + tableData, + focusedCol, + refetchTable, + class: className = undefined + }: Props = $props(); + + let newRowForm: HTMLFormElement | undefined = $state(); + let maxInputHeight = $state(36); + let isAddingRow = $state(false); + let uploadColumns: Record = 
$state({}); + let isLoadingAddRow = false; + let inputValues: Record = $state({}); + let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = $state(null); + + $effect(() => { + tableData; + isAddingRow; + uploadColumns; + resetMaxInputHeight(); + }); async function resetMaxInputHeight() { if (Object.entries(uploadColumns).some((val) => typeof val[1] === 'string')) { maxInputHeight = 150; - } else if (tableData.cols.find((col) => col.dtype === 'image' || col.dtype === 'audio')) { + } else if ( + tableData.cols.find( + (col) => col.dtype === 'image' || col.dtype === 'audio' || col.dtype === 'document' + ) + ) { maxInputHeight = 72; } else { maxInputHeight = 32; @@ -63,6 +84,7 @@ } async function handleAddRow(e: SubmitEvent & { currentTarget: EventTarget & HTMLFormElement }) { + e.preventDefault(); if (isLoadingAddRow) return; const formData = new FormData(e.currentTarget); const obj = Object.fromEntries( @@ -82,7 +104,7 @@ inputValues = {}; const clientRowID = uuidv4(); - genTableRows.addRow({ + tableRowsState.addRow({ ID: clientRowID, 'Updated at': new Date().toISOString(), ...(Object.fromEntries( @@ -90,26 +112,21 @@ ) as any) }); - console.log({ - [clientRowID]: tableData.cols - .filter((col) => col.gen_config && !Object.keys(data).includes(col.id)) - .map((col) => col.id) - }); tableState.addStreamingRows({ [clientRowID]: tableData.cols .filter((col) => col.gen_config && !Object.keys(data).includes(col.id)) .map((col) => col.id) }); - const response = await fetch(`${PUBLIC_JAMAI_URL}/api/v1/gen_tables/${tableType}/rows/add`, { + const response = await fetch(`${PUBLIC_JAMAI_URL}/api/owl/gen_tables/${tableType}/rows/add`, { method: 'POST', headers: { Accept: 'text/event-stream', 'Content-Type': 'application/json', - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id }, body: JSON.stringify({ - table_id: $page.params.table_id, + table_id: page.params.table_id, data: [data], stream: true }) @@ -129,100 +146,93 @@ } }); - genTableRows.deleteRow(clientRowID); + tableRowsState.deleteRow(clientRowID); tableState.delStreamingRows([clientRowID]); } else { const reader = response.body!.pipeThrough(new TextDecoderStream()).getReader(); - let isStreaming = true; - let lastMessage = ''; - let rowId = ''; + let rowID = ''; let addedRow = false; - while (isStreaming) { + // let references: Record> | null = null; + let buffer = ''; + // eslint-disable-next-line no-constant-condition + while (true) { try { const { value, done } = await reader.read(); if (done) break; - if (value.endsWith('\n\n')) { - const lines = (lastMessage + value) - .split('\n\n') - .filter((i) => i.trim()) - .flatMap((line) => line.split('\n')); //? 
Split by \n to handle collation - - lastMessage = ''; - - for (const line of lines) { - const sumValue = line.replace(/^data: /, '').replace(/data: \[DONE\]\s+$/, ''); - - if (sumValue.trim() == '[DONE]') break; - - let parsedValue; - try { - parsedValue = JSON.parse(sumValue) as GenTableStreamEvent; - } catch (err) { - console.error('Error parsing:', sumValue); - logger.error(toUpper(`${tableType}TBL_ROW_ADDSTREAMPARSE`), { - parsing: sumValue, - error: err - }); - continue; - } - - if (parsedValue.object === 'gen_table.completion.chunk') { - if (parsedValue.choices[0].finish_reason) { - switch (parsedValue.choices[0].finish_reason) { - case 'error': { - logger.error(toUpper(`${tableType}_ROW_ADDSTREAM`), parsedValue); - console.error('STREAMING_ERROR', parsedValue); - alert(`Error while streaming: ${parsedValue.choices[0].message.content}`); - break; + buffer += value; + const lines = buffer.split('\n'); //? Split by \n to handle collation + buffer = lines.pop() || ''; + + let parsedEvent: { data: GenTableStreamEvent } | undefined = undefined; + for (const line of lines) { + if (line === '') { + if (parsedEvent) { + if (parsedEvent.data.object === 'gen_table.completion.chunk') { + if (parsedEvent.data.choices[0].finish_reason) { + switch (parsedEvent.data.choices[0].finish_reason) { + case 'error': { + logger.error(toUpper(`${tableType}_ROW_ADDSTREAM`), parsedEvent.data); + console.error('STREAMING_ERROR', parsedEvent.data); + alert( + `Error while streaming: ${parsedEvent.data.choices[0].message.content}` + ); + break; + } + default: { + const streamingCols = + tableState.streamingRows[parsedEvent.data.row_id]?.filter( + (col) => col !== parsedEvent?.data.output_column_name + ) ?? []; + if (streamingCols.length === 0) { + tableState.delStreamingRows([parsedEvent.data.row_id]); + } else { + tableState.addStreamingRows({ + [parsedEvent.data.row_id]: streamingCols + }); + } + break; + } } - default: { - const streamingCols = $tableState.streamingRows[parsedValue.row_id].filter( - (col) => col !== parsedValue.output_column_name + } else { + rowID = parsedEvent.data.row_id; + + //* Add chunk to active row + if (!addedRow) { + tableRowsState.updateRow(clientRowID, { + ID: parsedEvent.data.row_id, + [parsedEvent.data.output_column_name]: { + value: parsedEvent.data.choices[0].message.content ?? '' + } + } as GenTableRow); + tableState.delStreamingRows([clientRowID]); + tableState.addStreamingRows({ + [parsedEvent.data.row_id]: tableData.cols + .filter((col) => col.gen_config && !Object.keys(data).includes(col.id)) + .map((col) => col.id) + }); + addedRow = true; + } else { + tableRowsState.stream( + parsedEvent.data.row_id, + parsedEvent.data.output_column_name, + parsedEvent.data.choices[0].message.content ?? '' ); - if (streamingCols.length === 0) { - tableState.delStreamingRows([parsedValue.row_id]); - } else { - tableState.addStreamingRows({ - [parsedValue.row_id]: streamingCols - }); - } - break; } } } else { - rowId = parsedValue.row_id; - - //* Add chunk to active row - if (!addedRow) { - genTableRows.updateRow(clientRowID, { - ID: parsedValue.row_id, - [parsedValue.output_column_name]: { - value: parsedValue.choices[0].message.content ?? 
'' - } - } as GenTableRow); - tableState.delStreamingRows([clientRowID]); - tableState.addStreamingRows({ - [parsedValue.row_id]: tableData.cols - .filter((col) => col.gen_config && !Object.keys(data).includes(col.id)) - .map((col) => col.id) - }); - addedRow = true; - } else { - genTableRows.stream( - parsedValue.row_id, - parsedValue.output_column_name, - parsedValue.choices[0].message.content ?? '' - ); - } + console.log('Unknown message:', parsedEvent.data); } } else { - console.log('Unknown message:', parsedValue); + console.warn('Unknown event object:', parsedEvent); } - } - } else { - lastMessage += value; + } else if (line.startsWith('data: ')) { + if (line.slice(6) === '[DONE]') break; + parsedEvent = { ...(parsedEvent ?? {}), data: JSON.parse(line.slice(6)) }; + } /* else if (line.startsWith('event: ')) { + parsedEvent = { ...(parsedEvent ?? {}), event: line.slice(7) }; + } */ } } catch (err) { logger.error(toUpper(`${tableType}TBL_ROW_ADDSTREAM`), err); @@ -231,7 +241,7 @@ } } - tableState.delStreamingRows([clientRowID, rowId]); + tableState.delStreamingRows([clientRowID, rowID]); refetchTable(); } } @@ -244,10 +254,10 @@ formData.append('file', files[0]); try { - const uploadRes = await axios.post(`${PUBLIC_JAMAI_URL}/api/v1/files/upload`, formData, { + const uploadRes = await axios.post(`${PUBLIC_JAMAI_URL}/api/owl/files/upload`, formData, { headers: { 'Content-Type': 'multipart/form-data', - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id }, signal: uploadController.signal }); @@ -264,11 +274,11 @@ ); return; } else { - const urlResponse = await fetch(`/api/v1/files/url/thumb`, { + const urlResponse = await fetch(`/api/owl/files/url/thumb`, { method: 'POST', headers: { 'Content-Type': 'application/json', - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id }, body: JSON.stringify({ uris: [uploadRes.data.uri] @@ -312,8 +322,8 @@ { - if (!newRowForm.contains(document.activeElement)) { + onclick={() => { + if (!newRowForm?.contains(document.activeElement)) { const formData = new FormData(newRowForm); const obj = Object.fromEntries( Array.from(formData.keys()).map((key) => [ @@ -334,40 +344,40 @@ }} /> - - + +
(isAddingRow = true)} - on:keydown={(event) => { + onclick={() => (isAddingRow = true)} + onkeydown={(event) => { if (event.key === 'Enter' && !event.shiftKey) { event.preventDefault(); - event.currentTarget.requestSubmit(); + if (isAddingRow) event.currentTarget.requestSubmit(); } }} - on:submit|preventDefault={handleAddRow} + onsubmit={handleAddRow} style="grid-template-columns: 45px {focusedCol === 'ID' ? '320px' : '120px'} {focusedCol === 'Updated at' ? '320px' - : '130px'} {$tableState.templateCols};" + : '130px'} {tableState.templateCols};" class={cn( - 'sticky top-[36px] z-20 grid place-items-start h-min max-h-[100px] sm:max-h-[150px] text-xs sm:text-sm text-[#667085] bg-[#FAFBFC] data-dark:bg-[#1E2024] group border-l border-l-transparent data-dark:border-l-transparent border-r-transparent data-dark:border-r-transparent border-b border-[#E4E7EC] data-dark:border-[#333]', + 'group sticky top-[36px] z-20 grid h-min max-h-[100px] place-items-start border-b border-l border-[#E4E7EC] border-l-transparent border-r-transparent bg-[#F2F4F7] text-xs text-[#667085] data-dark:border-[#333] data-dark:border-l-transparent data-dark:border-r-transparent data-dark:bg-[#1E2024] sm:max-h-[150px] sm:text-sm', className )} >
{#if isAddingRow} @@ -375,7 +385,7 @@ variant="ghost" type="button" title="Cancel" - on:click={(e) => { + onclick={(e) => { e.stopPropagation(); const formData = new FormData(newRowForm); @@ -399,9 +409,9 @@ uploadColumns = {}; inputValues = {}; }} - class="p-0 h-6 sm:h-7 rounded-full aspect-square" + class="aspect-square h-6 rounded-full p-0 sm:h-7" > - + {:else} @@ -410,9 +420,9 @@
New Row @@ -420,8 +430,8 @@ {#if isAddingRow} @@ -432,18 +442,20 @@ {#each tableData.cols as column} {#if column.id !== 'ID' && column.id !== 'Updated at'} {@const columnFile = uploadColumns[column.id]} - +
{#if isAddingRow} - {#if column.dtype === 'image' || column.dtype === 'audio'} + {#if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {#if typeof columnFile !== 'string'} {/if} {/if} diff --git a/services/app/src/lib/components/tables/(sub)/PlaceholderNewCol.svelte b/services/app/src/lib/components/tables/(sub)/PlaceholderNewCol.svelte new file mode 100644 index 0000000..96897d3 --- /dev/null +++ b/services/app/src/lib/components/tables/(sub)/PlaceholderNewCol.svelte @@ -0,0 +1,324 @@ + + + { + if (e.key === 'Escape') { + tableState.addingCol = false; + } + }} +/> + +
+ + + + + { + if (e.key === 'Enter') { + handleAddColumn(); + } + }} + style="left: {colIDPaddingWidth + 32}px; width: {colIDInputWidth}px;" + class="pointer-events-auto absolute -top-[26px] h-[20px] rounded-[2px] border-0 bg-transparent text-sm outline outline-1 outline-[#4169e1] data-dark:outline-[#5b7ee5]" + /> + +
+
+ + + {#snippet children()} + + {colType || 'Select Column Type'} + + {/snippet} + + + {#each Object.keys(genTableColTypes) as colType} + + {colType} + + {/each} + + +
+ +
+ + + {#snippet children()} + + {genTableDTypes[dType] || 'Select Data Type'} + + {/snippet} + + + {#each genTableColDTypes[colType] as dType} + + {genTableDTypes[dType]} + + {/each} + + +
+
+ +
+ + +
+
+
+ + + + {colType === 'Input' ? 'Input' : 'Output'} + + + {dType} + + + + + + + + + +
diff --git a/services/app/src/lib/components/tables/(sub)/SelectKnowledgeTableDialog.svelte b/services/app/src/lib/components/tables/(sub)/SelectKnowledgeTableDialog.svelte old mode 100644 new mode 100755 index 8009d68..44a17f5 --- a/services/app/src/lib/components/tables/(sub)/SelectKnowledgeTableDialog.svelte +++ b/services/app/src/lib/components/tables/(sub)/SelectKnowledgeTableDialog.svelte @@ -1,9 +1,8 @@ - + Choose Knowledge Table(s) -
+
(isAddingTable = true)} - class="place-self-end lg:place-self-center flex-[0_0_auto] relative flex items-center justify-center gap-1.5 mr-1 sm:mr-0.5 px-2 sm:px-3 py-2 h-min w-min text-xs sm:text-sm aspect-square sm:aspect-auto" + onclick={() => (isAddingTable = true)} + class="relative mr-1 flex aspect-square h-min w-min flex-[0_0_auto] items-center justify-center gap-1.5 place-self-end px-2 py-2 text-xs sm:mr-0.5 sm:aspect-auto sm:px-3 sm:text-sm lg:place-self-center" > @@ -187,32 +191,32 @@
{#if isLoadingKTables} {#each Array(12) as _} {/each} {:else} {#each pastKnowledgeTables as knowledgeTable} - + + {#snippet child({ props })} + + {/snippet} +
diff --git a/services/app/src/lib/components/tables/(sub)/TablePagination.svelte b/services/app/src/lib/components/tables/(sub)/TablePagination.svelte old mode 100644 new mode 100755 index 291fc77..0324999 --- a/services/app/src/lib/components/tables/(sub)/TablePagination.svelte +++ b/services/app/src/lib/components/tables/(sub)/TablePagination.svelte @@ -1,7 +1,7 @@ + + + + {#snippet child({ props })} + + {/snippet} + + + + +
+
+ { + if (v === 'ID') { + page.url.searchParams.delete('sort_by'); + } else { + page.url.searchParams.set('sort_by', v); + } + goto(`?${page.url.searchParams}`, { + replaceState: true, + invalidate: [`${tableType}-table:slug`] + }); + }} + > + + {#snippet children()} + + {page.url.searchParams.get('sort_by') ?? 'ID'} + + {/snippet} + + + {#each tableData?.cols ?? [] as column} + {@const colType = !column.gen_config ? 'input' : 'output'} + + {#if !['ID', 'Updated at'].includes(column.id)} + + + {colType} + + + {column.dtype} + + + + + {/if} + + {column.id} + + {/each} + + +
+ +
+ { + if (v === '0') { + page.url.searchParams.delete('asc'); + } else { + page.url.searchParams.set('asc', '1'); + } + goto(`?${page.url.searchParams}`, { + replaceState: true, + invalidate: [`${tableType}-table:slug`] + }); + }} + > + + {#snippet children()} + + {page.url.searchParams.get('asc') === '1' ? 'Ascending' : 'Descending'} + + {/snippet} + + + {#each ['0', '1'] as sortDirection} + + {sortDirection === '1' ? 'Ascending' : 'Descending'} + + {/each} + + +
+
+ + +
+
+
diff --git a/services/app/src/lib/components/tables/(sub)/index.ts b/services/app/src/lib/components/tables/(sub)/index.ts old mode 100644 new mode 100755 index 13e5a7a..7999658 --- a/services/app/src/lib/components/tables/(sub)/index.ts +++ b/services/app/src/lib/components/tables/(sub)/index.ts @@ -8,6 +8,7 @@ import FileThumbsFetch from './FileThumbsFetch.svelte'; import NewRow from './NewRow.svelte'; import SelectKnowledgeTableDialog from './SelectKnowledgeTableDialog.svelte'; import TablePagination from './TablePagination.svelte'; +import TableSorter from './TableSorter.svelte'; export { ColumnDropdown, ColumnHeader, @@ -18,5 +19,6 @@ export { FileThumbsFetch, NewRow, SelectKnowledgeTableDialog, - TablePagination + TablePagination, + TableSorter }; diff --git a/services/app/src/lib/components/tables/(svg)/NoRowsGraphic.svelte b/services/app/src/lib/components/tables/(svg)/NoRowsGraphic.svelte old mode 100644 new mode 100755 index f3c1e9e..ca41a17 --- a/services/app/src/lib/components/tables/(svg)/NoRowsGraphic.svelte +++ b/services/app/src/lib/components/tables/(svg)/NoRowsGraphic.svelte @@ -1,6 +1,10 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import { onDestroy } from 'svelte'; - import { page } from '$app/stores'; - import { tableState, genTableRows } from '$lib/components/tables/tablesStore'; + import { page } from '$app/state'; + import { getTableState, getTableRowsState } from '$lib/components/tables/tablesState.svelte'; import { cn, isValidUri } from '$lib/utils'; import logger from '$lib/logger'; - import type { GenTable, GenTableRow, UserRead } from '$lib/types'; + import type { GenTable, GenTableRow, User } from '$lib/types'; import { ColumnHeader, @@ -21,29 +21,44 @@ import { toast, CustomToastDesc } from '$lib/components/ui/sonner'; import LoadingSpinner from '$lib/icons/LoadingSpinner.svelte'; - export let userData: UserRead | undefined; - export let tableData: GenTable | undefined; - export let tableError: { error: number; message?: any } | undefined; - export let readonly = false; - export let refetchTable: (hideColumnSettings?: boolean) => Promise; + const tableState = getTableState(); + const tableRowsState = getTableRowsState(); - let rowThumbs: { [rowID: string]: { [colID: string]: { value: string; url: string } } } = {}; - let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = null; + interface Props { + user: User | undefined; + tableData: GenTable | undefined; + tableError: { error: number; message?: any } | undefined; + readonly?: boolean; + refetchTable: (hideColumnSettings?: boolean) => Promise; + } + + let { + user, + tableData = $bindable(), + tableError = $bindable(), + readonly = false, + refetchTable + }: Props = $props(); + + let rowThumbs: { [rowID: string]: { [colID: string]: { value: string; url: string } } } = $state( + {} + ); + let isDeletingFile: { rowID: string; columnID: string; fileUri?: string } | null = $state(null); let uploadController: AbortController | undefined = undefined; //? 
Expanding ID and Updated at columns - let focusedCol: string | null = null; + let focusedCol: string | null = $state(null); async function handleSaveEdit( e: KeyboardEvent & { currentTarget: EventTarget & HTMLTextAreaElement; } ) { - if (!tableData || !$genTableRows) return; - if (!$tableState.editingCell) return; + if (!tableData || !tableRowsState.rows) return; + if (!tableState.editingCell) return; const editedValue = e.currentTarget.value; - const cellToUpdate = $tableState.editingCell; + const cellToUpdate = tableState.editingCell; await saveEditCell(cellToUpdate, editedValue); } @@ -52,25 +67,26 @@ cellToUpdate: { rowID: string; columnID: string }, editedValue: string ) { - if (!tableData || !$genTableRows) return; + if (!tableData || !tableRowsState.rows) return; //? Optimistic update - const originalValue = $genTableRows.find((row) => row.ID === cellToUpdate!.rowID)?.[ + const originalValue = tableRowsState.rows.find((row) => row.ID === cellToUpdate!.rowID)?.[ cellToUpdate.columnID ]; - genTableRows.setCell(cellToUpdate, editedValue); + tableRowsState.setCell(cellToUpdate, editedValue); - const response = await fetch(`${PUBLIC_JAMAI_URL}/api/v1/gen_tables/action/rows/update`, { - method: 'POST', + const response = await fetch(`${PUBLIC_JAMAI_URL}/api/owl/gen_tables/action/rows`, { + method: 'PATCH', headers: { 'Content-Type': 'application/json', - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id }, body: JSON.stringify({ table_id: tableData.id, - row_id: cellToUpdate.rowID, data: { - [cellToUpdate.columnID]: editedValue + [cellToUpdate.rowID]: { + [cellToUpdate.columnID]: editedValue + } } }) }); @@ -88,7 +104,7 @@ }); //? Revert back to original value - genTableRows.setCell(cellToUpdate, originalValue); + tableRowsState.setCell(cellToUpdate, originalValue?.value); } else { tableState.setEditingCell(null); refetchTable(); @@ -101,20 +117,24 @@ e: CustomEvent<{ event: MouseEvent; value: boolean }>, row: GenTableRow ) { - if (!tableData || !$genTableRows) return; + if (!tableData || !tableRowsState.rows) return; //? Select multiple rows with shift key - const rowIndex = $genTableRows.findIndex(({ ID }) => ID === row.ID); - if (e.detail.event.shiftKey && $tableState.selectedRows.length && shiftOrigin != null) { + const rowIndex = tableRowsState.rows.findIndex(({ ID }) => ID === row.ID); + if (e.detail.event.shiftKey && tableState.selectedRows.length && shiftOrigin != null) { if (shiftOrigin < rowIndex) { tableState.setSelectedRows([ - ...$tableState.selectedRows.filter((i) => !$genTableRows?.some(({ ID }) => ID === i)), - ...$genTableRows.slice(shiftOrigin, rowIndex + 1).map(({ ID }) => ID) + ...tableState.selectedRows.filter( + (i) => !tableRowsState.rows?.some(({ ID }) => ID === i) + ), + ...tableRowsState.rows.slice(shiftOrigin, rowIndex + 1).map(({ ID }) => ID) ]); } else if (shiftOrigin > rowIndex) { tableState.setSelectedRows([ - ...$tableState.selectedRows.filter((i) => !$genTableRows?.some(({ ID }) => ID === i)), - ...$genTableRows.slice(rowIndex, shiftOrigin + 1).map(({ ID }) => ID) + ...tableState.selectedRows.filter( + (i) => !tableRowsState.rows?.some(({ ID }) => ID === i) + ), + ...tableRowsState.rows.slice(rowIndex, shiftOrigin + 1).map(({ ID }) => ID) ]); } else { tableState.toggleRowSelection(row.ID); @@ -127,7 +147,7 @@ } function keyboardNavigate(e: KeyboardEvent) { - if (!tableData || !$genTableRows) return; + if (!tableData || !tableRowsState.rows) return; // const isCtrl = window.navigator.userAgent.indexOf('Mac') != -1 ? 
e.metaKey : e.ctrlKey; // const activeElement = document.activeElement as HTMLElement; // const isInputActive = activeElement.tagName == 'INPUT' || activeElement.tagName == 'TEXTAREA'; @@ -137,8 +157,8 @@ // if (Object.keys(streamingRows).length !== 0) return; // selectedRows = [ - // ...selectedRows.filter((i) => !$genTableRows?.some(({ ID }) => ID === i)), - // ...$genTableRows.map(({ ID }) => ID) + // ...selectedRows.filter((i) => !tableRowsState?.some(({ ID }) => ID === i)), + // ...tableRowsState.map(({ ID }) => ID) // ]; // } @@ -148,30 +168,30 @@ } onDestroy(() => { - $genTableRows = undefined; + tableRowsState.rows = undefined; tableState.reset(); }); { + onmousedown={(e) => { const editingCell = document.querySelector('[data-editing="true"]'); //@ts-ignore if (e.target && editingCell && !editingCell.contains(e.target)) { tableState.setEditingCell(null); } }} - on:keydown={keyboardNavigate} + onkeydown={keyboardNavigate} /> {#if tableData}
{ + onscroll={(e) => { //? Used to prevent elements showing through the padding between side nav and table header //FIXME: Use transform for performance const el = document.getElementById('checkbox-bg-obscure'); @@ -180,48 +200,48 @@ } }} role="grid" - style="grid-template-rows: 36px {$genTableRows - ? `repeat(${$genTableRows.length + (!readonly ? 1 : 0)}, min-content)` + style="grid-template-rows: 36px {tableRowsState.rows && !tableRowsState.loading + ? `repeat(${tableRowsState.rows.length + (!readonly ? 1 : 0)}, min-content)` : 'minmax(0, 1fr)'};" - class="grow relative grid px-2 overflow-auto" + class="relative grid grow overflow-auto px-2" >
{#if !readonly} { - if ($genTableRows) { - return tableState.selectAllRows($genTableRows); + if (tableRowsState.rows) { + return tableState.selectAllRows(tableRowsState.rows); } else return false; }} - checked={($genTableRows ?? []).every((row) => - $tableState.selectedRows.includes(row.ID) + checked={(tableRowsState.rows ?? []).every((row) => + tableState.selectedRows.includes(row.ID) )} - class="h-4 sm:h-[18px] w-4 sm:w-[18px] [&>svg]:h-3 sm:[&>svg]:h-3.5 [&>svg]:w-3 sm:[&>svg]:w-3.5 [&>svg]:translate-x-[1px]" + class="h-4 w-4 sm:h-[18px] sm:w-[18px] [&>svg]:h-3 [&>svg]:w-3 [&>svg]:translate-x-[1px] sm:[&>svg]:h-3.5 sm:[&>svg]:w-3.5" /> {/if}
@@ -229,7 +249,7 @@
- {#if $genTableRows} + {#if tableRowsState.rows && !tableRowsState.loading} {#if !readonly} @@ -246,134 +266,142 @@ - {#each $genTableRows as row (row.ID)} + {#each tableRowsState.rows as row (row.ID)}
- {#if $tableState.streamingRows[row.ID]} + {#if tableState.streamingRows[row.ID]}
{/if}
{#if !readonly} handleSelectRow(e, row)} - checked={!!$tableState.selectedRows.find((i) => i === row.ID)} - class="mt-[1px] h-4 sm:h-[18px] w-4 sm:w-[18px] [&>svg]:h-3 sm:[&>svg]:h-3.5 [&>svg]:w-3 sm:[&>svg]:w-3.5 [&>svg]:translate-x-[1px]" + checked={!!tableState.selectedRows.find((i) => i === row.ID)} + class="mt-[1px] h-4 w-4 sm:h-[18px] sm:w-[18px] [&>svg]:h-3 [&>svg]:w-3 [&>svg]:translate-x-[1px] sm:[&>svg]:h-3.5 sm:[&>svg]:w-3.5" /> {/if}
{#each tableData.cols as column} {@const editMode = - $tableState.editingCell && - $tableState.editingCell.rowID === row.ID && - $tableState.editingCell.columnID === column.id} + tableState.editingCell && + tableState.editingCell.rowID === row.ID && + tableState.editingCell.columnID === column.id} {@const isValidFileUri = isValidUri(row[column.id]?.value)} - +
(focusedCol = column.id)} - on:focusout={() => (focusedCol = null)} - on:mousedown={(e) => { + onfocusin={() => (focusedCol = column.id)} + onfocusout={() => (focusedCol = null)} + onmousedown={(e) => { if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if ($tableState.streamingRows[row.ID] || $tableState.editingCell) return; + if (tableState.streamingRows[row.ID] || tableState.editingCell) return; if (e.detail > 1) { e.preventDefault(); } }} - on:dblclick={() => { + ondblclick={() => { if (readonly) return; if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if (!$tableState.streamingRows[row.ID]) { + if (!tableState.streamingRows[row.ID]) { tableState.setEditingCell({ rowID: row.ID, columnID: column.id }); } }} - on:keydown={(e) => { + onkeydown={(e) => { if (readonly) return; if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if (!editMode && e.key == 'Enter' && !$tableState.streamingRows[row.ID]) { + if (!editMode && e.key == 'Enter' && !tableState.streamingRows[row.ID]) { tableState.setEditingCell({ rowID: row.ID, columnID: column.id }); } }} - style={$tableState.columnSettings.column?.id == column.id && - $tableState.columnSettings.isOpen + style={tableState.columnSettings.column?.id == column.id && + tableState.columnSettings.isOpen ? 'background-color: #30A8FF17;' : ''} class={cn( - 'flex flex-col justify-start gap-1 h-full max-h-[99px] sm:max-h-[149px] w-full break-words [&:not(:last-child)]:border-r border-[#E4E7EC] data-dark:border-[#333]', + 'flex h-full max-h-[99px] w-full flex-col justify-start gap-1 break-words border-[#E4E7EC] data-dark:border-[#333] sm:max-h-[149px] [&:not(:last-child)]:border-r', editMode - ? 'p-0 bg-black/5 data-dark:bg-white/5' - : 'p-2 overflow-auto whitespace-pre-line', - $tableState.streamingRows[row.ID] + ? 'bg-black/5 p-0 data-dark:bg-white/5' + : 'overflow-auto whitespace-pre-line p-2', + tableState.streamingRows[row.ID] ? 
'bg-[#FDEFF4]' - : 'group-hover:bg-[#ECEDEE] data-dark:group-hover:bg-white/5' + : 'group-hover:bg-[#E7EBF1] data-dark:group-hover:bg-white/5' )} > - {#if $tableState.streamingRows[row.ID]?.includes(column.id) && !editMode && column.id !== 'ID' && column.id !== 'Updated at' && column.gen_config} + {#if tableState.streamingRows[row.ID]?.includes(column.id) && !editMode && column.id !== 'ID' && column.id !== 'Updated at' && column.gen_config} {/if} {#if editMode} - {#if column.dtype === 'image' || column.dtype === 'audio'} + {#if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {:else} - + {/if} - {:else if column.dtype === 'image' || column.dtype === 'audio'} + {:else if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {#if column.id === 'ID'} @@ -416,7 +445,7 @@ {:else if column.id === 'Updated at'} {new Date(row[column.id]).toISOString()} {:else} - {row[column.id]?.value === undefined ? '' : row[column.id]?.value} + {row[column.id]?.value ?? ''} {/if} {/if} @@ -432,23 +461,23 @@
{:else if tableError?.error == 404} - {#if tableError.message?.org_id && userData?.member_of.find((org) => org.organization_id === tableError.message?.org_id)} - {@const projectOrg = userData?.member_of.find( - (org) => org.organization_id === tableError.message?.org_id - )} + {#if tableError.message?.org_id && user?.org_memberships.find((org) => org.organization_id === tableError.message?.org_id)} + {@const projectOrg = user?.organizations.find((org) => org.id === tableError.message?.org_id)} {:else} -
-

Table not found

+
+

Table not found

{/if} {:else if tableError?.error} -
-

{tableError.error} Failed to load table

-

{JSON.stringify(tableError.message)}

+
+

{tableError.error} Failed to load table

+

+ {tableError.message.message ?? JSON.stringify(tableError.message)} +

{:else} -
+
{/if} diff --git a/services/app/src/lib/components/tables/ChatTable.svelte b/services/app/src/lib/components/tables/ChatTable.svelte old mode 100644 new mode 100755 index 506b03e..5e815f2 --- a/services/app/src/lib/components/tables/ChatTable.svelte +++ b/services/app/src/lib/components/tables/ChatTable.svelte @@ -1,11 +1,10 @@ { + onmousedown={(e) => { const editingCell = document.querySelector('[data-editing="true"]'); //@ts-ignore if (e.target && editingCell && !editingCell.contains(e.target)) { tableState.setEditingCell(null); } }} - on:keydown={keyboardNavigate} + onkeydown={keyboardNavigate} /> {#if tableData}
{ + onscroll={(e) => { //? Used to prevent elements showing through the padding between side nav and table header //FIXME: Use transform for performance const el = document.getElementById('checkbox-bg-obscure'); @@ -179,48 +198,48 @@ } }} role="grid" - style="grid-template-rows: 36px {$genTableRows - ? `repeat(${$genTableRows.length + (!readonly ? 1 : 0)}, min-content)` + style="grid-template-rows: 36px {tableRowsState.rows && !tableRowsState.loading + ? `repeat(${tableRowsState.rows.length + (!readonly ? 1 : 0)}, min-content)` : 'minmax(0, 1fr)'};" - class="grow relative grid px-2 overflow-auto" + class="relative grid grow overflow-auto px-2" >
{#if !readonly} { - if ($genTableRows) { - return tableState.selectAllRows($genTableRows); + if (tableRowsState.rows) { + return tableState.selectAllRows(tableRowsState.rows); } else return false; }} - checked={($genTableRows ?? []).every((row) => - $tableState.selectedRows.includes(row.ID) + checked={(tableRowsState.rows ?? []).every((row) => + tableState.selectedRows.includes(row.ID) )} - class="h-4 sm:h-[18px] w-4 sm:w-[18px] [&>svg]:h-3 sm:[&>svg]:h-3.5 [&>svg]:w-3 sm:[&>svg]:w-3.5 [&>svg]:translate-x-[1px]" + class="h-4 w-4 sm:h-[18px] sm:w-[18px] [&>svg]:h-3 [&>svg]:w-3 [&>svg]:translate-x-[1px] sm:[&>svg]:h-3.5 sm:[&>svg]:w-3.5" /> {/if}
@@ -228,7 +247,7 @@
- {#if $genTableRows} + {#if tableRowsState.rows && !tableRowsState.loading} {#if !readonly} @@ -245,134 +264,142 @@ - {#each $genTableRows as row} + {#each tableRowsState.rows as row}
- {#if $tableState.streamingRows[row.ID]} + {#if tableState.streamingRows[row.ID]}
{/if}
{#if !readonly} handleSelectRow(e, row)} - checked={!!$tableState.selectedRows.find((i) => i === row.ID)} - class="mt-[1px] h-4 sm:h-[18px] w-4 sm:w-[18px] [&>svg]:h-3 sm:[&>svg]:h-3.5 [&>svg]:w-3 sm:[&>svg]:w-3.5 [&>svg]:translate-x-[1px]" + checked={!!tableState.selectedRows.find((i) => i === row.ID)} + class="mt-[1px] h-4 w-4 sm:h-[18px] sm:w-[18px] [&>svg]:h-3 [&>svg]:w-3 [&>svg]:translate-x-[1px] sm:[&>svg]:h-3.5 sm:[&>svg]:w-3.5" /> {/if}
{#each tableData.cols as column} {@const editMode = - $tableState.editingCell && - $tableState.editingCell.rowID === row.ID && - $tableState.editingCell.columnID === column.id} + tableState.editingCell && + tableState.editingCell.rowID === row.ID && + tableState.editingCell.columnID === column.id} {@const isValidFileUri = isValidUri(row[column.id]?.value)} - +
(focusedCol = column.id)} - on:focusout={() => (focusedCol = null)} - on:mousedown={(e) => { + onfocusin={() => (focusedCol = column.id)} + onfocusout={() => (focusedCol = null)} + onmousedown={(e) => { if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if ($tableState.streamingRows[row.ID] || $tableState.editingCell) return; + if (tableState.streamingRows[row.ID] || tableState.editingCell) return; if (e.detail > 1) { e.preventDefault(); } }} - on:dblclick={() => { + ondblclick={() => { if (readonly) return; if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if (!$tableState.streamingRows[row.ID]) { + if (!tableState.streamingRows[row.ID]) { tableState.setEditingCell({ rowID: row.ID, columnID: column.id }); } }} - on:keydown={(e) => { + onkeydown={(e) => { if (readonly) return; if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if (!editMode && e.key == 'Enter' && !$tableState.streamingRows[row.ID]) { + if (!editMode && e.key == 'Enter' && !tableState.streamingRows[row.ID]) { tableState.setEditingCell({ rowID: row.ID, columnID: column.id }); } }} - style={$tableState.columnSettings.column?.id == column.id && - $tableState.columnSettings.isOpen + style={tableState.columnSettings.column?.id == column.id && + tableState.columnSettings.isOpen ? 'background-color: #30A8FF17;' : ''} class={cn( - 'flex flex-col justify-start gap-1 h-full max-h-[99px] sm:max-h-[149px] w-full break-words [&:not(:last-child)]:border-r border-[#E4E7EC] data-dark:border-[#333]', + 'flex h-full max-h-[99px] w-full flex-col justify-start gap-1 break-words border-[#E4E7EC] data-dark:border-[#333] sm:max-h-[149px] [&:not(:last-child)]:border-r', editMode - ? 'p-0 bg-black/5 data-dark:bg-white/5' - : 'p-2 overflow-auto whitespace-pre-line', - $tableState.streamingRows[row.ID] + ? 'bg-black/5 p-0 data-dark:bg-white/5' + : 'overflow-auto whitespace-pre-line p-2', + tableState.streamingRows[row.ID] ? 'bg-[#FDEFF4]' - : 'group-hover:bg-[#ECEDEE] data-dark:group-hover:bg-white/5' + : 'group-hover:bg-[#E7EBF1] data-dark:group-hover:bg-white/5' )} > - {#if $tableState.streamingRows[row.ID]?.includes(column.id) && !editMode && column.id !== 'ID' && column.id !== 'Updated at' && column.gen_config} + {#if tableState.streamingRows[row.ID]?.includes(column.id) && !editMode && column.id !== 'ID' && column.id !== 'Updated at' && column.gen_config} {/if} {#if editMode} - {#if column.dtype === 'image' || column.dtype === 'audio'} + {#if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {:else} - + {/if} - {:else if column.dtype === 'image' || column.dtype === 'audio'} + {:else if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {#if column.id === 'ID'} @@ -431,23 +459,23 @@
{:else if tableError?.error == 404} - {#if tableError.message?.org_id && userData?.member_of.find((org) => org.organization_id === tableError.message?.org_id)} - {@const projectOrg = userData?.member_of.find( - (org) => org.organization_id === tableError.message?.org_id - )} + {#if tableError.message?.org_id && user?.org_memberships.find((org) => org.organization_id === tableError.message?.org_id)} + {@const projectOrg = user?.organizations.find((org) => org.id === tableError.message?.org_id)} {:else} -
-

Table not found

+
+

Table not found

{/if} {:else if tableError?.error} -
-

{tableError.error} Failed to load table

-

{JSON.stringify(tableError.message)}

+
+

{tableError.error} Failed to load table

+

+ {tableError.message.message ?? JSON.stringify(tableError.message)} +

{:else} -
+
{/if} diff --git a/services/app/src/lib/components/tables/KnowledgeTable.svelte b/services/app/src/lib/components/tables/KnowledgeTable.svelte old mode 100644 new mode 100755 index 3291230..4fc58d8 --- a/services/app/src/lib/components/tables/KnowledgeTable.svelte +++ b/services/app/src/lib/components/tables/KnowledgeTable.svelte @@ -1,11 +1,11 @@ { + onmousedown={(e) => { const editingCell = document.querySelector('[data-editing="true"]'); //@ts-ignore if (e.target && editingCell && !editingCell.contains(e.target)) { tableState.setEditingCell(null); } }} - on:keydown={keyboardNavigate} + onkeydown={keyboardNavigate} /> {#if tableData}
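A minimal sketch of how one of these table components picks up the rune-based state introduced later in this diff (tablesState.svelte.ts) and drives the select-all checkbox. The getter names and methods come from that new file; the import path and the small wrapper functions are assumptions.

// Sketch only: must run during component initialisation, since the getters use getContext.
import { getTableRowsState, getTableState } from '$lib/components/tables/tablesState.svelte';

const tableState = getTableState();
const tableRowsState = getTableRowsState();

// Header checkbox: checked when every loaded row is selected; toggling selects or
// deselects all, mirroring the checkbox hunks in these table components.
function headerCheckboxChecked(): boolean {
	return (tableRowsState.rows ?? []).every((row) => tableState.selectedRows.includes(row.ID));
}

function toggleAllRows(): void {
	if (tableRowsState.rows) tableState.selectAllRows(tableRowsState.rows);
}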
{ + onscroll={(e) => { //? Used to prevent elements showing through the padding between side nav and table header //FIXME: Use transform for performance const el = document.getElementById('checkbox-bg-obscure'); @@ -183,50 +203,52 @@ } }} role="grid" - style={$genTableRows?.length !== 0 + style={tableRowsState.rows?.length !== 0 ? `grid-template-rows: 36px ${ - $genTableRows ? `repeat(${$genTableRows.length}, min-content)` : 'minmax(0, 1fr)' + tableRowsState.rows && !tableRowsState.loading + ? `repeat(${tableRowsState.rows.length}, min-content)` + : 'minmax(0, 1fr)' };` : undefined} - class="grow relative grid px-2 overflow-auto" + class="relative grid grow overflow-auto px-2" >
{#if !readonly} { - if ($genTableRows) { - return tableState.selectAllRows($genTableRows); + if (tableRowsState.rows) { + return tableState.selectAllRows(tableRowsState.rows); } else return false; }} - checked={($genTableRows ?? []).every((row) => - $tableState.selectedRows.includes(row.ID) + checked={(tableRowsState.rows ?? []).every((row) => + tableState.selectedRows.includes(row.ID) )} - class="h-4 sm:h-[18px] w-4 sm:w-[18px] [&>svg]:h-3 sm:[&>svg]:h-3.5 [&>svg]:w-3 sm:[&>svg]:w-3.5 [&>svg]:translate-x-[1px]" + class="h-4 w-4 sm:h-[18px] sm:w-[18px] [&>svg]:h-3 [&>svg]:w-3 [&>svg]:translate-x-[1px] sm:[&>svg]:h-3.5 sm:[&>svg]:w-3.5" /> {/if}
@@ -234,132 +256,140 @@
- {#if $genTableRows} - {#if $genTableRows.length > 0} - {#each $genTableRows as row} + {#if tableRowsState.rows && !tableRowsState.loading} + {#if tableRowsState.rows.length > 0} + {#each tableRowsState.rows as row}
- {#if $tableState.streamingRows[row.ID]} + {#if tableState.streamingRows[row.ID]}
{/if}
{#if !readonly} handleSelectRow(e, row)} - checked={!!$tableState.selectedRows.find((i) => i === row.ID)} - class="mt-[1px] h-4 sm:h-[18px] w-4 sm:w-[18px] [&>svg]:h-3 sm:[&>svg]:h-3.5 [&>svg]:w-3 sm:[&>svg]:w-3.5 [&>svg]:translate-x-[1px]" + checked={!!tableState.selectedRows.find((i) => i === row.ID)} + class="mt-[1px] h-4 w-4 sm:h-[18px] sm:w-[18px] [&>svg]:h-3 [&>svg]:w-3 [&>svg]:translate-x-[1px] sm:[&>svg]:h-3.5 sm:[&>svg]:w-3.5" /> {/if}
{#each tableData.cols as column} {@const editMode = - $tableState.editingCell && - $tableState.editingCell.rowID === row.ID && - $tableState.editingCell.columnID === column.id} + tableState.editingCell && + tableState.editingCell.rowID === row.ID && + tableState.editingCell.columnID === column.id} {@const isValidFileUri = isValidUri(row[column.id]?.value)} - +
(focusedCol = column.id)} - on:focusout={() => (focusedCol = null)} - on:mousedown={(e) => { + onfocusin={() => (focusedCol = column.id)} + onfocusout={() => (focusedCol = null)} + onmousedown={(e) => { if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if ($tableState.streamingRows[row.ID] || $tableState.editingCell) return; + if (tableState.streamingRows[row.ID] || tableState.editingCell) return; if (e.detail > 1) { e.preventDefault(); } }} - on:dblclick={() => { + ondblclick={() => { if (readonly) return; if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if (!$tableState.streamingRows[row.ID]) { + if (!tableState.streamingRows[row.ID]) { tableState.setEditingCell({ rowID: row.ID, columnID: column.id }); } }} - on:keydown={(e) => { + onkeydown={(e) => { if (readonly) return; if (column.id === 'ID' || column.id === 'Updated at') return; if ( - (column.dtype === 'image' || column.dtype === 'audio') && + (column.dtype === 'image' || + column.dtype === 'audio' || + column.dtype === 'document') && row[column.id]?.value && isValidFileUri ) return; if (uploadController) return; - if (!editMode && e.key == 'Enter' && !$tableState.streamingRows[row.ID]) { + if (!editMode && e.key == 'Enter' && !tableState.streamingRows[row.ID]) { tableState.setEditingCell({ rowID: row.ID, columnID: column.id }); } }} - style={$tableState.columnSettings.column?.id == column.id && - $tableState.columnSettings.isOpen + style={tableState.columnSettings.column?.id == column.id && + tableState.columnSettings.isOpen ? 'background-color: #30A8FF17;' : ''} class={cn( - 'flex flex-col justify-start gap-1 h-full max-h-[99px] sm:max-h-[149px] w-full break-words [&:not(:last-child)]:border-r border-[#E4E7EC] data-dark:border-[#333]', + 'flex h-full max-h-[99px] w-full flex-col justify-start gap-1 break-words border-[#E4E7EC] data-dark:border-[#333] sm:max-h-[149px] [&:not(:last-child)]:border-r', editMode - ? 'p-0 bg-black/5 data-dark:bg-white/5' - : 'p-2 overflow-auto whitespace-pre-line', - $tableState.streamingRows[row.ID] + ? 'bg-black/5 p-0 data-dark:bg-white/5' + : 'overflow-auto whitespace-pre-line p-2', + tableState.streamingRows[row.ID] ? 'bg-[#FDEFF4]' - : 'group-hover:bg-[#ECEDEE] data-dark:group-hover:bg-white/5' + : 'group-hover:bg-[#E7EBF1] data-dark:group-hover:bg-white/5' )} > - {#if $tableState.streamingRows[row.ID]?.includes(column.id) && !editMode && column.id !== 'ID' && column.id !== 'Updated at' && column.gen_config} + {#if tableState.streamingRows[row.ID]?.includes(column.id) && !editMode && column.id !== 'ID' && column.id !== 'Updated at' && column.gen_config} {/if} {#if editMode} - {#if column.dtype === 'image' || column.dtype === 'audio'} + {#if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {:else} - + {/if} - {:else if column.dtype === 'image' || column.dtype === 'audio'} + {:else if column.dtype === 'image' || column.dtype === 'audio' || column.dtype === 'document'} {#if column.id === 'ID'} @@ -410,26 +441,26 @@ {/each}
{/each} - {:else if $genTableRows.length === 0} + {:else if tableRowsState.rows.length === 0}
Upload Document - + Select a document to start generating your table -
@@ -443,23 +474,23 @@
{:else if tableError?.error == 404} - {#if tableError.message?.org_id && userData?.member_of.find((org) => org.organization_id === tableError.message?.org_id)} - {@const projectOrg = userData?.member_of.find( - (org) => org.organization_id === tableError.message?.org_id - )} + {#if tableError.message?.org_id && user?.org_memberships.find((org) => org.organization_id === tableError.message?.org_id)} + {@const projectOrg = user?.organizations.find((org) => org.id === tableError.message?.org_id)} {:else} -
-

Table not found

+
+

Table not found

{/if} {:else if tableError?.error} -
-

{tableError.error} Failed to load table

-

{JSON.stringify(tableError.message)}

+
+

{tableError.error} Failed to load table

+

+ {tableError.message.message ?? JSON.stringify(tableError.message)} +

{:else} -
+
{/if} diff --git a/services/app/src/lib/components/tables/tablesState.svelte.ts b/services/app/src/lib/components/tables/tablesState.svelte.ts new file mode 100644 index 0000000..992bb02 --- /dev/null +++ b/services/app/src/lib/components/tables/tablesState.svelte.ts @@ -0,0 +1,280 @@ +import type { GenTable, GenTableCol, GenTableRow } from '$lib/types'; +import { serializer } from '$lib/utils'; +import { getContext, setContext } from 'svelte'; +import { persisted } from 'svelte-persisted-store'; +import { writable } from 'svelte/store'; + +export const pastActionTables = writable[]>([]); +export const pastKnowledgeTables = writable[]>([]); +export const pastChatAgents = writable[]>([]); +export const chatTableMode = persisted<'chat' | 'table'>('table_mode', 'table', { + serializer, + storage: 'session' +}); + +interface ITableState { + templateCols: string; + colSizes: Record; + resizingCol: { columnID: string; diffX: number } | null; + editingCell: { rowID: string; columnID: string } | null; + selectedRows: string[]; + streamingRows: Record; + columnSettings: { + isOpen: boolean; + column: GenTableCol | null; + }; + renamingCol: string | null; + deletingCol: string | null; + setTemplateCols: (columns: GenTableCol[]) => void; + setColSize: (colID: string, value: number) => void; + setResizingCol: (value: ITableState['resizingCol']) => void; + setEditingCell: (cell: ITableState['editingCell']) => void; + toggleRowSelection: (rowID: string) => void; + selectAllRows: (tableRows: GenTableRow[]) => void; + setSelectedRows: (rows: ITableState['selectedRows']) => void; + addStreamingRows: (rows: ITableState['streamingRows']) => void; + delStreamingRows: (rowIDs: string[]) => void; + setColumnSettings: (value: ITableState['columnSettings']) => void; + setRenamingCol: (value: string | null) => void; + setDeletingCol: (value: string | null) => void; + reset: () => void; +} + +export class TableState implements ITableState { + templateCols = $state(''); + colSizes = $state>({}); + resizingCol = $state<{ columnID: string; diffX: number } | null>(null); + editingCell = $state<{ rowID: string; columnID: string } | null>(null); + selectedRows = $state([]); + streamingRows = $state>({}); + columnSettings = $state<{ + isOpen: boolean; + column: GenTableCol | null; + }>({ + isOpen: false, + column: null + }); + addingCol = $state(false); + renamingCol = $state(null); + deletingCol = $state(null); + + constructor() { + this.templateCols = ''; + this.colSizes = {}; + this.resizingCol = null; + this.editingCell = null; + this.selectedRows = []; + this.streamingRows = {}; + this.columnSettings = { + isOpen: false, + column: null + }; + this.renamingCol = null; + this.deletingCol = null; + } + + setTemplateCols(columns: GenTableCol[]) { + this.templateCols = columns + .filter((col) => col.id !== 'ID' && col.id !== 'Updated at') + .map((col) => { + const colSize = this.colSizes[col.id]; + if (colSize) return `${colSize}px`; + else return 'minmax(320px, 1fr)'; + }) + .join(' '); + } + + setColSize(colID: string, value: number) { + // const obj = structuredClone(state); + this.colSizes[colID] = value; + } + + setResizingCol(value: TableState['resizingCol']) { + this.resizingCol = value; + } + + setEditingCell(cell: TableState['editingCell']) { + this.editingCell = cell; + } + + toggleRowSelection(rowID: string) { + if (this.selectedRows.includes(rowID)) { + this.selectedRows = this.selectedRows.filter((id) => id !== rowID); + } else { + this.selectedRows = [...this.selectedRows, rowID]; + } + } + + 
selectAllRows(tableRows: GenTableRow[]) { + if (tableRows.every((row) => this.selectedRows.includes(row.ID))) { + this.selectedRows = this.selectedRows.filter((i) => !tableRows?.some(({ ID }) => ID === i)); + } else { + this.selectedRows = [ + ...this.selectedRows.filter((i) => !tableRows?.some(({ ID }) => ID === i)), + ...tableRows.map(({ ID }) => ID) + ]; + } + } + + setSelectedRows(rows: TableState['selectedRows']) { + this.selectedRows = rows; + } + + addStreamingRows(rows: TableState['streamingRows']) { + this.streamingRows = { ...this.streamingRows, ...rows }; + } + + delStreamingRows(rowIDs: string[]) { + this.streamingRows = Object.fromEntries( + Object.entries(this.streamingRows).filter(([rowId]) => !rowIDs.includes(rowId)) + ); + } + + setColumnSettings(value: TableState['columnSettings']) { + this.columnSettings = $state.snapshot(value); + } + + setRenamingCol(value: string | null) { + this.renamingCol = value; + } + + setDeletingCol(value: string | null) { + this.deletingCol = value; + } + + reset() { + this.templateCols = ''; + this.colSizes = {}; + this.resizingCol = null; + this.editingCell = null; + this.selectedRows = []; + this.streamingRows = {}; + this.columnSettings = { + isOpen: false, + column: null + }; + this.renamingCol = null; + this.deletingCol = null; + } +} + +const tableStateContextKey = 'tableState'; +export function setTableState() { + return setContext(tableStateContextKey, new TableState()); +} + +export function getTableState() { + return getContext(tableStateContextKey); +} + +export class TableRowsState { + rows = $state(undefined); + loading = $state(false); + + setRows(rows: GenTableRow[] | undefined) { + this.rows = rows; + this.loading = false; + } + + /** Adds a row at the beginning of the array */ + addRow(row: GenTableRow) { + this.rows = [row, ...(this.rows ?? [])]; + } + + /** Removes a row */ + deleteRow(rowID: string) { + this.rows = this.rows?.filter((row) => row.ID !== rowID); + } + + /** Updates a row */ + updateRow(rowID: string, data: GenTableRow) { + this.rows = this.rows?.map((row) => { + if (row.ID === rowID) { + return { ...row, ...data }; + } + return row; + }); + } + + /** Set cell value */ + setCell({ rowID, columnID }: { rowID: string; columnID: string }, value: any) { + this.rows = this.rows?.map((row) => { + if (row.ID === rowID) { + if (columnID === 'ID' || columnID === 'Updated at') { + return { ...row, [columnID]: value }; + } else { + return { ...row, [columnID]: { value: value } }; + } + } + return row; + }); + } + + /** Streaming prep, clears outputs */ + clearOutputs(tableData: GenTable, rowIDs: string[], columnIDs?: string[]) { + this.rows = this.rows?.map((row) => { + if (rowIDs.includes(row.ID)) { + return { + ...row, + ...Object.fromEntries( + Object.entries(row).map(([key, value]) => { + if (key === 'ID' || key === 'Updated at' || (columnIDs && !columnIDs.includes(key))) { + return [key, value as string]; + } else { + return [ + key, + { + value: tableData.cols.find((col) => col.id == key)?.gen_config + ? '' + : (value as { value: any }).value + } + ]; + } + }) + ) + }; + } + return row; + }); + } + + /** Stream to cell */ + stream(rowID: string, colID: string, value: any) { + this.rows = this.rows?.map((row) => { + if (row.ID === rowID) { + return { + ...row, + [colID]: { + value: (row[colID]?.value ?? 
'') + value + } + }; + } + return row; + }); + } + + /** Revert to original value */ + revert( + originalValues: { + id: string; + value: GenTableRow; + }[] + ) { + this.rows = this.rows?.map((row) => { + const originalRow = originalValues.find((o) => o.id === row.ID); + if (originalRow) { + return originalRow.value; + } + return row; + }); + } +} + +const tableRowsStateContextKey = 'tableRowsState'; +export function setTableRowsState() { + return setContext(tableRowsStateContextKey, new TableRowsState()); +} + +export function getTableRowsState() { + return getContext(tableRowsStateContextKey); +} diff --git a/services/app/src/lib/components/tables/tablesStore.ts b/services/app/src/lib/components/tables/tablesStore.ts deleted file mode 100644 index 8365f15..0000000 --- a/services/app/src/lib/components/tables/tablesStore.ts +++ /dev/null @@ -1,232 +0,0 @@ -import { writable } from 'svelte/store'; -import { persisted } from 'svelte-persisted-store'; -import { serializer } from '$lib/utils'; -import type { GenTable, GenTableCol, GenTableRow } from '$lib/types'; - -export const tableState = createTableStore(); -export const genTableRows = createGenTableRows(); -export const pastActionTables = writable[]>([]); -export const pastKnowledgeTables = writable[]>([]); -export const pastChatAgents = writable[]>([]); -export const chatTableMode = persisted<'chat' | 'table'>('table_mode', 'table', { - serializer, - storage: 'session' -}); - -interface TableState { - templateCols: string; - colSizes: Record; - resizingCol: { columnID: string; diffX: number } | null; - editingCell: { rowID: string; columnID: string } | null; - selectedRows: string[]; - streamingRows: Record; - columnSettings: { - isOpen: boolean; - column: GenTableCol | null; - }; - renamingCol: string | null; - deletingCol: string | null; -} - -function createTableStore() { - const defaultValue = { - templateCols: '', - colSizes: {}, - resizingCol: null, - editingCell: null, - selectedRows: [], - streamingRows: {}, - columnSettings: { - isOpen: false, - column: null - }, - renamingCol: null, - deletingCol: null - } satisfies TableState; - const { subscribe, set, update } = writable(defaultValue); - - return { - subscribe, - set, - setTemplateCols: (columns: GenTableCol[]) => - update((state) => ({ - ...state, - templateCols: columns - .filter((col) => col.id !== 'ID' && col.id !== 'Updated at') - .map((col) => { - const colSize = state.colSizes[col.id]; - if (colSize) return `${colSize}px`; - else return 'minmax(320px, 1fr)'; - }) - .join(' ') - })), - setColSize: (colID: string, value: number) => - update((state) => { - const obj = structuredClone(state); - obj.colSizes[colID] = value; - return obj; - }), - setResizingCol: (value: TableState['resizingCol']) => - update((state) => ({ ...state, resizingCol: value })), - setEditingCell: (cell: TableState['editingCell']) => - update((state) => ({ ...state, editingCell: cell })), - toggleRowSelection: (rowID: string) => - update((state) => ({ - ...state, - selectedRows: state.selectedRows.includes(rowID) - ? state.selectedRows.filter((id) => id !== rowID) - : [...state.selectedRows, rowID] - })), - selectAllRows: (tableRows: GenTableRow[]) => - update((state) => ({ - ...state, - selectedRows: tableRows.every((row) => state.selectedRows.includes(row.ID)) - ? 
state.selectedRows.filter((i) => !tableRows?.some(({ ID }) => ID === i)) - : [ - ...state.selectedRows.filter((i) => !tableRows?.some(({ ID }) => ID === i)), - ...tableRows.map(({ ID }) => ID) - ] - })), - setSelectedRows: (rows: TableState['selectedRows']) => - update((state) => ({ ...state, selectedRows: rows })), - addStreamingRows: (rows: TableState['streamingRows']) => - update((state) => ({ ...state, streamingRows: { ...state.streamingRows, ...rows } })), - delStreamingRows: (rowIDs: string[]) => - update((state) => ({ - ...state, - streamingRows: Object.fromEntries( - Object.entries(state.streamingRows).filter(([rowId]) => !rowIDs.includes(rowId)) - ) - })), - setColumnSettings: (value: TableState['columnSettings']) => - update((state) => ({ ...state, columnSettings: value })), - setRenamingCol: (value: string | null) => update((state) => ({ ...state, renamingCol: value })), - setDeletingCol: (value: string | null) => update((state) => ({ ...state, deletingCol: value })), - reset: () => set(defaultValue) - }; -} - -function createGenTableRows() { - const { subscribe, set, update } = writable(undefined); - - return { - subscribe, - set, - /** Adds a row at the beginning of the array */ - addRow: (row: GenTableRow) => - update((rows) => { - if (rows) { - return [row, ...rows]; - } else { - return rows; - } - }), - /** Removes a row */ - deleteRow: (rowID: string) => - update((rows) => { - if (rows) { - return rows.filter((row) => row.ID !== rowID); - } else { - return rows; - } - }), - /** Updates a row */ - updateRow: (rowID: string, data: GenTableRow) => - update((rows) => - rows?.map((row) => { - if (row.ID === rowID) { - return { - ...row, - ...data - }; - } - return row; - }) - ), - /** Set cell value */ - setCell: ({ rowID, columnID }: { rowID: string; columnID: string }, value: any) => - update((rows) => - rows?.map((row) => { - if (row.ID === rowID) { - if (columnID === 'ID' || columnID === 'Updated at') { - return { - ...row, - [columnID]: value - }; - } else { - return { - ...row, - [columnID]: { - value: value - } - }; - } - } - return row; - }) - ), - /** Streaming prep, clears outputs */ - clearOutputs: (tableData: GenTable, rowIDs: string[], columnIDs?: string[]) => - update((rows) => - rows?.map((row) => { - if (rowIDs.includes(row.ID)) { - return { - ...row, - ...Object.fromEntries( - Object.entries(row).map(([key, value]) => { - if ( - key === 'ID' || - key === 'Updated at' || - (columnIDs && !columnIDs.includes(key)) - ) { - return [key, value as string]; - } else { - return [ - key, - { - value: tableData.cols.find((col) => col.id == key)?.gen_config - ? '' - : (value as { value: any }).value - } - ]; - } - }) - ) - }; - } - return row; - }) - ), - /** Stream to cell */ - stream: (rowID: string, colID: string, value: any) => - update((rows) => - rows?.map((row) => { - if (row.ID === rowID) { - return { - ...row, - [colID]: { - value: (row[colID]?.value ?? 
'') + value - } - }; - } - return row; - }) - ), - /** Revert to original value */ - revert: ( - originalValues: { - id: string; - value: GenTableRow; - }[] - ) => - update((rows) => - rows?.map((row) => { - const originalRow = originalValues.find((o) => o.id === row.ID); - if (originalRow) { - return originalRow.value; - } - return row; - }) - ) - }; -} diff --git a/services/app/src/lib/components/ui/alert/alert-description.svelte b/services/app/src/lib/components/ui/alert/alert-description.svelte new file mode 100644 index 0000000..ef74aa4 --- /dev/null +++ b/services/app/src/lib/components/ui/alert/alert-description.svelte @@ -0,0 +1,16 @@ + + +
+ {@render children?.()} +
diff --git a/services/app/src/lib/components/ui/alert/alert-title.svelte b/services/app/src/lib/components/ui/alert/alert-title.svelte new file mode 100644 index 0000000..12ec9fc --- /dev/null +++ b/services/app/src/lib/components/ui/alert/alert-title.svelte @@ -0,0 +1,25 @@ + + +
+ {@render children?.()} +
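The UI primitives rewritten below (alert, badge, button, calendar, checkbox, dialog, dropdown-menu, input, label, pagination) share one runes-mode prop pattern: $props() destructuring with a bindable element ref, a class passthrough, and a children snippet rendered with {@render children?.()} in place of the old slot. A condensed sketch of the script side, valid only inside a .svelte component; the exact prop types vary per component and the shape shown here is an assumption.

// Sketch of the shared prop pattern; the markup would render {@render children?.()} on an
// element carrying bind:this={ref}, class={cn(..., className)} and {...restProps}.
import type { Snippet } from 'svelte';
import type { HTMLAttributes } from 'svelte/elements';

let {
	ref = $bindable(null),
	class: className,
	children,
	...restProps
}: HTMLAttributes<HTMLDivElement> & {
	ref?: HTMLDivElement | null;
	children?: Snippet;
} = $props();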
diff --git a/services/app/src/lib/components/ui/alert/alert.svelte b/services/app/src/lib/components/ui/alert/alert.svelte new file mode 100644 index 0000000..8b4528f --- /dev/null +++ b/services/app/src/lib/components/ui/alert/alert.svelte @@ -0,0 +1,39 @@ + + + + + diff --git a/services/app/src/lib/components/ui/alert/index.ts b/services/app/src/lib/components/ui/alert/index.ts new file mode 100644 index 0000000..97e21b4 --- /dev/null +++ b/services/app/src/lib/components/ui/alert/index.ts @@ -0,0 +1,14 @@ +import Root from "./alert.svelte"; +import Description from "./alert-description.svelte"; +import Title from "./alert-title.svelte"; +export { alertVariants, type AlertVariant } from "./alert.svelte"; + +export { + Root, + Description, + Title, + // + Root as Alert, + Description as AlertDescription, + Title as AlertTitle, +}; diff --git a/services/app/src/lib/components/ui/badge/badge.svelte b/services/app/src/lib/components/ui/badge/badge.svelte new file mode 100644 index 0000000..da26104 --- /dev/null +++ b/services/app/src/lib/components/ui/badge/badge.svelte @@ -0,0 +1,50 @@ + + + + + + {@render children?.()} + diff --git a/services/app/src/lib/components/ui/badge/index.ts b/services/app/src/lib/components/ui/badge/index.ts new file mode 100644 index 0000000..64e0aa9 --- /dev/null +++ b/services/app/src/lib/components/ui/badge/index.ts @@ -0,0 +1,2 @@ +export { default as Badge } from "./badge.svelte"; +export { badgeVariants, type BadgeVariant } from "./badge.svelte"; diff --git a/services/app/src/lib/components/ui/button/button.svelte b/services/app/src/lib/components/ui/button/button.svelte old mode 100644 new mode 100755 index 270d111..4db80bd --- a/services/app/src/lib/components/ui/button/button.svelte +++ b/services/app/src/lib/components/ui/button/button.svelte @@ -1,34 +1,101 @@ + + - - {#if loading} - - {/if} - - - +{#if href} + + {@render children?.()} + +{:else} + +{/if} diff --git a/services/app/src/lib/components/ui/button/index.ts b/services/app/src/lib/components/ui/button/index.ts old mode 100644 new mode 100755 index 48181d3..fb585d7 --- a/services/app/src/lib/components/ui/button/index.ts +++ b/services/app/src/lib/components/ui/button/index.ts @@ -1,53 +1,17 @@ -import Root from './button.svelte'; -import { tv, type VariantProps } from 'tailwind-variants'; -import type { Button as ButtonPrimitive } from 'bits-ui'; - -const buttonVariants = tv({ - base: 'inline-flex items-center justify-center text-sm font-medium whitespace-nowrap ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 rounded-full disabled:pointer-events-none disabled:opacity-50', - variants: { - variant: { - default: - 'text-[#FCFCFD] bg-[#BF416E] hover:bg-[#950048] focus-visible:bg-[#950048] active:bg-[#7A003B]', - destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/90', - outline: - 'text-[#BF416E] bg-transparent hover:bg-[#BF416E]/[0.025] focus-visible:bg-[#BF416E]/[0.025] active:bg-[#BF416E]/5 border border-[#BF416E]', - 'outline-neutral': - 'text-text bg-transparent hover:bg-[#F9FAFB] data-dark:hover:bg-white/[0.1] active:bg-[#F2F4F7] data-dark:bg-[#0D0E11] data-dark:hover:bg-white/[0.1] border border-[#DDD] data-dark:border-[#42464E]', - action: 'bg-[#F2F4F7] hover:bg-[#E4E7EC] text-black rounded-md', - warning: 'bg-warning hover:bg-warning/80 text-black', - ghost: 'hover:bg-[#F2F4F7] hover:text-accent-foreground', - link: 'text-primary underline-offset-4 hover:underline' - }, 
- size: { - default: 'h-10 px-4 py-2', - sm: 'h-9 rounded-md px-3', - lg: 'h-11 rounded-md px-8', - icon: 'h-10 w-10' - } - }, - defaultVariants: { - variant: 'default', - size: 'default' - } -}); - -type Variant = VariantProps['variant']; -type Size = VariantProps['size']; - -type Props = ButtonPrimitive.Props & { - variant?: Variant; - size?: Size; -}; - -type Events = ButtonPrimitive.Events; +import Root, { + type ButtonProps, + type ButtonSize, + type ButtonVariant, + buttonVariants, +} from "./button.svelte"; export { Root, - type Props, - type Events, + type ButtonProps as Props, // Root as Button, - type Props as ButtonProps, - type Events as ButtonEvents, - buttonVariants + buttonVariants, + type ButtonProps, + type ButtonSize, + type ButtonVariant, }; diff --git a/services/app/src/lib/components/ui/calendar/calendar-cell.svelte b/services/app/src/lib/components/ui/calendar/calendar-cell.svelte new file mode 100644 index 0000000..3c065a5 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-cell.svelte @@ -0,0 +1,19 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-day.svelte b/services/app/src/lib/components/ui/calendar/calendar-day.svelte new file mode 100644 index 0000000..d5e802a --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-day.svelte @@ -0,0 +1,30 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-grid-body.svelte b/services/app/src/lib/components/ui/calendar/calendar-grid-body.svelte new file mode 100644 index 0000000..8cd86de --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-grid-body.svelte @@ -0,0 +1,12 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-grid-head.svelte b/services/app/src/lib/components/ui/calendar/calendar-grid-head.svelte new file mode 100644 index 0000000..333edc4 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-grid-head.svelte @@ -0,0 +1,12 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-grid-row.svelte b/services/app/src/lib/components/ui/calendar/calendar-grid-row.svelte new file mode 100644 index 0000000..9032236 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-grid-row.svelte @@ -0,0 +1,12 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-grid.svelte b/services/app/src/lib/components/ui/calendar/calendar-grid.svelte new file mode 100644 index 0000000..1d7edb5 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-grid.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-head-cell.svelte b/services/app/src/lib/components/ui/calendar/calendar-head-cell.svelte new file mode 100644 index 0000000..dd5e55f --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-head-cell.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-header.svelte b/services/app/src/lib/components/ui/calendar/calendar-header.svelte new file mode 100644 index 0000000..e64feae --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-header.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-heading.svelte b/services/app/src/lib/components/ui/calendar/calendar-heading.svelte new file mode 100644 index 0000000..5d57a50 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-heading.svelte @@ -0,0 +1,12 @@ + + + diff --git 
a/services/app/src/lib/components/ui/calendar/calendar-months.svelte b/services/app/src/lib/components/ui/calendar/calendar-months.svelte new file mode 100644 index 0000000..4cd0ed7 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-months.svelte @@ -0,0 +1,20 @@ + + +
+ {@render children?.()} +
diff --git a/services/app/src/lib/components/ui/calendar/calendar-next-button.svelte b/services/app/src/lib/components/ui/calendar/calendar-next-button.svelte new file mode 100644 index 0000000..8581a43 --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-next-button.svelte @@ -0,0 +1,28 @@ + + +{#snippet Fallback()} + +{/snippet} + + diff --git a/services/app/src/lib/components/ui/calendar/calendar-prev-button.svelte b/services/app/src/lib/components/ui/calendar/calendar-prev-button.svelte new file mode 100644 index 0000000..0ad629e --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar-prev-button.svelte @@ -0,0 +1,28 @@ + + +{#snippet Fallback()} + +{/snippet} + + diff --git a/services/app/src/lib/components/ui/calendar/calendar.svelte b/services/app/src/lib/components/ui/calendar/calendar.svelte new file mode 100644 index 0000000..e05c46e --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/calendar.svelte @@ -0,0 +1,61 @@ + + + + + {#snippet children({ months, weekdays })} + + + + + + + {#each months as month (month)} + + + + {#each weekdays as weekday (weekday)} + + {weekday.slice(0, 2)} + + {/each} + + + + {#each month.weeks as weekDates (weekDates)} + + {#each weekDates as date (date)} + + + + {/each} + + {/each} + + + {/each} + + {/snippet} + diff --git a/services/app/src/lib/components/ui/calendar/index.ts b/services/app/src/lib/components/ui/calendar/index.ts new file mode 100644 index 0000000..ab257ab --- /dev/null +++ b/services/app/src/lib/components/ui/calendar/index.ts @@ -0,0 +1,30 @@ +import Root from "./calendar.svelte"; +import Cell from "./calendar-cell.svelte"; +import Day from "./calendar-day.svelte"; +import Grid from "./calendar-grid.svelte"; +import Header from "./calendar-header.svelte"; +import Months from "./calendar-months.svelte"; +import GridRow from "./calendar-grid-row.svelte"; +import Heading from "./calendar-heading.svelte"; +import GridBody from "./calendar-grid-body.svelte"; +import GridHead from "./calendar-grid-head.svelte"; +import HeadCell from "./calendar-head-cell.svelte"; +import NextButton from "./calendar-next-button.svelte"; +import PrevButton from "./calendar-prev-button.svelte"; + +export { + Day, + Cell, + Grid, + Header, + Months, + GridRow, + Heading, + GridBody, + GridHead, + HeadCell, + NextButton, + PrevButton, + // + Root as Calendar, +}; diff --git a/services/app/src/lib/components/ui/checkbox/checkbox.svelte b/services/app/src/lib/components/ui/checkbox/checkbox.svelte new file mode 100644 index 0000000..ca81843 --- /dev/null +++ b/services/app/src/lib/components/ui/checkbox/checkbox.svelte @@ -0,0 +1,35 @@ + + + + {#snippet children({ checked, indeterminate })} +
+ {#if indeterminate} + + {:else} + + {/if} +
+ {/snippet} +
diff --git a/services/app/src/lib/components/ui/checkbox/index.ts b/services/app/src/lib/components/ui/checkbox/index.ts new file mode 100644 index 0000000..6d92d94 --- /dev/null +++ b/services/app/src/lib/components/ui/checkbox/index.ts @@ -0,0 +1,6 @@ +import Root from "./checkbox.svelte"; +export { + Root, + // + Root as Checkbox, +}; diff --git a/services/app/src/lib/components/ui/dialog/dialog-actions.svelte b/services/app/src/lib/components/ui/dialog/dialog-actions.svelte old mode 100644 new mode 100755 index 9c5afb7..625fdfb --- a/services/app/src/lib/components/ui/dialog/dialog-actions.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-actions.svelte @@ -1,8 +1,13 @@
- + {@render children?.()}
diff --git a/services/app/src/lib/components/ui/dialog/dialog-content.svelte b/services/app/src/lib/components/ui/dialog/dialog-content.svelte old mode 100644 new mode 100755 index a440cb2..8ad83c4 --- a/services/app/src/lib/components/ui/dialog/dialog-content.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-content.svelte @@ -1,33 +1,41 @@ - + - + {@render children?.()} + diff --git a/services/app/src/lib/components/ui/dialog/dialog-description.svelte b/services/app/src/lib/components/ui/dialog/dialog-description.svelte old mode 100644 new mode 100755 index e1d796a..bc048e4 --- a/services/app/src/lib/components/ui/dialog/dialog-description.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-description.svelte @@ -2,15 +2,15 @@ import { Dialog as DialogPrimitive } from "bits-ui"; import { cn } from "$lib/utils.js"; - type $$Props = DialogPrimitive.DescriptionProps; - - let className: $$Props["class"] = undefined; - export { className as class }; + let { + ref = $bindable(null), + class: className, + ...restProps + }: DialogPrimitive.DescriptionProps = $props(); - - + bind:ref + class={cn("text-muted-foreground text-sm", className)} + {...restProps} +/> diff --git a/services/app/src/lib/components/ui/dialog/dialog-footer.svelte b/services/app/src/lib/components/ui/dialog/dialog-footer.svelte old mode 100644 new mode 100755 index 6f6e589..91ecaba --- a/services/app/src/lib/components/ui/dialog/dialog-footer.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-footer.svelte @@ -1,16 +1,20 @@
- + {@render children?.()}
diff --git a/services/app/src/lib/components/ui/dialog/dialog-header.svelte b/services/app/src/lib/components/ui/dialog/dialog-header.svelte old mode 100644 new mode 100755 index 33877fa..fab7e72 --- a/services/app/src/lib/components/ui/dialog/dialog-header.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-header.svelte @@ -1,32 +1,37 @@
- + {@render children?.()} + Close -
+
diff --git a/services/app/src/lib/components/ui/dialog/dialog-overlay.svelte b/services/app/src/lib/components/ui/dialog/dialog-overlay.svelte old mode 100644 new mode 100755 index ff264c0..9e40bf2 --- a/services/app/src/lib/components/ui/dialog/dialog-overlay.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-overlay.svelte @@ -1,21 +1,19 @@ diff --git a/services/app/src/lib/components/ui/dialog/dialog-portal.svelte b/services/app/src/lib/components/ui/dialog/dialog-portal.svelte old mode 100644 new mode 100755 index eb5d0a5..38b451f --- a/services/app/src/lib/components/ui/dialog/dialog-portal.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-portal.svelte @@ -1,8 +1,14 @@ - - + + {@render children?.()} diff --git a/services/app/src/lib/components/ui/dialog/dialog-root.svelte b/services/app/src/lib/components/ui/dialog/dialog-root.svelte deleted file mode 100644 index 656af2b..0000000 --- a/services/app/src/lib/components/ui/dialog/dialog-root.svelte +++ /dev/null @@ -1,22 +0,0 @@ - - - { - if ( - //@ts-ignore - document.getElementById('upload-tab-global')?.contains(e.target) || - //@ts-ignore - document.querySelector('[data-sonner-toaster]')?.contains(e.target) - ) - e.preventDefault(); - }} - bind:open - {...$$restProps} -> - - diff --git a/services/app/src/lib/components/ui/dialog/dialog-title.svelte b/services/app/src/lib/components/ui/dialog/dialog-title.svelte old mode 100644 new mode 100755 index 06574f3..9cf592c --- a/services/app/src/lib/components/ui/dialog/dialog-title.svelte +++ b/services/app/src/lib/components/ui/dialog/dialog-title.svelte @@ -2,15 +2,15 @@ import { Dialog as DialogPrimitive } from "bits-ui"; import { cn } from "$lib/utils.js"; - type $$Props = DialogPrimitive.TitleProps; - - let className: $$Props["class"] = undefined; - export { className as class }; + let { + ref = $bindable(null), + class: className, + ...restProps + }: DialogPrimitive.TitleProps = $props(); - - + {...restProps} +/> diff --git a/services/app/src/lib/components/ui/dialog/index.ts b/services/app/src/lib/components/ui/dialog/index.ts old mode 100644 new mode 100755 index e839343..fd0bb41 --- a/services/app/src/lib/components/ui/dialog/index.ts +++ b/services/app/src/lib/components/ui/dialog/index.ts @@ -1,16 +1,17 @@ -import { Dialog as DialogPrimitive } from 'bits-ui'; +import { Dialog as DialogPrimitive } from "bits-ui"; -const Trigger = DialogPrimitive.Trigger; +import Title from "./dialog-title.svelte"; +import Footer from "./dialog-footer.svelte"; +import Header from "./dialog-header.svelte"; +import Overlay from "./dialog-overlay.svelte"; +import Content from "./dialog-content.svelte"; +import Description from "./dialog-description.svelte"; +import Actions from "./dialog-actions.svelte"; -import Root from './dialog-root.svelte'; -import Title from './dialog-title.svelte'; -import Portal from './dialog-portal.svelte'; -import Footer from './dialog-footer.svelte'; -import Header from './dialog-header.svelte'; -import Overlay from './dialog-overlay.svelte'; -import Content from './dialog-content.svelte'; -import Description from './dialog-description.svelte'; -import Actions from './dialog-actions.svelte'; +const Root = DialogPrimitive.Root; +const Trigger = DialogPrimitive.Trigger; +const Close = DialogPrimitive.Close; +const Portal = DialogPrimitive.Portal; export { Root, @@ -22,6 +23,7 @@ export { Overlay, Content, Description, + Close, Actions, // Root as Dialog, @@ -33,5 +35,6 @@ export { Overlay as DialogOverlay, Content as DialogContent, Description as 
DialogDescription, + Close as DialogClose, Actions as DialogActions }; diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-checkbox-item.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-checkbox-item.svelte old mode 100644 new mode 100755 index cbca3c5..3f1575a --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-checkbox-item.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-checkbox-item.svelte @@ -1,35 +1,40 @@ - - - - - - + {#snippet children({ checked, indeterminate })} + + {#if indeterminate} + + {:else} + + {/if} + + {@render childrenProp?.()} + {/snippet} diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-content.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-content.svelte old mode 100644 new mode 100755 index 2c05927..fdbaa47 --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-content.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-content.svelte @@ -1,27 +1,26 @@ - - - + + + diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-group-heading.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-group-heading.svelte new file mode 100644 index 0000000..84d5cca --- /dev/null +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-group-heading.svelte @@ -0,0 +1,19 @@ + + + diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-item.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-item.svelte old mode 100644 new mode 100755 index 9cc73fd..70a5236 --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-item.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-item.svelte @@ -1,32 +1,23 @@ - - + {...restProps} +/> diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-label.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-label.svelte old mode 100644 new mode 100755 index 43f1527..9837d5a --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-label.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-label.svelte @@ -1,19 +1,23 @@ - - - + {@render children?.()} +
diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-group.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-group.svelte old mode 100644 new mode 100755 index 1c74ae1..c04aa6a --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-group.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-group.svelte @@ -3,9 +3,15 @@ type $$Props = DropdownMenuPrimitive.RadioGroupProps; - export let value: $$Props["value"] = undefined; + interface Props { + value?: $$Props["value"]; + children?: import('svelte').Snippet; + [key: string]: any + } + + let { value = $bindable(undefined), children, ...rest }: Props = $props(); - - + + {@render children?.()} diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-item.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-item.svelte old mode 100644 new mode 100755 index 4e6e3be..bcb960d --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-item.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-radio-item.svelte @@ -1,35 +1,30 @@ - - - - - - + {#snippet children({ checked })} + + {#if checked} + + {/if} + + {@render childrenProp?.({ checked })} + {/snippet} diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-separator.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-separator.svelte old mode 100644 new mode 100755 index 48d016a..32fac4b --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-separator.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-separator.svelte @@ -1,14 +1,16 @@ diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-shortcut.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-shortcut.svelte old mode 100644 new mode 100755 index 880d9b4..053e2a2 --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-shortcut.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-shortcut.svelte @@ -1,13 +1,20 @@ - - + + {@render children?.()} diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-content.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-content.svelte old mode 100644 new mode 100755 index ff20507..0bb6eea --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-content.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-content.svelte @@ -1,30 +1,19 @@ - - + {...restProps} +/> diff --git a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-trigger.svelte b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-trigger.svelte old mode 100644 new mode 100755 index 942e577..be175ad --- a/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-trigger.svelte +++ b/services/app/src/lib/components/ui/dropdown-menu/dropdown-menu-sub-trigger.svelte @@ -1,32 +1,28 @@ - - + {@render children?.()} + diff --git a/services/app/src/lib/components/ui/dropdown-menu/index.ts b/services/app/src/lib/components/ui/dropdown-menu/index.ts old mode 100644 new mode 100755 index c1749e9..40c4502 --- a/services/app/src/lib/components/ui/dropdown-menu/index.ts +++ b/services/app/src/lib/components/ui/dropdown-menu/index.ts @@ -1,48 +1,50 @@ import { DropdownMenu as DropdownMenuPrimitive } from "bits-ui"; +import CheckboxItem from "./dropdown-menu-checkbox-item.svelte"; 
+import Content from "./dropdown-menu-content.svelte"; +import GroupHeading from "./dropdown-menu-group-heading.svelte"; import Item from "./dropdown-menu-item.svelte"; import Label from "./dropdown-menu-label.svelte"; -import Content from "./dropdown-menu-content.svelte"; -import Shortcut from "./dropdown-menu-shortcut.svelte"; import RadioItem from "./dropdown-menu-radio-item.svelte"; import Separator from "./dropdown-menu-separator.svelte"; -import RadioGroup from "./dropdown-menu-radio-group.svelte"; +import Shortcut from "./dropdown-menu-shortcut.svelte"; import SubContent from "./dropdown-menu-sub-content.svelte"; import SubTrigger from "./dropdown-menu-sub-trigger.svelte"; -import CheckboxItem from "./dropdown-menu-checkbox-item.svelte"; const Sub = DropdownMenuPrimitive.Sub; const Root = DropdownMenuPrimitive.Root; const Trigger = DropdownMenuPrimitive.Trigger; const Group = DropdownMenuPrimitive.Group; +const RadioGroup = DropdownMenuPrimitive.RadioGroup; export { - Sub, - Root, - Item, - Label, - Group, - Trigger, - Content, - Shortcut, - Separator, - RadioItem, - SubContent, - SubTrigger, - RadioGroup, CheckboxItem, - // + Content, Root as DropdownMenu, - Sub as DropdownMenuSub, + CheckboxItem as DropdownMenuCheckboxItem, + Content as DropdownMenuContent, + Group as DropdownMenuGroup, + GroupHeading as DropdownMenuGroupHeading, Item as DropdownMenuItem, Label as DropdownMenuLabel, - Group as DropdownMenuGroup, - Content as DropdownMenuContent, - Trigger as DropdownMenuTrigger, - Shortcut as DropdownMenuShortcut, + RadioGroup as DropdownMenuRadioGroup, RadioItem as DropdownMenuRadioItem, Separator as DropdownMenuSeparator, - RadioGroup as DropdownMenuRadioGroup, + Shortcut as DropdownMenuShortcut, + Sub as DropdownMenuSub, SubContent as DropdownMenuSubContent, SubTrigger as DropdownMenuSubTrigger, - CheckboxItem as DropdownMenuCheckboxItem, + Trigger as DropdownMenuTrigger, + Group, + GroupHeading, + Item, + Label, + RadioGroup, + RadioItem, + Root, + Separator, + Shortcut, + Sub, + SubContent, + SubTrigger, + Trigger, }; diff --git a/services/app/src/lib/components/ui/input-otp/index.ts b/services/app/src/lib/components/ui/input-otp/index.ts new file mode 100644 index 0000000..e9ae273 --- /dev/null +++ b/services/app/src/lib/components/ui/input-otp/index.ts @@ -0,0 +1,15 @@ +import Root from "./input-otp.svelte"; +import Group from "./input-otp-group.svelte"; +import Slot from "./input-otp-slot.svelte"; +import Separator from "./input-otp-separator.svelte"; + +export { + Root, + Group, + Slot, + Separator, + Root as InputOTP, + Group as InputOTPGroup, + Slot as InputOTPSlot, + Separator as InputOTPSeparator, +}; diff --git a/services/app/src/lib/components/ui/input-otp/input-otp-group.svelte b/services/app/src/lib/components/ui/input-otp/input-otp-group.svelte new file mode 100644 index 0000000..7ef58a5 --- /dev/null +++ b/services/app/src/lib/components/ui/input-otp/input-otp-group.svelte @@ -0,0 +1,16 @@ + + +
+ {@render children?.()} +
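The dropdown-menu and input-otp hunks above all apply the same Svelte 4 to Svelte 5 migration: export let props become a typed $props() destructure, two-way bound values go through $bindable(), $$restProps becomes a ...restProps rest element, and the default <slot /> is replaced by an optional children snippet rendered with {@render children?.()}. A minimal sketch of a migrated <script lang="ts"> block, written against a hypothetical value prop rather than any one component in this diff:

    import type { Snippet } from 'svelte';

    interface Props {
        value?: string;               // hypothetical bindable prop
        class?: string;               // pass-through class name
        children?: Snippet;           // replaces the Svelte 4 default <slot />
        [key: string]: unknown;       // any other attributes are forwarded
    }

    let {
        value = $bindable(undefined), // parents can still use bind:value
        class: className,
        children,
        ...restProps
    }: Props = $props();

In the markup, <slot /> then becomes {@render children?.()} and {...$$restProps} becomes {...restProps}.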
diff --git a/services/app/src/lib/components/ui/input-otp/input-otp-separator.svelte b/services/app/src/lib/components/ui/input-otp/input-otp-separator.svelte new file mode 100644 index 0000000..8e99e58 --- /dev/null +++ b/services/app/src/lib/components/ui/input-otp/input-otp-separator.svelte @@ -0,0 +1,19 @@ + + +
+ {#if children} + {@render children?.()} + {:else} + + {/if} +
diff --git a/services/app/src/lib/components/ui/input-otp/input-otp-slot.svelte b/services/app/src/lib/components/ui/input-otp/input-otp-slot.svelte new file mode 100644 index 0000000..f5c6035 --- /dev/null +++ b/services/app/src/lib/components/ui/input-otp/input-otp-slot.svelte @@ -0,0 +1,30 @@ + + + + {cell.char} + {#if cell.hasFakeCaret} +
+ +
+ {/if} +
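The rewritten dropdown-menu index.ts above and the new input-otp, input, popover, separator and select index.ts files in this diff all follow the same barrel-export convention: import the Svelte parts (plus any bits-ui primitives used as-is), then export each one twice, under its bare name and under a component-prefixed alias. A minimal sketch for a hypothetical widget component, not a file that exists in the repo:

    // services/app/src/lib/components/ui/widget/index.ts (hypothetical)
    import Root from "./widget.svelte";
    import Item from "./widget-item.svelte";

    export {
        Root,
        Item,
        // prefixed aliases, so callers can import { Widget, WidgetItem } directly
        Root as Widget,
        Item as WidgetItem,
    };

Callers can either import * as Widget from the folder and write Widget.Root, or pull in the prefixed aliases individually; both spellings resolve to the same components.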
diff --git a/services/app/src/lib/components/ui/input-otp/input-otp.svelte b/services/app/src/lib/components/ui/input-otp/input-otp.svelte new file mode 100644 index 0000000..8b59b3f --- /dev/null +++ b/services/app/src/lib/components/ui/input-otp/input-otp.svelte @@ -0,0 +1,22 @@ + + + diff --git a/services/app/src/lib/components/ui/input/index.ts b/services/app/src/lib/components/ui/input/index.ts new file mode 100644 index 0000000..f47b6d3 --- /dev/null +++ b/services/app/src/lib/components/ui/input/index.ts @@ -0,0 +1,7 @@ +import Root from "./input.svelte"; + +export { + Root, + // + Root as Input, +}; diff --git a/services/app/src/lib/components/ui/input/input.svelte b/services/app/src/lib/components/ui/input/input.svelte new file mode 100644 index 0000000..dae6b9a --- /dev/null +++ b/services/app/src/lib/components/ui/input/input.svelte @@ -0,0 +1,46 @@ + + +{#if type === 'file'} + +{:else} + +{/if} diff --git a/services/app/src/lib/components/ui/label/index.ts b/services/app/src/lib/components/ui/label/index.ts old mode 100644 new mode 100755 diff --git a/services/app/src/lib/components/ui/label/label.svelte b/services/app/src/lib/components/ui/label/label.svelte old mode 100644 new mode 100755 index 2a7d479..f251d98 --- a/services/app/src/lib/components/ui/label/label.svelte +++ b/services/app/src/lib/components/ui/label/label.svelte @@ -1,21 +1,26 @@ - + {@render children?.()} + {#if required} + * + {/if} diff --git a/services/app/src/lib/components/ui/pagination/index.ts b/services/app/src/lib/components/ui/pagination/index.ts old mode 100644 new mode 100755 diff --git a/services/app/src/lib/components/ui/pagination/pagination-content.svelte b/services/app/src/lib/components/ui/pagination/pagination-content.svelte old mode 100644 new mode 100755 index 9279558..6ba3cd3 --- a/services/app/src/lib/components/ui/pagination/pagination-content.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination-content.svelte @@ -1,13 +1,16 @@ -
    - +
      + {@render children?.()}
    diff --git a/services/app/src/lib/components/ui/pagination/pagination-ellipsis.svelte b/services/app/src/lib/components/ui/pagination/pagination-ellipsis.svelte old mode 100644 new mode 100755 index fe064c3..e2155e1 --- a/services/app/src/lib/components/ui/pagination/pagination-ellipsis.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination-ellipsis.svelte @@ -1,19 +1,22 @@ diff --git a/services/app/src/lib/components/ui/pagination/pagination-item.svelte b/services/app/src/lib/components/ui/pagination/pagination-item.svelte old mode 100644 new mode 100755 index 009ad17..09c1076 --- a/services/app/src/lib/components/ui/pagination/pagination-item.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination-item.svelte @@ -1,13 +1,14 @@ -
  • - +
  • + {@render children?.()}
  • diff --git a/services/app/src/lib/components/ui/pagination/pagination-link.svelte b/services/app/src/lib/components/ui/pagination/pagination-link.svelte old mode 100644 new mode 100755 index ebec229..4a41b47 --- a/services/app/src/lib/components/ui/pagination/pagination-link.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination-link.svelte @@ -1,35 +1,36 @@ +{#snippet Fallback()} + {page.value} +{/snippet} + - {page.value} - + children={children || Fallback} + {...restProps} +/> diff --git a/services/app/src/lib/components/ui/pagination/pagination-next-button.svelte b/services/app/src/lib/components/ui/pagination/pagination-next-button.svelte old mode 100644 new mode 100755 index 02a43cf..84eee90 --- a/services/app/src/lib/components/ui/pagination/pagination-next-button.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination-next-button.svelte @@ -1,10 +1,31 @@ - - - +{#snippet Fallback()} + Next + +{/snippet} + + diff --git a/services/app/src/lib/components/ui/pagination/pagination-prev-button.svelte b/services/app/src/lib/components/ui/pagination/pagination-prev-button.svelte old mode 100644 new mode 100755 index 23f0c04..2b1a991 --- a/services/app/src/lib/components/ui/pagination/pagination-prev-button.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination-prev-button.svelte @@ -1,10 +1,31 @@ - - - +{#snippet Fallback()} + + Previous +{/snippet} + + diff --git a/services/app/src/lib/components/ui/pagination/pagination.svelte b/services/app/src/lib/components/ui/pagination/pagination.svelte old mode 100644 new mode 100755 index 1cbcce3..4cdc9b1 --- a/services/app/src/lib/components/ui/pagination/pagination.svelte +++ b/services/app/src/lib/components/ui/pagination/pagination.svelte @@ -3,31 +3,23 @@ import { cn } from "$lib/utils.js"; - type $$Props = PaginationPrimitive.Props; - type $$Events = PaginationPrimitive.Events; - - let className: $$Props["class"] = undefined; - export let count: $$Props["count"] = 0; - export let perPage: $$Props["perPage"] = 10; - export let page: $$Props["page"] = 1; - export let siblingCount: $$Props["siblingCount"] = 1; - export { className as class }; - - $: currentPage = page; + let { + ref = $bindable(null), + class: className, + count = 0, + perPage = 10, + page = $bindable(1), + siblingCount = 1, + ...restProps + }: PaginationPrimitive.RootProps = $props(); - - + {...restProps} +/> diff --git a/services/app/src/lib/components/ui/popover/index.ts b/services/app/src/lib/components/ui/popover/index.ts new file mode 100644 index 0000000..63aecf9 --- /dev/null +++ b/services/app/src/lib/components/ui/popover/index.ts @@ -0,0 +1,17 @@ +import { Popover as PopoverPrimitive } from "bits-ui"; +import Content from "./popover-content.svelte"; +const Root = PopoverPrimitive.Root; +const Trigger = PopoverPrimitive.Trigger; +const Close = PopoverPrimitive.Close; + +export { + Root, + Content, + Trigger, + Close, + // + Root as Popover, + Content as PopoverContent, + Trigger as PopoverTrigger, + Close as PopoverClose, +}; diff --git a/services/app/src/lib/components/ui/popover/popover-content.svelte b/services/app/src/lib/components/ui/popover/popover-content.svelte new file mode 100644 index 0000000..d2fbace --- /dev/null +++ b/services/app/src/lib/components/ui/popover/popover-content.svelte @@ -0,0 +1,28 @@ + + + + + diff --git a/services/app/src/lib/components/ui/range-calendar/index.ts b/services/app/src/lib/components/ui/range-calendar/index.ts new file mode 100644 index 0000000..d949b05 --- /dev/null +++ 
b/services/app/src/lib/components/ui/range-calendar/index.ts @@ -0,0 +1,32 @@ +import { RangeCalendar as RangeCalendarPrimitive } from "bits-ui"; +import Root from "./range-calendar.svelte"; +import Cell from "./range-calendar-cell.svelte"; +import Day from "./range-calendar-day.svelte"; +import Grid from "./range-calendar-grid.svelte"; +import Header from "./range-calendar-header.svelte"; +import Months from "./range-calendar-months.svelte"; +import GridRow from "./range-calendar-grid-row.svelte"; +import Heading from "./range-calendar-heading.svelte"; +import HeadCell from "./range-calendar-head-cell.svelte"; +import NextButton from "./range-calendar-next-button.svelte"; +import PrevButton from "./range-calendar-prev-button.svelte"; + +const GridHead = RangeCalendarPrimitive.GridHead; +const GridBody = RangeCalendarPrimitive.GridBody; + +export { + Day, + Cell, + Grid, + Header, + Months, + GridRow, + Heading, + GridBody, + GridHead, + HeadCell, + NextButton, + PrevButton, + // + Root as RangeCalendar, +}; diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-cell.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-cell.svelte new file mode 100644 index 0000000..596bd71 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-cell.svelte @@ -0,0 +1,19 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-day.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-day.svelte new file mode 100644 index 0000000..09650e5 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-day.svelte @@ -0,0 +1,34 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-grid-row.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-grid-row.svelte new file mode 100644 index 0000000..3286b2a --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-grid-row.svelte @@ -0,0 +1,12 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-grid.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-grid.svelte new file mode 100644 index 0000000..7379a71 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-grid.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-head-cell.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-head-cell.svelte new file mode 100644 index 0000000..3c5b869 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-head-cell.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-header.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-header.svelte new file mode 100644 index 0000000..be2bc82 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-header.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-heading.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-heading.svelte new file mode 100644 index 0000000..a39e4e2 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-heading.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-months.svelte 
b/services/app/src/lib/components/ui/range-calendar/range-calendar-months.svelte new file mode 100644 index 0000000..4cd0ed7 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-months.svelte @@ -0,0 +1,20 @@ + + +
    + {@render children?.()} +
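The pagination-link, pagination-next-button and pagination-prev-button hunks above introduce another recurring idiom, and the range-calendar nav buttons below reuse it: the component declares a local Fallback snippet holding its default content and passes children={children || Fallback} to the underlying primitive, so a caller-supplied snippet still takes precedence. A rough self-contained sketch of the idiom, using a plain button instead of a bits-ui primitive and not copying the repo's exact markup:

    <script lang="ts">
        import type { Snippet } from 'svelte';

        // children is optional; when the caller omits it, Fallback below is used.
        let { children, ...restProps }: { children?: Snippet; [key: string]: unknown } = $props();
    </script>

    {#snippet Fallback()}
        Next
    {/snippet}

    <button {...restProps}>
        {@render (children || Fallback)()}
    </button>

In the actual components the combined snippet is handed to the primitive as a prop (children={children || Fallback}) rather than rendered directly.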
    diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-next-button.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-next-button.svelte new file mode 100644 index 0000000..c627c60 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-next-button.svelte @@ -0,0 +1,28 @@ + + +{#snippet Fallback()} + +{/snippet} + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar-prev-button.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar-prev-button.svelte new file mode 100644 index 0000000..b457412 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar-prev-button.svelte @@ -0,0 +1,28 @@ + + +{#snippet Fallback()} + +{/snippet} + + diff --git a/services/app/src/lib/components/ui/range-calendar/range-calendar.svelte b/services/app/src/lib/components/ui/range-calendar/range-calendar.svelte new file mode 100644 index 0000000..f255351 --- /dev/null +++ b/services/app/src/lib/components/ui/range-calendar/range-calendar.svelte @@ -0,0 +1,57 @@ + + + + {#snippet children({ months, weekdays })} + + + + + + + {#each months as month (month)} + + + + {#each weekdays as weekday (weekday)} + + {weekday.slice(0, 2)} + + {/each} + + + + {#each month.weeks as weekDates (weekDates)} + + {#each weekDates as date (date)} + + + + {/each} + + {/each} + + + {/each} + + {/snippet} + diff --git a/services/app/src/lib/components/ui/select/index.ts b/services/app/src/lib/components/ui/select/index.ts old mode 100644 new mode 100755 index 327541c..f31b8ae --- a/services/app/src/lib/components/ui/select/index.ts +++ b/services/app/src/lib/components/ui/select/index.ts @@ -1,34 +1,34 @@ import { Select as SelectPrimitive } from "bits-ui"; -import Label from "./select-label.svelte"; +import GroupHeading from "./select-group-heading.svelte"; import Item from "./select-item.svelte"; import Content from "./select-content.svelte"; import Trigger from "./select-trigger.svelte"; import Separator from "./select-separator.svelte"; +import ScrollDownButton from "./select-scroll-down-button.svelte"; +import ScrollUpButton from "./select-scroll-up-button.svelte"; const Root = SelectPrimitive.Root; const Group = SelectPrimitive.Group; -const Input = SelectPrimitive.Input; -const Value = SelectPrimitive.Value; export { Root, Group, - Input, - Label, + GroupHeading, Item, - Value, Content, Trigger, Separator, + ScrollDownButton, + ScrollUpButton, // Root as Select, Group as SelectGroup, - Input as SelectInput, - Label as SelectLabel, + GroupHeading as SelectGroupHeading, Item as SelectItem, - Value as SelectValue, Content as SelectContent, Trigger as SelectTrigger, Separator as SelectSeparator, + ScrollDownButton as SelectScrollDownButton, + ScrollUpButton as SelectScrollUpButton, }; diff --git a/services/app/src/lib/components/ui/select/select-content.svelte b/services/app/src/lib/components/ui/select/select-content.svelte old mode 100644 new mode 100755 index f4148df..56916ef --- a/services/app/src/lib/components/ui/select/select-content.svelte +++ b/services/app/src/lib/components/ui/select/select-content.svelte @@ -1,39 +1,39 @@ - -
    - -
    -
    + + + + + {@render children?.()} + + + + diff --git a/services/app/src/lib/components/ui/select/select-group-heading.svelte b/services/app/src/lib/components/ui/select/select-group-heading.svelte new file mode 100644 index 0000000..7984bef --- /dev/null +++ b/services/app/src/lib/components/ui/select/select-group-heading.svelte @@ -0,0 +1,16 @@ + + + diff --git a/services/app/src/lib/components/ui/select/select-item.svelte b/services/app/src/lib/components/ui/select/select-item.svelte old mode 100644 new mode 100755 index 92c46ed..2acc257 --- a/services/app/src/lib/components/ui/select/select-item.svelte +++ b/services/app/src/lib/components/ui/select/select-item.svelte @@ -1,60 +1,37 @@ - {#if labelSelected} - {#if disabled} - - - + {#snippet children({ selected, highlighted })} + + {#if selected && !restProps.disabled} + + {/if} + + {#if childrenProp} + {@render childrenProp({ selected, highlighted })} {:else} - - - - - + {label || value} {/if} - {/if} - - {label ? label : value} - + {/snippet} diff --git a/services/app/src/lib/components/ui/select/select-label.svelte b/services/app/src/lib/components/ui/select/select-label.svelte deleted file mode 100644 index d966450..0000000 --- a/services/app/src/lib/components/ui/select/select-label.svelte +++ /dev/null @@ -1,16 +0,0 @@ - - - - - diff --git a/services/app/src/lib/components/ui/select/select-scroll-down-button.svelte b/services/app/src/lib/components/ui/select/select-scroll-down-button.svelte new file mode 100644 index 0000000..c17d5d1 --- /dev/null +++ b/services/app/src/lib/components/ui/select/select-scroll-down-button.svelte @@ -0,0 +1,19 @@ + + + + + diff --git a/services/app/src/lib/components/ui/select/select-scroll-up-button.svelte b/services/app/src/lib/components/ui/select/select-scroll-up-button.svelte new file mode 100644 index 0000000..8ba08c0 --- /dev/null +++ b/services/app/src/lib/components/ui/select/select-scroll-up-button.svelte @@ -0,0 +1,19 @@ + + + + + diff --git a/services/app/src/lib/components/ui/select/select-separator.svelte b/services/app/src/lib/components/ui/select/select-separator.svelte old mode 100644 new mode 100755 index bc518e6..38a3ab0 --- a/services/app/src/lib/components/ui/select/select-separator.svelte +++ b/services/app/src/lib/components/ui/select/select-separator.svelte @@ -1,11 +1,13 @@ - + diff --git a/services/app/src/lib/components/ui/select/select-trigger.svelte b/services/app/src/lib/components/ui/select/select-trigger.svelte old mode 100644 new mode 100755 index 8714e0b..64908f1 --- a/services/app/src/lib/components/ui/select/select-trigger.svelte +++ b/services/app/src/lib/components/ui/select/select-trigger.svelte @@ -1,23 +1,27 @@ span]:line-clamp-1', + 'flex h-10 w-full min-w-full items-center justify-between gap-8 rounded-lg border-transparent bg-[#F2F4F7] px-3 py-2 pl-3 pr-2 text-sm transition-colors placeholder:text-muted-foreground hover:bg-[#e1e2e6] disabled:cursor-not-allowed disabled:opacity-100 data-[placeholder]:italic data-[placeholder]:text-muted-foreground data-dark:bg-[#42464e] [&>span]:line-clamp-1', className )} - {...$$restProps} - let:builder - on:click - on:keydown + {...restProps} > - + {@render children?.()} + {#if showArrow} + + {/if} diff --git a/services/app/src/lib/components/ui/separator/index.ts b/services/app/src/lib/components/ui/separator/index.ts new file mode 100644 index 0000000..82442d2 --- /dev/null +++ b/services/app/src/lib/components/ui/separator/index.ts @@ -0,0 +1,7 @@ +import Root from "./separator.svelte"; + +export { + Root, + // + 
Root as Separator, +}; diff --git a/services/app/src/lib/components/ui/separator/separator.svelte b/services/app/src/lib/components/ui/separator/separator.svelte new file mode 100644 index 0000000..839494d --- /dev/null +++ b/services/app/src/lib/components/ui/separator/separator.svelte @@ -0,0 +1,22 @@ + + + diff --git a/services/app/src/lib/components/ui/skeleton/index.ts b/services/app/src/lib/components/ui/skeleton/index.ts old mode 100644 new mode 100755 diff --git a/services/app/src/lib/components/ui/skeleton/skeleton.svelte b/services/app/src/lib/components/ui/skeleton/skeleton.svelte old mode 100644 new mode 100755 index 5a6f269..4089b49 --- a/services/app/src/lib/components/ui/skeleton/skeleton.svelte +++ b/services/app/src/lib/components/ui/skeleton/skeleton.svelte @@ -1,15 +1,17 @@
    diff --git a/services/app/src/lib/components/ui/sonner/CustomToastDesc.svelte b/services/app/src/lib/components/ui/sonner/CustomToastDesc.svelte old mode 100644 new mode 100755 index 59249ea..c51a611 --- a/services/app/src/lib/components/ui/sonner/CustomToastDesc.svelte +++ b/services/app/src/lib/components/ui/sonner/CustomToastDesc.svelte @@ -3,28 +3,32 @@ import CheckDoneIcon from '$lib/icons/CheckDoneIcon.svelte'; import CopyIcon from '$lib/icons/CopyIcon.svelte'; - export let description: string; - export let requestID: string; + interface Props { + description: string; + requestID: string; + } - let requestIDCopied = false; - let requestIDCopiedTimeout: ReturnType; + let { description, requestID }: Props = $props(); + + let requestIDCopied = $state(false); + let requestIDCopiedTimeout: ReturnType | undefined = $state();
    -

    {description}

    +

    {description}

    {#if requestID}
    {requestID} - {#if !$page.data.hideBreadcrumbs} + {#if !page.data.hideBreadcrumbs} {/if} + + {#if !page.data.hideUserDetailsBtn} + + {/if}
    - + {@render children?.()}
    - + {#if page.data.rightDock} + + {/if}
@@ -174,7 +214,7 @@ {#if $showLoadingOverlay}
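The +layout.svelte hunk above and the BreadcrumbsBar rewrite that follows also migrate route state from the $page store to the page object exported by $app/state, so every read drops the $ prefix: $page.data.hideBreadcrumbs becomes page.data.hideBreadcrumbs, $page.url.pathname becomes page.url.pathname, and so on; chat.svelte.ts further down imports it the same way. A small sketch of the new style inside a component script block, using only reads that already appear in this diff:

    import { page } from '$app/state';

    // In runes mode these reads stay reactive without the $ store prefix.
    const onChatRoute = $derived(page.url.pathname.startsWith('/chat'));
    const hideBreadcrumbs = $derived(Boolean(page.data.hideBreadcrumbs));
    // Fall back to a query parameter, as chat.svelte.ts does for project_id:
    const projectId = $derived(page.params.project_id ?? page.url.searchParams.get('project_id'));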
diff --git a/services/app/src/routes/(main)/BreadcrumbsBar.svelte b/services/app/src/routes/(main)/BreadcrumbsBar.svelte old mode 100644 new mode 100755 index 6bc9d4c..2a11e97 --- a/services/app/src/routes/(main)/BreadcrumbsBar.svelte +++ b/services/app/src/routes/(main)/BreadcrumbsBar.svelte @@ -1,182 +1,279 @@
- {#if PUBLIC_IS_LOCAL === 'false'} + {#if page.url.pathname.startsWith('/system')} +
+ + System +
+ {:else if page.url.pathname.startsWith('/chat')} +
+ + JamAI Chat +
+ {:else} - - + + {#snippet child({ props })} + + {/snippet} - {#each $page.data.userData?.member_of ?? [] as org} - { - if (org?.organization_id !== $activeOrganization?.organization_id) { - $activeOrganization = org; - await tick(); - if ($page.route.id?.includes('/project/[project_id]')) { - goto('/project'); - } else { - invalidate('layout:root'); +

Organization

+ + {#each (page.data.user as User)?.org_memberships ?? [] as orgMembership} + {@const org = (page.data.user as User)?.organizations.find( + (org) => org.id === orgMembership.organization_id + )} + {#if org} + { + if (org?.id !== $activeOrganization?.id) { + activeOrganization.setOrgCookie(org.id); + if (page.route.id?.includes('/project/[project_id]')) { + goto('/project'); + } else { + invalidate('layout:root'); + } } - } - }} - class="flex items-center gap-1 text-xs cursor-pointer {$activeOrganization?.organization_id === - org.organization_id - ? 'bg-[#F7F7F7]' - : ''} rounded-sm" - > - - {org.organization_name} + }} + class="flex cursor-pointer items-center gap-1 text-xs {$activeOrganization?.id === + org.id + ? '!bg-[#D0F7FB]' + : ''} rounded-sm" + > + + {org.name} - - + +
+ {/if} {/each}
- - - - New Organization - + + {#snippet child({ props })} + + {@render joinOrgIcon('h-4 w-4')} + {m['breadcrumbs.org_join_btn']()} + + {/snippet} + + + {#snippet child({ props })} +
+ + {m['breadcrumbs.org_create_btn']()} + + {/snippet} - {:else} -
- - Default Organization -
{/if} - / - {#if $page.route.id?.startsWith('/(main)/project')} + + {#if !page.url.pathname.startsWith('/chat') && $activeOrganization?.id === PUBLIC_ADMIN_ORGANIZATION_ID} + + {/if} + + {#if page.route.id?.startsWith('/(main)/project')} - - Projects + + {m['project.heading']()} - {:else if $page.url.pathname.startsWith('/organization')} + {:else if page.url.pathname.startsWith('/organization')}
- + Organization
- {:else if $page.url.pathname.startsWith('/home')} + {:else if page.url.pathname.startsWith('/analytics')}
- - Home + + Analytics
- {:else if $page.route.id?.endsWith('/template')} -
+ {:else if page.url.pathname.startsWith('/template')} + Discover -
- {:else if $page.route.id?.startsWith('/(main)/template/[template_id]')} + + {/if} + + {#if page.route.id?.startsWith('/(main)/(cloud)/template/[template_id]')} + / - - + + - {$page.data?.templateData?.data?.name ?? $page.params.template_id} + {page.data?.templateData?.data?.name ?? page.params.template_id} - {/if} - - {#if $page.route.id?.startsWith('/(main)/project/[project_id]')} + {:else if page.route.id?.startsWith('/(main)/project/[project_id]')} / - + {$activeProject?.name ?? - ($loadingProjectData.loading ? 'Loading...' : $page.params.project_id)} + ($loadingProjectData.loading ? 'Loading...' : page.params.project_id)} + {:else if page.route.id?.startsWith('/(main)/template/[template_id]')} + + + + + {page.data?.templateData?.data?.name ?? page.params.template_id} + {/if} - {#if $page.route.id?.endsWith('/action-table/[table_id]')} + {#if page.route.id?.endsWith('/action-table/[table_id]')} /
- {$page.params.table_id} + {page.params.table_id}
- {:else if $page.route.id?.endsWith('/knowledge-table/[table_id]')} + {:else if page.route.id?.endsWith('/knowledge-table/[table_id]')} /
- {$page.params.table_id} + {page.params.table_id}
- {:else if $page.route.id?.endsWith('/chat-table/[table_id]')} + {:else if page.route.id?.endsWith('/chat-table/[table_id]')} /
- {$page.params.table_id} + {page.params.table_id}
{/if}
+{#snippet joinOrgIcon(className = '')} + + + + + + + +{/snippet} + diff --git a/services/app/src/routes/(main)/chat/[project_id]/[conversation_id]/+page.ts b/services/app/src/routes/(main)/chat/[project_id]/[conversation_id]/+page.ts new file mode 100644 index 0000000..a4d609b --- /dev/null +++ b/services/app/src/routes/(main)/chat/[project_id]/[conversation_id]/+page.ts @@ -0,0 +1,5 @@ +export async function load() { + return { + hideUserDetailsBtn: true + }; +} diff --git a/services/app/src/routes/(main)/chat/chat.svelte.ts b/services/app/src/routes/(main)/chat/chat.svelte.ts new file mode 100644 index 0000000..87c65f1 --- /dev/null +++ b/services/app/src/routes/(main)/chat/chat.svelte.ts @@ -0,0 +1,989 @@ +import { browser } from '$app/environment'; +import { goto } from '$app/navigation'; +import { page } from '$app/state'; +import { env as publicEnv } from '$env/dynamic/public'; +import logger from '$lib/logger'; +import type { + ChatReferences, + ChatThreads, + Conversation, + GenTable, + GenTableStreamEvent +} from '$lib/types'; +import { waitForElement } from '$lib/utils'; +import { tick } from 'svelte'; +import { v4 as uuidv4 } from 'uuid'; + +import { CustomToastDesc, toast } from '$lib/components/ui/sonner'; +import { fileColumnFiletypes } from '$lib/constants'; +import axios from 'axios'; + +const { PUBLIC_JAMAI_URL } = publicEnv; + +export class ChatState { + // Agents + agents: Record = $state({}); + + // Conversations + fetchController: AbortController | null = null; + conversations: Conversation[] = $state([]); + loadingConvsError: { status: number; message: string } | null = $state(null); + isLoadingConvs = $state(true); + isLoadingMoreConvs = $state(false); + moreConvsFinished = false; + currentOffsetConvs = 0; + private limitConvs = 50; + searchQuery = $state(''); + isLoadingSearch = $state(false); + + // Actual chat + agent = $state< + (Omit & { agent_id: string }) | null + >(null); + conversation: Conversation | null = $state(null); + loadingConversation: any = $state(true); + messages: ChatThreads['threads'] = $state({}); + loadingMessages: any = $state(true); + chatWindow: HTMLDivElement | null = $state(null); + chatForm: HTMLFormElement | null = $state(null); + chat: HTMLTextAreaElement | null = $state(null); + chatMessage = $state(''); + editingContent: { + rowID: string; + columnID: string; + fileColumns: Record; + } | null = $state(null); + generationStatus: string[] | null = $state(null); + isLoadingMoreMessages = $state(false); + moreMessagesFinished = false; + currentOffsetMessages = 0; + private limitMessages = 10; + fileColumns = $derived( + (!page.params.conversation_id ? this.agent : this.conversation)?.cols.filter( + (col) => col.dtype === 'image' || col.dtype === 'audio' || col.dtype === 'document' + ) ?? 
[] + ); + uploadColumns: Record = $state({}); + loadedStreams: Record> = $state({}); + latestStreams: Record> = $state({}); + loadedReferences: Record> | null = null; + + private getConvController: AbortController | null = null; + private getMessagesController: AbortController | null = null; + + async getConversation() { + if (!page.params.project_id || !page.params.conversation_id) return; + + this.getConvController?.abort('Duplicate'); + this.getConvController = new AbortController(); + + try { + const response = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/conversations?${new URLSearchParams([ + ['conversation_id', page.params.conversation_id] + ])}`, + { + headers: { + 'x-project-id': page.params.project_id + }, + signal: this.getConvController.signal + } + ); + const responseBody = await response.json(); + + if (response.ok) { + this.conversation = responseBody; + this.loadingConversation = false; + } else { + this.loadingConversation = responseBody; + logger.error('CHAT_GET_CONV', responseBody); + toast.error('Failed to load conversation', { + id: responseBody.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody.message || JSON.stringify(responseBody), + requestID: responseBody.request_id + } + }); + } + } catch (err) { + //* don't show abort errors in browser + if (err !== 'Duplicate') { + console.error(err); + } + } + } + + async getMessages(scroll = true) { + if (!page.params.project_id || !page.params.conversation_id) return; + + this.getMessagesController?.abort('Duplicate'); + this.getMessagesController = new AbortController(); + + try { + const searchParams = new URLSearchParams([ + ['conversation_id', page.params.conversation_id], + ['offset', this.currentOffsetMessages.toString()], + ['limit', this.limitMessages.toString()], + ['order_ascending', 'false'] + // ['organization_id', $activeOrganization.id] + ]); + + const response = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/conversations/threads?${searchParams}`, + { + headers: { + 'x-project-id': page.params.project_id + }, + signal: this.getMessagesController.signal + } + ); + const responseBody = await response.json(); + + this.currentOffsetMessages += this.limitMessages; + + if (response.ok) { + this.messages = responseBody.threads; + + this.moreMessagesFinished = true; + + //! 
Old paginated response + /* if (responseBody.items.length) { + if (this.chatWindow && !scroll) { + this.chatWindow.scrollTop += 1; + } + this.messages = [...responseBody.items.reverse(), ...this.messages]; + } else { + this.moreMessagesFinished = true; + } */ + this.loadingMessages = false; + } else { + this.loadingMessages = responseBody; + logger.error('CHAT_GET_MESSAGES', responseBody); + toast.error('Failed to load messages', { + id: responseBody.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody.message || JSON.stringify(responseBody), + requestID: responseBody.request_id + } + }); + } + + if (scroll) { + await tick(); + if (this.messages.length) { + await waitForElement('[data-testid=chat-message]'); + } + await this.scrollChatToBottom(); + } + } catch (err) { + //* don't show abort errors in browser + if (err !== 'Duplicate') { + console.error(err); + } + } + } + + async sendMessage() { + if ( + this.generationStatus || + (!this.chatMessage.trim() && Object.values(chatState.uploadColumns).every((col) => !col.uri)) + ) + return; + + const cachedPrompt = this.chatMessage; + const cachedFiles = structuredClone($state.snapshot(this.uploadColumns)); + this.chatMessage = ''; + this.uploadColumns = {}; + + if (this.chat) this.chat.style.height = '3rem'; + + //? Get agent threads + if (!page.params.conversation_id) { + if (!this.agent) return; + const agentThreadRes = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/conversations/threads?${new URLSearchParams([ + ['conversation_id', this.agent.agent_id] + ])}`, + { + headers: { + 'x-project-id': page.params.project_id ?? page.url.searchParams.get('project_id') + } + } + ); + const agentThreadBody = await agentThreadRes.json(); + + if (agentThreadRes.ok) { + this.messages = Object.fromEntries( + Object.entries((agentThreadBody as ChatThreads).threads).map(([outCol, thread]) => [ + outCol, + { + ...thread, + thread: [ + ...thread.thread, + { + row_id: uuidv4(), + role: 'user', + content: [ + { + type: 'text' as const, + text: cachedPrompt + }, + ...Object.entries(cachedFiles).map(([uploadColumn, val]) => ({ + type: 'input_s3' as const, + uri: val.uri, + column_name: uploadColumn + })) + ], + name: null, + user_prompt: cachedPrompt, + references: null + } + ] + } + ]) + ); + } else { + logger.error('CHAT_CONV_GETAGENT', agentThreadBody); + toast.error('Failed to send message', { + id: agentThreadBody.message || JSON.stringify(agentThreadBody), + description: CustomToastDesc as any, + componentProps: { + description: agentThreadBody.message || JSON.stringify(agentThreadBody), + requestID: agentThreadBody.request_id + } + }); + return; + } + } else { + //? Add user message to the chat + this.messages = Object.fromEntries( + Object.entries(this.messages).map(([outCol, thread]) => [ + outCol, + { + ...thread, + thread: [ + ...thread.thread, + { + row_id: uuidv4(), + role: 'user', + content: [ + { + type: 'text' as const, + text: cachedPrompt + }, + ...Object.entries(cachedFiles).map(([uploadColumn, val]) => ({ + type: 'input_s3' as const, + uri: val.uri, + column_name: uploadColumn + })) + ], + name: null, + user_prompt: cachedPrompt, + references: null + } + ] + } + ]) + ); + } + + this.generationStatus = ['new']; + this.loadedStreams = { + new: Object.fromEntries( + (page.params.conversation_id ? this.conversation! : this.agent!).cols + .map((col) => + col.gen_config?.object === 'gen_config.llm' && col.gen_config.multi_turn + ? 
[[col.id, []]] + : [] + ) + .flat() + ) + }; + this.latestStreams = { + new: Object.fromEntries( + (page.params.conversation_id ? this.conversation! : this.agent!).cols + .map((col) => + col.gen_config?.object === 'gen_config.llm' && col.gen_config.multi_turn + ? [[col.id, '']] + : [] + ) + .flat() + ) + }; + + //? Show user message + await tick(); + if (this.chatWindow) this.chatWindow.scrollTop = this.chatWindow.scrollHeight; + + //? Send message to the server + const apiUrl = page.params.conversation_id + ? '/api/owl/conversations/messages' + : '/api/owl/conversations'; + const response = await fetch(`${PUBLIC_JAMAI_URL}${apiUrl}`, { + method: 'POST', + headers: { + Accept: 'text/event-stream', + 'Content-Type': 'application/json', + 'x-project-id': page.params.project_id ?? page.url.searchParams.get('project_id') + }, + body: JSON.stringify({ + data: { + User: cachedPrompt, + ...Object.fromEntries( + Object.entries(cachedFiles).map(([uploadColumn, val]) => [uploadColumn, val.uri]) + ) + }, + agent_id: page.params.conversation_id ? undefined : this.agent?.agent_id, + conversation_id: page.params.conversation_id || undefined + }) + }); + + if (response.status != 200) { + const responseBody = await response.json(); + logger.error(this.conversation ? 'CHAT_MESSAGE_ADD' : 'CHAT_CONV_CREATE', responseBody); + toast.error('Failed to add message', { + id: responseBody.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody.message || JSON.stringify(responseBody), + requestID: responseBody.request_id + } + }); + this.messages = Object.fromEntries( + Object.entries(this.messages).map(([outCol, thread]) => [ + outCol, + { + ...thread, + thread: thread.thread.slice(0, -1) + } + ]) + ); + this.chatMessage = cachedPrompt; + this.uploadColumns = cachedFiles; + } else { + const { row_id } = await this.parseStream( + response.body!.pipeThrough(new TextDecoderStream()).getReader(), + true + ); + + this.loadedStreams = Object.fromEntries( + Object.entries(this.loadedStreams).map(([row, colStreams]) => [ + row, + Object.fromEntries( + Object.entries(colStreams).map(([col, streams]) => [ + col, + [...streams, this.latestStreams[row][col]] + ]) + ) + ]) + ); + + this.messages = Object.fromEntries( + Object.entries(this.messages).map(([outCol, thread]) => { + const loadedStreamCol = this.loadedStreams.new[outCol]; + const colReferences = this.loadedReferences?.new?.[outCol] ?? null; + const userPrompt = thread.thread.at(-1)!; + + return [ + outCol, + { + ...thread, + thread: [ + ...thread.thread.slice(0, -1), + { + ...userPrompt, + row_id + }, + { + row_id, + role: 'assistant', + content: [ + { + type: 'text', + text: loadedStreamCol.join('') + } + ], + name: null, + user_prompt: null, + references: colReferences + } + ] + } + ]; + }) + ); + + this.getMessages(); + if (apiUrl === '/api/owl/conversations') { + chatState.refetchConversations(); + } + } + + this.generationStatus = null; + this.loadedStreams = {}; + this.latestStreams = {}; + this.loadedReferences = {}; + } + + async handleSaveFile(files: File[], editing = false) { + const formData = new FormData(); + formData.append('file', files[0]); + + const nextAvailableCol = this.fileColumns.find( + (col) => + !(editing ? this.editingContent?.fileColumns ?? {} : this.uploadColumns)[col.id]?.uri && + fileColumnFiletypes + .filter(({ type }) => col.dtype === type) + .map(({ ext }) => ext) + .includes('.' + (files[0].name.split('.').pop() ?? 
'').toLowerCase()) + ); + if (!nextAvailableCol) + return alert('No more files of this type can be uploaded: all columns filled.'); + + if (editing) { + if (this.editingContent) { + this.editingContent.fileColumns[nextAvailableCol.id] = { + uri: 'loading', + url: '' + }; + } + } else { + this.uploadColumns[nextAvailableCol.id] = { + uri: 'loading', + url: '' + }; + } + + try { + const uploadRes = await axios.post(`${PUBLIC_JAMAI_URL}/api/owl/files/upload`, formData, { + headers: { + 'Content-Type': 'multipart/form-data', + 'x-project-id': page.url.searchParams.get('project_id') ?? page.params.project_id + } + }); + + if (uploadRes.status !== 200) { + logger.error('CHAT_FILE_UPLOAD', { + file: files[0].name, + response: uploadRes.data + }); + alert( + 'Failed to upload file: ' + + (uploadRes.data.message || JSON.stringify(uploadRes.data)) + + `\nRequest ID: ${uploadRes.data.request_id}` + ); + return; + } else { + const urlResponse = await fetch(`/api/owl/files/url/thumb`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-project-id': page.params.project_id + }, + body: JSON.stringify({ + uris: [uploadRes.data.uri] + }) + }); + const urlBody = await urlResponse.json(); + + if (urlResponse.ok) { + if (editing) { + if (this.editingContent) { + this.editingContent.fileColumns[nextAvailableCol.id] = { + uri: uploadRes.data.uri, + url: urlBody.urls[0] + }; + } + } else { + this.uploadColumns[nextAvailableCol.id] = { + uri: uploadRes.data.uri, + url: urlBody.urls[0] + }; + } + } else { + if (editing) { + if (this.editingContent) { + this.editingContent.fileColumns[nextAvailableCol.id] = { + uri: uploadRes.data.uri, + url: '' + }; + } + } else { + this.uploadColumns[nextAvailableCol.id] = { uri: uploadRes.data.uri, url: '' }; + } + toast.error('Failed to retrieve thumbnail', { + id: urlBody.message || JSON.stringify(urlBody), + description: CustomToastDesc as any, + componentProps: { + description: urlBody.message || JSON.stringify(urlBody), + requestID: urlBody.request_id + } + }); + } + } + } catch (err) { + if (!(err instanceof axios.CanceledError && err.code == 'ERR_CANCELED')) { + //@ts-expect-error AxiosError + logger.error('CHAT_FILE_UPLOAD', err?.response?.data); + alert( + 'Failed to upload file: ' + + //@ts-expect-error AxiosError + (err?.response?.data.message || JSON.stringify(err?.response?.data)) + + //@ts-expect-error AxiosError + `\nRequest ID: ${err?.response?.data?.request_id}` + ); + } + } + } + + async regenMessage(rowID: string) { + if (this.generationStatus) return; + + const cachedMessages = $state.snapshot(this.messages); + + this.messages = Object.fromEntries( + Object.entries(this.messages).map(([outCol, thread]) => { + return [ + outCol, + { + ...thread, + thread: thread.thread.map((v) => + v.row_id === rowID && v.role === 'assistant' + ? { ...v, row_id: rowID, content: '', references: null } + : v + ) + } + ]; + }) + ); + + const longestThreadCol = Object.keys(chatState.messages).reduce( + (a, b) => + Array.isArray(chatState.messages[b].thread) && + (!a || chatState.messages[b].thread.length > chatState.messages[a].thread.length) + ? 
b + : a, + '' + ); + const rowsToRegen = this.messages[longestThreadCol].thread + .filter((m) => m.role !== 'User') + .slice(this.messages[longestThreadCol].thread.findIndex((m) => m.row_id === rowID)) + .map((m) => m.row_id); + this.loadedStreams = Object.fromEntries( + rowsToRegen.map((row) => [ + row, + Object.fromEntries( + this.conversation!.cols.map((col) => + col.gen_config?.object === 'gen_config.llm' && col.gen_config.multi_turn + ? [[col.id, []]] + : [] + ).flat() + ) + ]) + ); + this.latestStreams = Object.fromEntries( + rowsToRegen.map((row) => [ + row, + Object.fromEntries( + this.conversation!.cols.map((col) => + col.gen_config?.object === 'gen_config.llm' && col.gen_config.multi_turn + ? [[col.id, '']] + : [] + ).flat() + ) + ]) + ); + + this.generationStatus = rowsToRegen; + + //? Show user message + await tick(); + + //? Send message to the server + const response = await fetch(`${PUBLIC_JAMAI_URL}/api/owl/conversations/messages/regen`, { + method: 'POST', + headers: { + Accept: 'text/event-stream', + 'Content-Type': 'application/json', + 'x-project-id': page.params.project_id + }, + body: JSON.stringify({ + conversation_id: page.params.conversation_id, + row_id: rowID + }) + }); + + if (response.status != 200) { + const responseBody = await response.json(); + logger.error('CHAT_MESSAGE_REGEN', responseBody); + toast.error('Failed to regen message response', { + id: responseBody.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody.message || JSON.stringify(responseBody), + requestID: responseBody.request_id + } + }); + this.messages = cachedMessages; + } else { + await this.parseStream(response.body!.pipeThrough(new TextDecoderStream()).getReader()); + + this.loadedStreams = Object.fromEntries( + Object.entries(this.loadedStreams).map(([row, colStreams]) => [ + row, + Object.fromEntries( + Object.entries(colStreams).map(([col, streams]) => [ + col, + [...streams, this.latestStreams[row][col]] + ]) + ) + ]) + ); + + this.messages = Object.fromEntries( + Object.entries(this.messages).map(([outCol, thread]) => { + return [ + outCol, + { + ...thread, + thread: [ + ...thread.thread.map((v) => + v.row_id === rowID && v.role === 'assistant' + ? { + ...v, + content: this.loadedStreams[v.row_id]?.[outCol]?.join('') ?? v.content, + references: this.loadedReferences?.[v.row_id]?.[outCol] ?? v.references + } + : v + ) + ] + } + ]; + }) + ); + + this.getMessages(); + } + + this.generationStatus = null; + this.loadedStreams = {}; + this.latestStreams = {}; + this.loadedReferences = {}; + } + + async saveEditedContent(newContent: Record) { + if (!this.editingContent || this.generationStatus) return; + + // const editingMessage = this.messages.find((m) => m.ID === this.editingContent?.rowID)!; + const response = await fetch(`${PUBLIC_JAMAI_URL}/api/owl/conversations/messages`, { + method: 'PATCH', + headers: { + 'Content-Type': 'application/json', + 'x-project-id': page.params.project_id + }, + body: JSON.stringify({ + conversation_id: page.params.conversation_id, + row_id: this.editingContent.rowID, + data: newContent + }) + }); + const responseBody = await response.json(); + + if (response.ok) { + if (this.editingContent.columnID === 'User') { + this.messages = Object.fromEntries( + Object.entries(this.messages).map(([column, thread]) => [ + column, + { + ...thread, + thread: thread.thread.map((m) => + m.row_id === this.editingContent?.rowID && m.role === 'user' + ? 
{ + ...m, + content: Object.entries(newContent).map(([col, val]) => + col === 'User' + ? { type: 'text', text: newContent.User } + : { type: 'input_s3', uri: val, column_name: col } + ), + user_prompt: newContent.User + } + : m + ) + } + ]) + ); + } else { + this.messages = { + [this.editingContent.columnID]: { + ...this.messages[this.editingContent.columnID], + thread: this.messages[this.editingContent.columnID].thread.map((v) => + v.row_id === this.editingContent?.rowID + ? { + ...v, + content: Object.entries(newContent).map(([col, val]) => + col === this.editingContent?.columnID + ? { type: 'text', text: newContent[this.editingContent.columnID] } + : { type: 'input_s3', uri: val, column_name: col } + ) + } + : v + ) + } + }; + } + // editingMessage[this.editingContent.columnID] = newContent; + this.editingContent = null; + } else { + logger.error('CHAT_MESSAGE_EDIT', responseBody); + toast.error('Failed to edit message', { + id: responseBody.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody.message || JSON.stringify(responseBody), + requestID: responseBody.request_id + } + }); + } + + this.getMessages(); + } + + resetChat() { + this.conversation = null; + this.loadingConversation = true; + this.messages = {}; + this.loadingMessages = true; + this.currentOffsetMessages = 0; + this.moreMessagesFinished = false; + this.uploadColumns = {}; + } + + private async parseStream(reader: ReadableStreamDefaultReader, newMessage = false) { + let rowID = ''; + let buffer = ''; + let renderCount = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + try { + const { value, done } = await reader.read(); + if (done) break; + + buffer += value; + const lines = buffer.split('\n'); //? Split by \n to handle collation + buffer = lines.pop() || ''; + + let parsedEvent: + | { event: 'metadata'; data: Conversation } + | { event: undefined; data: GenTableStreamEvent } + | undefined = undefined; + for (const line of lines) { + if (line === '') { + if (parsedEvent) { + if (parsedEvent.event) { + if (parsedEvent.event === 'metadata') { + if (!page.params.conversation_id) { + goto( + `/chat/${page.url.searchParams.get('project_id')}/${encodeURIComponent(parsedEvent.data.conversation_id)}` + ); + setTimeout(() => chatState.refetchConversations(), 5000); + } + + this.conversation = parsedEvent.data; + } + } else if (parsedEvent.data.object === 'gen_table.completion.chunk') { + if (parsedEvent.data.choices[0].finish_reason) { + switch (parsedEvent.data.choices[0].finish_reason) { + case 'error': { + logger.error('CHAT_MESSAGE_ADDSTREAM', parsedEvent.data); + console.error('STREAMING_ERROR', parsedEvent.data); + alert( + `Error while streaming: ${parsedEvent.data.choices[0].message.content}` + ); + break; + } + } + } else { + rowID = parsedEvent.data.row_id; + const streamDataRowID = newMessage ? 'new' : rowID; + + if (this.loadedStreams[streamDataRowID][parsedEvent.data.output_column_name]) { + if (renderCount++ >= 20) { + this.loadedStreams[streamDataRowID][parsedEvent.data.output_column_name] = [ + ...this.loadedStreams[streamDataRowID][parsedEvent.data.output_column_name], + this.latestStreams[streamDataRowID][parsedEvent.data.output_column_name] + + (parsedEvent.data.choices[0]?.message?.content ?? 
'') + ]; + this.latestStreams[streamDataRowID][parsedEvent.data.output_column_name] = ''; + } else { + this.latestStreams[streamDataRowID][parsedEvent.data.output_column_name] += + parsedEvent.data.choices[0]?.message?.content ?? ''; + } + } + + this.scrollChatToBottom(); + } + } else if (parsedEvent.data.object === 'gen_table.references') { + this.loadedReferences = { + ...(this.loadedReferences ?? {}), + [parsedEvent.data.row_id]: { + ...((this.loadedReferences ?? {})[parsedEvent.data.row_id] ?? {}), + [parsedEvent.data.output_column_name]: + parsedEvent.data as unknown as ChatReferences + } + }; + } else { + console.warn('Unknown event data:', parsedEvent.data); + } + } else { + console.warn('Unknown event object:', parsedEvent); + } + } else if (line.startsWith('data: ')) { + if (line.slice(6) === '[DONE]') break; + //@ts-expect-error missing type + parsedEvent = { ...(parsedEvent ?? {}), data: JSON.parse(line.slice(6)) }; + } else if (line.startsWith('event: ')) { + //@ts-expect-error missing type + parsedEvent = { ...(parsedEvent ?? {}), event: line.slice(7) }; + } + } + } catch (err) { + logger.error('CHAT_MESSAGE_ADDSTREAM', err); + console.error(err); + break; + } + } + + return { row_id: rowID }; + } + + async getConversations() { + if (!page.params.project_id && !page.url.searchParams.has('project_id')) return; + + this.fetchController?.abort('Duplicate'); + this.fetchController = new AbortController(); + + try { + // autoAnimateController?.disable(); + this.isLoadingMoreConvs = true; + + const searchParams = new URLSearchParams([ + ['offset', this.currentOffsetConvs.toString()], + ['limit', this.limitConvs.toString()], + ['order_by', 'updated_at'], + ['order_ascending', 'false'] + // ['organization_id', $activeOrganization.id] + ]); + + if (this.searchQuery.trim() !== '') { + searchParams.append('search_query', this.searchQuery.trim()); + } + + const response = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/conversations/list?${searchParams}`, + { + credentials: 'same-origin', + signal: this.fetchController.signal, + headers: { + 'x-project-id': page.params.project_id || page.url.searchParams.get('project_id')! + } + } + ); + this.currentOffsetConvs += this.limitConvs; + + if (response.status == 200) { + const moreProjects = await response.json(); + if (moreProjects.items.length) { + this.conversations = [...this.conversations, ...moreProjects.items]; + } else { + //* Finished loading oldest conversation + this.moreConvsFinished = true; + } + } else { + const responseBody = await response.json(); + console.error(responseBody); + toast.error('Failed to fetch conversations', { + id: responseBody?.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody?.message || JSON.stringify(responseBody), + requestID: responseBody?.request_id + } + }); + this.loadingConvsError = { + status: response.status, + message: responseBody + }; + } + + this.isLoadingMoreConvs = false; + } catch (err) { + //* don't show abort errors in browser + if (err !== 'Duplicate') { + console.error(err); + } + } + } + + async editConversationTitle( + newTitle: string, + conversationID: string, + projectID: string, + successCb: () => void + ) { + const response = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/conversations/title?${new URLSearchParams([ + ['conversation_id', conversationID], + ['title', newTitle ?? 
''] + ])}`, + { + method: 'PATCH', + headers: { + 'x-project-id': projectID + } + } + ); + const responseBody = await response.json(); + + if (response.ok) { + successCb(); + } else { + logger.error('CHAT_TITLE_EDIT', responseBody); + toast.error('Failed to edit conversation title', { + id: responseBody.message || JSON.stringify(responseBody), + description: CustomToastDesc as any, + componentProps: { + description: responseBody.message || JSON.stringify(responseBody), + requestID: responseBody.request_id + } + }); + } + } + + async refetchConversations() { + this.fetchController?.abort('Duplicate'); + this.conversations = []; + this.currentOffsetConvs = 0; + this.moreConvsFinished = false; + await tick(); + this.getConversations(); + this.isLoadingSearch = false; + } + + async scrollChatToBottom() { + if (!browser || !this.chatWindow) return; + + if ( + this.chatWindow.scrollHeight - this.chatWindow.clientHeight - this.chatWindow.scrollTop < + 100 || + !this.generationStatus + ) { + await tick(); + await tick(); + this.chatWindow.scrollTop = this.chatWindow.scrollHeight; + } + } +} + +export const chatState = new ChatState(); diff --git a/services/app/src/routes/(main)/organization/+layout.server.ts b/services/app/src/routes/(main)/organization/+layout.server.ts new file mode 100755 index 0000000..24ca8f2 --- /dev/null +++ b/services/app/src/routes/(main)/organization/+layout.server.ts @@ -0,0 +1,141 @@ +import { env } from '$env/dynamic/private'; +import logger from '$lib/logger'; +import { getPrices } from '$lib/server/nodeCache'; +import { error } from '@sveltejs/kit'; +import Stripe from 'stripe'; + +const stripe = new Stripe(env.OWL_STRIPE_API_KEY); + +export async function load({ cookies, depends, locals, parent }) { + depends('layout:settings'); + const data = await parent(); + const { user, organizationData } = data; + + const prices = await getPrices(locals.user?.id); + + if (!data.ossMode && !prices) { + throw error(500, 'Failed to get prices'); + } + + if (data.ossMode || !env.OWL_STRIPE_API_KEY || !locals.user) { + return { + prices, + billing_info: [ + { data: null, status: 401 }, + { data: null, status: 401 } + ], + payment_methods: { data: null, status: 401 } + }; + } + + if (!user || !organizationData || !organizationData.stripe_id) { + return { + prices, + billing_info: [ + { data: null, status: 401 }, + { data: null, status: 401 } + ], + payment_methods: { data: null, status: 401 } + }; + // throw error(500, 'Failed to get organization data'); + } + + //? 
check if user is in organization + const activeOrganizationId = cookies.get('activeOrganizationId'); + if (!user.org_memberships.find((org) => org.organization_id === activeOrganizationId)) { + throw error(403, 'Unauthorized'); + } + + const getSubscription = async () => { + try { + const subscription = await stripe.subscriptions.list({ + customer: organizationData.stripe_id!, + expand: ['data.latest_invoice.payment_intent.latest_charge'] + }); + + if (subscription.data.length > 0) { + return { data: subscription.data[0], status: 200 }; + } else { + return { data: null, status: 404, error: 'No subscription found' }; + } + } catch (err) { + if ((err as any).type === 'StripeInvalidRequestError' && (err as any).statusCode === 404) { + return { data: null, status: 404, error: 'No subscription found' }; + } else { + logger.error('SETTINGS_SUBSCRIPTION_GET', err); + return { data: null, status: 500, error: err }; + } + } + }; + + const getInvoice: () => Promise<{ + data: Stripe.Response | null; + status: number; + error?: any; + }> = async () => { + try { + const invoice = await stripe.invoices.retrieveUpcoming({ + customer: organizationData.stripe_id! + }); + return { data: invoice, status: 200 }; + } catch (err) { + if ((err as any).type === 'StripeInvalidRequestError' && (err as any).statusCode === 404) { + return { data: null, status: 404, error: 'No invoice found' }; + } else { + logger.error('SETTINGS_INVOICE_GET', err); + return { data: null, status: 500, error: err }; + } + } + }; + + const getBillingInfo = () => { + return [getSubscription(), getInvoice()] as const; + }; + + const getPaymentMethods = async (): Promise<{ + data: Stripe.PaymentMethod[] | null; + status: number; + error?: any; + }> => { + //? Get payment methods for customer + try { + const paymentMethods = await stripe.customers.listPaymentMethods(organizationData.stripe_id!); + return { data: paymentMethods.data, status: 200 }; + } catch (err) { + logger.error('SETTINGS_PAYMENTMETHODS_GET', err); + return { data: null, status: 500, error: err }; + } + }; + + const getCustomer = async (): Promise<{ + data: Stripe.Customer | null; + status: number; + error?: any; + }> => { + try { + const customer = await stripe.customers.retrieve(organizationData.stripe_id!); + if (customer.deleted) { + return { data: null, status: 404, error: 'Customer not found' }; + } + return { data: customer, status: 200 }; + } catch (err) { + logger.error('SETTINGS_CUSTOMER_GET', err); + return { data: null, status: 500, error: err }; + } + }; + + const isAdmin = + locals.user.org_memberships.find((org) => org.organization_id === activeOrganizationId) + ?.role === 'ADMIN'; + return { + prices, + billing_info: isAdmin + ? getBillingInfo() + : [ + { data: null, status: 403 }, + { data: null, status: 403 } + ], + payment_methods: isAdmin ? getPaymentMethods() : { data: null, status: 403 }, + customer: isAdmin ? await getCustomer() : { data: null, status: 403 } + }; +} diff --git a/services/app/src/routes/(main)/organization/+layout.svelte b/services/app/src/routes/(main)/organization/+layout.svelte new file mode 100755 index 0000000..e1c3c97 --- /dev/null +++ b/services/app/src/routes/(main)/organization/+layout.svelte @@ -0,0 +1,93 @@ + + + moveHighlighter(page.url.pathname)} /> + +
+
+
+ +

Organization

+
+ +
+ {#each links.filter((l) => !l.exclude) as { title, href }, index (href)} + + {title} + + {/each} + +
+
+
+ + {@render children?.()} +
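The organization +layout.server.ts above and the general/+page.server.ts actions below repeat the same guard: read activeOrganizationId from the cookie, look the organization up in the user's org_memberships, and only expose billing data or allow update/leave/delete when that membership exists and, for mutations and billing, carries the ADMIN role. A small helper expressing that check is sketched here for illustration; the repo inlines the logic in each load function and action rather than sharing a helper like this:

    import type { User } from '$lib/types';

    // Returns the caller's role in the active organization, or null if they are not a member.
    export function getActiveOrgRole(
        user: User | null | undefined,
        activeOrganizationId: string | undefined
    ): string | null {
        if (!user || !activeOrganizationId) return null;
        const membership = user.org_memberships.find(
            (org) => org.organization_id === activeOrganizationId
        );
        return membership?.role ?? null;
    }

    // Typical use inside an action or load function:
    //   const role = getActiveOrgRole(locals.user, cookies.get('activeOrganizationId'));
    //   if (role === null) return fail(403, new APIError('Forbidden').getSerializable());
    //   if (role !== 'ADMIN') return fail(403, new APIError('Forbidden').getSerializable());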
diff --git a/services/app/src/routes/(main)/organization/+page.server.ts b/services/app/src/routes/(main)/organization/+page.server.ts new file mode 100755 index 0000000..b5b5d9b --- /dev/null +++ b/services/app/src/routes/(main)/organization/+page.server.ts @@ -0,0 +1,5 @@ +import { redirect } from '@sveltejs/kit'; + +export function load() { + throw redirect(302, '/organization/general'); +} diff --git a/services/app/src/routes/(main)/organization/general/+page.server.ts b/services/app/src/routes/(main)/organization/general/+page.server.ts new file mode 100755 index 0000000..dbc013d --- /dev/null +++ b/services/app/src/routes/(main)/organization/general/+page.server.ts @@ -0,0 +1,205 @@ +import { env } from '$env/dynamic/private'; +import logger, { APIError } from '$lib/logger.js'; +import type { User } from '$lib/types.js'; +import { fail } from '@sveltejs/kit'; + +const { OWL_SERVICE_KEY, OWL_URL } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const actions = { + update: async ({ cookies, fetch, locals, request }) => { + const data = await request.formData(); + const organization_name = data.get('organization_name'); + const activeOrganizationId = cookies.get('activeOrganizationId'); + + if (typeof organization_name !== 'string' || organization_name.trim() === '') { + return fail(400, new APIError('Invalid organization name').getSerializable()); + } + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const userApiRes = await fetch( + `${OWL_URL}/api/v2/users?${new URLSearchParams([['user_id', locals.user.id]])}`, + { + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + const userApiBody = (await userApiRes.json()) as User; + if (userApiRes.ok) { + const targetOrg = userApiBody.org_memberships.find( + (org) => org.organization_id === activeOrganizationId + ); + if (!targetOrg || targetOrg.role !== 'ADMIN') { + return fail(403, new APIError('Forbidden').getSerializable()); + } + } else { + logger.error('ORG_UPDATE_USERGET', userApiBody); + return fail( + userApiRes.status, + new APIError('Failed to get user', userApiBody as any).getSerializable() + ); + } + + const updateOrgRes = await fetch( + `${OWL_URL}/api/v2/organizations?${new URLSearchParams([['organization_id', activeOrganizationId]])}`, + { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': locals.user.id, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + name: organization_name + }) + } + ); + + const updateOrgBody = await updateOrgRes.json(); + if (!updateOrgRes.ok) { + logger.error('ORG_UPDATE_UPDATE', updateOrgBody); + return fail( + updateOrgRes.status, + new APIError('Failed to update organization', updateOrgBody as any).getSerializable() + ); + } else { + return updateOrgBody; + } + }, + + leave: async ({ cookies, fetch, locals }) => { + const activeOrganizationId = cookies.get('activeOrganizationId'); + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const userApiRes = await fetch( + `${OWL_URL}/api/v2/users?${new URLSearchParams([['user_id', locals.user.id]])}`, + { + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + const userApiBody = (await 
userApiRes.json()) as User; + if (userApiRes.ok) { + const targetOrg = userApiBody.org_memberships.find( + (org) => org.organization_id === activeOrganizationId + ); + if (!targetOrg) { + return fail(403, new APIError('Forbidden').getSerializable()); + } + } else { + logger.error('ORG_LEAVE_USERGET', userApiBody); + return fail( + userApiRes.status, + new APIError('Failed to get user', userApiBody as any).getSerializable() + ); + } + + const leaveOrgRes = await fetch( + `${OWL_URL}/api/v2/organizations/members?${new URLSearchParams([ + ['user_id', locals.user.id], + ['organization_id', activeOrganizationId] + ])}`, + { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + + const leaveOrgBody = await leaveOrgRes.json(); + if (!leaveOrgRes.ok) { + logger.error('ORG_LEAVE_DELETE', leaveOrgBody); + return fail( + leaveOrgRes.status, + new APIError('Failed to leave organization', leaveOrgBody as any).getSerializable() + ); + } else { + return leaveOrgBody; + } + }, + + delete: async ({ cookies, fetch, locals }) => { + const activeOrganizationId = cookies.get('activeOrganizationId'); + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const userApiRes = await fetch( + `${OWL_URL}/api/v2/users?${new URLSearchParams([['user_id', locals.user.id]])}`, + { + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + const userApiBody = (await userApiRes.json()) as User; + if (userApiRes.ok) { + const targetOrg = userApiBody.org_memberships.find( + (org) => org.organization_id === activeOrganizationId + ); + if (!targetOrg || targetOrg.role !== 'ADMIN') { + return fail(403, new APIError('Forbidden').getSerializable()); + } + } else { + logger.error('ORG_DELETE_USERGET', userApiBody); + return fail( + userApiRes.status, + new APIError('Failed to get user', userApiBody as any).getSerializable() + ); + } + + const deleteOrgRes = await fetch( + `${OWL_URL}/api/v2/organizations?${new URLSearchParams([['organization_id', activeOrganizationId]])}`, + { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + + const deleteOrgBody = await deleteOrgRes.json(); + if (!deleteOrgRes.ok) { + logger.error('ORG_DELETE_DELETE', deleteOrgBody); + return fail( + deleteOrgRes.status, + new APIError('Failed to delete organization', deleteOrgBody as any).getSerializable() + ); + } else { + return deleteOrgBody; + } + } +}; diff --git a/services/app/src/routes/(main)/organization/general/+page.svelte b/services/app/src/routes/(main)/organization/general/+page.svelte new file mode 100755 index 0000000..20cd7de --- /dev/null +++ b/services/app/src/routes/(main)/organization/general/+page.svelte @@ -0,0 +1,354 @@ + + + + General - Organization + + +
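Not part of the diff — a minimal standalone sketch of the Owl call pattern used by the `update` action above (service-key auth plus an `x-user-id` header, with the target organization passed as a query parameter). The helper name `renameOrganization` is illustrative; the endpoint, headers, and payload mirror the handler shown above, assuming `OWL_URL` and `OWL_SERVICE_KEY` are configured as in the diff.

```ts
// Hypothetical helper mirroring the PATCH issued by the `update` action above.
// Assumes OWL_URL and OWL_SERVICE_KEY environment variables, as in the server route.
export async function renameOrganization(
	organizationId: string,
	userId: string,
	name: string
): Promise<unknown> {
	const url = `${process.env.OWL_URL}/api/v2/organizations?${new URLSearchParams([
		['organization_id', organizationId]
	])}`;
	const res = await fetch(url, {
		method: 'PATCH',
		headers: {
			Authorization: `Bearer ${process.env.OWL_SERVICE_KEY}`, // service key, not a user token
			'x-user-id': userId, // the user on whose behalf the call is made
			'Content-Type': 'application/json'
		},
		body: JSON.stringify({ name })
	});
	if (!res.ok) throw new Error(`Rename failed with status ${res.status}`);
	return res.json();
}
```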
+

YOUR ORGANIZATION

+ +
+
+
+

Organization Name

+ {$activeOrganization?.name ?? ''} +
+ + + + +
+ +
+

Organization ID

+
+ {$activeOrganization?.id ?? ''} + +
+
+
+ +
+

ORGANIZATION REMOVAL

+ +

+ Leaving this organization will remove you from it + + + , while deleting it will permanently remove all data associated with it. + {#snippet deniedMessage()} + . + {/snippet} + + +

+ +
+ + + + + +
+
+
+ + (editOrgName = $activeOrganization?.name ?? '')} +> + + Edit organization name + + + { + isLoadingEditOrgName = true; + + return async ({ result, update }) => { + if (result.type !== 'success') { + //@ts-ignore + const data = result.data; + toast.error('Error updating organization details', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } + + await update({ reset: false }); + isLoadingEditOrgName = false; + isEditingOrgName = false; + }; + }} + method="POST" + action="?/update" + class="w-full grow overflow-auto" + > +
+ + + +
+ + + +
+ + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
+ + + + + + Close + + + +
{ + isLoadingLeaveOrg = true; + + return async ({ result, update }) => { + if (result.type !== 'success') { + //@ts-ignore + const data = result.data; + toast.error('Error leaving organization', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else { + return location.reload(); + } + + isLoadingLeaveOrg = false; + update({ reset: false, invalidateAll: false }); + }; + }} + onkeydown={(event) => event.key === 'Enter' && event.preventDefault()} + method="POST" + action="?/leave" + class="flex flex-col items-start gap-2 p-8 pb-10" + > + +

Are you sure?

+

+ Do you really want to leave organization + + `{$activeOrganization?.name}` + ? +

+ + + +
+ + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
+ + { + if (!e) { + confirmOrgName = ''; + } + }} +> + + Delete organization + + +
{ + isLoadingDeleteOrg = true; + + return async ({ result, update }) => { + if (result.type !== 'success') { + //@ts-ignore + const data = result.data; + toast.error('Error deleting organization', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else { + return location.reload(); + } + + isLoadingDeleteOrg = false; + update({ reset: false, invalidateAll: false }); + }; + }} + onkeydown={(event) => event.key === 'Enter' && event.preventDefault()} + method="POST" + action="?/delete" + class="w-full grow overflow-auto" + > +
+

+ Do you really want to delete organization + + `{$activeOrganization?.name}` + ? This process cannot be undone. +

+ +
+ + + +
+
+
+ + +
+ + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
diff --git a/services/app/src/routes/(main)/organization/secrets/(components)/DeleteExtKeyDialog.svelte b/services/app/src/routes/(main)/organization/secrets/(components)/DeleteExtKeyDialog.svelte new file mode 100644 index 0000000..eb29cff --- /dev/null +++ b/services/app/src/routes/(main)/organization/secrets/(components)/DeleteExtKeyDialog.svelte @@ -0,0 +1,111 @@ + + + isDeletingExtKey.open, + (v) => (isDeletingExtKey = { ...isDeletingExtKey, open: v })} +> + + + + Close + + +
+ +

Are you sure?

+

+ Do you really want to delete API key + + `{PROVIDERS[isDeletingExtKey.value ?? ''] || isDeletingExtKey.value}` + ? This process cannot be undone. +

+
+ + +
{ + loadingDeleteExtKey = true; + + if (!formData.get('provider') || !organizationData) { + cancel(); + } else { + Object.keys(organizationData.external_keys).forEach((key) => { + if (key !== formData.get('provider')?.toString()) { + formData.append(key, organizationData.external_keys[key]); + } + }); + + formData.delete('provider'); + } + + return async ({ update, result }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error('Error deleting external key', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + isDeletingExtKey = { ...isDeletingExtKey, open: false }; + } + + loadingDeleteExtKey = false; + update({ reset: false }); + }; + }} + action="?/update-external-keys" + class="flex gap-2 overflow-x-auto overflow-y-hidden" + > + + + + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
diff --git a/services/app/src/routes/(main)/organization/secrets/(components)/EditExtKeyDialog.svelte b/services/app/src/routes/(main)/organization/secrets/(components)/EditExtKeyDialog.svelte new file mode 100644 index 0000000..61646cf --- /dev/null +++ b/services/app/src/routes/(main)/organization/secrets/(components)/EditExtKeyDialog.svelte @@ -0,0 +1,180 @@ + + + isEditingExtKey.open, (v) => (isEditingExtKey = { ...isEditingExtKey, open: v })} +> + + + {isEditingExtKey.value ? 'Edit' : 'Add'} API Key + + +
{ + loadingEditExtKey = true; + + if (!formData.get('provider') || !organizationData) { + cancel(); + } else { + formData.set( + formData.get('provider')?.toString()!, + formData.get('external_key')?.toString() ?? '' + ); + + Object.keys(organizationData.external_keys).forEach((key) => { + if (key !== formData.get('provider')?.toString()) { + formData.append(key, organizationData.external_keys[key]); + } + }); + + formData.delete('provider'); + formData.delete('external_key'); + } + + return async ({ update, result }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error('Error updating external keys', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + isEditingExtKey = { ...isEditingExtKey, open: false }; + } + + loadingEditExtKey = false; + update({ reset: false }); + }; + }} + action="?/update-external-keys" + class="flex grow flex-col gap-3 overflow-auto py-3" + > +
+ + {#if customProvider} +
+ + +
+ {:else} +
+ { + selectedProvider = value; + }} + > + + {PROVIDERS[selectedProvider] || 'Select a Provider'} + + + + {#each Object.entries(PROVIDERS) as [key, value]} + {value} + {/each} + + + + +
+ {/if} +
+ +
+ + +
+
+ + +
+ + +
+
+
+
diff --git a/services/app/src/routes/(main)/organization/secrets/(components)/index.ts b/services/app/src/routes/(main)/organization/secrets/(components)/index.ts new file mode 100644 index 0000000..e1e10c8 --- /dev/null +++ b/services/app/src/routes/(main)/organization/secrets/(components)/index.ts @@ -0,0 +1,4 @@ +import DeleteExtKeyDialog from './DeleteExtKeyDialog.svelte'; +import EditExtKeyDialog from './EditExtKeyDialog.svelte'; + +export { DeleteExtKeyDialog, EditExtKeyDialog }; diff --git a/services/app/src/routes/(main)/organization/secrets/+page.server.ts b/services/app/src/routes/(main)/organization/secrets/+page.server.ts new file mode 100755 index 0000000..3a89945 --- /dev/null +++ b/services/app/src/routes/(main)/organization/secrets/+page.server.ts @@ -0,0 +1,55 @@ +import { env } from '$env/dynamic/private'; +import { APIError } from '$lib/logger.js'; +import { fail } from '@sveltejs/kit'; + +const { OWL_SERVICE_KEY, OWL_URL } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const actions = { + 'update-external-keys': async ({ cookies, fetch, locals, request }) => { + const activeOrganizationId = cookies.get('activeOrganizationId'); + + const data = await request.formData(); + const externalKeys: Record = {}; + for (const [key, value] of data.entries()) { + externalKeys[key] = (value as string).trim(); + } + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const updateExternalKeysRes = await fetch( + `${OWL_URL}/api/v2/organizations?${new URLSearchParams([['organization_id', activeOrganizationId]])}`, + { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': locals.user.id, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + external_keys: externalKeys + }) + } + ); + + const updateExternalKeysBody = await updateExternalKeysRes.json(); + if (!updateExternalKeysRes.ok) { + return fail( + updateExternalKeysRes.status, + new APIError('Failed to update external keys', updateExternalKeysBody).getSerializable() + ); + } else { + return updateExternalKeysBody; + } + } +}; diff --git a/services/app/src/routes/(main)/organization/secrets/+page.svelte b/services/app/src/routes/(main)/organization/secrets/+page.svelte new file mode 100755 index 0000000..37ab060 --- /dev/null +++ b/services/app/src/routes/(main)/organization/secrets/+page.svelte @@ -0,0 +1,138 @@ + + + + Secrets - Organization + + +
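Not part of the diff — a small sketch of how the `update-external-keys` action above assembles its payload: every form field is treated as a provider-to-key pair and sent as `external_keys` in a single PATCH to Owl. `buildExternalKeys` is an illustrative name.

```ts
// Illustrative only: collapse FormData entries into the { provider: key } map
// that the `update-external-keys` action above sends to Owl.
export function buildExternalKeys(data: FormData): Record<string, string> {
	const externalKeys: Record<string, string> = {};
	for (const [provider, value] of data.entries()) {
		externalKeys[provider] = String(value).trim(); // keys are trimmed before saving
	}
	return externalKeys;
}

// Example: a form with fields "openai" and "anthropic" becomes
// { openai: "sk-...", anthropic: "sk-ant-..." } and is sent as
// JSON.stringify({ external_keys: buildExternalKeys(data) }).
```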
+
+

EXTERNAL API KEYS

+ + +
+
+
+
Provider
+
API Key
+
+
+
+ + {#if Object.keys(organizationData?.external_keys ?? {}).length > 0} + {@const extKeys = Object.keys(organizationData?.external_keys ?? {})} +
+ {#each extKeys as provider} +
+
+

+ {PROVIDERS[provider] ?? provider} +

+
+ +
+ +
+ +
+ + + +
+ +
+ {/each} +
+ {:else} +
+
+
+

No external keys have been added to this organization

+
+
+
+ {/if} +
+ +
+ +
+ + {#snippet deniedMessage()} +
+

You need to be an Admin to manage external keys in your organization

+
+ {/snippet} +
+
+
+ + + diff --git a/services/app/src/routes/(main)/organization/team/+page.server.ts b/services/app/src/routes/(main)/organization/team/+page.server.ts new file mode 100755 index 0000000..e229fc2 --- /dev/null +++ b/services/app/src/routes/(main)/organization/team/+page.server.ts @@ -0,0 +1,262 @@ +import { env } from '$env/dynamic/private'; +import { userRoles } from '$lib/constants.js'; +import logger, { APIError } from '$lib/logger.js'; +import { fail } from '@sveltejs/kit'; + +const { ORIGIN, OWL_SERVICE_KEY, OWL_URL, RESEND_API_KEY } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const actions = { + invite: async ({ cookies, fetch, locals, request }) => { + const data = await request.formData(); + const user_email = data.get('user_email'); + const user_role = data.get('user_role'); + const valid_days = data.get('valid_days'); + const activeOrganizationId = cookies.get('activeOrganizationId'); + + if (typeof user_email !== 'string' || user_email.trim() === '') { + return fail(400, new APIError('Invalid user email').getSerializable()); + } + + if ( + typeof user_role !== 'string' || + user_role.trim() === '' || + !userRoles.includes(user_role as (typeof userRoles)[number]) + ) { + return fail(400, new APIError('Invalid user role').getSerializable()); + } + + if ( + typeof valid_days !== 'string' || + valid_days.trim() === '' || + isNaN(Number(valid_days)) || + Number(valid_days) <= 0 + ) { + return fail(400, new APIError('Invalid valid days').getSerializable()); + } + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const getInviteToken = await fetch( + `${OWL_URL}/api/v2/organizations/invites?${new URLSearchParams({ + user_email: user_email.trim(), + organization_id: activeOrganizationId, + role: user_role, + valid_days + })}`, + { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + const inviteToken = await getInviteToken.json(); + + if (!getInviteToken.ok) { + if (![403].includes(getInviteToken.status)) { + logger.error('ORGTEAM_INVITE_TOKEN', inviteToken); + } + return fail( + getInviteToken.status, + new APIError('Failed to get invite token', inviteToken as any).getSerializable() + ); + } + + if (RESEND_API_KEY) { + const sendEmailRes = await fetch('https://api.resend.com/emails', { + method: 'POST', + headers: { + Authorization: `Bearer ${RESEND_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + from: 'JamAI Base ', + to: user_email, + subject: 'You have been invited to join an organization on JamAI Base', + html: getInviteEmailBody(locals.user.email, inviteToken.id) + }) + }); + + if (!sendEmailRes.ok) { + logger.error('ORGTEAM_INVITE_EMAIL', await sendEmailRes.json()); + return fail(sendEmailRes.status, new APIError('Failed to send email').getSerializable()); + } + } + + return RESEND_API_KEY ? 
{ ok: true } : inviteToken; + }, + + update: async ({ cookies, fetch, locals, request }) => { + const data = await request.formData(); + const user_id = data.get('user_id'); + const user_role = data.get('user_role'); + const activeOrganizationId = cookies.get('activeOrganizationId'); + + if (typeof user_id !== 'string' || user_id.trim() === '') { + return fail(400, new APIError('Invalid user ID').getSerializable()); + } + if ( + typeof user_role !== 'string' || + user_role.trim() === '' || + !userRoles.includes(user_role as (typeof userRoles)[number]) + ) { + return fail(400, new APIError('Invalid user role').getSerializable()); + } + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const updateRoleRes = await fetch( + `${OWL_URL}/api/v2/organizations/members/role?${new URLSearchParams([ + ['user_id', user_id], + ['organization_id', activeOrganizationId], + ['role', user_role] + ])}`, + { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + + const updateRoleBody = await updateRoleRes.json(); + if (!updateRoleRes.ok) { + logger.error('ORGTEAM_UPDATE_ROLE', updateRoleBody); + return fail( + updateRoleRes.status, + new APIError('Failed to update role', updateRoleBody as any).getSerializable() + ); + } else { + return updateRoleBody; + } + }, + + remove: async ({ cookies, fetch, locals, request }) => { + const data = await request.formData(); + const user_id = data.get('user_id'); + const activeOrganizationId = cookies.get('activeOrganizationId'); + + if (typeof user_id !== 'string' || user_id.trim() === '') { + return fail(400, new APIError('Invalid user ID').getSerializable()); + } + + if (!activeOrganizationId) { + return fail(400, new APIError('No active organization').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const leaveOrgRes = await fetch( + `${OWL_URL}/api/v2/organizations/members?${new URLSearchParams([ + ['user_id', user_id], + ['organization_id', activeOrganizationId] + ])}`, + { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user.id + } + } + ); + + const leaveOrgBody = await leaveOrgRes.json(); + if (!leaveOrgRes.ok) { + logger.error('ORGTEAM_REMOVE_REMOVE', leaveOrgBody); + return fail( + leaveOrgRes.status, + new APIError('Failed to remove user', leaveOrgBody as any).getSerializable() + ); + } else { + return leaveOrgBody; + } + } +}; + +// eslint-disable-next-line @typescript-eslint/no-unused-vars +const getInviteEmailBody = (inviterEmail: string, inviteToken: string) => ` + + + + +
+ + + + +
+
+

+ JamAI Logo +

+ +

${inviterEmail} has invited you to join their organization on JamAI Base

+ +

You have been invited to join an organization on JamAI Base. Click the link below to accept the invitation:

+ +

Join JamAI Base

+ +

This link will expire in 7 days.

+ +
+ Thanks! +
+ + JamAI Base + +

+
+

+ If you were not expecting this invitation, you can safely ignore this email. +

+
+
+
+ +`; diff --git a/services/app/src/routes/(main)/organization/team/+page.svelte b/services/app/src/routes/(main)/organization/team/+page.svelte new file mode 100755 index 0000000..fa73e6d --- /dev/null +++ b/services/app/src/routes/(main)/organization/team/+page.svelte @@ -0,0 +1,340 @@ + + + + Team - Organization + + +
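Not part of the diff — a condensed sketch of the invite flow implemented by the `invite` action above: the invite token is created against Owl, and only when a Resend key is configured is an email actually sent; otherwise the token itself is returned so the UI can surface it as an invitation code. Function names, the sender address, and the simplified email body are illustrative; the endpoints and headers follow the handler above.

```ts
// Illustrative sketch of the invite flow in the `invite` action above.
// Assumes OWL_URL, OWL_SERVICE_KEY and (optionally) RESEND_API_KEY env vars.
export async function inviteMember(opts: {
	userEmail: string;
	organizationId: string;
	role: string;
	validDays: string;
	invitedBy: string; // user id placed in the x-user-id header
}): Promise<{ ok: true } | { id: string }> {
	const query = new URLSearchParams({
		user_email: opts.userEmail.trim(),
		organization_id: opts.organizationId,
		role: opts.role,
		valid_days: opts.validDays
	});
	const res = await fetch(`${process.env.OWL_URL}/api/v2/organizations/invites?${query}`, {
		method: 'POST',
		headers: {
			Authorization: `Bearer ${process.env.OWL_SERVICE_KEY}`,
			'x-user-id': opts.invitedBy
		}
	});
	if (!res.ok) throw new Error(`Invite failed with status ${res.status}`);
	const invite = (await res.json()) as { id: string };

	// Email delivery is optional: without a Resend key the token is returned
	// so the UI can show it as an invitation code instead.
	if (!process.env.RESEND_API_KEY) return invite;
	await fetch('https://api.resend.com/emails', {
		method: 'POST',
		headers: {
			Authorization: `Bearer ${process.env.RESEND_API_KEY}`,
			'Content-Type': 'application/json'
		},
		body: JSON.stringify({
			from: 'JamAI Base <noreply@example.com>', // placeholder sender address
			to: opts.userEmail,
			subject: 'You have been invited to join an organization on JamAI Base',
			html: `<p>Use invitation code ${invite.id} to join.</p>` // simplified body
		})
	});
	return { ok: true };
}
```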
+
+
+

ORGANIZATION MEMBERS

+ + + + +
+ +
+
+
+
No.
+
Name
+
Email
+
Role
+
Created at
+
+
+
+ +
+ {#each organizationMembers.data ?? [] as user, index} + {@const userOrgCreatedAt = new Date(user.created_at).toLocaleString(undefined, { + day: '2-digit', + month: 'short', + year: isThisYear(new Date(user.created_at)) ? undefined : 'numeric', + hour: 'numeric', + minute: '2-digit', + second: '2-digit' + })} +
+
+ {index + 1} +
+
+ + {user.user.name} + + + {user.user_id} + +
+
+ + {user.user.email ?? ''} + +
+
+ + {lowerCase(user.role)} + +
+
+ + {userOrgCreatedAt} + +
+
+ + + + + +
+ +
+
+ {/each} +
+
+
+
+ + + + !!editingUser, () => (editingUser = null)}> + + Edit user role + + +
{ + isLoadingEdit = true; + + return async ({ result, update }) => { + if (result.type !== 'success') { + //@ts-ignore + const data = result.data; + toast.error('Error updating user role', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else { + editingUser = null; + } + + isLoadingEdit = false; + update({ reset: false }); + }; + }} + method="POST" + action="?/update" + class="w-full grow overflow-auto" + > +
+ + +
+ + + + + + {#snippet children()} + + {selectedUserRole} + + {/snippet} + + + {#each userRoles as roleType} + + {roleType} + + {/each} + + +
+
+
+ + +
+ + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
+ + !!deletingUser, + (v) => { + if (!v) { + deletingUser = null; + } + }} +> + + + + Close + + +
+ +

Are you sure?

+

+ Do you really want to remove user + + `{deletingUser?.user_id}` + ? +

+
+ + +
{ + isLoadingDelete = true; + + return async ({ result, update }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error('Error removing user', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + deletingUser = null; + } + + isLoadingDelete = false; + update({ reset: false }); + }; + }} + method="POST" + action="?/remove" + class="flex gap-2 overflow-x-auto overflow-y-hidden" + > + + + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
diff --git a/services/app/src/routes/(main)/organization/team/+page.ts b/services/app/src/routes/(main)/organization/team/+page.ts new file mode 100644 index 0000000..44229e0 --- /dev/null +++ b/services/app/src/routes/(main)/organization/team/+page.ts @@ -0,0 +1,34 @@ +import { env } from '$env/dynamic/public'; +import logger from '$lib/logger'; +import type { OrgMemberRead } from '$lib/types'; + +const { PUBLIC_JAMAI_URL } = env; + +export const load = async ({ fetch, parent }) => { + const data = await parent(); + + const getOrgMembers = async () => { + const activeOrganizationId = data.organizationData?.id; + if (!activeOrganizationId) { + return { error: 400, message: 'No active organization' }; + } + + const orgMembersRes = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/organizations/members/list?${new URLSearchParams([['organization_id', activeOrganizationId]])}` + ); + const orgMembersBody = await orgMembersRes.json(); + + if (!orgMembersRes.ok) { + logger.error('ORGTEAM_MEMBERS_ERROR', orgMembersBody); + return { error: orgMembersRes.status, message: orgMembersBody }; + } else { + return { + data: orgMembersBody.items as OrgMemberRead[] + }; + } + }; + return { + ...data, + organizationMembers: await getOrgMembers() + }; +}; diff --git a/services/app/src/routes/(main)/organization/team/OrgInviteDialog.svelte b/services/app/src/routes/(main)/organization/team/OrgInviteDialog.svelte new file mode 100755 index 0000000..268135a --- /dev/null +++ b/services/app/src/routes/(main)/organization/team/OrgInviteDialog.svelte @@ -0,0 +1,188 @@ + + + + + Invite user + + +
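Not part of the diff — a brief sketch of how a consumer can narrow the `organizationMembers` value produced by the `+page.ts` load above, which resolves to either an error shape or a `data` array. The `OrgMemberRead` import path follows the diff; the guard function name is illustrative.

```ts
// Illustrative narrowing of the union returned by getOrgMembers() above.
import type { OrgMemberRead } from '$lib/types';

type OrgMembersResult = { error: number; message: unknown } | { data: OrgMemberRead[] };

export function membersOrEmpty(result: OrgMembersResult): OrgMemberRead[] {
	// The load returns { error, message } when the members list call fails,
	// so callers can fall back to an empty list (or render the error) here.
	return 'data' in result ? result.data : [];
}
```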
{ + isLoadingInvite = true; + + return async ({ result, update }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error('Error inviting user to organization', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + await update({ reset: false }); + email = ''; + isInvitingUser = false; + if ((result.data as any).id) { + showCodeDialog = (result.data as any).id; + } else { + toast.success('Invite email sent!', { id: 'invite-sent' }); + } + } + + isLoadingInvite = false; + }; + }} + onkeydown={(event) => event.key === 'Enter' && event.preventDefault()} + method="POST" + action="?/invite" + class="flex grow flex-col gap-3 overflow-auto py-3" + > +
+ + +
+ +
+ + + + + {selectedUserRoleInvite} + + + + {#each userRoles as roleType} + + {roleType} + + {/each} + + +
+ +
+ + + + + {inviteValidity} days + + + + {#each ['1', '2', '3', '4', '5', '6', '7'] as daysValid} + + {daysValid} days + + {/each} + + +
+
+ + +
+ + {#snippet child({ props })} + + {/snippet} + + +
+
+
+
+ + !!showCodeDialog, () => (showCodeDialog = null)}> + + Invite code + +
+
+

Invitation Code:

+
+ {showCodeDialog} + +
+

Share this code with the user you want to invite.

+
+
+ + +
+ + {#snippet child({ props })} + + {/snippet} + +
+
+
+
diff --git a/services/app/src/routes/(main)/organization/usage/+page.svelte b/services/app/src/routes/(main)/organization/usage/+page.svelte new file mode 100755 index 0000000..b827aee --- /dev/null +++ b/services/app/src/routes/(main)/organization/usage/+page.svelte @@ -0,0 +1,77 @@ + + + + Usage - Organization + + +
+
+

QUOTAS

+ +
+ {#if organizationData} + {#each Object.keys(organizationData.quotas) as key} + {@const productQuota = + organizationData.price_plan?.products[key as keyof PriceRes['products']]} + {#if productQuota} +
+ + {productQuota.name} + + + {parseQuotas(organizationData.quotas[key].usage)} + {productQuota.unit} + + + + +
+ + Balance:
+ + {parseQuotas( + organizationData.quotas[key].quota - organizationData.quotas[key].usage + )} + {productQuota.unit} + +
+ + Quota:
+ + {organizationData.quotas[key].quota} + {productQuota.unit} + +
+
+
+ {/if} + {/each} + {/if} +
+
+
diff --git a/services/app/src/routes/(main)/project/+layout.svelte b/services/app/src/routes/(main)/project/+layout.svelte old mode 100644 new mode 100755 index 88642b2..28383b9 --- a/services/app/src/routes/(main)/project/+layout.svelte +++ b/services/app/src/routes/(main)/project/+layout.svelte @@ -1,56 +1,61 @@ - +{@render children?.()} diff --git a/services/app/src/routes/(main)/project/+layout.ts b/services/app/src/routes/(main)/project/+layout.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/project/+page.svelte b/services/app/src/routes/(main)/project/+page.svelte old mode 100644 new mode 100755 index bd9ff03..eb36728 --- a/services/app/src/routes/(main)/project/+page.svelte +++ b/services/app/src/routes/(main)/project/+page.svelte @@ -1,9 +1,10 @@ - Projects + {m['project.heading']()} -
-
+
+
- -

Projects

-
- -
- { - //@ts-expect-error Generic type - debouncedSearchProjects(e.target?.value ?? ''); - }} - bind:value={searchQuery} - type="search" - placeholder="Search Project" - class="pl-8 h-9 w-[16rem] placeholder:not-italic placeholder:text-[#98A2B3] bg-[#F2F4F7] rounded-full" - > - - {#if isLoadingSearch} -
- -
- {:else} - - {/if} -
-
+

{m['project.heading']()}

-
-
+
+
+ +
+ +
+ + Join Project +
+ + + +
+ +
+ + Browse Templates +
-
-
-

All Projects

+
+
+

{m['project.subheading']()}

+ +
+
+ + {#snippet leading()} + {#if isLoadingSearch} +
+ +
+ {:else} + + {/if} + {/snippet} +
+
- + +
{#if !loadingProjectsError}
{#if isLoadingProjects} {#each Array(8) as _} {/each} {:else} {#each orgProjects ?? [] as project (project.id)} ($activeProject = project)} - href="/project/{project.id}" + onclick={() => ($activeProject = project)} + href="/project/{encodeURIComponent(project.id)}" title={project.id} - class="flex flex-col bg-white data-dark:bg-[#42464E] border border-[#E5E5E5] data-dark:border-[#333] rounded-lg hover:-translate-y-0.5 hover:shadow-float transition-[transform,box-shadow]" + class="flex flex-col rounded-lg border border-[#E5E5E5] bg-white transition-[transform,box-shadow] hover:-translate-y-0.5 hover:shadow-float data-dark:border-[#333] data-dark:bg-[#42464E]" > -
+
- - + + - + {project.name}
- - + + {#snippet child({ props })} + + {/snippet} - + (isEditingProjectName = project)} + onclick={() => (isEditingProjectName = project)} class="text-[#344054] data-[highlighted]:text-[#344054]" > - - Rename project + + {m['project.settings_rename']()} - - handleExportProject(project.id)} - class="text-[#344054] data-[highlighted]:text-[#344054]" - > - - Export project - + + {#snippet children({ handleExportProject })} + handleExportProject(project.id)} + class="text-[#344054] data-[highlighted]:text-[#344054]" + > + + {m['project.settings_export']()} + + {/snippet} (isDeletingProject = project.id)} + onclick={() => (isDeletingProject = project.id)} class="text-destructive data-[highlighted]:text-destructive" > - - Delete project + + {m['project.settings_delete']()} @@ -426,16 +406,16 @@
- Last updated - - {new Date(project.updated_at).toLocaleString(undefined, { + {m['project.updated_at']()} + + {new Date(project.updated_at).toLocaleString(getLocale(), { month: 'long', day: 'numeric', year: 'numeric' @@ -448,18 +428,18 @@ {/if} {#if isLoadingMoreProjects} -
+
{/if}
{:else} -
+
{loadingProjectsError.status}

{JSON.stringify(loadingProjectsError.message)}

diff --git a/services/app/src/routes/(main)/project/ExportProjectButton.svelte b/services/app/src/routes/(main)/project/ExportProjectButton.svelte old mode 100644 new mode 100755 index ff295a6..6c2b1ac --- a/services/app/src/routes/(main)/project/ExportProjectButton.svelte +++ b/services/app/src/routes/(main)/project/ExportProjectButton.svelte @@ -1,19 +1,34 @@ - +{@render children?.({ handleExportProject })} diff --git a/services/app/src/routes/(main)/project/ProjectDialogs.svelte b/services/app/src/routes/(main)/project/ProjectDialogs.svelte old mode 100644 new mode 100755 index 939ec5d..c855d2e --- a/services/app/src/routes/(main)/project/ProjectDialogs.svelte +++ b/services/app/src/routes/(main)/project/ProjectDialogs.svelte @@ -1,92 +1,113 @@ - { - if (!e) { - isEditingProjectName = null; - } - }} -> + !!isEditingProjectName, () => (isEditingProjectName = null)}> - Edit project name + {m['project.edit.heading']()} - +
-
- +
+
- +
- - - + + {#snippet child({ props })} + + {/snippet} +
- + isAddingProject, + (v) => { + isAddingProject = v; + page.url.searchParams.delete('new'); + history.replaceState(history.state, '', page.url); + }} +> - New project + {m['project.create.heading']()} - +
-
- Project name* +
+ - +
- +
- - - + + {#snippet child({ props })} + + {/snippet} +
@@ -238,10 +270,9 @@ !!isDeletingProject, (v) => (isDeletingProject = null)} onOpenChange={(e) => { if (!e) { - isDeletingProject = null; confirmProjectName = ''; } }} @@ -251,32 +282,37 @@ data-testid="delete-project-dialog" class="max-h-[90vh] w-[clamp(0px,35rem,100%)]" > - Delete project + {m['project.delete.heading']()} - +
event.key === 'Enter' && event.preventDefault()} - on:submit={handleDeleteProject} - class="grow w-full overflow-auto" + id="deleteProjectForm" + onkeydown={(event) => event.key === 'Enter' && event.preventDefault()} + onsubmit={handleDeleteProject} + class="w-full grow overflow-auto" > -
-

- Do you really want to delete project - - `{targetProject?.name ?? isDeletingProject}` - ? This process cannot be undone. +

+

+ {@html m['project.delete.text_content']({ + project_name: escapeHtmlText(targetProject?.name ?? isDeletingProject ?? '') + })}

-
- - Enter project {targetProject?.name ? 'name' : 'ID'} to confirm +
+ + {m['project.delete.text_confirm']({ + confirm_text: targetProject?.name ? 'name' : 'ID' + })}
@@ -284,20 +320,21 @@
- - - + + {#snippet child({ props })} + + {/snippet} +
diff --git a/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte b/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte old mode 100644 new mode 100755 index 71605be..7160c32 --- a/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(components)/ActionsDropdown.svelte @@ -1,14 +1,14 @@ - - + + {#snippet child({ props })} + + {/snippet} - - - Order by - Created - - -
- { - const query = new URLSearchParams($page.url.searchParams.toString()); - query.set('asc', '1'); - goto(`?${query.toString()}`, { replaceState: true }); - }} - class="z-10 transition-colors ease-in-out rounded-full px-4 py-1 w-full text-center {$page.url.searchParams.get( - 'asc' - ) === '1' - ? 'text-[#667085]' - : 'text-[#98A2B3]'} cursor-pointer" - > - Ascending - - - { - const query = new URLSearchParams($page.url.searchParams.toString()); - query.delete('asc'); - goto(`?${query.toString()}`, { replaceState: true }); - }} - class="z-10 transition-colors ease-in-out rounded-full px-4 py-1 w-full text-center {$page.url.searchParams.get( - 'asc' - ) !== '1' - ? 'text-[#667085]' - : 'text-[#98A2B3]'} cursor-pointer" - > - Descending - -
-
- - - - {#if tableType !== 'chat' || !tableData?.parent_id} - - (isAddingColumn = { type: 'input', showDialog: true })}> - - - Add - input -
- column -
-
- (isAddingColumn = { type: 'output', showDialog: true })}> - - - Add - output -
- column -
-
-
- - - {/if} - - - + + Import rows - - + + Export rows (.csv) - - - - Export table - + + {#snippet children({ handleExportTable })} + + + Export table + + {/snippet} @@ -317,10 +248,10 @@ (isDeletingTable = $page.params.table_id)} - class="text-[#D92D20] hover:!text-[#D92D20] data-[highlighted]:text-[#D92D20] hover:!bg-[#FEF3F2] data-[highlighted]:bg-[#FEF3F2]" + onclick={() => (isDeletingTable = page.params.table_id)} + class="text-[#D92D20] hover:!bg-[#FEF3F2] hover:!text-[#D92D20] data-[highlighted]:bg-[#FEF3F2] data-[highlighted]:text-[#D92D20]" > - + Delete table @@ -333,7 +264,7 @@ bind:isDeletingTable deletedCb={(success) => { if (success) { - goto(`/project/${$page.params.project_id}/${tableType}-table`); + goto(`/project/${page.params.project_id}/${tableType}-table`); } }} /> diff --git a/services/app/src/routes/(main)/project/[project_id]/(components)/ExportTableButton.svelte b/services/app/src/routes/(main)/project/[project_id]/(components)/ExportTableButton.svelte old mode 100644 new mode 100755 index 604ef67..a24d3a1 --- a/services/app/src/routes/(main)/project/[project_id]/(components)/ExportTableButton.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(components)/ExportTableButton.svelte @@ -1,32 +1,42 @@ - +{@render children?.({ handleExportTable })} diff --git a/services/app/src/routes/(main)/project/[project_id]/(components)/GenerateButton.svelte b/services/app/src/routes/(main)/project/[project_id]/(components)/GenerateButton.svelte old mode 100644 new mode 100755 index 978dc54..9bd4260 --- a/services/app/src/routes/(main)/project/[project_id]/(components)/GenerateButton.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(components)/GenerateButton.svelte @@ -1,28 +1,40 @@ diff --git a/services/app/src/routes/(main)/project/[project_id]/(components)/index.ts b/services/app/src/routes/(main)/project/[project_id]/(components)/index.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/AddColumnDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/AddColumnDialog.svelte old mode 100644 new mode 100755 index 4042e3d..dc64f9b --- a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/AddColumnDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/AddColumnDialog.svelte @@ -1,9 +1,7 @@ { - if (!e) { - isAddingColumn = { ...isAddingColumn, showDialog: false }; - } - }} + bind:open={() => isAddingColumn.showDialog, + (v) => (isAddingColumn = { ...isAddingColumn, showDialog: v })} > New {isAddingColumn.type} column -
- +
+
-
- Data type* - - { - if (v) { - selectedDatatype = v.value; - } - }} - > - - + {/snippet} - + {#each Object.keys(genTableDTypes).filter((dtype) => (isAddingColumn.type === 'output' || !dtype.endsWith('_code')) && (isAddingColumn.type === 'input' || dtype.startsWith('str') || dtype === 'file_code')) as dType} {genTableDTypes[dType]} @@ -200,40 +194,32 @@ {#if isAddingColumn.type == 'output'} {#if !selectedDatatype.endsWith('_code')} -
- Models +
+ { - selectedModel = model; - const modelDetails = $modelsAvailable.find((val) => val.id == model); if (modelDetails && parseInt(maxTokens) > modelDetails.context_length) { maxTokens = modelDetails.context_length.toString(); } }} - buttonText={($modelsAvailable.find((model) => model.id == selectedModel)?.name ?? - selectedModel) || - 'Select model'} - class="bg-[#F2F4F7] data-dark:bg-[#42464e] hover:bg-[#e1e2e6] border-transparent" + class="border-transparent bg-[#F2F4F7] hover:bg-[#e1e2e6] data-dark:bg-[#42464e]" />
-
-
- +
+
+ { + onchange={(e) => { const value = parseFloat(e.currentTarget.value); if (isNaN(value)) { @@ -246,22 +232,20 @@ temperature = value.toFixed(2); } }} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" + class="rounded-md border border-transparent bg-[#F2F4F7] px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" />
-
- +
+ { + onchange={(e) => { const value = parseInt(e.currentTarget.value); const model = $modelsAvailable.find((model) => model.id == selectedModel); @@ -275,7 +259,7 @@ maxTokens = value.toString(); } }} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" + class="rounded-md border border-transparent bg-[#F2F4F7] px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" />
-
- +
+ { + onchange={(e) => { const value = parseFloat(e.currentTarget.value); if (isNaN(value)) { @@ -309,22 +291,17 @@ topP = value.toFixed(3); } }} - class="px-3 py-2 text-sm bg-[#F2F4F7] data-dark:bg-[#42464e] rounded-md border border-transparent placeholder:text-muted-foreground focus-visible:outline-none focus-visible:border-[#4169e1] data-dark:focus-visible:border-[#5b7ee5] disabled:cursor-not-allowed disabled:opacity-50 transition-colors" + class="rounded-md border border-transparent bg-[#F2F4F7] px-3 py-2 text-sm transition-colors placeholder:text-muted-foreground focus-visible:border-[#d5607c] focus-visible:shadow-[0_0_0_1px_#FFD8DF] focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 data-dark:bg-[#42464e] data-dark:focus-visible:border-[#5b7ee5]" />
-
- +
+
-
- +
+ -
- Columns: +
+ Columns: {#each usableColumns as column}
-
-
- - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ColumnMatchDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ColumnMatchDialog.svelte old mode 100644 new mode 100755 index 391bb2f..fe10386 --- a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ColumnMatchDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ColumnMatchDialog.svelte @@ -1,7 +1,6 @@ !!isMatchingImportCols, () => (isMatchingImportCols = null)} onOpenChange={(e) => { if (!e) { - isMatchingImportCols = null; match = []; } }} @@ -87,46 +91,50 @@ Drag columns to match -
+
- Source file + {#snippet leading()} + Source file + {/snippet} - - - + {#snippet listItem({ + item: col, + itemIndex: index, + dragStart, + dragMove, + dragOver, + dragEnd, + draggingItem: draggingColumn, + draggingItemIndex: draggingColumnIndex + })} + +
dragStart(e, col, index, !!col.name)} - on:drag={dragMove} - on:dragover|preventDefault={(e) => dragOver(e, index)} - on:dragend={dragEnd} - on:touchstart={(e) => dragStart(e, col, index, !!col.name)} - on:touchmove={dragMove} - on:touchend={dragEnd} + onclick={(e) => e.stopPropagation()} + ondragstart={(e) => dragStart(e, col, index, !!col.name)} + ondrag={dragMove} + ondragover={(e) => { + e.preventDefault(); + dragOver(e, index); + }} + ondragend={dragEnd} + ontouchstart={(e) => dragStart(e, col, index, !!col.name)} + ontouchmove={dragMove} + ontouchend={dragEnd} draggable={!!col.name} - class="flex items-center gap-2 px-2 h-[40px] bg-white data-dark:bg-[#42464E] {col.name - ? 'border cursor-grab hover:shadow-float' + class="flex h-[40px] items-center gap-2 bg-white px-2 data-dark:bg-[#42464E] {col.name + ? 'cursor-grab border hover:shadow-float' : ''} border-[#E4E7EC] data-dark:border-[#333] {draggingColumn?.id === col.id ? 'opacity-0' : include.includes(col.id) && index + 1 <= filterTableCols.length ? draggingColumnIndex === null && col.name ? 'hover:shadow-float' : '' - : 'opacity-60'} transition-shadow rounded touch-none" + : 'opacity-60'} touch-none rounded transition-shadow" >
-
+ {/snippet} - - {#if dragMouseCoords && draggingColumn} - + {#snippet draggedItem({ dragMouseCoords, draggingItem: draggingColumn })} + + {#if dragMouseCoords && draggingColumn}
{draggingColumn.name}
-
- {/if} -
+ {/if} + + {/snippet}
  {#each match as col, index}
- Table + Table {#each filterTableCols as col} {@const colType = !col.gen_config ? 'input' : 'output'}
- + {colType} {col.dtype} {col.id} @@ -238,23 +243,23 @@ type="submit" loading={isLoadingImport} disabled={isLoadingImport} - class="hidden relative grow px-6 rounded-full" + class="relative hidden grow px-6" />
- - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteDialogs.svelte b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteDialogs.svelte old mode 100644 new mode 100755 index 45186ea..ba2072d --- a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteDialogs.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteDialogs.svelte @@ -1,9 +1,8 @@ - { - if (!e) { - tableState.setDeletingCol(null); - } - }} -> + !!tableState.deletingCol, () => tableState.setDeletingCol(null)}> - Close - +
-

Are you sure?

-

+

Are you sure?

+

Do you really want to drop column - - `{$tableState.deletingCol}` + + `{tableState.deletingCol}` ? This process cannot be undone.

- - - + + {#snippet child({ props })} + + {/snippet} + @@ -144,31 +146,24 @@ - { - if (!e) { - isDeletingRow = null; - } - }} -> + !!isDeletingRow, () => (isDeletingRow = null)}> - Close - +
-

Are you sure?

-

+

Are you sure?

+

Do you really want to delete these row(s)? This process cannot be undone.

@@ -189,17 +184,17 @@
- - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteTableDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteTableDialog.svelte old mode 100644 new mode 100755 index 1f3f6f5..2b3f8e3 --- a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteTableDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/DeleteTableDialog.svelte @@ -2,14 +2,13 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import toUpper from 'lodash/toUpper'; import { goto } from '$app/navigation'; - import { page } from '$app/stores'; - import { Dialog as DialogPrimitive } from 'bits-ui'; + import { page } from '$app/state'; import logger from '$lib/logger'; import { pastActionTables, pastKnowledgeTables, pastChatAgents - } from '$lib/components/tables/tablesStore'; + } from '$lib/components/tables/tablesState.svelte'; import { toast, CustomToastDesc } from '$lib/components/ui/sonner'; import { Button } from '$lib/components/ui/button'; @@ -17,23 +16,28 @@ import DialogCloseIcon from '$lib/icons/DialogCloseIcon.svelte'; import CloseIcon from '$lib/icons/CloseIcon.svelte'; - export let tableType: 'action' | 'knowledge' | 'chat'; - export let isDeletingTable: string | null; - export let deletedCb: ((success: boolean, deletedTableID?: string) => any) | undefined = - undefined; + interface Props { + tableType: 'action' | 'knowledge' | 'chat'; + isDeletingTable: string | null; + deletedCb?: ((success: boolean, deletedTableID?: string) => any) | undefined; + } + + let { tableType, isDeletingTable = $bindable(), deletedCb = undefined }: Props = $props(); - let isLoading = false; + let isLoading = $state(false); async function handleDeleteTable() { if (isLoading || !isDeletingTable) return; isLoading = true; const response = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/${tableType}/${isDeletingTable}`, + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/${tableType}?${new URLSearchParams([ + ['table_id', isDeletingTable] + ])}`, { method: 'DELETE', headers: { - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id } } ); @@ -56,9 +60,9 @@ switch (tableType) { case 'action': { $pastActionTables = $pastActionTables.filter((t) => t.id !== isDeletingTable); - if ($page.params.table_id === isDeletingTable) { + if (page.params.table_id === isDeletingTable) { goto( - `/project/${$page.params.project_id}/action-table/${$pastActionTables[0]?.id || ''}` + `/project/${page.params.project_id}/action-table/${$pastActionTables[0]?.id || ''}` ); } break; @@ -67,9 +71,9 @@ $pastKnowledgeTables = $pastKnowledgeTables.filter( (table) => table.id !== isDeletingTable ); - if ($page.params.table_id === isDeletingTable) { + if (page.params.table_id === isDeletingTable) { goto( - `/project/${$page.params.project_id}/knowledge-table/${$pastKnowledgeTables[0]?.id || ''}` + `/project/${page.params.project_id}/knowledge-table/${$pastKnowledgeTables[0]?.id || ''}` ); } break; @@ -89,33 +93,26 @@ } - { - if (!e) { - isDeletingTable = null; - } - }} -> + !!isDeletingTable, () => (isDeletingTable = null)}> - Close - +
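Not part of the diff — a minimal sketch of the reworked generative-table delete call used in DeleteTableDialog above: the table id moves from the URL path (old `/api/v1` route) into a `table_id` query parameter on the `/api/owl` route. `deleteGenTable` is an illustrative name; `PUBLIC_JAMAI_URL` and the `x-project-id` header follow the component.

```ts
// Illustrative wrapper around the DELETE issued by DeleteTableDialog above.
// Assumes PUBLIC_JAMAI_URL points at the app's API proxy, as in the diff.
import { PUBLIC_JAMAI_URL } from '$env/static/public';

export async function deleteGenTable(
	tableType: 'action' | 'knowledge' | 'chat',
	tableId: string,
	projectId: string
): Promise<void> {
	const res = await fetch(
		`${PUBLIC_JAMAI_URL}/api/owl/gen_tables/${tableType}?${new URLSearchParams([
			['table_id', tableId] // previously part of the path: /api/v1/gen_tables/{type}/{id}
		])}`,
		{
			method: 'DELETE',
			headers: { 'x-project-id': projectId }
		}
	);
	if (!res.ok) throw new Error(`Delete failed with status ${res.status}`);
}
```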
-

Are you sure?

-

+

Are you sure?

+

Do you really want to delete table - + `{isDeletingTable}` ? This process cannot be undone.

@@ -123,17 +120,17 @@
- - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ImportTableDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ImportTableDialog.svelte old mode 100644 new mode 100755 index 6e42763..6f2fe55 --- a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ImportTableDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/ImportTableDialog.svelte @@ -2,26 +2,28 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import toUpper from 'lodash/toUpper'; import axios, { CanceledError } from 'axios'; - import { Dialog as DialogPrimitive } from 'bits-ui'; - import { page } from '$app/stores'; import logger from '$lib/logger'; import { toast } from 'svelte-sonner'; import InputText from '$lib/components/InputText.svelte'; + import { Label } from '$lib/components/ui/label'; import { Button } from '$lib/components/ui/button'; import * as Dialog from '$lib/components/ui/dialog'; import DocumentFilledIcon from '$lib/icons/DocumentFilledIcon.svelte'; - import { tick } from 'svelte'; - export let isImportingTable: File | null; - export let tableType: 'action' | 'knowledge' | 'chat'; - export let refetchTables: () => Promise; + interface Props { + isImportingTable: File | null; + tableType: 'action' | 'knowledge' | 'chat'; + refetchTables: () => Promise; + } + + let { isImportingTable = $bindable(), tableType, refetchTables }: Props = $props(); - let form: HTMLFormElement; - let isLoading = false; - let uploadProgress: number | null = null; + let isLoading = $state(false); + let uploadProgress: number | null = $state(null); async function handleImportTable(e: SubmitEvent & { currentTarget: HTMLFormElement }) { + e.preventDefault(); if (!isImportingTable) return; const tableId = new FormData(e.currentTarget).get('table_id') as string; @@ -39,12 +41,11 @@ isLoading = true; try { const uploadRes = await axios.post( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/${tableType}/import`, + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/${tableType}/import`, formData, { headers: { - 'Content-Type': 'multipart/form-data', - 'x-project-id': $page.params.project_id + 'Content-Type': 'multipart/form-data' }, onUploadProgress: (progressEvent) => { if (!progressEvent.total) return; @@ -93,30 +94,27 @@ } - { - if (!e) { - isImportingTable = null; - } - }} -> - + !!isImportingTable, () => (isImportingTable = null)}> + Import table
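Not part of the diff — an illustrative version of the table-import upload in ImportTableDialog above: a multipart POST via axios to the `/api/owl/gen_tables/{type}/import` route with client-side progress reporting. The form-data field names are hypothetical, since the diff elides the exact keys.

```ts
// Illustrative sketch of the import upload performed by ImportTableDialog above.
// Assumes PUBLIC_JAMAI_URL and the /api/owl/gen_tables/{type}/import route from the diff.
import axios from 'axios';
import { PUBLIC_JAMAI_URL } from '$env/static/public';

export async function importGenTable(
	tableType: 'action' | 'knowledge' | 'chat',
	file: File,
	tableId: string,
	onProgress: (percent: number) => void
) {
	const formData = new FormData();
	formData.append('file', file);
	formData.append('table_id_dst', tableId); // hypothetical field name

	return axios.post(`${PUBLIC_JAMAI_URL}/api/owl/gen_tables/${tableType}/import`, formData, {
		headers: { 'Content-Type': 'multipart/form-data' },
		onUploadProgress: (progressEvent) => {
			if (!progressEvent.total) return;
			// Report whole-number upload progress to the caller.
			onProgress(Math.round((progressEvent.loaded / progressEvent.total) * 100));
		}
	});
}
```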
-
- Table ID* +
+ {#if isImportingTable} -
+
-

+

{isImportingTable.name}

@@ -147,17 +145,17 @@ {#if uploadProgress}
{:else} @@ -165,31 +163,21 @@
{/if} - - -
- - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/RenameTableDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/RenameTableDialog.svelte old mode 100644 new mode 100755 index 76dd002..01808b4 --- a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/RenameTableDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/RenameTableDialog.svelte @@ -2,42 +2,50 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; import toUpper from 'lodash/toUpper'; import { goto } from '$app/navigation'; - import { page } from '$app/stores'; - import { Dialog as DialogPrimitive } from 'bits-ui'; + import { page } from '$app/state'; import { pastActionTables, pastChatAgents, pastKnowledgeTables - } from '$lib/components/tables/tablesStore'; + } from '$lib/components/tables/tablesState.svelte'; import logger from '$lib/logger'; import InputText from '$lib/components/InputText.svelte'; import { toast, CustomToastDesc } from '$lib/components/ui/sonner'; + import { Label } from '$lib/components/ui/label'; import { Button } from '$lib/components/ui/button'; import * as Dialog from '$lib/components/ui/dialog'; - export let tableType: 'action' | 'knowledge' | 'chat'; - export let isEditingTableID: string | null; - export let editedCb: ((success: boolean, tableID?: string) => any) | undefined = undefined; + interface Props { + tableType: 'action' | 'knowledge' | 'chat'; + isEditingTableID: string | null; + editedCb?: ((success: boolean, tableID?: string) => any) | undefined; + } + + let { tableType, isEditingTableID = $bindable(), editedCb = undefined }: Props = $props(); - let form: HTMLFormElement; - let isLoadingSaveEdit = false; + let isLoadingSaveEdit = $state(false); async function handleSaveTableID( e: SubmitEvent & { currentTarget: EventTarget & HTMLFormElement } ) { - const editedTableID = e.currentTarget.getElementsByTagName('input')[0].value.trim(); + e.preventDefault(); + if (!isEditingTableID) return; + const editedTableID = e.currentTarget.getElementsByTagName('input')[0].value.trim(); if (isEditingTableID === editedTableID) return; isLoadingSaveEdit = true; const response = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/${tableType}/rename/${isEditingTableID}/${editedTableID}`, + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/${tableType}/rename?${new URLSearchParams([ + ['table_id_src', isEditingTableID], + ['table_id_dst', editedTableID] + ])}`, { method: 'POST', headers: { - 'x-project-id': $page.params.project_id + 'x-project-id': page.params.project_id } } ); @@ -85,8 +93,8 @@ $pastActionTables = $pastActionTables; } - if ($page.params.table_id === isEditingTableID) { - goto(`/project/${$page.params.project_id}/action-table/${editedTableID}`); + if (page.params.table_id === isEditingTableID) { + goto(`/project/${page.params.project_id}/action-table/${editedTableID}`); } break; } @@ -101,8 +109,8 @@ $pastKnowledgeTables = $pastKnowledgeTables; } - if ($page.params.table_id === isEditingTableID) { - goto(`/project/${$page.params.project_id}/knowledge-table/${editedTableID}`); + if (page.params.table_id === isEditingTableID) { + goto(`/project/${page.params.project_id}/knowledge-table/${editedTableID}`); } break; } @@ -117,8 +125,8 @@ $pastChatAgents = $pastChatAgents; } - if ($page.params.table_id === isEditingTableID) { - goto(`/project/${$page.params.project_id}/chat-table/${editedTableID}`); + if (page.params.table_id === isEditingTableID) { + 
goto(`/project/${page.params.project_id}/chat-table/${editedTableID}`); } break; } @@ -135,45 +143,32 @@ } - { - if (!e) { - isEditingTableID = null; - } - }} -> + !!isEditingTableID, () => (isEditingTableID = null)}> Edit table ID - -
-
- Table ID* + + +
+ - +
- - -
- - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/(dialogs)/index.ts b/services/app/src/routes/(main)/project/[project_id]/(dialogs)/index.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/project/[project_id]/+layout.server.ts b/services/app/src/routes/(main)/project/[project_id]/+layout.server.ts new file mode 100644 index 0000000..ed1f0ab --- /dev/null +++ b/services/app/src/routes/(main)/project/[project_id]/+layout.server.ts @@ -0,0 +1,44 @@ +import { env } from '$env/dynamic/private'; +import logger from '$lib/logger.js'; +import type { ProjectMemberRead } from '$lib/types.js'; + +const { OWL_URL, OWL_SERVICE_KEY /* RESEND_API_KEY */ } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export async function load({ cookies, locals }) { + //TODO: Paginate this + const getProjectMembers = async () => { + const activeProjectId = cookies.get('activeProjectId'); + + if (!activeProjectId) { + return { error: 400, message: 'No active project' }; + } + + const response = await fetch( + `${OWL_URL}/api/v2/projects/members/list?${new URLSearchParams([ + ['project_id', activeProjectId] + ])}`, + { + headers: { + ...headers, + 'x-user-id': locals.user?.id ?? '' + } + } + ); + const responseBody = await response.json(); + + if (!response.ok) { + logger.error('ORGMEMBER_LIST_ERROR', responseBody, locals.user?.id); + return { error: response.status, message: responseBody }; + } + + return { data: responseBody.items as ProjectMemberRead[] }; + }; + + return { + projectMembers: getProjectMembers() + }; +} diff --git a/services/app/src/routes/(main)/project/[project_id]/+layout.svelte b/services/app/src/routes/(main)/project/[project_id]/+layout.svelte old mode 100644 new mode 100755 index 9c37acb..2709ed7 --- a/services/app/src/routes/(main)/project/[project_id]/+layout.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/+layout.svelte @@ -1,7 +1,7 @@ -
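Not part of the diff — note that the `+layout.server.ts` load above returns `getProjectMembers()` without awaiting it, so SvelteKit streams the promise to the client. Below is a hedged sketch of resolving that streamed value; the result type alias and function name are illustrative, while `ProjectMemberRead` comes from the diff.

```ts
// Illustrative consumer of the streamed projectMembers promise returned by
// the +layout.server.ts load above. ProjectMembersResult mirrors its two shapes.
import type { ProjectMemberRead } from '$lib/types';

type ProjectMembersResult = { error: number; message: unknown } | { data: ProjectMemberRead[] };

export async function resolveProjectMembers(
	projectMembers: Promise<ProjectMembersResult>
): Promise<ProjectMemberRead[]> {
	const result = await projectMembers; // resolves once the Owl members call completes
	if ('error' in result) {
		// Surface the failure however the caller prefers; here we simply fall back.
		return [];
	}
	return result.data;
}
```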
-
-
- - - +
+
+
+ -

+

{$activeProject?.name ?? ($loadingProjectData.loading ? 'Loading...' - : $loadingProjectData.error ?? $page.params.project_id)} + : $loadingProjectData.error ?? page.params.project_id)}

- - + + {#snippet child({ props })} + + {/snippet} - + (isEditingProjectName = $activeProject)} + onclick={() => (isEditingProjectName = $activeProject)} class="text-[#344054] data-[highlighted]:text-[#344054]" > - + Rename project { - navigator.clipboard.writeText($page.params.project_id ?? ''); + onclick={() => { + navigator.clipboard.writeText(page.params.project_id ?? ''); toast.success('Project ID copied to clipboard', { id: 'project-id-copied' }); }} + class="text-[#344054] data-[highlighted]:text-[#344054]" > (isDeletingProject = $page.params.project_id)} + onclick={() => (isDeletingProject = page.params.project_id)} class="text-destructive data-[highlighted]:text-destructive" > - + Delete project - - + + + {/snippet}
- - Action Table - - - - Knowledge Table - - - - Chat Table - + {#each tabItems as { title, href, route }} + + {title} + + {/each}
@@ -165,7 +185,7 @@ />
- + {@render children?.()}
{#if $activeProject} diff --git a/services/app/src/routes/(main)/project/[project_id]/+page.ts b/services/app/src/routes/(main)/project/[project_id]/+page.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/project/[project_id]/action-table/(dialogs)/AddTableDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/action-table/(dialogs)/AddTableDialog.svelte old mode 100644 new mode 100755 index 8567d04..19940a0 --- a/services/app/src/routes/(main)/project/[project_id]/action-table/(dialogs)/AddTableDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/action-table/(dialogs)/AddTableDialog.svelte @@ -1,19 +1,25 @@ - + New action table -
-
- - Table ID* - +
+
+
-
-

Columns

+
+

Table Columns

- {#if columns.length === 0} -

No columns added

- {:else} +
-
- - - - Column ID* - - - - Data Type* - - - - Output* - - - -
- - - + + + + + + + + +
+ + + {#snippet listItem({ + item: column, + itemIndex: index, + dragStart, + dragMove, + dragOver, + dragEnd, + draggingItem: draggingColumn + })} +
(columns.length === index + 1 ? null : dragOver(e, index))} + style="grid-template-columns: 30px minmax(0, 1.2fr) repeat(2, minmax(0, 1fr)) auto;" + class="grid gap-2 {draggingColumn?.drag_id == column.drag_id ? 'opacity-0' : ''}" > -
dragOver(e, index)} - style="grid-template-columns: 30px repeat(2, minmax(0, 1fr)) 50px 40px;" - class="grid gap-2 {draggingColumn?.drag_id == column.drag_id ? 'opacity-0' : ''}" + +
    + { + if (columns.length === index + 1) { + if (columns[index].col_type === null) columns[index].col_type = 'Input'; + columns[index].dtype = + 'str' /* genTableColDTypes[column.col_type ?? 'Input'][0] */; + columns = [...columns, newColDefault()]; + } + }} + placeholder="New column" + class="h-[38px] border border-[#E4E7EC] {columns.length === index + 1 + ? 'bg-[#F9FAFB]' + : 'bg-white data-dark:bg-[#42464e]'}" + /> +
    + +
    + { + columns[index].col_type = v as keyof typeof genTableColTypes; + columns[index].gen_config = genTableColTypes[columns[index].col_type]; + if ( + !genTableColDTypes[column.col_type ?? 'Input'].includes( + columns[index].dtype + ) + ) { + columns[index].dtype = + 'str' /* genTableColDTypes[column.col_type ?? 'Input'][0] */; + } + + if (columns.length === index + 1) { + columns = [...columns, newColDefault()]; + } + }} > - - - -
    - -
    - -
    - { - if (v) { - columns[index].dtype = v.value; - if (columns[index].gen_config) { - columns[index].gen_config = v.value.endsWith('_code') - ? CODE_GEN_CONFIG_DEFAULT - : LLM_GEN_CONFIG_DEFAULT; - } - } - }} + - - - - - {#each Object.keys(genTableDTypes).filter((dtype) => (column.gen_config || !dtype.endsWith('_code')) && (!column.gen_config || dtype.startsWith('str') || dtype === 'file_code')) as dType} - - {genTableDTypes[dType]} - - {/each} - - -
    - -
    - { - if (e.detail.value) { - columns[index].gen_config = columns[index].dtype.endsWith('_code') - ? CODE_GEN_CONFIG_DEFAULT - : LLM_GEN_CONFIG_DEFAULT; - - if (!['str', 'image'].includes(columns[index].dtype)) { - columns[index].dtype = 'str'; - } - } else { - columns[index].gen_config = null; - - if (columns[index].dtype.endsWith('_code')) { - columns[index].dtype = 'str'; - } - } - }} - checked={!!column.gen_config} - class="h-5 w-5 [&>svg]:translate-x-[1px]" - /> -
    - -
    + +
    + { + columns[index].dtype = v; + + if (columns.length === index + 1) { + columns[index].col_type = (Object.entries(genTableColDTypes).find( + ([colType, dTypes]) => dTypes.includes(v) + )?.[0] ?? 'Input') as keyof typeof genTableColDTypes; + columns = [...columns, newColDefault()]; + } + }} > - - -
- - - - {#if dragMouseCoords && draggingColumn} -
- -
    - -
    - -
    - -
    - -
    - -
    + {genTableDTypes[dType]} + + {/each} + + +
+ + + + {/snippet} + {#snippet draggedItem({ dragMouseCoords, draggingItem: draggingColumn })} + + {#if dragMouseCoords && draggingColumn} +
+ +
    + +
    + +
    -
-
    +
    + +
    + +
    + + + {/if} - - -
    - {/if} + + {/snippet} + +
    -
    +
    - +
    - - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/action-table/+page.svelte b/services/app/src/routes/(main)/project/[project_id]/action-table/+page.svelte old mode 100644 new mode 100755 index c0b8218..0685259 --- a/services/app/src/routes/(main)/project/[project_id]/action-table/+page.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/action-table/+page.svelte @@ -3,9 +3,9 @@ import { onMount } from 'svelte'; import debounce from 'lodash/debounce'; import Trash_2 from 'lucide-svelte/icons/trash-2'; - import { page } from '$app/stores'; + import { page } from '$app/state'; import { aTableSort as sortOptions } from '$globalStore'; - import { pastActionTables } from '$lib/components/tables/tablesStore'; + import { pastActionTables } from '$lib/components/tables/tablesState.svelte'; import logger from '$lib/logger'; import AddTableDialog from './(dialogs)/AddTableDialog.svelte'; @@ -27,16 +27,17 @@ import SortAlphabetIcon from '$lib/icons/SortAlphabetIcon.svelte'; import ImportIcon from '$lib/icons/ImportIcon.svelte'; import ExportIcon from '$lib/icons/ExportIcon.svelte'; + import InputText from '$lib/components/InputText.svelte'; + import SearchIcon from '$lib/icons/SearchIcon.svelte'; - export let data; - $: ({ userData } = data); - - let windowWidth: number; + let { data } = $props(); + let { user } = $derived(data); let fetchController: AbortController | null = null; - let loadingATablesError: { status: number; message: string; org_id: string } | null = null; - let isLoadingATables = true; - let isLoadingMoreATables = false; + let loadingATablesError: { status: number; message: string; org_id: string } | null = + $state(null); + let isLoadingATables = $state(true); + let isLoadingMoreATables = $state(false); let moreATablesFinished = false; //FIXME: Bandaid fix for infinite loop caused by loading circle let currentOffset = 0; const limit = 50; @@ -45,14 +46,14 @@ { id: 'updated_at', title: 'Date modified', Icon: SortByIcon } ]; - let searchQuery = ''; + let searchQuery = $state(''); let searchController: AbortController | null = null; - let isLoadingSearch = false; + let isLoadingSearch = $state(false); - let isAddingTable = false; - let isEditingTableID: string | null = null; - let isDeletingTable: string | null = null; - let isImportingTable: File | null = null; + let isAddingTable = $state(false); + let isEditingTableID: string | null = $state(null); + let isDeletingTable: string | null = $state(null); + let isImportingTable: File | null = $state(null); onMount(() => { getActionTables(); @@ -75,7 +76,7 @@ offset: currentOffset.toString(), limit: limit.toString(), order_by: $sortOptions.orderBy, - order_descending: $sortOptions.order === 'asc' ? 'false' : 'true', + order_ascending: $sortOptions.order === 'asc' ? 
'true' : 'false', search_query: searchQuery.trim() } as Record; @@ -84,13 +85,10 @@ } const response = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/action?` + new URLSearchParams(searchParams), + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/action/list?` + new URLSearchParams(searchParams), { credentials: 'same-origin', - signal: fetchController.signal, - headers: { - 'x-project-id': $page.params.project_id - } + signal: fetchController.signal } ); currentOffset += limit; @@ -157,17 +155,14 @@ try { const response = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/action?${new URLSearchParams({ + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/action/list?${new URLSearchParams({ limit: limit.toString(), order_by: $sortOptions.orderBy, - order_descending: $sortOptions.order === 'asc' ? 'false' : 'true', + order_ascending: $sortOptions.order === 'asc' ? 'true' : 'false', search_query: q })}`, { - signal: searchController.signal, - headers: { - 'x-project-id': $page.params.project_id - } + signal: searchController.signal } ); currentOffset = limit; @@ -241,129 +236,127 @@ Action Table - - {#if !loadingATablesError} -
    +
    -
    - - - -
    - - + + +
    + + + +
    {#if isLoadingATables} {#each Array(12) as _} {/each} {:else} {#each $pastActionTables as actionTable (actionTable.id)} -
    +
    - - + + {actionTable.id}
    - - + + {#snippet child({ props })} + + {/snippet} - + (isEditingTableID = actionTable.id)} + onclick={() => (isEditingTableID = actionTable.id)} class="text-[#344054] data-[highlighted]:text-[#344054]" > - + Rename table - - - - Export table - + + {#snippet children({ handleExportTable })} + + + Export table + + {/snippet} (isDeletingTable = actionTable.id)} + onclick={() => (isDeletingTable = actionTable.id)} class="text-destructive data-[highlighted]:text-destructive" > - + Delete table @@ -378,7 +371,7 @@ day: 'numeric', year: 'numeric' })} - class="font-medium text-xs text-[#98A2B3] data-dark:text-[#C9C9C9] line-clamp-1" + class="line-clamp-1 text-xs font-medium text-[#98A2B3] data-dark:text-[#C9C9C9]" > Last updated @@ -394,25 +387,23 @@ {/each} {#if isLoadingMoreATables} -
    +
    {/if} {/if}
    -{:else if loadingATablesError.status === 404 && loadingATablesError.org_id && userData?.member_of.find((org) => org.organization_id === loadingATablesError?.org_id)} - {@const projectOrg = userData?.member_of.find( - (org) => org.organization_id === loadingATablesError?.org_id - )} +{:else if loadingATablesError.status === 404 && loadingATablesError.org_id && user?.org_memberships.find((org) => org.organization_id === loadingATablesError?.org_id)} + {@const projectOrg = user?.organizations.find((org) => org.id === loadingATablesError?.org_id)} {:else} -
    +
    {loadingATablesError.status}

    {loadingATablesError.message}

    diff --git a/services/app/src/routes/(main)/project/[project_id]/action-table/+page.ts b/services/app/src/routes/(main)/project/[project_id]/action-table/+page.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page.ts b/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page.ts old mode 100644 new mode 100755 index 6120f4c..d6ab9e2 --- a/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page.ts +++ b/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page.ts @@ -1,13 +1,14 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; -import { error } from '@sveltejs/kit'; -import logger from '$lib/logger.js'; import { actionRowsPerPage } from '$lib/constants.js'; +import logger from '$lib/logger.js'; import type { GenTable, GenTableRow } from '$lib/types.js'; +import { error } from '@sveltejs/kit'; export const load = async ({ depends, fetch, params, parent, url }) => { depends('action-table:slug'); await parent(); const page = parseInt(url.searchParams.get('page') ?? '1'); + const orderBy = url.searchParams.get('sort_by'); const orderAsc = parseInt(url.searchParams.get('asc') ?? '0'); if (!params.table_id) { @@ -16,11 +17,7 @@ export const load = async ({ depends, fetch, params, parent, url }) => { const getTable = async () => { const tableDataRes = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/action/${params.table_id}?` + - new URLSearchParams({ - offset: '0', - limit: '1' - }), + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/action?${new URLSearchParams([['table_id', params.table_id]])}`, { headers: { 'x-project-id': params.project_id @@ -30,7 +27,7 @@ export const load = async ({ depends, fetch, params, parent, url }) => { const tableDataBody = await tableDataRes.json(); if (!tableDataRes.ok) { - if (tableDataRes.status !== 404 && tableDataRes.status !== 422) { + if (![403, 404, 422].includes(tableDataRes.status)) { logger.error('ACTIONTBL_TBL_GET', tableDataBody); } return { error: tableDataRes.status, message: tableDataBody }; @@ -42,13 +39,22 @@ export const load = async ({ depends, fetch, params, parent, url }) => { }; const getRows = async () => { + const q = url.searchParams.get('q'); + + const searchParams = new URLSearchParams([ + ['table_id', params.table_id], + ['offset', ((page - 1) * actionRowsPerPage).toString()], + ['limit', actionRowsPerPage.toString()], + ['order_by', orderBy ?? 'ID'], + ['order_ascending', orderAsc === 1 ? 'true' : 'false'] + ]); + + if (q) { + searchParams.set('search_query', q); + } + const tableRowsRes = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/action/${params.table_id}/rows?` + - new URLSearchParams({ - offset: ((page - 1) * actionRowsPerPage).toString(), - limit: actionRowsPerPage.toString(), - order_descending: orderAsc === 1 ? 
'false' : 'true' - }), + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/action/rows/list?${searchParams}`, { headers: { 'x-project-id': params.project_id @@ -58,7 +64,7 @@ export const load = async ({ depends, fetch, params, parent, url }) => { const tableRowsBody = await tableRowsRes.json(); if (!tableRowsRes.ok) { - if (tableRowsRes.status !== 404 && tableRowsRes.status !== 422) { + if (![403, 404, 422].includes(tableRowsRes.status)) { logger.error('ACTIONTBL_TBL_GETROWS', tableRowsBody); } return { error: tableRowsRes.status, message: tableRowsBody }; diff --git a/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page@project.svelte b/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page@project.svelte old mode 100644 new mode 100755 index 6fa7ae4..0b961d7 --- a/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page@project.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/action-table/[table_id]/+page@project.svelte @@ -1,17 +1,18 @@ - {$page.params.table_id} - Action Table + {page.params.table_id} - Action Table
    - - - - + + + - {#await table} - {$page.params.table_id} + {#await data.table} + {page.params.table_id} {:then { data }} {data ? data.id : tableError?.error === 404 ? 'Not found' : 'Failed to load'} {/await}
    -
    - {#if tableLoaded || (tableData && $genTableRows)} - - +
    + {#if tableLoaded || (tableData && tableRowsState.rows)}
    button>div]:bg-[#E4E7EC]'} transition-[opacity,grid-template-columns]" + : 'opacity-80 [&>button>div]:bg-[#E4E7EC] [&_*]:!text-[#98A2B3] [&_button]:bg-[#E4E7EC]'} transition-[opacity,grid-template-columns]" > @@ -233,17 +232,56 @@ Get Code --> - - {:else} - - - + {/if}
    - +
    + {#if tableLoaded || (tableData && tableRowsState.rows)} +
    + + + +
    + +
    + +
    + + + {:else} + + + + {/if} +
    + + {#if !tableError} diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddAgentDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddAgentDialog.svelte old mode 100644 new mode 100755 index caad262..d5be18b --- a/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddAgentDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddAgentDialog.svelte @@ -1,9 +1,8 @@ - + New agent -
    -
    - +
    +
    + - +
    -
    - Models* +
    + model.id == selectedModel)?.name ?? - selectedModel) || - 'Select model'} class="{!selectedModel ? 'italic text-muted-foreground' - : ''} bg-[#F2F4F7] data-dark:bg-[#42464e] hover:bg-[#e1e2e6] border-transparent" + : ''} border-transparent bg-[#F2F4F7] hover:bg-[#e1e2e6] data-dark:bg-[#42464e]" />
    - - - + + {#snippet child({ props })} + + {/snippet} + diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddConversationDialog.svelte b/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddConversationDialog.svelte old mode 100644 new mode 100755 index 75ce94f..fa5c3aa --- a/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddConversationDialog.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/chat-table/(dialogs)/AddConversationDialog.svelte @@ -1,30 +1,41 @@ Chat Table - - -
    +
    {#if !loadingAgentsError}
    -

    Agents

    +

    Agents

    scrollHandler(e, 'agent'), 300)} - class="grow flex flex-col gap-1 overflow-auto" + onscroll={debounce((e) => scrollHandler(e, 'agent'), 300)} + class="flex grow flex-col gap-1 overflow-auto" > {#if isLoadingCAgents} {#each Array(6) as _} {/each} {:else} {#each $pastChatAgents as chatTable (chatTable.id)} {/each} {#if isLoadingMoreCAgents} -
    +
    {/if} @@ -474,23 +469,23 @@
    scrollHandler(e, 'filtered'), 300)} - class="flex flex-col gap-1 pl-6 @2xl:pl-0.5 supports-[not(container-type:inline-size)]:lg:pl-0.5 pr-6 py-2.5 @2xl:py-4 supports-[not(container-type:inline-size)]:lg:py-4" + onscroll={debounce((e) => scrollHandler(e, 'filtered'), 300)} + class="flex flex-col gap-1 py-2.5 pl-6 pr-6 @2xl:py-4 @2xl:pl-0.5 supports-[not(container-type:inline-size)]:lg:py-4 supports-[not(container-type:inline-size)]:lg:pl-0.5" > {#if filterByAgent} -
    +
    -
    - - - -
    - - + + +
    + + + +
    scrollHandler(e, 'filtered'), 300)} + onscroll={debounce((e) => scrollHandler(e, 'filtered'), 300)} style="grid-auto-rows: 120px;" - class="grow grid grid-cols-[repeat(auto-fill,_minmax(300px,_1fr))] grid-flow-row gap-3 pt-1 px-1 h-1 overflow-auto [scrollbar-gutter:stable]" + class="grid h-1 grow grid-flow-row grid-cols-[repeat(auto-fill,_minmax(300px,_1fr))] gap-3 overflow-auto px-1 pt-1 [scrollbar-gutter:stable]" > {#if isLoadingFilteredConv} -
    +
    {:else} {#each filteredConversations as chatTable}
    -
    +
    - - + + {chatTable.id}
    - - + + {#snippet child({ props })} + + {/snippet} - + (isEditingTableID = chatTable.id)} + onclick={() => (isEditingTableID = chatTable.id)} class="text-[#344054] data-[highlighted]:text-[#344054]" > - + Rename table - - - - Export table - + + {#snippet children({ handleExportTable })} + + + Export table + + {/snippet} (isDeletingTable = chatTable.id)} + onclick={() => (isDeletingTable = chatTable.id)} class="text-destructive data-[highlighted]:text-destructive" > - + Delete table @@ -704,7 +697,7 @@ day: 'numeric', year: 'numeric' })} - class="font-medium text-xs text-[#98A2B3] data-dark:text-[#C9C9C9] line-clamp-1" + class="line-clamp-1 text-xs font-medium text-[#98A2B3] data-dark:text-[#C9C9C9]" > Last updated @@ -722,7 +715,7 @@ style="background-color: {mappedColors ? mappedColors.bg : '#E3F2FD'}; color: {mappedColors ? mappedColors.text : '#0295FF'};" - class="w-min px-1 py-0.5 text-xs font-medium whitespace-nowrap rounded-[0.1875rem] select-none" + class="w-min select-none whitespace-nowrap rounded-[0.1875rem] px-1 py-0.5 text-xs font-medium" > {chatTable.parent_id} @@ -732,31 +725,29 @@ {/each} {#if isLoadingMoreFilteredConv} -
    +
    {/if} {/if}
    {:else} - + Select agent to filter conversations {/if}
    - {:else if loadingAgentsError.status === 404 && loadingAgentsError.org_id && userData?.member_of.find((org) => org.organization_id === loadingAgentsError?.org_id)} - {@const projectOrg = userData?.member_of.find( - (org) => org.organization_id === loadingAgentsError?.org_id - )} + {:else if loadingAgentsError.status === 404 && loadingAgentsError.org_id && user?.org_memberships?.find((org) => org.organization_id === loadingAgentsError?.org_id)} + {@const projectOrg = user?.organizations.find((org) => org.id === loadingAgentsError?.org_id)} {:else} -
    +
    {loadingAgentsError.status}

    {loadingAgentsError.message}

    diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/+page.ts b/services/app/src/routes/(main)/project/[project_id]/chat-table/+page.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page.ts b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page.ts old mode 100644 new mode 100755 index 41e13c9..e4ab3b3 --- a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page.ts +++ b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page.ts @@ -1,13 +1,14 @@ import { PUBLIC_JAMAI_URL } from '$env/static/public'; -import { error } from '@sveltejs/kit'; -import logger from '$lib/logger.js'; import { chatRowsPerPage } from '$lib/constants.js'; +import logger from '$lib/logger.js'; import type { GenTable, GenTableRow } from '$lib/types.js'; +import { error } from '@sveltejs/kit'; export const load = async ({ depends, fetch, params, parent, url }) => { depends('chat-table:slug'); await parent(); const page = parseInt(url.searchParams.get('page') ?? '1'); + const orderBy = url.searchParams.get('sort_by'); const orderAsc = parseInt(url.searchParams.get('asc') ?? '0'); if (!params.table_id) { @@ -16,11 +17,8 @@ export const load = async ({ depends, fetch, params, parent, url }) => { const getTable = async () => { const tableDataRes = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/chat/${params.table_id}?` + - new URLSearchParams({ - offset: '0', - limit: '1' - }), + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/chat?` + + new URLSearchParams([['table_id', params.table_id]]), { headers: { 'x-project-id': params.project_id @@ -30,7 +28,7 @@ export const load = async ({ depends, fetch, params, parent, url }) => { const tableDataBody = await tableDataRes.json(); if (!tableDataRes.ok) { - if (tableDataRes.status !== 404 && tableDataRes.status !== 422) { + if (![403, 404, 422].includes(tableDataRes.status)) { logger.error('CHATTBL_TBL_GET', tableDataBody); } return { error: tableDataRes.status, message: tableDataBody }; @@ -42,13 +40,22 @@ export const load = async ({ depends, fetch, params, parent, url }) => { }; const getRows = async () => { + const q = url.searchParams.get('q'); + + const searchParams = new URLSearchParams([ + ['table_id', params.table_id], + ['offset', ((page - 1) * chatRowsPerPage).toString()], + ['limit', chatRowsPerPage.toString()], + ['order_by', orderBy ?? 'ID'], + ['order_ascending', orderAsc === 1 ? 'true' : 'false'] + ]); + + if (q) { + searchParams.set('search_query', q); + } + const tableRowsRes = await fetch( - `${PUBLIC_JAMAI_URL}/api/v1/gen_tables/chat/${params.table_id}/rows?` + - new URLSearchParams({ - offset: ((page - 1) * chatRowsPerPage).toString(), - limit: chatRowsPerPage.toString(), - order_descending: orderAsc === 1 ? 
'false' : 'true' - }), + `${PUBLIC_JAMAI_URL}/api/owl/gen_tables/chat/rows/list?${searchParams}`, { headers: { 'x-project-id': params.project_id @@ -58,7 +65,7 @@ export const load = async ({ depends, fetch, params, parent, url }) => { const tableRowsBody = await tableRowsRes.json(); if (!tableRowsRes.ok) { - if (tableRowsRes.status !== 404 && tableRowsRes.status !== 422) { + if (![403, 404, 422].includes(tableRowsRes.status)) { logger.error('CHATTBL_TBL_GETROWS', tableRowsBody); } return { error: tableRowsRes.status, message: tableRowsBody }; diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page@project.svelte b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page@project.svelte old mode 100644 new mode 100755 index 5ad4b95..a5e7980 --- a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page@project.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/+page@project.svelte @@ -1,17 +1,22 @@ - {$page.params.table_id} - Chat Table + {page.params.table_id} - Chat Table
    -
    - - - - +
    + + - {#await table} - {$page.params.table_id} + {#await data.table} + {page.params.table_id} {:then { data }} {data ? data.id : tableError?.error === 404 ? 'Not found' : 'Failed to load'} {/await} +
    -
    - {#if tableLoaded || (tableData && $genTableRows)} -
    - {#if $chatTableMode != 'chat'} - + {#if tableLoaded || (tableData && tableRowsState.rows)} +
    + {#if $chatTableMode != 'chat'} +
    + - {/if} -
    - -
    - {#if $chatTableMode != 'chat'} -
    - - - -
    - {:else} + - {/if} - - { - // Prevent toggling if streaming - if (generationStatus || Object.keys($tableState.streamingRows).length) { - return false; - } else return true; - }} - checked={$chatTableMode == 'table'} - on:checkedChange={(e) => { - if (e.detail.value) { - $chatTableMode = 'table'; - } else { - $chatTableMode = 'chat'; - } - }} - /> - - -
    - {:else} - -
    - - - -
    - {/if} -
    +
    + {:else} + + {/if} + + { + // Prevent toggling if streaming + if (generationStatus || Object.keys(tableState.streamingRows).length) { + return false; + } else return true; + }} + checked={$chatTableMode == 'table'} + on:checkedChange={(e) => { + if (e.detail.value) { + $chatTableMode = 'table'; + } else { + $chatTableMode = 'chat'; + } + }} + /> +
    + {:else} +
    + + +
    + {/if}
    + {#if $chatTableMode == 'table'} +
    + {#if tableLoaded || (tableData && tableRowsState.rows)} +
    + + + +
    + +
    + +
    + + + {:else} + + + + {/if} +
    + {/if} + {#if $chatTableMode == 'chat'} - + {:else} - + {#if !tableError} @@ -388,9 +389,7 @@ filterByAgent={tableData?.parent_id ?? ''} refetchTables={async (tableID) => { threadLoaded = false; - await goto( - `${$page.url.pathname.substring(0, $page.url.pathname.lastIndexOf('/'))}/${tableID}` - ); + await goto(`${page.url.pathname.substring(0, page.url.pathname.lastIndexOf('/'))}/${tableID}`); }} /> diff --git a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ChatMode.svelte b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ChatMode.svelte old mode 100644 new mode 100755 index 40e20bb..39fea75 --- a/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ChatMode.svelte +++ b/services/app/src/routes/(main)/project/[project_id]/chat-table/[table_id]/ChatMode.svelte @@ -1,18 +1,30 @@ (isResizing = false)} - on:click={handleCustomBtnClick} - on:keydown={(e) => { + onmousemove={handleResize} + onmouseup={() => (isResizing = false)} + onclick={handleCustomBtnClick} + onkeydown={(e) => { if ( //@ts-ignore e.target.tagName !== 'INPUT' && @@ -326,241 +496,258 @@ bind:this={chatWindow} data-testid="chat-window" id="chat-window" - class="@container/chat relative grow flex flex-col gap-4 pt-6 pb-16 overflow-x-hidden overflow-y-auto [scrollbar-gutter:stable]" + class="relative flex grow flex-col gap-4 overflow-y-auto overflow-x-hidden pt-6 [scrollbar-gutter:stable]" > {#if threadLoaded} - {@const displayedLoadedStreams = Object.keys(loadedStreams).filter((colID) => { - // Filter out columns to display - const col = tableData?.cols?.find((col) => col.id === colID); - return ( - col?.gen_config?.object === 'gen_config.llm' && - col.gen_config.multi_turn && - /\${User}/g.test(col.gen_config.prompt ?? '') - ); - })} - - {#each thread as threadItem, index} - {@const nonErrorMessage = threadItem.find((v) => 'content' in v && v.role === 'user')} - {@const messages = nonErrorMessage ? [nonErrorMessage] : threadItem} - {@const messagesWithContent = messages.filter( - (v) => ('content' in v && v.content?.trim()) || 'error' in v - )} + {@const multiturnCols = Object.keys(tableThread)} + {@const longestThreadColLen = tableThread[longestThreadCol]?.thread?.length ?? 0} + {#each Array(longestThreadColLen).fill('') as _, index}
    - {#each messages as message} - {@const { column_id } = message} - {#if 'content' in message} - {@const { content, role } = message} - {#if content?.trim()} - {#if role == 'assistant'} -
    -
    - + +
    + +
    +
    + {/if} + {/each} +
    + {/if} +

    - {content} + {threadItem.user_prompt}

    -
    + {/if}
    - {/if} - {/if} - {:else if messages.every((v) => ('content' in v && v.role === 'assistant') || 'error' in v)} -
    -
    - +
    - + {:else if threadItem.role === 'assistant'} + {@const isEditingCell = false}
    1 + ? 'min-w-full @5xl/chat:min-w-[50%] supports-[not(container-type:inline-size)]:xl:min-w-[50%]' + : '', + multiturnCols.length == 1 + ? '@5xl/chat:pr-[20%] supports-[not(container-type:inline-size)]:xl:pr-[20%]' + : 'last:pr-3 @2xl/chat:last:pr-6 @4xl/chat:last:pr-20 @5xl/chat:last:pr-0 supports-[not(container-type:inline-size)]:last:pr-6 supports-[not(container-type:inline-size)]:lg:last:pr-20 supports-[not(container-type:inline-size)]:xl:last:pr-0' + )} > -
    - - {message.error} - + {#if threadItem.row_id !== generationStatus} +
    +
    + + {column} + +
    + +
    + +
    +
    +
    1} + class="group relative max-w-full scroll-my-2 self-start rounded-xl bg-[#F2F4F7] p-4 text-text data-dark:bg-[#5B7EE5]" > -

    - {message.message.message || JSON.stringify(message)} -

    + {#if isEditingCell} + + {:else if typeof threadItem.content === 'string'} +

    + {#if showRawTexts} + {threadItem.content} + {:else} + {@const rawHtml = converter.makeHtml(threadItem.content)} + {@html rawHtml} + {/if} +

    + {:else} + {@const textContent = threadItem.content + .filter((c) => c.type === 'text') + .map((c) => c.text) + .join('')} + +

    + {#if showRawTexts} + {textContent} + {:else} + {@const rawHtml = converter.makeHtml(textContent)} + {@html rawHtml} + {/if} +

    + {/if}
    -
    + + {#if isEditingCell} + + {/if} + {:else} + {@render generatingMessages(column)} + {/if}
    -
    - {:else} - empty block + {/if} {/if} {/each}
    {/each} -
    - {#each displayedLoadedStreams as key} - {@const loadedStream = loadedStreams[key]} - {@const latestStream = latestStreams[key] ?? ''} -
    -
    - - output - - str - - -
    -
    - -
    -
    - - - {key} - -
    - -
    -

    - {@html converter.makeHtml(loadedStream.join(''))} - {latestStream} - - {#if loadedStream.length === 0 && latestStream === ''} - - {/if} -

    -
    -
    - {/each} -
    + {#if generationStatus === 'new'} + {@render generatingMessages()} + {/if} {:else} -
    +
    {/if} @@ -568,52 +755,52 @@
    -
    - + +
    @@ -632,6 +819,83 @@
    --> + + + +{#snippet generatingMessages(columnID?: string)} + {#if columnID} + {#each displayedLoadedStreams as key} + {#if key === columnID} + {@const loadedStream = loadedStreams[key]} + {@const latestStream = latestStreams[key] ?? ''} +
    + + {key} + +
    + +
    1} + class="group relative max-w-full scroll-my-2 self-start rounded-xl bg-[#F2F4F7] p-4 text-text data-dark:bg-[#5B7EE5]" + > +

    + {@html converter.makeHtml(loadedStream.join(''))} + {latestStream} + + {#if loadedStream.length === 0 && latestStream === ''} + + {/if} +

    +
    + {/if} + {/each} + {:else} +
    + {#each displayedLoadedStreams as key} + {@const loadedStream = loadedStreams[key]} + {@const latestStream = latestStreams[key] ?? ''} +
    1 + ? 'min-w-full @5xl/chat:min-w-[50%] supports-[not(container-type:inline-size)]:xl:min-w-[50%]' + : '', + displayedLoadedStreams.length == 1 + ? '@5xl/chat:pr-[20%] supports-[not(container-type:inline-size)]:xl:pr-[20%]' + : '' + )} + > +
    + + {key} + +
    + +
    1} + class="group relative max-w-full scroll-my-2 self-start rounded-xl bg-[#F2F4F7] p-4 text-text data-dark:bg-[#5B7EE5]" + > +

    + {@html converter.makeHtml(loadedStream.join(''))} + {latestStream} + + {#if loadedStream.length === 0 && latestStream === ''} + + {/if} +

    +
    +
    + {/each} +
    + {/if} +{/snippet} + + + +
    + + + + +
    +
    +

    + JamAI Logo +

    + +

    ${inviterEmail} has invited you to join their project on JamAI Base

    + +

    You have been invited to join a project on JamAI Base. Click the link below to accept the invitation:

    + +

    Join JamAI Base

    + +

    This link will expire in 7 days.

    + +
    + Thanks! +
    + + JamAI Base + +

    +
    +

+ If you did not make this request, you can ignore this email.

    +
    +
    +
    + +`; diff --git a/services/app/src/routes/(main)/project/[project_id]/members/+page.svelte b/services/app/src/routes/(main)/project/[project_id]/members/+page.svelte new file mode 100644 index 0000000..895c7bb --- /dev/null +++ b/services/app/src/routes/(main)/project/[project_id]/members/+page.svelte @@ -0,0 +1,187 @@ + + + + Project Members + + +
    +
    +
    +
    + +
    + +
    + + + + +
    + +
    +
    + + + + Name + Member + Role + + + + + {#await data.projectMembers} + {#each Array(6) as _} + + + + + + {/each} + {:then projectMembers} + {#if projectMembers.data} + {@const filteredMembers = filterMembers(projectMembers.data)} + {#if filteredMembers.length > 0} + {#each filteredMembers as member} + + +
    +
    +
    + {member.user.name?.charAt(0).toUpperCase() || '?'} +
    +
    +
    +
    {member.user.name}
    +
    {member.user.email}
    +
    +
    +
    + + {formatDistanceToNow(new Date(member.created_at), { addSuffix: true })} + + + + • + {member.role} + + + + + + + + (isEditingUser = member)} + > + Edit role + + (isRemovingUser = { open: true, value: member })} + > + Remove member + + + + +
    + {/each} + {:else} +
    +
    + No members found +
    +
    + {/if} + {:else} +
    +
    + Error fetching members + + {projectMembers?.message.message || JSON.stringify(projectMembers?.message)} + +
    +
    + {/if} + {/await} +
    +
    +
    +
    +
    + + + + diff --git a/services/app/src/routes/(main)/project/[project_id]/members/+page.ts b/services/app/src/routes/(main)/project/[project_id]/members/+page.ts new file mode 100644 index 0000000..f1162a7 --- /dev/null +++ b/services/app/src/routes/(main)/project/[project_id]/members/+page.ts @@ -0,0 +1,35 @@ +import { env } from '$env/dynamic/public'; +import logger from '$lib/logger'; +import type { OrgMemberRead } from '$lib/types'; + +const { PUBLIC_JAMAI_URL } = env; + +export const load = async ({ fetch, parent, data }) => { + const parentData = await parent(); + + const getOrgMembers = async () => { + const activeOrganizationId = parentData.organizationData?.id; + if (!activeOrganizationId) { + return { error: 400, message: 'No active organization' }; + } + + const orgMembersRes = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/organizations/members/list?${new URLSearchParams([['organization_id', activeOrganizationId]])}` + ); + const orgMembersBody = await orgMembersRes.json(); + + if (!orgMembersRes.ok) { + logger.error('PROJTEAM_ORGMEMBERS_ERROR', orgMembersBody); + return { error: orgMembersRes.status, message: orgMembersBody }; + } else { + return { + data: orgMembersBody.items as OrgMemberRead[] + }; + } + }; + + return { + ...data, + organizationMembers: await getOrgMembers() + }; +}; diff --git a/services/app/src/routes/(main)/project/[project_id]/overview/+page.svelte b/services/app/src/routes/(main)/project/[project_id]/overview/+page.svelte new file mode 100644 index 0000000..a971f17 --- /dev/null +++ b/services/app/src/routes/(main)/project/[project_id]/overview/+page.svelte @@ -0,0 +1,43 @@ + + + + Overview + + +
    +
    +
    Placeholder picture
    + +

    {$activeProject?.name}

    + +
    + {#each $activeProject?.tags ?? [] as tag} + {tag} + {/each} +
    + + + +
    +
    + + + {#await data.projectMembers} + 0 + {:then projectMembers} + {projectMembers.data?.length} + {/await} + + members +
    +
    +
    + +
    Description
    +
    diff --git a/services/app/src/routes/(main)/settings/+layout.svelte b/services/app/src/routes/(main)/settings/+layout.svelte old mode 100644 new mode 100755 index 5ed2f8b..94d0780 --- a/services/app/src/routes/(main)/settings/+layout.svelte +++ b/services/app/src/routes/(main)/settings/+layout.svelte @@ -1,10 +1,15 @@ - moveHighlighter($page.url.pathname)} /> + moveHighlighter(page.url.pathname)} />
    -

    Account Settings

    +

    Account Settings

    -
    +
    -
    diff --git a/services/app/src/routes/(main)/settings/+layout.ts b/services/app/src/routes/(main)/settings/+layout.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/(main)/settings/+page.ts b/services/app/src/routes/(main)/settings/+page.ts old mode 100644 new mode 100755 index 3538a56..2482054 --- a/services/app/src/routes/(main)/settings/+page.ts +++ b/services/app/src/routes/(main)/settings/+page.ts @@ -1,10 +1,5 @@ -import { PUBLIC_IS_LOCAL } from '$env/static/public'; import { redirect } from '@sveltejs/kit'; export function load() { - if (PUBLIC_IS_LOCAL === 'false') { - return redirect(302, '/settings/account'); - } else { - throw redirect(302, '/'); - } + return redirect(302, '/settings/account'); } diff --git a/services/app/src/routes/(main)/settings/account/(components)/ChangePasswordDialog.svelte b/services/app/src/routes/(main)/settings/account/(components)/ChangePasswordDialog.svelte new file mode 100644 index 0000000..774b51a --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/(components)/ChangePasswordDialog.svelte @@ -0,0 +1,135 @@ + + + + + Change password + +
    { + loadingChangePW = true; + + return async ({ update, result }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } + + loadingChangePW = false; + isChangingPW = false; + signOut(); + await update(); + }; + }} + method="POST" + action="?/change-password" + class="grow overflow-auto" + > +
    +
    + + + +
    + +
    + + + +
    + +
    + + + +
    +
    +
    + + +
    + + +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/settings/account/(components)/CreatePATDialog.svelte b/services/app/src/routes/(main)/settings/account/(components)/CreatePATDialog.svelte new file mode 100644 index 0000000..1e8234d --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/(components)/CreatePATDialog.svelte @@ -0,0 +1,165 @@ + + + + + Create PAT + + +
    { + isLoadingCreatePAT = true; + + return async ({ result, update }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error('Error creating PAT', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + isCreatingPAT = false; + } + + isLoadingCreatePAT = false; + update({ reset: false }); + }; + }} + method="POST" + action="?/create-pat" + class="h-full w-full grow overflow-auto" + > +
    + + + +
    + +
    + + + + + {#snippet child({ props })} + + {/snippet} + + + + + + + +
    + +
    + + + + +
    + {#if selectedProject} + {selectedProjectData?.name ?? ''}  â€“  + + {selectedProjectOrg?.name ?? selectedProjectData?.organization_id} + + {:else} + Optional + {/if} +
    +
    + + {#each user?.projects ?? [] as project} + {@const projectOrg = (user?.organizations ?? []).find( + (o) => project.organization_id === o.id + )} + + {project.name}  â€“  + + {projectOrg?.name ?? project.organization_id} + + + {/each} + +
    +
    +
    + + +
    + + {#snippet child({ props })} + + {/snippet} + + +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/settings/account/(components)/DeleteAccountDialog.svelte b/services/app/src/routes/(main)/settings/account/(components)/DeleteAccountDialog.svelte new file mode 100644 index 0000000..57a7188 --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/(components)/DeleteAccountDialog.svelte @@ -0,0 +1,101 @@ + + + { + if (!e) { + confirmEmail = ''; + } + }} +> + + Delete account + + +
    { + isLoadingDeleteAccount = true; + + return async ({ result, update }) => { + if (result.type !== 'success') { + //@ts-ignore + const data = result.data; + toast.error('Error deleting account', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else { + return location.reload(); + } + + isLoadingDeleteAccount = false; + update({ reset: false, invalidateAll: false }); + }; + }} + onkeydown={(event) => event.key === 'Enter' && event.preventDefault()} + method="POST" + action="?/delete-account" + class="w-full grow overflow-auto" + > +
    +

    + Do you really want to delete your account + + `{user?.email}` + ? This process cannot be undone. +

    + +
    + + + +
    +
    +
    + + +
    + + {#snippet child({ props })} + + {/snippet} + + +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/settings/account/(components)/DeletePATDialog.svelte b/services/app/src/routes/(main)/settings/account/(components)/DeletePATDialog.svelte new file mode 100644 index 0000000..506dafd --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/(components)/DeletePATDialog.svelte @@ -0,0 +1,89 @@ + + + !!isDeletingPAT, () => (isDeletingPAT = null)}> + + + + Close + + +
    + +

    Are you sure?

    +

    + Do you really want to remove PAT + + `{isDeletingPAT}` + ? +

    +
    + + +
    { + isLoadingDeletePAT = true; + + return async ({ result, update }) => { + if (result.type !== 'success') { + //@ts-ignore + const data = result.data; + toast.error('Error deleting PAT', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else { + isDeletingPAT = null; + } + + isLoadingDeletePAT = false; + update({ reset: false }); + }; + }} + method="POST" + action="?/delete-pat" + class="flex gap-2 overflow-x-auto overflow-y-hidden" + > + + + {#snippet child({ props })} + + {/snippet} + + +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/settings/account/(components)/index.ts b/services/app/src/routes/(main)/settings/account/(components)/index.ts new file mode 100644 index 0000000..451c358 --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/(components)/index.ts @@ -0,0 +1,6 @@ +import ChangePasswordDialog from './ChangePasswordDialog.svelte'; +import CreatePatDialog from './CreatePATDialog.svelte'; +import DeleteAccountDialog from './DeleteAccountDialog.svelte'; +import DeletePatDialog from './DeletePATDialog.svelte'; + +export { ChangePasswordDialog, CreatePatDialog, DeleteAccountDialog, DeletePatDialog }; diff --git a/services/app/src/routes/(main)/settings/account/+page.server.ts b/services/app/src/routes/(main)/settings/account/+page.server.ts new file mode 100755 index 0000000..288ce66 --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/+page.server.ts @@ -0,0 +1,236 @@ +import { env } from '$env/dynamic/private'; +import logger, { APIError } from '$lib/logger.js'; +import type { PATRead } from '$lib/types.js'; +import { fail, redirect } from '@sveltejs/kit'; +import { ManagementClient } from 'auth0'; + +const { + AUTH0_CLIENT_ID, + AUTH0_ISSUER_BASE_URL, + AUTH0_MGMTAPI_CLIENT_ID, + AUTH0_MGMTAPI_CLIENT_SECRET, + OWL_SERVICE_KEY, + OWL_URL +} = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +const management = new ManagementClient({ + domain: AUTH0_ISSUER_BASE_URL?.replace('https://', ''), + clientId: AUTH0_MGMTAPI_CLIENT_ID, + clientSecret: AUTH0_MGMTAPI_CLIENT_SECRET +}); + +export async function load({ locals }) { + //TODO: Infinite scroll this + const getPats = async () => { + const patListRes = await fetch(`${OWL_URL}/api/v2/pats/list`, { + headers: { + ...headers, + 'x-user-id': locals.user?.id ?? '' + } + }); + const patListBody = await patListRes.json(); + + if (!patListRes.ok) { + logger.error('PAT_LIST_ERROR', patListBody); + return { error: patListRes.status, message: patListBody }; + } else { + return { + data: patListBody.items as PATRead[] + }; + } + }; + + return { + pats: await getPats() + }; +} + +export const actions = { + 'change-password': async ({ locals, request }) => { + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + if (locals.auth0Mode) { + try { + const pwChangeRes = await management.tickets.changePassword({ + user_id: locals.user.sub, + client_id: AUTH0_CLIENT_ID, + ttl_sec: 0, + mark_email_as_verified: false, + includeEmailInRedirect: true + }); + if (pwChangeRes.status !== 200 && pwChangeRes.status !== 201) { + return fail( + pwChangeRes.status, + new APIError('Failed to change password', pwChangeRes as any).getSerializable() + ); + } else { + throw redirect(303, pwChangeRes.data.ticket); + } + } catch (err) { + //@ts-expect-error library throws error for redirects??? 
+ if (err?.status === 303) { + //@ts-expect-error see above + throw redirect(303, err.location); + } else { + logger.error('PASSWORD_CHANGE_CHANGE', err); + return fail(500, new APIError('Failed to change password', err as any).getSerializable()); + } + } + } else { + try { + const data = await request.formData(); + const password = data.get('password'); + const new_password = data.get('new_password'); + + if ( + !password || + typeof password !== 'string' || + !new_password || + typeof new_password !== 'string' + ) { + return fail(400, new APIError('Invalid form data').getSerializable()); + } + + const response = await fetch(`${OWL_URL}/api/v2/auth/login/password`, { + method: 'PATCH', + headers: { + ...headers, + 'Content-Type': 'application/json', + 'x-user-id': locals.user.id + }, + body: JSON.stringify({ + email: locals.user.email, + password, + new_password + }) + }); + + const responseData = await response.json(); + + if (!response.ok) { + logger.error('PASSWORD_CHANGE_ERROR', responseData, locals.user.email || locals.user.id); + return fail( + response.status, + new APIError('Failed to change password', responseData).getSerializable() + ); + } + + return responseData; + } catch (err) { + logger.error('PASSWORD_CHANGE_ERROR', err); + return fail(500, new APIError('Failed to change password', err as any).getSerializable()); + } + } + }, + + 'create-pat': async ({ locals, request }) => { + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const data = await request.formData(); + const patName = data.get('pat_name'); + const patExpiry = data.get('pat_expiry'); + const patProject = data.get('pat_project'); + + if ( + typeof patName !== 'string' || + typeof patExpiry !== 'string' || + typeof patProject !== 'string' + ) { + return fail(400, new APIError('Invalid form data').getSerializable()); + } + + const patCreateRes = await fetch(`${OWL_URL}/api/v2/pats`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user?.id, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + name: patName, + expiry: patExpiry || null, + project_id: patProject || null + }) + }); + const patCreateBody = await patCreateRes.json(); + + if (patCreateRes.ok) { + return patCreateBody; + } else { + return fail( + patCreateRes.status, + new APIError('Failed to create PAT', patCreateBody as any).getSerializable() + ); + } + }, + + 'delete-pat': async ({ locals, request }) => { + const data = await request.formData(); + const key = data.get('key'); + + if (typeof key !== 'string' || key.trim() === '') { + return fail(400, new APIError('Invalid PAT').getSerializable()); + } + + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const patDeleteRes = await fetch( + `${OWL_URL}/api/v2/pats?${new URLSearchParams([['pat_id', key]])}`, + { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user?.id + } + } + ); + const patDeleteBody = await patDeleteRes.json(); + + if (patDeleteRes.ok) { + return patDeleteBody; + } else { + return fail( + patDeleteRes.status, + new APIError('Failed to delete PAT', patDeleteBody as any).getSerializable() + ); + } + }, + + 'delete-account': async ({ locals }) => { + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + const deleteUserRes = await fetch(`${OWL_URL}/api/v2/users`, { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user.id + } + }); 
+ + const deleteUserBody = await deleteUserRes.json(); + if (!deleteUserRes.ok) { + logger.error('USER_DELETE_DELETE', deleteUserBody); + return fail( + deleteUserRes.status, + new APIError('Failed to delete account', deleteUserBody as any).getSerializable() + ); + } else { + return deleteUserBody; + } + } +}; diff --git a/services/app/src/routes/(main)/settings/account/+page.svelte b/services/app/src/routes/(main)/settings/account/+page.svelte new file mode 100755 index 0000000..ab08bcd --- /dev/null +++ b/services/app/src/routes/(main)/settings/account/+page.svelte @@ -0,0 +1,264 @@ + + + + Account - Settings + + +
    +

    ACCOUNT

    + +
    +
    + {#if user?.picture_url} + User Avatar + {:else} + + {(user?.name ?? 'Default User').charAt(0)} + + {/if} +
    + +
    + + {user?.name} + +
    +
    + +
    +
    +

    User ID

    + {user?.id ?? ''} +
    + +
    +

    Email

    + {user?.email ?? ''} +
    +
    + + + + {#if data.auth0Mode} + +
    { + isLoadingChangePassword = true; + + return async ({ result, update }) => { + if (result.type !== 'redirect') { + //@ts-ignore + const data = result.data; + toast.error('Error getting password reset link', { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } + + isLoadingChangePassword = false; + update({ reset: false, invalidateAll: false }); + }; + }} + onkeydown={(event) => event.key === 'Enter' && event.preventDefault()} + method="POST" + action="?/change-password" + class="mb-8 flex w-full flex-col gap-3" + > +

    PASSWORD

    + + +
    + {:else} +
    +

    PASSWORD

    + + +
    + {/if} + +
    +

    PERSONAL ACCESS TOKEN

    + +
    +
    +
    +
    Name
    +
    Key
    +
    Project
    +
    Expiry
    +
    +
    + + {#if (pats.data ?? []).length > 0} +
    + {#each pats.data ?? [] as apiKey} +
    +
    +

    + {apiKey.name} +

    +
    + +
    + + + + + +
    + +
    +

    + {apiKey.project_id || '-'} +

    +
    + +
    + {apiKey.expiry ? new Date(apiKey.expiry).toLocaleString() : '-'} +
    + + +
    + {/each} +
    + {:else} +
    +
    +
    +

    No PATs have been created for this user

    +
    +
    +
    + {/if} +
    + +
    + +
    +
    + +
    +

    ACCOUNT REMOVAL

    + +

    + Delete your account permanently. +

    + +
    + +
    +
    + + +
    + + + + + diff --git a/services/app/src/routes/(main)/settings/theme/page.svelte b/services/app/src/routes/(main)/settings/theme/page.svelte old mode 100644 new mode 100755 index 8cbee89..7e5b350 --- a/services/app/src/routes/(main)/settings/theme/page.svelte +++ b/services/app/src/routes/(main)/settings/theme/page.svelte @@ -18,7 +18,7 @@
    + {/snippet} + + +
    + + + diff --git a/services/app/src/routes/(main)/system/models/(components)/AddModelConfigDialog.svelte b/services/app/src/routes/(main)/system/models/(components)/AddModelConfigDialog.svelte new file mode 100644 index 0000000..576a3ce --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/AddModelConfigDialog.svelte @@ -0,0 +1,498 @@ + + + + + + Add Model Config + + +
    +
    +

    SUGGESTED MODELS

    +
    + {#if data.modelPresets.data} +
    + {#each data.modelPresets.data.filter((p) => p.deployments[0].routing_id) as config} +
    +
    +
    {config.name}
    +

    {config.id}

    +
    + config.id === selectedSuggestedConfig?.id, + (v) => { + currentStep = 0; + if (v) { + selectedSuggestedConfig = config; + modelType = config.type; + selectedCapabilities = config.capabilities; + modelIcon = (config?.meta?.icon as string) || undefined; + } else { + selectedSuggestedConfig = null; + modelType = undefined; + selectedCapabilities = []; + modelIcon = undefined; + } + }} + class="border-gray-400 data-[state=checked]:bg-[#1B748A]" + /> +
    + {/each} +
    + {/if} +
    +
    +
    { + if (!modelType) { + toast.error('Model type is required'); + cancel(); + return; + } + + if (selectedCapabilities.length === 0) { + toast.error('At least one capability is required'); + cancel(); + return; + } + + loading = true; + if (modelIcon) { + formData.set('icon', modelIcon); + } + if (baseTier) { + formData.set('base_tier_id', baseTier); + } + formData.set('type', modelType); + formData.set('capabilities', JSON.stringify(selectedCapabilities)); + + const languages = formData.get('languages'); + if (languages) { + formData.set('languages', JSON.stringify(stringToArray(languages.toString()))); + } + + return async ({ update, result }) => { + //@ts-ignore + const data = result.data; + if (result.type === 'failure') { + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + open = false; + toast.success('Model config added successfully', { + id: 'add-model-config-success' + }); + selectedCapabilities = []; + } + + loading = false; + await update({ invalidateAll: false }); + await invalidate('system:models'); + }; + }} + action="?/add-model-config" + class="flex w-3/5 flex-col space-y-6 overflow-y-auto p-1 pr-0" + > +
    +
    + {#each steps as _, i} +
    +
    + {i + 1} +
    + {#if i < steps.length - 1} +
    + {/if} +
    + {/each} +
    +
    + {#each steps as step, i} +
    + {step} +
    + {/each} +
    +
    + +
    + + + + + +
    +
    +
    + + +
    + + {#snippet child({ props })} + + {/snippet} + + {#if currentStep > 0} + + {/if} + {#if currentStep < steps.length - 1} + + {:else} + + {/if} +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/system/models/(components)/DeleteDeploymentDialog.svelte b/services/app/src/routes/(main)/system/models/(components)/DeleteDeploymentDialog.svelte new file mode 100644 index 0000000..4a205bb --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/DeleteDeploymentDialog.svelte @@ -0,0 +1,96 @@ + + + + + + + Close + + +
    + +

    Are you sure?

    +

    + Do you really want to delete deployment + + `{deployment?.name || deployment?.id}` + ? This process cannot be undone. +

    +
    + + +
    { + loading = true; + + return async ({ update, result }) => { + //@ts-ignore + const data = result.data; + if (result.type === 'failure') { + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + open = false; + toast.success('Model deployment deleted successfully', { + id: 'delete-deployment-success' + }); + + update({ invalidateAll: false }); + invalidate('system:models'); + invalidate('system:modelsslug'); + } + + loading = false; + }; + }} + action="/system/models?/delete-deployment" + class="flex gap-2 overflow-x-auto overflow-y-hidden" + > + + + + {#snippet child({ props })} + + {/snippet} + + +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/system/models/(components)/DeleteModelConfigDialog.svelte b/services/app/src/routes/(main)/system/models/(components)/DeleteModelConfigDialog.svelte new file mode 100644 index 0000000..c705bde --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/DeleteModelConfigDialog.svelte @@ -0,0 +1,96 @@ + + + open.open, (v) => (open = { ...open, open: v })}> + + + + Close + + +
    + +

    Are you sure?

    +

    + Do you really want to delete model + + `{open.value?.name || open.value?.id}` + ? This process cannot be undone. +

    +
    + + +
    { + loading = true; + + return async ({ update, result }) => { + //@ts-ignore + const data = result.data; + if (result.type === 'failure') { + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + open = { ...open, open: false }; + toast.success('Model config deleted successfully', { + id: 'delete-model-config-success' + }); + } + + await update({ invalidateAll: false }); + if (page.params.model_id) { + await goto('/system/models'); + } + invalidate('system:models'); + loading = false; + }; + }} + action="/system/models?/delete-model-config" + class="flex gap-2 overflow-x-auto overflow-y-hidden" + > + + + + {#snippet child({ props })} + + {/snippet} + + +
    +
    +
    +
    diff --git a/services/app/src/routes/(main)/system/models/(components)/DeploymentDetails.svelte b/services/app/src/routes/(main)/system/models/(components)/DeploymentDetails.svelte new file mode 100644 index 0000000..25c0c68 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/DeploymentDetails.svelte @@ -0,0 +1,208 @@ + + +
    { + loading = true; + formData.set('provider', selectedProvider); + + return async ({ update, result }) => { + if (result.type === 'failure') { + const data = result.data as any; + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + isEditing = false; + toast.success('Model deployment updated successfully', { + id: 'edit-deployment-success' + }); + } + + loading = false; + update({ invalidateAll: false }); + invalidate('system:models'); + invalidate('system:modelsslug'); + deployment.refetch(); + }; + }} + action="/system/models?/edit-deployment" + class="grow overflow-y-scroll py-1" +> +
    +
    + + {#if isEditing} + + {:else} +
    {deployment.data?.model_id}
    + {/if} +
    + +
    + + {#if isEditing} + + {:else} +
    {deployment.data?.name}
    + {/if} +
    + + + + +
    + + {#if isEditing} + {#await (page.data as LayoutData).providers} + + {:then providers} + + + {PROVIDERS[selectedProvider] || selectedProvider || 'Select a provider'} + + + + {#if providers.data} + {#each providers.data.filter((p) => p !== '') as provider} + + {PROVIDERS[provider] || provider} + + {/each} + {/if} + + + {/await} + {:else} +
    + {PROVIDERS[deployment.data?.provider ?? ''] || deployment.data?.provider || '-'} +
    + {/if} +
    + +
    + + {#if isEditing} + + {:else} +
    {deployment.data?.routing_id || '-'}
    + {/if} +
    + +
    + + {#if isEditing} + + {:else} +
    {deployment.data?.api_base || '-'}
    + {/if} +
    + +
    + + {#if isEditing} + + {:else} +
    {deployment.data?.weight || 1}
    + {/if} +
    +
    +
    + +{#if !isEditing} +
    + +
    +{:else} +
    + + +
    +{/if} diff --git a/services/app/src/routes/(main)/system/models/(components)/DeploymentManagement.svelte b/services/app/src/routes/(main)/system/models/(components)/DeploymentManagement.svelte new file mode 100644 index 0000000..1c680f7 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/DeploymentManagement.svelte @@ -0,0 +1,131 @@ + + +
    +

    + Cloud Deployment Management +

    + +
    +
    + + + + Endpoint Name + Model + Model ID + API Base + Actions + + + + {#await data.deployments} + {#each Array(4) as _} + + + + + + {/each} + {:then deployments} + {#if deployments.data} + {#if deployments.data.length > 0} + {#each deployments.data as deployment} + + {deployment.name} + {deployment.model?.name} + {deployment.model_id} + + {deployment.api_base} + +
    + + +
    +
    +
    + {/each} + + + {:else} + + +
    +
    + No cloud deployments found + Deploy a model to see it listed here +
    +
    +
    +
    + {/if} + {:else} +
    +
    + Error loading deployments + + {deployments?.error.message || JSON.stringify(deployments?.error)} + +
    +
    + {/if} + {/await} +
    +
    +
    +
    + + {#if selectedEditDeploymentId} + + {/if} + +
    diff --git a/services/app/src/routes/(main)/system/models/(components)/ManageDeploymentDialog.svelte b/services/app/src/routes/(main)/system/models/(components)/ManageDeploymentDialog.svelte new file mode 100644 index 0000000..19dd575 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/ManageDeploymentDialog.svelte @@ -0,0 +1,108 @@ + + + + + +
    +

    Model Deployment

    + + {#await deployment} +

    Loading...

    + {:then deployment} + {deployment.data?.name} + {/await} +
    +
    +
    + +
    +
    + +
    + + + {#await deployment} +
    + +
    + {:then deployment} + {#if deployment.data} +
    + {#if activeTab === 'details'} + + {/if} +
    + {:else} +
    +
    + Error loading deployment +

+ {deployment?.error?.message || JSON.stringify(deployment?.error)} +

    +
    +
    + {/if} + {/await} +
    +
    +
    diff --git a/services/app/src/routes/(main)/system/models/(components)/ModelCatalogue.svelte b/services/app/src/routes/(main)/system/models/(components)/ModelCatalogue.svelte new file mode 100644 index 0000000..13988cb --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/ModelCatalogue.svelte @@ -0,0 +1,344 @@ + + +
    +

    Model Catalogue

    +
    +
    +
    + +
    + +
    + +
    + +
    +
    + {#each ['all', ...Object.keys(MODEL_TYPES)] as modelType} + + {/each} +
    + + {#await data.modelConfigs then modelConfigs} + {#if modelConfigs.data} + {@const filteredConfigs = filterModelConfigs( + modelConfigs.data, + searchQuery, + $modelConfigSort.filter + )} + + {#snippet children({ pages, currentPage })} + + + + {#snippet child({ props })} + + {/snippet} + + + {#each pages as page (page.key)} + {#if page.type === 'ellipsis'} + + + + {:else} + {@const pageFontSize = + 99 % page.value === 99 + ? 999 % page.value === 999 + ? 'text-[0.6rem]' + : 'text-xs' + : 'text-sm'} + + + {#snippet child({ props })} + + {/snippet} + + + {/if} + {/each} + + + {#snippet child({ props })} + + {/snippet} + + + + {/snippet} + + {/if} + {/await} +
    + +
    + {#await data.modelConfigs} + {#each Array(6) as _} + + {/each} + {:then modelConfigs} + {#if modelConfigs.data} + {@const { filteredConfigs, paginatedConfigs } = getPaginatedModelConfigs( + modelConfigs.data, + searchQuery, + $modelConfigSort.filter, + currentPage, + itemsPerPage + )} + + {#if filteredConfigs.length === 0} +
    +

+ {$modelConfigSort.filter === 'all' + ? 'No model configs found.' + : `No ${MODEL_TYPES[$modelConfigSort.filter] ?? $modelConfigSort.filter} models found in the catalogue.`} +

    +
    + {:else} + {#each paginatedConfigs as modelConfig (modelConfig.id)} + + {/each} + {/if} + {:else} +
    +

+ {modelConfigs?.error?.message || JSON.stringify(modelConfigs?.error)} +

    +
    + {/if} + {/await} +
    +
    + + + + + diff --git a/services/app/src/routes/(main)/system/models/(components)/ModelConfigCard.svelte b/services/app/src/routes/(main)/system/models/(components)/ModelConfigCard.svelte new file mode 100644 index 0000000..937331f --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/ModelConfigCard.svelte @@ -0,0 +1,117 @@ + + + + +
    { + goto(`/system/models/${encodeURIComponent(modelConfig.id)}`, { state: { page: currentPage } }); + }} + oncontextmenu={(e) => { + e.preventDefault(); + menuOpen = !menuOpen; + }} + class="flex h-full cursor-pointer flex-col justify-start space-y-2 overflow-auto rounded-xl border border-[#E4E7EC] bg-white p-4 transition-[transform,box-shadow] hover:-translate-y-0.5 hover:shadow-float" + class:!bg-[#FFF8EA]={!modelConfig.deployments.length} +> +
    + +
    +
    +

    + {modelConfig.name} +

    +

    + {modelConfig.id} + +

    +
    +
    + {modelConfig.type} +
    + {#if modelConfig.deployments.length} +
    + {modelConfig.deployments.length} + {modelConfig.deployments.length === 1 ? 'deployment' : 'deployments'} +
    + {:else} +
    + No Deployment +
    + {/if} +
    +
    + +
    + + + + + + + goto(`/system/models/${encodeURIComponent(modelConfig.id)}?edit=true`)} + class="text-[#344054] data-[highlighted]:text-[#344054]" + > + + Edit + + (deleteOpen = { open: true, value: modelConfig })} + class="text-destructive data-[highlighted]:text-destructive" + > + + Delete + + + +
    +
    diff --git a/services/app/src/routes/(main)/system/models/(components)/index.ts b/services/app/src/routes/(main)/system/models/(components)/index.ts new file mode 100755 index 0000000..bf3a223 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/(components)/index.ts @@ -0,0 +1,19 @@ +import AddDeploymentDialog from './AddDeploymentDialog.svelte'; +import AddModelConfigDialog from './AddModelConfigDialog.svelte'; +import DeleteDeploymentDialog from './DeleteDeploymentDialog.svelte'; +import DeleteModelConfigDialog from './DeleteModelConfigDialog.svelte'; +import DeploymentManagement from './DeploymentManagement.svelte'; +import ManageDeploymentDialog from './ManageDeploymentDialog.svelte'; +import ModelCatalogue from './ModelCatalogue.svelte'; +import ModelConfigCard from './ModelConfigCard.svelte'; + +export { + AddDeploymentDialog, + AddModelConfigDialog, + DeleteDeploymentDialog, + DeleteModelConfigDialog, + DeploymentManagement, + ManageDeploymentDialog, + ModelCatalogue, + ModelConfigCard +}; diff --git a/services/app/src/routes/(main)/system/models/+layout.server.ts b/services/app/src/routes/(main)/system/models/+layout.server.ts new file mode 100644 index 0000000..07faa0e --- /dev/null +++ b/services/app/src/routes/(main)/system/models/+layout.server.ts @@ -0,0 +1,32 @@ +import { env } from '$env/dynamic/private'; +import logger from '$lib/logger.js'; + +const { OWL_URL, OWL_SERVICE_KEY } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export async function load({ locals }) { + const getProviders = async () => { + const response = await fetch(`${OWL_URL}/api/v2/models/deployments/providers/cloud`, { + headers: { + ...headers, + 'x-user-id': locals.user?.id || '' + } + }); + + const responseBody = await response.json(); + + if (!response.ok) { + logger.error('PROVIDERS_GET_ERROR', responseBody, locals.user?.id); + return { error: response.status, message: responseBody }; + } + + return { data: responseBody as string[] }; + }; + + return { + providers: getProviders() + }; +} diff --git a/services/app/src/routes/(main)/system/models/+layout.ts b/services/app/src/routes/(main)/system/models/+layout.ts new file mode 100644 index 0000000..247304c --- /dev/null +++ b/services/app/src/routes/(main)/system/models/+layout.ts @@ -0,0 +1,77 @@ +import { PUBLIC_JAMAI_URL } from '$env/static/public'; +import { activeOrganization } from '$globalStore'; +import logger from '$lib/logger.js'; +import type { ModelConfig, ModelDeployment } from '$lib/types.js'; +import { get } from 'svelte/store'; + +export const ssr = false; + +export async function load({ data, depends, fetch }) { + depends('system:models'); + + //TODO: Maybe paginate this + const getModelConfigs = async () => { + const activeOrg = get(activeOrganization); + + const limit = 1000; + const offset = 0; + const response = await fetch( + `${PUBLIC_JAMAI_URL}/api/owl/models/configs/list?${new URLSearchParams({ + organization_id: activeOrg?.id ?? 
'', + offset: offset.toString(), + limit: limit.toString() + })}` + ); + const responseBody = await response.json(); + + if (!response.ok) { + logger.error('MODELCONFIGS_GET_ERROR', responseBody); + return { data: null, error: responseBody as any, status: response.status }; + } + + return { data: responseBody.items as ModelConfig[] }; + }; + + const getDeployments = async () => { + const limit = 1000; + const offset = 0; + const response = await fetch( + `/api/owl/models/deployments/list?${new URLSearchParams([ + ['offset', offset.toString()], + ['limit', limit.toString()] + ])}` + ); + const responseBody = await response.json(); + + if (!response.ok) { + logger.error('DEPLOYMENTS_GET_ERROR', data); + return { data: null, error: responseBody as any, status: response.status }; + } + + return { data: responseBody.items as ModelDeployment[] }; + }; + + const getModelPresets = async () => { + const response = await fetch( + 'https://raw.githubusercontent.com/EmbeddedLLM/JamAIBase/refs/heads/main/services/api/src/owl/configs/preset_models.json', + { + method: 'GET' + } + ); + + if (!response.ok) { + const error = await response.text(); + logger.error('MODELPRESETS_GET_ERROR', error); + return { error: response.status, message: 'Failed to fetch model presets' }; + } + + return { data: (await response.json()) as ModelConfig[] }; + }; + + return { + ...data, + modelConfigs: getModelConfigs(), + deployments: getDeployments(), + modelPresets: await getModelPresets() + }; +} diff --git a/services/app/src/routes/(main)/system/models/+page.server.ts b/services/app/src/routes/(main)/system/models/+page.server.ts new file mode 100644 index 0000000..87e4bb5 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/+page.server.ts @@ -0,0 +1,279 @@ +import { env } from '$env/dynamic/private'; +import { PUBLIC_ADMIN_ORGANIZATION_ID } from '$env/static/public'; +import logger, { APIError } from '$lib/logger.js'; +import { error, fail } from '@sveltejs/kit'; + +const { OWL_URL, OWL_SERVICE_KEY } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export async function load({ cookies, locals }) { + if ( + cookies.get('activeOrganizationId') !== PUBLIC_ADMIN_ORGANIZATION_ID || + !locals.user?.org_memberships.find( + (org) => org.organization_id === PUBLIC_ADMIN_ORGANIZATION_ID + ) + ) { + throw error(404, 'Not found'); + } +} + +export const actions = { + 'add-model-config': async function ({ locals, request }) { + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + const formData = await request.formData(); + + const data: Record = {}; + + try { + for (const [key, value] of formData.entries()) { + // Skip empty values + if (value === '' || value === null || value === undefined) continue; + + // Handle value conversion + const numValue = Number(value); + if (!isNaN(numValue)) { + data[key] = numValue; + } else { + data[key] = value; + } + } + + data['capabilities'] = JSON.parse(formData.get('capabilities') as string); + if (data['languages']) { + data['languages'] = JSON.parse(formData.get('languages') as string); + } + if (data.provisioned_to !== undefined && data.provisioned_to !== null) { + data.provisioned_to = String(data.provisioned_to); + } + + if (data['icon']) { + if (!data['meta']) { + data['meta'] = {}; + } + data['meta']['icon'] = data['icon']; + delete data['icon']; + } + + const response = await fetch(`${OWL_URL}/api/v2/models/configs`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user.id || '', + 'Content-Type': 
'application/json' + }, + body: JSON.stringify(data) + }); + + const responseData = await response.json(); + + if (!response.ok) { + logger.error('MODELCONFIG_ADD_ERROR', responseData, locals.user.id); + return fail( + response.status, + new APIError('Failed to create model config', responseData).getSerializable() + ); + } + + return responseData; + } catch (error) { + logger.error('MODELCONFIG_ADD_ERROR', error, locals.user.id); + return fail( + 500, + new APIError('Failed to create model config', error as any).getSerializable() + ); + } + }, + 'delete-model-config': async function ({ locals, request }) { + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + const formData = await request.formData(); + + try { + const model_id = formData.get('model_id')?.toString(); + + if (!model_id || typeof model_id !== 'string' || model_id.trim() === '') { + return fail(400, new APIError('Model ID (type string) is required').getSerializable()); + } + + const response = await fetch( + `${OWL_URL}/api/v2/models/configs?${new URLSearchParams([['model_id', model_id]])}`, + { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user.id || '' + } + } + ); + + const responseData = await response.json(); + + if (!response.ok) { + logger.error('MODELCONFIG_DELETE_ERROR', responseData, locals.user.id); + return fail( + response.status, + new APIError('Failed to delete model config', responseData).getSerializable() + ); + } + + return responseData; + } catch (error: any) { + logger.error('MODELCONFIG_DELETE_ERROR', error, locals.user.id); + return fail( + 500, + new APIError('Failed to delete model config', error as any).getSerializable() + ); + } + }, + + 'add-deployment': async function ({ locals, request }) { + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + const formData = await request.formData(); + + const data: Record = {}; + + try { + for (const [key, value] of formData.entries()) { + // Skip empty values + if (value === '' || value === null || value === undefined) continue; + + // Handle value conversion + const numValue = Number(value); + if (!isNaN(numValue)) { + data[key] = numValue; + } else { + data[key] = value; + } + } + + const response = await fetch(`${OWL_URL}/api/v2/models/deployments/cloud`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user.id || '', + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data) + }); + + const responseData = await response.json(); + + if (!response.ok) { + logger.error('DEPLOYMENT_ADD_ERROR', responseData, locals.user.id); + return fail( + response.status, + new APIError('Failed to create deployment', responseData as any).getSerializable() + ); + } + + return responseData; + } catch (error) { + logger.error('DEPLOYMENT_ADD_ERROR', error, locals.user.id); + return fail(500, new APIError('Failed to create deployment', error as any).getSerializable()); + } + }, + 'edit-deployment': async function ({ locals, request }) { + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + const formData = await request.formData(); + + const data: Record = {}; + + try { + for (const [key, value] of formData.entries()) { + // Skip empty values + if (value === '' || value === null || value === undefined) continue; + + // Handle value conversion + const numValue = Number(value); + if (!isNaN(numValue)) { + data[key] = numValue; + } else { + data[key] = value; + } + } + + const response = await fetch( + `${OWL_URL}/api/v2/models/deployments?${new URLSearchParams([['deployment_id', 
data.id]])}`, + { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': locals.user.id || '', + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data) + } + ); + + const responseData = await response.json(); + + if (!response.ok) { + logger.error('DEPLOYMENT_EDIT_ERROR', responseData, locals.user.id); + return fail( + response.status, + new APIError('Failed to edit deployment', responseData).getSerializable() + ); + } + + return responseData; + } catch (error) { + logger.error('DEPLOYMENT_EDIT_ERROR', error, locals.user.id); + return fail(500, new APIError('Failed to edit deployment', error as any).getSerializable()); + } + }, + 'delete-deployment': async function ({ locals, request }) { + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + const formData = await request.formData(); + + try { + const deployment_id = formData.get('deployment_id')?.toString(); + + if (!deployment_id || typeof deployment_id !== 'string' || deployment_id.trim() === '') { + return fail(400, new APIError('Deployment ID (type string) is required').getSerializable()); + } + + const response = await fetch( + `${OWL_URL}/api/v2/models/deployments?${new URLSearchParams([['deployment_id', deployment_id]])}`, + { + method: 'DELETE', + headers: { + ...headers, + 'x-user-id': locals.user.id || '' + } + } + ); + + const responseData = await response.json(); + + if (!response.ok) { + logger.error('DEPLOYMENT_DELETE_ERROR', responseData, locals.user.id); + return fail( + response.status, + new APIError('Failed to delete deployment', responseData).getSerializable() + ); + } + + return responseData; + } catch (error) { + logger.error('DEPLOYMENT_DELETE_ERROR', error, locals.user.id); + return fail(500, new APIError('Failed to delete deployment', error as any).getSerializable()); + } + } +}; diff --git a/services/app/src/routes/(main)/system/models/+page.svelte b/services/app/src/routes/(main)/system/models/+page.svelte new file mode 100644 index 0000000..a4531c2 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/+page.svelte @@ -0,0 +1,56 @@ + + + + Model Setup + + +
    + + + +
    + + + + { + if (!v) { + page.url.searchParams.delete('onboarding'); + goto(`?${page.url.searchParams}`, { invalidate: [] }); + } + }} +> + +

    Start deploying your first model

    + +
    +
    diff --git a/services/app/src/routes/(main)/system/models/[model_id]/(components)/CloudDeployments.svelte b/services/app/src/routes/(main)/system/models/[model_id]/(components)/CloudDeployments.svelte new file mode 100644 index 0000000..4be96ab --- /dev/null +++ b/services/app/src/routes/(main)/system/models/[model_id]/(components)/CloudDeployments.svelte @@ -0,0 +1,106 @@ + + + + + + Endpoint Name + API Base + Provider + Routing ID + Actions + + + + + + + + + {#if model.deployments?.length} + {@const cloudDeployments = model.deployments.flat()} + {#if cloudDeployments.length > 0} + {#each cloudDeployments as deployment} + + {deployment.name} + {deployment.api_base} + + {PROVIDERS[deployment.provider] || deployment.provider} + + {deployment.routing_id} + +
    + + +
    +
    +
    + {/each} + {:else} + + + No cloud deployments found + + {/if} + {/if} +
    +
    + + + +{#if selectedEditDeploymentId} + +{/if} + diff --git a/services/app/src/routes/(main)/system/models/[model_id]/(components)/ModelDetails.svelte b/services/app/src/routes/(main)/system/models/[model_id]/(components)/ModelDetails.svelte new file mode 100644 index 0000000..6b7f3d4 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/[model_id]/(components)/ModelDetails.svelte @@ -0,0 +1,513 @@ + + +
    +
    { + if (!modelType) { + toast.error('Model type is required', { + id: 'model-type-required' + }); + cancel(); + return; + } + + if (selectedCapabilities.length === 0) { + toast.error('At least one capability is required', { + id: 'capability-required' + }); + cancel(); + return; + } + + loading = true; + if (modelIcon) { + formData.set('icon', modelIcon); + } + + formData.set('type', modelType); + formData.set('capabilities', JSON.stringify(selectedCapabilities)); + + const languages = formData.get('languages'); + if (languages) { + formData.set('languages', JSON.stringify(stringToArray(languages.toString()))); + } + + const allowed_orgs = formData.get('allowed_orgs'); + if (allowed_orgs) { + formData.set('allowed_orgs', JSON.stringify(stringToArray(allowed_orgs.toString()))); + } + + const blocked_orgs = formData.get('blocked_orgs'); + if (blocked_orgs) { + formData.set('blocked_orgs', JSON.stringify(stringToArray(blocked_orgs.toString()))); + } + + return async ({ update, result }) => { + //@ts-ignore + const data = result.data; + if (result.type === 'failure') { + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + toast.success('Model config updated successfully', { + id: 'edit-model-config-success' + }); + isEditing = false; + + page.url.searchParams.delete('edit'); + goto( + `/system/models/${encodeURIComponent(formData.get('id')?.toString()!)}?${page.url.searchParams}`, + { + invalidate: ['system:models', 'system:modelsslug'], + replaceState: true + } + ); + } + + update({ invalidateAll: false }); + loading = false; + }; + }} + action="?/edit-model-config" + class="px-5 py-5" + class:space-y-5={!isEditing} + > +
    +
    +

    Basic Information

    +
    + {#if isEditing} +
    + + +
    + +
    + +
    +
    + +
    + +
    + +
    +
    + +
    + +
    + + + {#if modelIcon} +
    + + {modelIcon} +
    + {:else} + Select model icon + {/if} +
    + + {#each Object.keys(modelLogos) as icon} + + + {icon} + + {/each} + +
    +
    +
    + +
    + +
    + + + {MODEL_TYPES[modelType ?? ''] || 'Select model type'} + + + {#each Object.keys(MODEL_TYPES) as type} + {MODEL_TYPES[type]} + {/each} + + +
    +
    + +
    + +
    + {#each MODEL_CAPABILITIES as capability} + + {/each} +
    +
    + +
    + +
    + +
    +
    + +
    + +
    + +
    +
    +
    + {:else} +
    +
    +

    Model Name

    +

    {model.name}

    +
    +
    +

    Model ID

    +

    {model.id}

    +
    +
    +

    Model Type

    +
    +

    {model.type}

    +
    +
    +
    +

    Capabilities

    +

    + {model.capabilities.map((cap) => capitalize(cap)).join(', ')} +

    +
    +
    +

    Priority

    +

    {model.priority}

    +
    +
    +

    Owned By

    +

    {model.owned_by}

    +
    +
    + {/if} +
    + +
    +
    +

    Model Specification

    +
    + {#if isEditing} +
    +
    + +
    + +
    +
    + {#if modelType === MODEL_TYPES.embed.toLowerCase()} +
    + +
    + +
    +
    +
    + +
    + +
    +
    +
    + +
    + +
    +
    + {/if} +
    + +
    + +
    +
    +
    + {:else} +
    +
    +

    Context Length

    +

    {model.context_length}

    +
    +
    +

    Languages

    +

    {model.languages.join(', ')}

    +
    +
    + {/if} +
    + +
    +
    +

    Cost Configuration

    +
    + {#if isEditing} +
    + {#if modelType === MODEL_TYPES.llm.toLowerCase()} +
    + +
    + +
    +
    +
    + +
    + +
    +
    + {/if} + {#if modelType === MODEL_TYPES.embed.toLowerCase()} +
    + +
    + +
    +
    + {/if} + {#if modelType === MODEL_TYPES.rerank.toLowerCase()} +
    + +
    + +
    +
    + {/if} +
    + {:else} +
    +
    +

    Cost in USD per million input tokens

    +

    {model.llm_input_cost_per_mtoken}

    +
    +
    +

    Cost in USD per million output tokens

    +

    {model.llm_output_cost_per_mtoken}

    +
    +
    + {/if} +
    + +
    +
    +

    Access Control

    +
    + {#if isEditing} +
    +
    + +
    + +
    +
    +
    + +
    + +
    +
    +
    + {:else} +
    +
    +

    Allowed Orgs

    +

    {model.allowed_orgs.join(', ')}

    +
    +
    +

    Blocked Orgs

    +

    {model.blocked_orgs.join(', ')}

    +
    +
    +

    Is Private

    +
    + +
    +
    +
    + {/if} +
    +
    + {#if !isEditing} +
    + + + +
    + {:else} +
    + + +
    + {/if} +
    + + diff --git a/services/app/src/routes/(main)/system/models/[model_id]/(components)/index.ts b/services/app/src/routes/(main)/system/models/[model_id]/(components)/index.ts new file mode 100644 index 0000000..36f0313 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/[model_id]/(components)/index.ts @@ -0,0 +1,4 @@ +import CloudDeployments from './CloudDeployments.svelte'; +import ModelDetails from './ModelDetails.svelte'; + +export { CloudDeployments, ModelDetails }; diff --git a/services/app/src/routes/(main)/system/models/[model_id]/+page.server.ts b/services/app/src/routes/(main)/system/models/[model_id]/+page.server.ts new file mode 100644 index 0000000..dd38a52 --- /dev/null +++ b/services/app/src/routes/(main)/system/models/[model_id]/+page.server.ts @@ -0,0 +1,125 @@ +import { env } from '$env/dynamic/private'; +import { PUBLIC_ADMIN_ORGANIZATION_ID } from '$env/static/public'; +import logger, { APIError } from '$lib/logger.js'; +import type { ModelConfig } from '$lib/types.js'; +import { error, fail } from '@sveltejs/kit'; + +const { OWL_URL, OWL_SERVICE_KEY } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export async function load({ cookies, locals, depends, params }) { + depends('system:modelsslug'); + + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + if ( + cookies.get('activeOrganizationId') !== PUBLIC_ADMIN_ORGANIZATION_ID || + !locals.user?.org_memberships.find( + (org) => org.organization_id === PUBLIC_ADMIN_ORGANIZATION_ID + ) + ) { + throw error(404, 'Not found'); + } + + const getModelConfig = async () => { + const response = await fetch( + `${OWL_URL}/api/v2/models/configs?${new URLSearchParams([['model_id', params.model_id]])}`, + { + headers: { + ...headers, + 'x-user-id': locals.user!.id || '' + } + } + ); + + const data = await response.json(); + + if (!response.ok) { + logger.error('MODELCONFIG_GET_ERROR', data, locals.user!.id); + return { data: null, status: response.status, error: data as any }; + } + + return { data: data as ModelConfig, status: response.status }; + }; + + return { + modelConfig: getModelConfig() + }; +} + +export const actions = { + 'edit-model-config': async function ({ locals, request }) { + if (!locals.user) { + return error(401, 'Unauthorized'); + } + + const formData = await request.formData(); + + const data: Record = {}; + + try { + for (const [key, value] of formData.entries()) { + // Skip empty values + if (value === '' || value === null || value === undefined) continue; + + // Handle value conversion + const numValue = Number(value); + if (!isNaN(numValue)) { + data[key] = numValue; + } else { + data[key] = value; + } + } + + data['capabilities'] = JSON.parse((formData.get('capabilities') as string) || '[]'); + + data['languages'] = JSON.parse((formData.get('languages') as string) || '[]'); + + data['allowed_orgs'] = JSON.parse((formData.get('allowed_orgs') as string) || '[]'); + data['blocked_orgs'] = JSON.parse((formData.get('blocked_orgs') as string) || '[]'); + + data['owned_by'] = (formData.get('owned_by') as string) || ''; + + if (data['icon']) { + if (!data['meta']) { + data['meta'] = {}; + } + data['meta']['icon'] = data['icon']; + delete data['icon']; + } + + const response = await fetch( + `${OWL_URL}/api/v2/models/configs?${new URLSearchParams([['model_id', data.model_id]])}`, + { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': locals.user.id || '', + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data) + } + ); + + const 
responseData = await response.json(); + + if (!response.ok) { + logger.error('MODELCONFIG_EDIT_ERROR', responseData, locals.user.id); + return fail( + response.status, + new APIError('Failed to edit model config', responseData).getSerializable() + ); + } + + return responseData; + } catch (error) { + logger.error('MODELCONFIG_EDIT_ERROR', error, locals.user.id); + return fail(500, new APIError('Failed to edit model config', error as any).getSerializable()); + } + } +}; diff --git a/services/app/src/routes/(main)/system/models/[model_id]/+page.svelte b/services/app/src/routes/(main)/system/models/[model_id]/+page.svelte new file mode 100644 index 0000000..eab37ef --- /dev/null +++ b/services/app/src/routes/(main)/system/models/[model_id]/+page.svelte @@ -0,0 +1,97 @@ + + + + Model Setup - {page.params.model_id} + + +
    + {#await data.modelConfig} +
    +
    +
    +
    +
    +
    + {:then modelConfig} + {#if modelConfig.data} +
    + +
    + + +

    + {modelConfig.data.name} + + {MODEL_TYPES[modelConfig.data.type] ?? modelConfig.data.type} + +

    +
    + + +
    +
    + + + + +
    + +
    + {#if activeTab === 'details'} + + {:else if activeTab === 'cloud'} + + + {/if} +
    +
    +
    + {/if} + {/await} +
    diff --git a/services/app/src/routes/+error.svelte b/services/app/src/routes/+error.svelte old mode 100644 new mode 100755 index a632a37..04a061f --- a/services/app/src/routes/+error.svelte +++ b/services/app/src/routes/+error.svelte @@ -1,20 +1,20 @@ - {$page.error.message} + {page.error.message}
    - {$page.status} + {page.status}
    -

    {$page.error.message}

    +

    {page.error.message}

    diff --git a/services/app/src/routes/+layout.server.ts b/services/app/src/routes/+layout.server.ts old mode 100644 new mode 100755 index 0ec5ec8..01df0eb --- a/services/app/src/routes/+layout.server.ts +++ b/services/app/src/routes/+layout.server.ts @@ -1,23 +1,28 @@ -import { PUBLIC_IS_LOCAL, PUBLIC_IS_SPA } from '$env/static/public'; -import { JAMAI_URL, JAMAI_SERVICE_KEY } from '$env/static/private'; -import { error, redirect } from '@sveltejs/kit'; -import { getPrices } from '$lib/server/nodeCache.js'; +import { env } from '$env/dynamic/private'; +import { PUBLIC_IS_SPA } from '$env/static/public'; import logger from '$lib/logger.js'; -import type { OrganizationReadRes, PriceRes, UserRead } from '$lib/types.js'; +import type { OrganizationReadRes } from '$lib/types.js'; +import { redirect } from '@sveltejs/kit'; import type { LayoutServerLoadEvent } from './$types.js'; interface Data { - prices: PriceRes | undefined; user: App.Locals['user']; - userData?: UserRead; dockOpen: boolean; rightDockOpen: boolean; activeOrganizationId?: string; organizationData?: OrganizationReadRes; + OWL_STRIPE_PUBLISHABLE_KEY: string; } +const { + OWL_SERVICE_KEY, + OWL_URL, + OWL_STRIPE_PUBLISHABLE_KEY_LIVE, + OWL_STRIPE_PUBLISHABLE_KEY_TEST +} = env; + const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` + Authorization: `Bearer ${OWL_SERVICE_KEY}` }; export const prerender = PUBLIC_IS_SPA !== 'true' ? false : 'auto'; @@ -44,8 +49,6 @@ export const load: (event: LayoutServerLoadEvent) => Promise = async ({ }); } - const prices = await getPrices(); - const showDock = cookies.get('dockOpen') === 'true'; const showRightDock = cookies.get('rightDockOpen') === 'true'; @@ -56,206 +59,111 @@ export const load: (event: LayoutServerLoadEvent) => Promise = async ({ cookies.set('rightDockOpen', 'false', { path: '/', httpOnly: false }); } - if (PUBLIC_IS_LOCAL === 'false') { + if (!url.pathname.startsWith('/login') && !url.pathname.startsWith('/register')) { if (!locals.user!.email_verified && !url.pathname.startsWith('/verify-email')) { throw redirect( 302, - `/verify-email${url.searchParams.size > 0 ? `?${url.searchParams.toString()}` : ''}` + `/verify-email${url.searchParams.size > 0 ? `?${url.searchParams}` : ''}` ); } - if (locals.user!.email_verified) { + if (locals.user?.email_verified) { let activeOrganizationId = cookies.get('activeOrganizationId'); - const userApiRes = await fetch( - `${JAMAI_URL}/api/admin/backend/v1/users/${locals.user!.sub}`, - { - headers - } - ); - if (userApiRes.status === 404) { - const userUpsertRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users`, { - method: 'POST', - headers: { - ...headers, - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - id: locals.user!.sub, - name: - locals.user!.email === locals.user!.name ? locals.user!.nickname : locals.user!.name, - description: '', - email: locals.user!.email - }) - }); - const userUpsertBody = (await userUpsertRes.json()) as UserRead; - - if (userUpsertRes.ok) { - //? Redirect to create org if no orgs - if ( - userUpsertBody.member_of.length === 0 && - !url.pathname.startsWith('/new-organization') && - !url.pathname.startsWith('/accept-invite') - ) { - throw redirect(302, '/new-organization'); - } - - //? 
Set org ID if not set or if it's not in the list of orgs - if ( - !activeOrganizationId || - !userUpsertBody.member_of.find((org) => org.organization_id === activeOrganizationId) - ) { - cookies.set('activeOrganizationId', userUpsertBody.member_of[0].organization_id, { - path: '/', - sameSite: 'strict', - maxAge: 604800, - httpOnly: false, - secure: false - }); - - activeOrganizationId = cookies.get('activeOrganizationId'); - } - - const orgData = await getOrganizationData(activeOrganizationId!); - const userRoleInOrg = orgData?.members?.find( - (user) => user.user_id === locals.user?.sub - )?.role; - - //* Obfuscate external keys if not admin - if (orgData && userRoleInOrg !== 'admin') { - if (orgData.external_keys) { - orgData.external_keys = Object.fromEntries( - Object.entries(orgData.external_keys).map(([key, value]) => [ - key, - value.trim() === '' ? '' : '********' - ]) - ); - } - delete orgData.members; - - //* Obfuscate credit - orgData.credit = orgData.credit > 0 ? 1 : 0; - orgData.credit_grant = orgData.credit_grant > 0 ? 1 : 0; - - //* Remove JamAI api keys if not member - if (userRoleInOrg !== 'member') { - delete orgData.api_keys; - } - } - - return { - prices, - user: locals.user, - userData: userUpsertBody, - dockOpen: cookies.get('dockOpen') === 'true', - rightDockOpen: cookies.get('rightDockOpen') === 'true', - activeOrganizationId, - organizationData: orgData - }; - } else { - logger.error('APP_USER_UPSERT', userUpsertBody); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - throw error(userUpsertRes.status, userUpsertBody as any); - } - } else if (userApiRes.ok) { - const userApiBody = (await userApiRes.json()) as UserRead; - - //? Redirect to create org if no orgs - if ( - userApiBody.member_of.length === 0 && - !url.pathname.startsWith('/new-organization') && - !url.pathname.startsWith('/accept-invite') - ) { - throw redirect(302, '/new-organization'); - } - - //? Set org ID if not set or if it's not in the list of orgs - if ( - !activeOrganizationId || - !userApiBody.member_of.find((org) => org.organization_id === activeOrganizationId) - ) { - cookies.set('activeOrganizationId', userApiBody.member_of[0]?.organization_id, { - path: '/', - sameSite: 'strict', - maxAge: 604800, - httpOnly: false, - secure: false - }); - - activeOrganizationId = cookies.get('activeOrganizationId'); - } - - const orgData = await getOrganizationData(activeOrganizationId!); - const userRoleInOrg = orgData?.members?.find( - (user) => user.user_id === locals.user?.sub - )?.role; + //? Redirect to create org if no orgs + if ( + locals.user?.org_memberships.length === 0 && + !url.pathname.startsWith('/new-organization') && + !url.pathname.startsWith('/join-organization') + ) { + throw redirect(302, '/new-organization'); + } - //* Obfuscate external keys if not admin - if (orgData && userRoleInOrg !== 'admin') { - if (orgData.external_keys) { - orgData.external_keys = Object.fromEntries( - Object.entries(orgData.external_keys).map(([key, value]) => [ - key, - value.trim() === '' ? '' : '********' - ]) - ); - } - delete orgData.members; + //? 
Set org ID if not set or if it's not in the list of orgs + if ( + locals.user?.org_memberships.length !== 0 && + (!activeOrganizationId || + !locals.user?.org_memberships.find((org) => org.organization_id === activeOrganizationId)) + ) { + cookies.set('activeOrganizationId', locals.user!.org_memberships[0].organization_id!, { + path: '/', + sameSite: 'strict', + maxAge: 604800, + httpOnly: false, + secure: false + }); - //* Obfuscate credit - orgData.credit = orgData.credit > 0 ? 1 : 0; - orgData.credit_grant = orgData.credit_grant > 0 ? 1 : 0; + activeOrganizationId = cookies.get('activeOrganizationId'); + } - //* Remove JamAI api keys if not member - if (userRoleInOrg !== 'member') { - delete orgData.api_keys; - } + const orgData = await getOrganizationData(activeOrganizationId!); + const userRoleInOrg = locals.user?.org_memberships.find( + (org) => org.organization_id === activeOrganizationId + )?.role; + + //* Obfuscate external keys if not admin + if (orgData && userRoleInOrg !== 'ADMIN') { + if (orgData.external_keys) { + orgData.external_keys = Object.fromEntries( + Object.entries(orgData.external_keys).map(([key, value]) => [ + key, + value.trim() === '' ? '' : '********' + ]) + ); } - return { - prices, - user: locals.user, - userData: userApiBody, - dockOpen: cookies.get('dockOpen') === 'true', - rightDockOpen: cookies.get('rightDockOpen') === 'true', - activeOrganizationId, - organizationData: orgData - }; - } else { - logger.error('APP_USER_GET', await userApiRes.json()); - //FIXME: Throw error if user API fails, maybe? - return { - prices, - user: locals.user, - dockOpen: cookies.get('dockOpen') === 'true', - rightDockOpen: cookies.get('rightDockOpen') === 'true' - }; - throw error(userApiRes.status, await userApiRes.json()); + //* Obfuscate credit + orgData.credit = orgData.credit > 0 ? 1 : 0; + orgData.credit_grant = orgData.credit_grant > 0 ? 1 : 0; } + + return { + user: locals.user, + dockOpen: cookies.get('dockOpen') === 'true', + rightDockOpen: cookies.get('rightDockOpen') === 'true', + activeOrganizationId, + organizationData: orgData, + ossMode: locals.ossMode, + auth0Mode: locals.auth0Mode, + OWL_STRIPE_PUBLISHABLE_KEY: + OWL_STRIPE_PUBLISHABLE_KEY_LIVE || OWL_STRIPE_PUBLISHABLE_KEY_TEST || '' + }; } else { return { - prices, user: locals.user, dockOpen: cookies.get('rightDockOpen') === 'true', - rightDockOpen: cookies.get('rightDockOpen') === 'true' + rightDockOpen: cookies.get('rightDockOpen') === 'true', + ossMode: locals.ossMode, + auth0Mode: locals.auth0Mode, + OWL_STRIPE_PUBLISHABLE_KEY: + OWL_STRIPE_PUBLISHABLE_KEY_LIVE || OWL_STRIPE_PUBLISHABLE_KEY_TEST || '' }; } } else { return { - prices, user: locals.user, activeOrganizationId: 'default', dockOpen: cookies.get('rightDockOpen') === 'true', - rightDockOpen: cookies.get('rightDockOpen') === 'true' + rightDockOpen: cookies.get('rightDockOpen') === 'true', + ossMode: locals.ossMode, + auth0Mode: locals.auth0Mode, + OWL_STRIPE_PUBLISHABLE_KEY: + OWL_STRIPE_PUBLISHABLE_KEY_LIVE || OWL_STRIPE_PUBLISHABLE_KEY_TEST || '' }; } // eslint-disable-next-line @typescript-eslint/no-explicit-any async function getOrganizationData(orgId: string): Promise { if (!orgId) return undefined; - const orgInfoRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/organizations/${orgId}`, { - headers - }); + const orgInfoRes = await fetch( + `${OWL_URL}/api/v2/organizations?${new URLSearchParams([['organization_id', orgId]])}`, + { + headers: { + ...headers, + 'x-user-id': locals.user?.id ?? 
'' + } + } + ); const orgInfoBody = (await orgInfoRes.json()) as OrganizationReadRes; if (!orgInfoRes.ok) { diff --git a/services/app/src/routes/+layout.svelte b/services/app/src/routes/+layout.svelte old mode 100644 new mode 100755 index 63551c1..3e0f28a --- a/services/app/src/routes/+layout.svelte +++ b/services/app/src/routes/+layout.svelte @@ -8,49 +8,66 @@ import '@fontsource-variable/roboto-flex'; import { showDock, showRightDock, preferredTheme, activeOrganization } from '$globalStore'; - import { Toaster } from '$lib/components/ui/sonner'; + import * as Tooltip from '$lib/components/ui/tooltip'; + import { CustomToastDesc, toast, Toaster } from '$lib/components/ui/sonner'; let timeout: NodeJS.Timeout; NProgress.configure({ showSpinner: false }); - beforeNavigate(() => (timeout = setTimeout(() => NProgress.start(), 250))); - afterNavigate(() => { - clearTimeout(timeout); - NProgress.done(); - }); + // beforeNavigate(() => (timeout = setTimeout(() => NProgress.start(), 250))); + // afterNavigate(() => { + // clearTimeout(timeout); + // NProgress.done(); + // }); - export let data; - $: ({ dockOpen, rightDockOpen, userData, activeOrganizationId } = data); + let { data, children } = $props(); + let { dockOpen, rightDockOpen, user, activeOrganizationId } = $derived(data); //* Initialize showDock using cookie store - $: $showDock = dockOpen; - $: $showRightDock = rightDockOpen; - - $: if (browser) { - document.cookie = `dockOpen=${$showDock}; path=/; sameSite=Lax`; - document.cookie = `rightDockOpen=${$showRightDock}; path=/; sameSite=Lax`; - } - - $: if (browser) { - if ($preferredTheme == 'SYSTEM') { - document.documentElement.setAttribute( - 'data-theme', - window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light' - ); - } else { - document.documentElement.setAttribute( - 'data-theme', - $preferredTheme == 'LIGHT' ? 'light' : 'dark' - ); + // svelte-ignore state_referenced_locally (mimic run function) + $showDock = dockOpen; + $effect.pre(() => { + $showDock = dockOpen; + }); + // svelte-ignore state_referenced_locally (mimic run function) + $showRightDock = rightDockOpen; + $effect.pre(() => { + $showRightDock = rightDockOpen; + }); + + $effect(() => { + if (browser) { + document.cookie = `dockOpen=${$showDock}; path=/; sameSite=Lax`; + document.cookie = `rightDockOpen=${$showRightDock}; path=/; sameSite=Lax`; } - } + }); - $: if (activeOrganizationId) { - $activeOrganization = - userData?.member_of?.find((org) => org.organization_id === activeOrganizationId) ?? null; - } - $: if (browser && $activeOrganization) { - document.cookie = `activeOrganizationId=${$activeOrganization?.organization_id}; path=/; max-age=604800; samesite=strict`; - } + $effect(() => { + if (browser) { + if ($preferredTheme == 'SYSTEM') { + document.documentElement.setAttribute( + 'data-theme', + window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light' + ); + } else { + document.documentElement.setAttribute( + 'data-theme', + $preferredTheme == 'LIGHT' ? 'light' : 'dark' + ); + } + } + }); + + $effect(() => { + if (activeOrganizationId) { + $activeOrganization = + user?.organizations?.find((org) => org.id === activeOrganizationId) ?? 
null; + } + }); + // $effect(() => { + // if (browser && $activeOrganization) { + // document.cookie = `activeOrganizationId=${$activeOrganization?.id}; path=/; max-age=604800; samesite=strict`; + // } + // }); onMount(() => { //* Reflect changes to user preference for immediately @@ -85,6 +102,53 @@ // document.startViewTransition(switchTheme); // } // } + + // function switchTheme(e: KeyboardEvent) { + // const target = e.target as HTMLElement; + // if ( + // target.tagName == 'INPUT' || + // target.tagName == 'TEXTAREA' || + // target.getAttribute('contenteditable') == 'true' + // ) + // return; + + // if (!e.ctrlKey && !e.shiftKey && !e.metaKey) { + // switch (e.key) { + // case 'e': + // toast.error('Test', { + // duration: Number.POSITIVE_INFINITY, + // description: CustomToastDesc as any, + // componentProps: { + // description: 'Error desc here', + // requestID: 'Request ID here' + // } + // }); + // break; + // case 's': + // toast.success('Test', { + // duration: Number.POSITIVE_INFINITY, + // description: CustomToastDesc as any, + // componentProps: { + // description: 'Error desc here', + // requestID: 'Request ID here' + // } + // }); + // break; + // case 'i': + // toast.info('Test', { + // duration: Number.POSITIVE_INFINITY, + // description: CustomToastDesc as any, + // componentProps: { + // description: 'Error desc here', + // requestID: 'Request ID here' + // } + // }); + // break; + // default: + // break; + // } + // } + // } @@ -113,4 +177,6 @@ - + + {@render children?.()} + diff --git a/services/app/src/routes/+page.ts b/services/app/src/routes/+page.ts old mode 100644 new mode 100755 diff --git a/services/app/src/routes/_layout.ts b/services/app/src/routes/_layout.ts old mode 100644 new mode 100755 index 5f4a9d3..d8eff40 --- a/services/app/src/routes/_layout.ts +++ b/services/app/src/routes/_layout.ts @@ -1,10 +1,10 @@ -import { PUBLIC_IS_LOCAL, PUBLIC_IS_SPA } from '$env/static/public'; +import { PUBLIC_IS_SPA } from '$env/static/public'; //@ts-expect-error missing types export async function load({ parent }) { - await parent(); + const data = await parent(); - if (PUBLIC_IS_LOCAL !== 'false' && PUBLIC_IS_SPA === 'true') { + if (data.ossMode && PUBLIC_IS_SPA === 'true') { return { activeOrganizationId: 'default', dockOpen: true, diff --git a/services/app/src/routes/api/admin/org/v1/projects/+server.ts b/services/app/src/routes/api/admin/org/v1/projects/+server.ts deleted file mode 100644 index c75e6f3..0000000 --- a/services/app/src/routes/api/admin/org/v1/projects/+server.ts +++ /dev/null @@ -1,191 +0,0 @@ -import { PUBLIC_IS_LOCAL } from '$env/static/public'; -import { JAMAI_URL, JAMAI_SERVICE_KEY } from '$env/static/private'; -import { json } from '@sveltejs/kit'; -import { projectIDPattern } from '$lib/constants.js'; -import logger, { APIError } from '$lib/logger.js'; -import type { UserRead } from '$lib/types.js'; - -const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` -}; - -export const GET = async ({ cookies, locals, url }) => { - const activeOrganizationId = cookies.get('activeOrganizationId'); - - if (PUBLIC_IS_LOCAL === 'false') { - if (!activeOrganizationId) { - return json(new APIError('No active organization'), { status: 400 }); - } - - //* Verify user perms - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - - const userApiRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${locals.user.sub}`, { - headers - }); - const userApiBody = (await userApiRes.json()) as UserRead; - if (userApiRes.ok) { 
- const targetOrg = userApiBody.member_of.find( - (org) => org.organization_id === activeOrganizationId - ); - if (!targetOrg) { - return json(new APIError('Forbidden'), { status: 403 }); - } - } else { - logger.error('PROJECT_LIST_GETUSER', userApiBody); - return json(new APIError('Failed to get user info', userApiBody as any), { - status: userApiRes.status - }); - } - } - - const searchParams = new URLSearchParams({ organization_id: activeOrganizationId ?? '' }); - url.searchParams.forEach((value, key) => { - if (key === 'organization_id' && PUBLIC_IS_LOCAL === 'false') return; - searchParams.set(key, value); - }); - - const projectsListRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects?${searchParams}`, { - headers - }); - const projectsListBody = await projectsListRes.json(); - - if (!projectsListRes.ok) { - logger.error('PROJECT_LIST_LIST', projectsListBody); - return json(new APIError('Failed to get projects list', projectsListBody), { - status: projectsListRes.status - }); - } else { - return json(projectsListBody); - } -}; - -export const POST = async ({ cookies, fetch, locals, request }) => { - const activeOrganizationId = cookies.get('activeOrganizationId'); - - const { name: project_name } = await request.json(); - if (!project_name || typeof project_name !== 'string' || project_name.trim() === '') { - return json(new APIError('Invalid project name'), { status: 400 }); - } - - if (!projectIDPattern.test(project_name)) { - return json( - new APIError( - 'Project name must contain only alphanumeric characters and underscores/hyphens/spaces/periods, and start and end with alphanumeric characters, between 2 and 100 characters.' - ), - { status: 400 } - ); - } - - if (PUBLIC_IS_LOCAL === 'false') { - //* Verify user perms - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - - const userApiRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${locals.user.sub}`, { - headers - }); - const userApiBody = (await userApiRes.json()) as UserRead; - if (userApiRes.ok) { - const targetOrg = userApiBody.member_of.find( - (org) => org.organization_id === activeOrganizationId - ); - if (!targetOrg || (targetOrg.role !== 'admin' && targetOrg.role !== 'member')) { - return json(new APIError('Forbidden'), { status: 403 }); - } - } else { - logger.error('PROJECT_CREATE_GETUSER', userApiBody); - return json(new APIError('Failed to get user info', userApiBody as any), { - status: userApiRes.status - }); - } - } - - const createProjectRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects`, { - method: 'POST', - headers: { - ...headers, - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - name: project_name, - organization_id: activeOrganizationId - }) - }); - - const createProjectBody = await createProjectRes.json(); - if (!createProjectRes.ok) { - logger.error('PROJECT_CREATE_CREATE', createProjectBody); - return json(new APIError('Failed to create project', createProjectBody), { - status: createProjectRes.status - }); - } else { - return json(createProjectBody); - } -}; - -export const PATCH = async ({ locals, request }) => { - const { id: projectId, name: project_name } = await request.json(); - if (!project_name || typeof project_name !== 'string' || project_name.trim() === '') { - return json(new APIError('Invalid project name'), { status: 400 }); - } - - if (PUBLIC_IS_LOCAL === 'false') { - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - - const projectApiRes = await 
fetch(`${JAMAI_URL}/api/admin/org/v1/projects/${projectId}`, { - headers - }); - const projectApiBody = await projectApiRes.json(); - - if (!projectApiRes.ok) { - logger.error('PROJECT_PATCH_GETPROJ', projectApiBody); - } - - const userApiRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${locals.user.sub}`, { - headers - }); - const userApiBody = (await userApiRes.json()) as UserRead; - - if (userApiRes.ok) { - const targetOrg = userApiBody.member_of.find( - (org) => org.organization_id === projectApiBody.organization_id - ); - if (!targetOrg || targetOrg.role !== 'admin') { - return json(new APIError('Forbidden'), { status: 403 }); - } - } else { - logger.error('PROJECT_PATCH_GETUSER', userApiBody); - return json(new APIError('Failed to get user info', userApiBody as any), { - status: userApiRes.status - }); - } - } - - const patchProjectRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects`, { - method: 'PATCH', - headers: { - ...headers, - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - id: projectId, - name: project_name - }) - }); - - const patchProjectBody = await patchProjectRes.json(); - if (!patchProjectRes.ok) { - logger.error('PROJECT_PATCH_PATCH', patchProjectBody); - return json(new APIError('Failed to update project', patchProjectBody as any), { - status: patchProjectRes.status - }); - } else { - return json({ ok: true }); - } -}; diff --git a/services/app/src/routes/api/admin/org/v1/projects/[project_id]/+server.ts b/services/app/src/routes/api/admin/org/v1/projects/[project_id]/+server.ts deleted file mode 100644 index dfbe582..0000000 --- a/services/app/src/routes/api/admin/org/v1/projects/[project_id]/+server.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { PUBLIC_IS_LOCAL } from '$env/static/public'; -import { JAMAI_URL, JAMAI_SERVICE_KEY } from '$env/static/private'; -import { json } from '@sveltejs/kit'; -import logger, { APIError } from '$lib/logger.js'; -import type { Project, UserRead } from '$lib/types.js'; - -const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` -}; - -export const GET = async ({ locals, params }) => { - if (PUBLIC_IS_LOCAL === 'false') { - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - } - - const projectRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects/${params.project_id}`, { - headers - }); - const projectBody = await projectRes.json(); - - if (projectRes.ok) { - if ( - PUBLIC_IS_LOCAL !== 'false' || - (projectBody as Project).organization.members?.find( - (user) => user.user_id === locals.user?.sub - ) - ) { - return json(projectBody); - } else { - return json(new APIError('Project not found'), { status: 404 }); - } - } else { - return json(new APIError('Failed to get project', projectBody as any), { - status: projectRes.status - }); - } -}; - -export const DELETE = async ({ locals, params }) => { - const projectId = params.project_id; - - if (PUBLIC_IS_LOCAL === 'false') { - //* Verify user perms - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - - const projectApiRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects/${projectId}`, { - headers - }); - const projectApiBody = await projectApiRes.json(); - - if (!projectApiRes.ok) { - logger.error('PROJECT_DELETE_GETPROJ', projectApiBody); - } - - const userApiRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${locals.user.sub}`, { - headers - }); - const userApiBody = (await userApiRes.json()) as UserRead; - - if (userApiRes.ok) { - const targetOrg = 
userApiBody.member_of.find( - (org) => org.organization_id === projectApiBody.organization_id - ); - if (!targetOrg || targetOrg.role !== 'admin') { - return json(new APIError('Forbidden'), { status: 403 }); - } - } else { - logger.error('PROJECT_DELETE_GETUSER', userApiBody); - return json(new APIError('Failed to get user info', userApiBody as any), { - status: userApiRes.status - }); - } - } - - const deleteProjectRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects/${projectId}`, { - method: 'DELETE', - headers - }); - - const deleteProjectBody = await deleteProjectRes.json(); - if (!deleteProjectRes.ok) { - logger.error('PROJECT_DELETE_DELETE', deleteProjectBody); - return json(new APIError('Failed to delete project', deleteProjectBody as any), { - status: deleteProjectRes.status - }); - } else { - return json({ ok: true }); - } -}; diff --git a/services/app/src/routes/api/admin/org/v1/projects/[project_id]/export/+server.ts b/services/app/src/routes/api/admin/org/v1/projects/[project_id]/export/+server.ts deleted file mode 100644 index 4c0671f..0000000 --- a/services/app/src/routes/api/admin/org/v1/projects/[project_id]/export/+server.ts +++ /dev/null @@ -1,65 +0,0 @@ -import { PUBLIC_IS_LOCAL } from '$env/static/public'; -import { JAMAI_URL, JAMAI_SERVICE_KEY } from '$env/static/private'; -import { json } from '@sveltejs/kit'; -import logger, { APIError } from '$lib/logger.js'; -import type { UserRead } from '$lib/types.js'; - -const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` -}; - -export const GET = async ({ locals, params }) => { - const projectId = params.project_id; - - if (PUBLIC_IS_LOCAL === 'false') { - //* Verify user perms - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - - const projectApiRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects/${projectId}`, { - headers - }); - const projectApiBody = await projectApiRes.json(); - - if (!projectApiRes.ok) { - logger.error('PROJECT_EXPORT_GETPROJ', projectApiBody); - } - - const userApiRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${locals.user.sub}`, { - headers - }); - const userApiBody = (await userApiRes.json()) as UserRead; - - if (userApiRes.ok) { - const targetOrg = userApiBody.member_of.find( - (org) => org.organization_id === projectApiBody.organization_id - ); - if (!targetOrg) { - return json(new APIError('Forbidden'), { status: 403 }); - } - } else { - logger.error('PROJECT_EXPORT_GETUSER', userApiBody); - return json(new APIError('Failed to get user info', userApiBody as any), { - status: userApiRes.status - }); - } - } - - const exportProjectRes = await fetch( - `${JAMAI_URL}/api/admin/org/v1/projects/${projectId}/export`, - { - headers - } - ); - - if (!exportProjectRes.ok) { - const exportProjectBody = await exportProjectRes.json(); - logger.error('PROJECT_EXPORT_EXPORT', exportProjectBody); - return json(new APIError('Failed to export project', exportProjectBody as any), { - status: exportProjectRes.status - }); - } else { - return exportProjectRes; - } -}; diff --git a/services/app/src/routes/api/admin/org/v1/projects/import/[organization_id]/+server.ts b/services/app/src/routes/api/admin/org/v1/projects/import/[organization_id]/+server.ts deleted file mode 100644 index cb15352..0000000 --- a/services/app/src/routes/api/admin/org/v1/projects/import/[organization_id]/+server.ts +++ /dev/null @@ -1,66 +0,0 @@ -import { PUBLIC_IS_LOCAL } from '$env/static/public'; -import { JAMAI_URL, JAMAI_SERVICE_KEY } from '$env/static/private'; 
-import { json } from '@sveltejs/kit'; -import axios from 'axios'; -import logger, { APIError } from '$lib/logger.js'; -import type { UserRead } from '$lib/types.js'; - -const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` -}; - -export const POST = async ({ locals, params, request }) => { - const organizationId = params.organization_id; - - if (PUBLIC_IS_LOCAL === 'false') { - //* Verify user perms - if (!locals.user) { - return json(new APIError('Unauthorized'), { status: 401 }); - } - - const userApiRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${locals.user.sub}`, { - headers - }); - const userApiBody = (await userApiRes.json()) as UserRead; - - if (userApiRes.ok) { - const targetOrg = userApiBody.member_of.find((org) => org.organization_id === organizationId); - if (!targetOrg || targetOrg.role === 'guest') { - return json(new APIError('Forbidden'), { status: 403 }); - } - } else { - logger.error('PROJECT_IMPORT_GETUSER', userApiBody); - return json(new APIError('Failed to get user info', userApiBody as any), { - status: userApiRes.status - }); - } - } - - try { - const importProjectRes = await axios.post( - `${JAMAI_URL}/api/admin/org/v1/projects/import/${organizationId}`, - await request.formData(), - { - headers: { - ...headers, - 'Content-Type': 'multipart/form-data' - } - } - ); - if (importProjectRes.status != 200) { - logger.error('PROJECT_IMPORT_IMPORT', importProjectRes.data); - return json(new APIError('Failed to import project', importProjectRes.data as any), { - status: importProjectRes.status - }); - } else { - return new Response(importProjectRes.data); - } - } catch (err) { - //@ts-expect-error AxiosError - logger.error('PROJECT_IMPORT_IMPORT', err?.response?.data); - //@ts-expect-error AxiosError - return json(new APIError('Failed to import project', err?.response?.data), { - status: 500 - }); - } -}; diff --git a/services/app/src/routes/api/log/+server.ts b/services/app/src/routes/api/log/+server.ts old mode 100644 new mode 100755 index 07c20bc..5fe2960 --- a/services/app/src/routes/api/log/+server.ts +++ b/services/app/src/routes/api/log/+server.ts @@ -1,6 +1,6 @@ +import { enumerateObj } from '$lib/utils.js'; import { json } from '@sveltejs/kit'; import { z } from 'zod'; -import { enumerateObj } from '$lib/utils.js'; export const POST = async ({ request, locals }) => { const body = await request.json().catch(() => null); @@ -27,31 +27,31 @@ export const POST = async ({ request, locals }) => { switch (type) { case 'error': console.error( - `Logged from client (${locals.user?.sub ?? 'Unknown'}): ${event}\n`, + `Logged from client (${locals.user?.id ?? 'Unknown'}): ${event}\n`, stringMessage ); break; case 'warn': console.warn( - `Logged from client (${locals.user?.sub ?? 'Unknown'}): ${event}\n`, + `Logged from client (${locals.user?.id ?? 'Unknown'}): ${event}\n`, stringMessage ); break; case 'info': console.info( - `Logged from client (${locals.user?.sub ?? 'Unknown'}): ${event}\n`, + `Logged from client (${locals.user?.id ?? 'Unknown'}): ${event}\n`, stringMessage ); break; case 'log': console.log( - `Logged from client (${locals.user?.sub ?? 'Unknown'}): ${event}\n`, + `Logged from client (${locals.user?.id ?? 'Unknown'}): ${event}\n`, stringMessage ); break; default: console.log( - `Logged from client (${locals.user?.sub ?? 'Unknown'}): ${event}\n`, + `Logged from client (${locals.user?.id ?? 
'Unknown'}): ${event}\n`, stringMessage ); break; diff --git a/services/app/src/routes/api/v2/projects/export/+server.ts b/services/app/src/routes/api/v2/projects/export/+server.ts new file mode 100755 index 0000000..3bada5e --- /dev/null +++ b/services/app/src/routes/api/v2/projects/export/+server.ts @@ -0,0 +1,35 @@ +import { env } from '$env/dynamic/private'; +import logger, { APIError } from '$lib/logger.js'; +import { json } from '@sveltejs/kit'; + +const { OWL_SERVICE_KEY, OWL_URL } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const GET = async ({ locals, url }) => { + const projectId = url.searchParams.get('project_id'); + + //* Verify user perms + if (!locals.user) { + return json(new APIError('Unauthorized'), { status: 401 }); + } + + const exportProjectRes = await fetch( + `${OWL_URL}/api/v2/projects/export?${new URLSearchParams([['project_id', projectId ?? '']])}`, + { + headers + } + ); + + if (!exportProjectRes.ok) { + const exportProjectBody = await exportProjectRes.json(); + logger.error('PROJECT_EXPORT_EXPORT', exportProjectBody); + return json(new APIError('Failed to export project', exportProjectBody as any), { + status: exportProjectRes.status + }); + } else { + return exportProjectRes; + } +}; diff --git a/services/app/src/routes/api/v2/projects/import/parquet/+server.ts b/services/app/src/routes/api/v2/projects/import/parquet/+server.ts new file mode 100755 index 0000000..2a496d2 --- /dev/null +++ b/services/app/src/routes/api/v2/projects/import/parquet/+server.ts @@ -0,0 +1,45 @@ +import { env } from '$env/dynamic/private'; +import logger, { APIError } from '$lib/logger.js'; +import { json } from '@sveltejs/kit'; +import axios from 'axios'; + +const { OWL_SERVICE_KEY, OWL_URL } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const POST = async ({ locals, request }) => { + //* Verify user perms + if (!locals.user) { + return json(new APIError('Unauthorized'), { status: 401 }); + } + + try { + const importProjectRes = await axios.post( + `${OWL_URL}/api/v2/projects/import/parquet`, + await request.formData(), + { + headers: { + ...headers, + 'Content-Type': 'multipart/form-data' + } + } + ); + if (importProjectRes.status != 200) { + logger.error('PROJECT_IMPORT_IMPORT', importProjectRes.data); + return json(new APIError('Failed to import project', importProjectRes.data as any), { + status: importProjectRes.status + }); + } else { + return new Response(importProjectRes.data); + } + } catch (err) { + //@ts-expect-error AxiosError + logger.error('PROJECT_IMPORT_IMPORT', err?.response?.data); + //@ts-expect-error AxiosError + return json(new APIError('Failed to import project', err?.response?.data), { + status: 500 + }); + } +}; diff --git a/services/app/src/routes/join-organization/+page.server.ts b/services/app/src/routes/join-organization/+page.server.ts new file mode 100755 index 0000000..125a132 --- /dev/null +++ b/services/app/src/routes/join-organization/+page.server.ts @@ -0,0 +1,80 @@ +import { env } from '$env/dynamic/private'; +import logger from '$lib/logger.js'; +import { error, redirect } from '@sveltejs/kit'; + +const { OWL_SERVICE_KEY, OWL_URL } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const load = async ({ locals, url, parent }) => { + await parent(); + const token = url.searchParams.get('token'); + + if (token) { + if (!locals.user) { + throw error(401, 'Unauthorized'); + } + + const inviteUserRes = await fetch( + 
`${OWL_URL}/api/v2/organizations/members?${new URLSearchParams([ + ['user_id', locals.user.id], + ['invite_code', token] + ])}`, + { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user?.id ?? '' + } + } + ); + + const inviteUserBody = await inviteUserRes.json(); + if (!inviteUserRes.ok) { + if (inviteUserRes.status !== 404) { + logger.error('INVITEORG_TOKEN_ERROR', inviteUserBody); + } + throw error(inviteUserRes.status, inviteUserBody.message || JSON.stringify(inviteUserBody)); + } else { + throw redirect(302, '/'); + } + } +}; + +export const actions = { + /** Form actions method of invite */ + // default: async function ({ locals, request }) { + // if (!locals.user) { + // return error(401, 'Unauthorized'); + // } + // const formdata = await request.formData(); + // const code = formdata.get('code'); + // if (!code || typeof code !== 'string' || code.trim() === '') { + // return fail(400, new APIError('Code (type string) is required').getSerializable()); + // } + // const response = await fetch( + // `${OWL_URL}/api/v2/organizations/members?${new URLSearchParams([ + // ['user_id', locals.user?.id ?? ''], + // ['invite_code', code] + // ])}`, + // { + // method: 'POST', + // headers: { + // ...headers, + // // 'x-user-id': locals.user.id || '', + // 'Content-Type': 'application/json' + // } + // } + // ); + // const responseBody = await response.json(); + // if (response.ok) { + // return responseBody?.organization; + // } + // return fail( + // response.status, + // new APIError('Failed to join organization', responseBody).getSerializable() + // ); + // } +}; diff --git a/services/app/src/routes/join-organization/+page.svelte b/services/app/src/routes/join-organization/+page.svelte new file mode 100644 index 0000000..85bd8db --- /dev/null +++ b/services/app/src/routes/join-organization/+page.svelte @@ -0,0 +1,112 @@ + + + + Join Organization + + +
    + + +
    +
    +

    Join Organization

    +

    + Enter the code provided by your organization administrator +

    +
    + +
    { + loading = true; + return async ({ update, result }) => { + //@ts-ignore + const data = result.data; + errorMessage = data?.err_message?.message || ''; + if (result.type === 'failure') { + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + activeOrganization.setOrgCookie(data?.data?.id ?? $activeOrganization?.id); + await goto('/'); + } + loading = false; + await update(); + }; + }} + class="mt-10 space-y-6" + > +
    + + {#snippet children({ cells })} + + {#each cells as cell (cell)} + + {/each} + + {/snippet} + + + {#if errorMessage} + + {errorMessage} + + {/if} +
    +
    + +
    +

    Don't have a code? Contact your organization administrator for an invitation.

    +
    + +
    + + +
    +
    +
    diff --git a/services/app/src/routes/join-project/+page.server.ts b/services/app/src/routes/join-project/+page.server.ts new file mode 100755 index 0000000..318a266 --- /dev/null +++ b/services/app/src/routes/join-project/+page.server.ts @@ -0,0 +1,80 @@ +import { env } from '$env/dynamic/private'; +import logger from '$lib/logger.js'; +import { error, redirect } from '@sveltejs/kit'; + +const { OWL_SERVICE_KEY, OWL_URL } = env; + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export const load = async ({ locals, url, parent }) => { + await parent(); + const token = url.searchParams.get('token'); + + if (token) { + if (!locals.user) { + throw error(401, 'Unauthorized'); + } + + const inviteUserRes = await fetch( + `${OWL_URL}/api/v2/projects/members?${new URLSearchParams([ + ['user_id', locals.user.id], + ['invite_code', token] + ])}`, + { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user?.id ?? '' + } + } + ); + + const inviteUserBody = await inviteUserRes.json(); + if (!inviteUserRes.ok) { + if (inviteUserRes.status !== 404) { + logger.error('INVITEPROJ_TOKEN_ERROR', inviteUserBody); + } + throw error(inviteUserRes.status, inviteUserBody.message || JSON.stringify(inviteUserBody)); + } else { + throw redirect(302, '/'); + } + } +}; + +export const actions = { + /** Form actions method of invite */ + // default: async function ({ locals, request }) { + // if (!locals.uer) { + // return error(401, 'Unauthorized'); + // } + // const formdata = await request.formData(); + // const code = formdata.get('code'); + // if (!code || typeof code !== 'string' || code.trim() === '') { + // return fail(400, new APIError('Code (type string) is required').getSerializable()); + // } + // const response = await fetch( + // `${OWL_URL}/api/v2/organizations/members?${new URLSearchParams([ + // ['user_id', locals.user?.id ?? ''], + // ['invite_code', code] + // ])}`, + // { + // method: 'POST', + // headers: { + // ...headers, + // // 'x-user-id': locals.user.id || '', + // 'Content-Type': 'application/json' + // } + // } + // ); + // const responseBody = await response.json(); + // if (response.ok) { + // return responseBody?.organization; + // } + // return fail( + // response.status, + // new APIError('Failed to join organization', responseBody).getSerializable() + // ); + // } +}; diff --git a/services/app/src/routes/join-project/+page.svelte b/services/app/src/routes/join-project/+page.svelte new file mode 100644 index 0000000..490cb91 --- /dev/null +++ b/services/app/src/routes/join-project/+page.svelte @@ -0,0 +1,110 @@ + + + + Join Organization + + +
    + + +
    +
    +

    Join Project

    +

    + Enter the code provided by your project administrator +

    +
    + +
    { + loading = true; + return async ({ update, result }) => { + //@ts-ignore + const data = result.data; + errorMessage = data?.err_message?.message || ''; + if (result.type === 'failure') { + toast.error(data.error, { + id: data?.err_message?.message || JSON.stringify(data), + description: CustomToastDesc as any, + componentProps: { + description: data?.err_message?.message || JSON.stringify(data), + requestID: data?.err_message?.request_id ?? '' + } + }); + } else if (result.type === 'success') { + activeOrganization.setOrgCookie(data?.data?.id ?? $activeOrganization?.id); + await goto('/'); + } + loading = false; + await update(); + }; + }} + class="mt-10 space-y-6" + > +
    + + {#snippet children({ cells })} + + {#each cells as cell (cell)} + + {/each} + + {/snippet} + + + {#if errorMessage} + + {errorMessage} + + {/if} +
    +
    + +
    +

    Don't have a code? Contact your project administrator for an invitation.

    +
    + +
    + + +
    +
    +
    diff --git a/services/app/src/routes/login/+page.svelte b/services/app/src/routes/login/+page.svelte new file mode 100644 index 0000000..aa0ecd8 --- /dev/null +++ b/services/app/src/routes/login/+page.svelte @@ -0,0 +1,138 @@ + + + + Log in | JamAI Base + + +
    +
    +
    + +

    Welcome back

    +
    + +
    +
    +
    + + +
    + +
    + + + +
    +
    + + {#if error} + + {error} + + {/if} + + +
    + +
    + Don't have an account? + +
    +
    +
    diff --git a/services/app/src/routes/login/auth-errors.ts b/services/app/src/routes/login/auth-errors.ts new file mode 100644 index 0000000..f8adaf0 --- /dev/null +++ b/services/app/src/routes/login/auth-errors.ts @@ -0,0 +1,22 @@ +export const AUTH_ERROR_MESSAGES = { + invalid_credentials: 'Invalid email or password', + default: 'An error occurred during authentication', + user_exists: 'An account with this email already exists', + user_not_found: 'No user found with the email provided', + weak_password: 'Password should be at least 8 characters long', + email_verification: 'Please verify your email address', + account_disabled: 'Your account has been disabled', + rate_limit: 'Too many attempts. Please try again later', + invalid_token: 'Your session has expired. Please sign in again', + server_error: 'Server error. Please try again later' +}; + +// Type to ensure the code is one of the keys in AUTH_ERROR_MESSAGES +export type TCode = keyof typeof AUTH_ERROR_MESSAGES; + +export const getAuthErrorMessage = (code?: string | null): string => { + if (code && code in AUTH_ERROR_MESSAGES) { + return AUTH_ERROR_MESSAGES[code as TCode]; + } + return AUTH_ERROR_MESSAGES.default; +}; diff --git a/services/app/src/routes/new-organization/+page.server.ts b/services/app/src/routes/new-organization/+page.server.ts new file mode 100755 index 0000000..bd0a407 --- /dev/null +++ b/services/app/src/routes/new-organization/+page.server.ts @@ -0,0 +1,7 @@ +import { redirect } from '@sveltejs/kit'; + +export async function load({ locals }) { + if (!locals.user) { + throw redirect(302, '/'); + } +} diff --git a/services/app/src/routes/new-organization/+page.svelte b/services/app/src/routes/new-organization/+page.svelte new file mode 100755 index 0000000..a4bba32 --- /dev/null +++ b/services/app/src/routes/new-organization/+page.svelte @@ -0,0 +1,190 @@ + + + + Create new organization + + +
    + + +
    + {#if (user?.org_memberships ?? []).length > 0} + Create a new organization + {:else} + Welcome,
    let's get you ready! + {/if} +
    + +
    + +
    +
    + + + + + + Your organization's display name. You can change this later. + +
    + + +
    + +
    + +
    +
    + + +
    + {#if (user?.org_memberships ?? []).length > 0} + + {:else} + + {/if} + + +
    +
    +
    +
    +
    + + diff --git a/services/app/src/routes/register/+page.svelte b/services/app/src/routes/register/+page.svelte new file mode 100644 index 0000000..5aa20ad --- /dev/null +++ b/services/app/src/routes/register/+page.svelte @@ -0,0 +1,177 @@ + + + + Sign Up | JamAI Base + + +
    +
    +
    + +

    Create account

    +
    + +
    +
    +
    + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + +
    +
    + + {#if error} + + {error} + + {/if} + + +
    + +
    + Already have an account? + +
    +
    +
    diff --git a/services/app/src/routes/register/auth-errors.ts b/services/app/src/routes/register/auth-errors.ts new file mode 100644 index 0000000..697b0d4 --- /dev/null +++ b/services/app/src/routes/register/auth-errors.ts @@ -0,0 +1,22 @@ +export const AUTH_ERROR_MESSAGES = { + invalid_credentials: 'Invalid username or password', + default: 'An error occurred during authentication', + user_exists: 'An account with this username already exists', + user_not_found: 'No user found with the username provided', + weak_password: 'Password should be at least 8 characters long', + email_verification: 'Please verify your email address', + account_disabled: 'Your account has been disabled', + rate_limit: 'Too many attempts. Please try again later', + invalid_token: 'Your session has expired. Please sign in again', + server_error: 'Server error. Please try again later' +}; + +// Type to ensure the code is one of the keys in AUTH_ERROR_MESSAGES +export type TCode = keyof typeof AUTH_ERROR_MESSAGES; + +export const getAuthErrorMessage = (code?: string | null): string => { + if (code && code in AUTH_ERROR_MESSAGES) { + return AUTH_ERROR_MESSAGES[code as TCode]; + } + return AUTH_ERROR_MESSAGES.default; +}; diff --git a/services/app/src/routes/verify-email/+page.server.ts b/services/app/src/routes/verify-email/+page.server.ts new file mode 100755 index 0000000..30b79e6 --- /dev/null +++ b/services/app/src/routes/verify-email/+page.server.ts @@ -0,0 +1,301 @@ +import { env } from '$env/dynamic/private'; +import { emailCodeCooldownSecs } from '$lib/constants.js'; +import logger, { APIError } from '$lib/logger.js'; +import { error, fail, redirect } from '@sveltejs/kit'; +import { ManagementClient } from 'auth0'; + +const { + AUTH0_CLIENT_ID, + AUTH0_ISSUER_BASE_URL, + AUTH0_MGMTAPI_CLIENT_ID, + AUTH0_MGMTAPI_CLIENT_SECRET, + ORIGIN, + OWL_SERVICE_KEY, + OWL_URL, + RESEND_API_KEY +} = env; + +const management = new ManagementClient({ + domain: AUTH0_ISSUER_BASE_URL?.replace('https://', '') ?? '', + clientId: AUTH0_MGMTAPI_CLIENT_ID ?? '', + clientSecret: AUTH0_MGMTAPI_CLIENT_SECRET ?? '' +}); + +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +export async function load({ locals, url }) { + if (!locals.user || locals.user.email_verified) { + throw redirect(302, '/'); + } + + const token = url.searchParams.get('token'); + + if (token) { + const verifyUserRes = await fetch( + `${OWL_URL}/api/v2/users/verify/email?${new URLSearchParams([['verification_code', token]])}`, + { + method: 'POST', + headers: { + ...headers, + 'x-user-id': locals.user.id ?? 
'' + } + } + ); + + const verifyUserBody = await verifyUserRes.json(); + if (!verifyUserRes.ok) { + if (verifyUserRes.status !== 404) { + logger.error('VERIFYEMAIL_LOAD_TOKEN', verifyUserBody); + } + throw error(verifyUserRes.status, verifyUserBody.message || JSON.stringify(verifyUserBody)); + } else { + throw redirect(302, '/'); + } + } + + const listCodesRes = await fetch( + `${OWL_URL}/api/v2/users/verify/email/code/list?${new URLSearchParams([ + ['limit', '1'], + ['search_query', locals.user.email], + ['search_columns', 'user_email'] + ])}`, + { + headers: { + ...headers, + 'x-user-id': '0' + } + } + ); + const listCodesBody = await listCodesRes.json(); + + if ( + listCodesRes.ok && + (!listCodesBody.items[0] || + new Date(listCodesBody.items[0]?.expiry).getTime() < new Date().getTime()) + ) { + const sendCodeRes = await fetch( + `${OWL_URL}/api/v2/users/verify/email/code?${new URLSearchParams([ + ['user_email', locals.user.email], + ['valid_days', '1'] + ])}`, + { + method: 'POST', + headers: { + ...headers, + 'x-user-id': '0' + } + } + ); + const sendCodeBody = await sendCodeRes.json(); + + if (!sendCodeRes.ok) { + logger.error('VERIFYEMAIL_LOAD_GETCODE', sendCodeBody); + } else { + const sendEmailRes = await fetch('https://api.resend.com/emails', { + method: 'POST', + headers: { + Authorization: `Bearer ${RESEND_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + from: 'JamAI Base ', + to: locals.user.email, + subject: 'Verify your JamAI Base email address', + html: getVerificationEmailBody(sendCodeBody.id) + }) + }); + + if (!sendEmailRes.ok) { + logger.error('VERIFYEMAIL_LOAD_SENDCODE', await sendEmailRes.json()); + } + } + } +} + +export const actions = { + 'resend-verification-email': async ({ locals }) => { + //* Verify user perms + if (!locals.user) { + return fail(401, new APIError('Unauthorized').getSerializable()); + } + + if (locals.auth0Mode) { + try { + const resendEmailRes = await management.jobs.verifyEmail({ + user_id: locals.user.sub!, + client_id: AUTH0_CLIENT_ID + }); + if (resendEmailRes.status !== 200 && resendEmailRes.status !== 201) { + logger.error('VERIFY_RESEND_EMAIL', resendEmailRes.data); + return fail( + resendEmailRes.status, + new APIError( + 'Failed to resend verification email', + resendEmailRes.data as any + ).getSerializable() + ); + } else { + return resendEmailRes.data; + } + } catch (err) { + logger.error('VERIFY_RESEND_EMAILERR', err); + return fail( + 500, + new APIError('Failed to resend verification email', err as any).getSerializable() + ); + } + } else { + try { + //? 
Check if resend cooldown is up + const response = await fetch( + `${OWL_URL}/api/v2/users/verify/email/code/list?${new URLSearchParams([ + ['limit', '1'], + ['search_query', locals.user!.email], + ['search_columns', 'user_email'] + ])}`, + { + headers: { + ...headers, + 'x-user-id': '0' + } + } + ); + const responseBody = await response.json(); + + if (response.ok) { + if ( + new Date().getTime() - new Date(responseBody.items[0]?.created_at).getTime() > + emailCodeCooldownSecs * 1000 + ) { + const sendCodeRes = await fetch( + `${OWL_URL}/api/v2/users/verify/email/code?${new URLSearchParams([ + ['user_email', locals.user.email], + ['valid_days', '1'] + ])}`, + { + method: 'POST', + headers: { + ...headers, + 'x-user-id': '0' + } + } + ); + const sendCodeBody = await sendCodeRes.json(); + + if (!sendCodeRes.ok) { + logger.error('VERIFYEMAIL_RESEND_GETCODE', sendCodeBody); + } else { + const sendEmailRes = await fetch('https://api.resend.com/emails', { + method: 'POST', + headers: { + Authorization: `Bearer ${RESEND_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + from: 'JamAI Base ', + to: locals.user.email, + subject: 'Verify your JamAI Base email address', + html: getVerificationEmailBody(sendCodeBody.id) + }) + }); + + if (!sendEmailRes.ok) { + logger.error('VERIFYEMAIL_RESEND_SENDCODE', await sendEmailRes.json()); + } + } + } else { + return fail( + 403, + new APIError( + 'Too many resend verification email requests, please wait.' + ).getSerializable() + ); + } + } else { + logger.error('VERIFY_RESEND_LISTCODE', responseBody); + return fail( + 500, + new APIError('Failed to resend verification email', responseBody).getSerializable() + ); + } + } catch (err) { + logger.error('VERIFY_RESEND_EMAILERR', err); + return fail( + 500, + new APIError('Failed to resend verification email', err as any).getSerializable() + ); + } + } + } +}; + +const getVerificationEmailBody = (verificationToken: string) => ` + + + + +
    + + + + +
    +
    +

    + JamAI Logo +

    +

    Verify your email address

    +

    Welcome to JamAI Base! To complete your account setup, please verify your email address by clicking the link below:

    + + + + +
    + + Verify Email Address + +
    +

    If the button above doesn't work, you can copy and paste this link into your browser:

    +

    ${ORIGIN}/verify-email?token=${verificationToken}

    +

    This verification link will expire in 24 hours.

    +
    + Thanks! +
    + JamAI Base +

    +
    +

    If you did not make this request, you can safely ignore this email.

    +
    +
    +
    + +`; diff --git a/services/app/src/routes/verify-email/+page.svelte b/services/app/src/routes/verify-email/+page.svelte new file mode 100755 index 0000000..de5f18c --- /dev/null +++ b/services/app/src/routes/verify-email/+page.svelte @@ -0,0 +1,89 @@ + + + + Verify your email + + +
    +
    +
    +
    + {#if user?.picture_url} + User Avatar + {:else} + + {((page.data.user as User).name ?? 'Default User').charAt(0)} + + {/if} +
    + {user?.email} +
    + +

    Verify your email

    +

    + We've sent you an email with a link to verify your email address. Please check your inbox and + click the link to continue. +

    + +
    +
    { + if (emailResent) { + cancel(); + } else { + isLoading = true; + } + + return async ({ result, update }) => { + if (result.type !== 'success') { + toast.error('Error resending verification email', { + //@ts-ignore + id: result.data?.err_message?.message || JSON.stringify(result.data), + description: CustomToastDesc as any, + componentProps: { + //@ts-ignore + description: result.data?.err_message?.message || JSON.stringify(result.data) + } + }); + } else { + emailResent = true; + } + + isLoading = false; + update({ reset: result.type === 'success', invalidateAll: false }); + }; + }} + method="POST" + action="?/resend-verification-email" + > + +
    + + +
    + +

    Verification email sent. Please check your inbox.

    +
    +
    diff --git a/services/app/src/showdown-theme.css b/services/app/src/showdown-theme.css old mode 100644 new mode 100755 diff --git a/services/app/static/favicon.ico b/services/app/static/favicon.ico old mode 100644 new mode 100755 diff --git a/services/app/static/favicon.png b/services/app/static/favicon.png old mode 100644 new mode 100755 diff --git a/services/app/static/jamai-onboarding-bg.svg b/services/app/static/jamai-onboarding-bg.svg old mode 100644 new mode 100755 diff --git a/services/app/static/logo.png b/services/app/static/logo.png old mode 100644 new mode 100755 diff --git a/services/app/svelte.config.js b/services/app/svelte.config.js old mode 100644 new mode 100755 diff --git a/services/app/tailwind.config.js b/services/app/tailwind.config.js old mode 100644 new mode 100755 index 07dd773..d1718ff --- a/services/app/tailwind.config.js +++ b/services/app/tailwind.config.js @@ -93,10 +93,25 @@ const config = { '100%': { opacity: '0' } + }, + 'accordion-down': { + from: { height: '0' }, + to: { height: 'var(--bits-accordion-content-height)' } + }, + 'accordion-up': { + from: { height: 'var(--bits-accordion-content-height)' }, + to: { height: '0' } + }, + 'caret-blink': { + '0%,70%,100%': { opacity: '1' }, + '20%,50%': { opacity: '0' } } }, animation: { - blink: 'blink 1060ms steps(1) infinite' + blink: 'blink 1060ms steps(1) infinite', + 'accordion-down': 'accordion-down 0.2s ease-out', + 'accordion-up': 'accordion-up 0.2s ease-out', + 'caret-blink': 'caret-blink 1.25s ease-out infinite' } } }, diff --git a/services/app/tests/auth.setup.ts b/services/app/tests/auth.setup.ts old mode 100644 new mode 100755 index d3babec..b7da7c5 --- a/services/app/tests/auth.setup.ts +++ b/services/app/tests/auth.setup.ts @@ -1,11 +1,12 @@ +import { test as setup } from '@playwright/test'; import 'dotenv/config'; import { existsSync } from 'fs'; -import { test as setup } from '@playwright/test'; +const ossMode = !process.env.OWL_SERVICE_KEY; const authFile = 'playwright/.auth/user.json'; setup('authenticate', async ({ browser, page }) => { - if (process.env.PUBLIC_IS_LOCAL === 'false') { + if (!ossMode) { if (existsSync(authFile)) { await page.close(); const context = await browser.newContext({ storageState: authFile }); @@ -16,13 +17,11 @@ setup('authenticate', async ({ browser, page }) => { const isCredentialsValid = !/.*\/(login)/.test(page.url()); if (!isCredentialsValid) { - await page.getByLabel('Email address').fill(process.env.TEST_ACC_EMAIL!); - await page.getByLabel('Password').fill(process.env.TEST_ACC_PW!); - await page.getByRole('button', { name: 'Continue', exact: true }).click(); + await page.getByPlaceholder('Username').fill(process.env.TEST_USER_USERNAME!); + await page.getByPlaceholder('Password').fill(process.env.TEST_USER_PASSWORD!); + await page.getByRole('button', { name: 'Login', exact: true }).click(); } - await page.goto('/'); - await page.waitForURL(/.*\/(project|new-organization)/); await page.context().storageState({ path: authFile }); diff --git a/services/app/tests/fixtures/sample-csv.csv b/services/app/tests/fixtures/sample-csv.csv old mode 100644 new mode 100755 diff --git a/services/app/tests/fixtures/sample-data.json b/services/app/tests/fixtures/sample-data.json new file mode 100644 index 0000000..2522cd8 --- /dev/null +++ b/services/app/tests/fixtures/sample-data.json @@ -0,0 +1,122 @@ +{ + "admin": { + "email": "admin@example.com", + "name": "Sue Dough", + "username": "admin", + "password": "password" + }, + + "test": { + "email": "test@example.com", + 
"name": "Deepak Gurr", + "username": "test-user", + "password": "password" + }, + + "chat_model": { + "meta": { + "icon": "openai" + }, + "id": "openai/gpt-4o", + "name": "OpenAI GPT-4o", + "type": "llm", + "context_length": 1047576, + "capabilities": ["chat", "image"], + "languages": ["en", "fr"], + "llm_input_cost_per_mtoken": 2.5, + "llm_output_cost_per_mtoken": 10.0 + }, + "chat_model_deployment": { + "model_id": "openai/gpt-4o", + "name": "OpenAI GPT-4o", + "provider": "openai", + "routing_id": "openai/gpt-4o", + "api_base": "" + }, + + "embedding_model": { + "meta": { + "icon": "openai" + }, + "id": "openai/text-embedding-3-small-1536", + "name": "OpenAI Text Embedding 3 Small (1536-dim)", + "type": "embed", + "context_length": 8192, + "capabilities": ["embed"], + "languages": ["en", "fr"], + "embedding_size": 1536, + "embedding_cost_per_mtoken": 0.022 + }, + "embedding_model_deployment": { + "model_id": "openai/text-embedding-3-small-1536", + "name": "OpenAI Text Embedding 3 Small (1536-dim)", + "provider": "openai", + "routing_id": "openai/text-embedding-3-small", + "api_base": "" + }, + + "pro_plan": { + "name": "Pro", + "stripe_price_id_live": "a", + "stripe_price_id_test": "a", + "flat_cost": 0, + "credit_grant": 10000, + "max_users": 20, + "products": { + "llm_tokens": { + "name": "LLM Tokens", + "included": { + "unit_cost": 0, + "up_to": 100000 + }, + "tiers": [], + "unit": "tokens" + }, + "embedding_tokens": { + "name": "Embedding Tokens", + "included": { + "unit_cost": 0, + "up_to": 100000 + }, + "tiers": [], + "unit": "Tokens" + }, + "reranker_searches": { + "name": "Reranker Searches", + "included": { + "unit_cost": 0, + "up_to": 100000 + }, + "tiers": [], + "unit": "Searches" + }, + "db_storage": { + "name": "DB Storage", + "included": { + "unit_cost": 0, + "up_to": 100000 + }, + "tiers": [], + "unit": "bytes" + }, + "file_storage": { + "name": "File Storage", + "included": { + "unit_cost": 0, + "up_to": 100000 + }, + "tiers": [], + "unit": "bytes" + }, + "egress": { + "name": "Egress", + "included": { + "unit_cost": 0, + "up_to": 100000 + }, + "tiers": [], + "unit": "bytes" + } + } + } +} diff --git a/services/app/tests/fixtures/sample-doc.txt b/services/app/tests/fixtures/sample-doc.txt old mode 100644 new mode 100755 diff --git a/services/app/tests/fixtures/sample-img.jpg b/services/app/tests/fixtures/sample-img.jpg old mode 100644 new mode 100755 diff --git a/services/app/tests/main.setup.ts b/services/app/tests/main.setup.ts old mode 100644 new mode 100755 index 66d58ea..a9c1329 --- a/services/app/tests/main.setup.ts +++ b/services/app/tests/main.setup.ts @@ -1,36 +1,288 @@ -import 'dotenv/config'; +import type { GenTableCol, ModelConfig } from '$lib/types'; import { test as setup } from '@playwright/test'; -import type { AvailableModel, GenTableCol } from '$lib/types'; +import 'dotenv/config'; +import { readFileSync } from 'fs'; +import Stripe from 'stripe'; -const { JAMAI_URL, JAMAI_SERVICE_KEY, TEST_ACC_USERID } = process.env; +const { OWL_URL, OWL_SERVICE_KEY, OWL_STRIPE_API_KEY } = process.env; +const stripe = new Stripe(OWL_STRIPE_API_KEY!); -setup('create org and tables', async () => { - const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` - }; +const testDataFile = 'tests/fixtures/sample-data.json'; +const headers = { + Authorization: `Bearer ${OWL_SERVICE_KEY}` +}; + +//TODO: Clean slate tests with teardown +setup.skip('create users', async () => { + const users = JSON.parse(readFileSync(testDataFile, 'utf-8')); + // const getUserRes = await 
fetch( + // `${OWL_URL}/api/v2/users?${new URLSearchParams([['user_id', '0']])}`, + // { + // headers + // } + // ); + // const getUserBody = await getUserRes.json(); + + // if (!getUserRes.ok) { + // if (getUserRes.status !== 404) throw { code: 'get_user', ...getUserBody }; + + // } + const createAdminRes = await fetch(`${OWL_URL}/api/v2/auth/register/password`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(users.admin) + }); + const createAdminBody = await createAdminRes.json(); + + if (!createAdminRes.ok) throw { code: 'create_admin_user', ...createAdminBody }; + + const createTestUserRes = await fetch(`${OWL_URL}/api/v2/auth/register/password`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(users.test) + }); + const createTestUserBody = await createTestUserRes.json(); + + if (!createTestUserRes.ok) throw { code: 'create_test_user', ...createTestUserBody }; + + process.env.TEST_ADMIN_ID = createAdminBody.id; + process.env.TEST_USER_ID = createTestUserBody.id; + + // Verify accounts + const verifyAdminRes = await fetch(`${OWL_URL}/api/v2/users`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': process.env.TEST_ADMIN_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + id: process.env.TEST_ADMIN_ID!, + email_verified: true + }) + }); + if (!verifyAdminRes.ok) throw { code: 'verify_admin_user', ...(await verifyAdminRes.json()) }; + + const verifyTestUserRes = await fetch(`${OWL_URL}/api/v2/users`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': process.env.TEST_ADMIN_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + id: process.env.TEST_USER_ID!, + email_verified: true + }) + }); + if (!verifyTestUserRes.ok) + throw { code: 'verify_test_user', ...(await verifyTestUserRes.json()) }; +}); + +setup.skip('add model config and deployment', async () => { + // const modelPresetsRes = await fetch( + // 'https://raw.githubusercontent.com/EmbeddedLLM/JamAIBase/refs/heads/main/services/api/src/owl/configs/preset_models.json', + // { + // method: 'GET' + // } + // ); + + // if (!modelPresetsRes.ok) { + // const error = await modelPresetsRes.text(); + // throw { code: '', status: modelPresetsRes.status, message: error }; + // } + + // const modelPresetsBody = (await modelPresetsRes.json()) as ModelConfig[]; + + const models = JSON.parse(readFileSync(testDataFile, 'utf-8')); + + const createModelConfigs = await Promise.allSettled([ + createModelConfig(models.chat_table), + createModelConfig(models.embedding_model) + ]); + + if (createModelConfigs.some((val) => val.status === 'rejected')) { + const rejected = await Promise.all( + createModelConfigs.flatMap((val, index) => + val.status === 'rejected' ? { index, ...val.reason } : [] + ) + ); + throw { code: 'create_model_configs', rejected }; + } + + const createModelDeployments = await Promise.allSettled([ + createModelDeployment(models.chat_model_deployment), + createModelDeployment(models.embedding_model_deployment) + ]); + + if (createModelDeployments.some((val) => val.status === 'rejected')) { + const rejected = await Promise.all( + createModelDeployments.flatMap((val, index) => + val.status === 'rejected' ? 
{ index, ...val.reason } : [] + ) + ); + throw { code: 'create_model_deployments', rejected }; + } + + async function createModelConfig(body: any) { + const createModelConfigRes = await fetch(`${OWL_URL}/api/v2/models/configs`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': process.env.TEST_ADMIN_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(body) + }); + if (!createModelConfigRes.ok) throw await createModelConfigRes.json(); + } - const createOrgRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/organizations`, { + async function createModelDeployment(body: any) { + const createModelDeploymentRes = await fetch(`${OWL_URL}/api/v2/models/deployment/cloud`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': process.env.TEST_ADMIN_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(body) + }); + if (!createModelDeploymentRes.ok) throw await createModelDeploymentRes.json(); + } +}); + +setup.skip('create price plans', async () => { + const prices = JSON.parse(readFileSync(testDataFile, 'utf-8')); + + const createPlans = await Promise.allSettled([createPricePlan(prices.pro_plan)]); + + if (createPlans.some((val) => val.status === 'rejected')) { + const rejected = await Promise.all( + createPlans.flatMap((val, index) => + val.status === 'rejected' ? { index, ...val.reason } : [] + ) + ); + throw { code: 'create_price_plan', rejected }; + } + + process.env.TEST_PRO_PLAN_ID = + createPlans[0].status === 'fulfilled' ? createPlans[0].value.id : null; + + async function createPricePlan(body: any) { + const createPricePlanRes = await fetch(`${OWL_URL}/api/v2/prices/plans`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': process.env.TEST_ADMIN_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(body) + }); + const createPricePlanBody = await createPricePlanRes.json(); + + if (!createPricePlanRes.ok) throw createPricePlanBody; + return createPricePlanBody; + } +}); + +setup.skip('create admin org', async () => { + const createOrgRes = await fetch(`${OWL_URL}/api/v2/organizations`, { + method: 'POST', + headers: { + ...headers, + 'x-user-id': process.env.TEST_ADMIN_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + name: 'admin-org', + currency: 'USD' + }) + }); + const createOrgBody = await createOrgRes.json(); + + if (!createOrgRes.ok) throw { code: 'create_org', ...createOrgBody }; + + process.env.TEST_ADMIN_ORGID = createOrgBody.id; +}); + +setup('create org and tables', async () => { + const createOrgRes = await fetch(`${OWL_URL}/api/v2/organizations`, { method: 'POST', headers: { ...headers, + 'x-user-id': process.env.TEST_USER_ID!, 'Content-Type': 'application/json' }, body: JSON.stringify({ - creator_user_id: TEST_ACC_USERID, name: 'test-org', - tier: 'team' + currency: 'USD' }) }); const createOrgBody = await createOrgRes.json(); - console.log(createOrgBody); + if (!createOrgRes.ok) throw { code: 'create_org', ...createOrgBody }; const organizationId = createOrgBody.id; - const createProjectRes = await fetch(`${JAMAI_URL}/api/admin/org/v1/projects`, { + //stripe add payment method and subscribe plan + const paymentMethod = await stripe.paymentMethods.create({ + type: 'card', + card: { + token: 'tok_visa' + } + }); + await stripe.paymentMethods.attach(paymentMethod.id, { + customer: createOrgBody.stripe_id + }); + await stripe.customers.update(createOrgBody.stripe_id, { + invoice_settings: { + default_payment_method: paymentMethod.id + } + }); + + const changeOrgPlanRes = 
await fetch( + `${OWL_URL}/api/v2/organizations/plan?${new URLSearchParams([ + ['organization_id', organizationId], + ['price_plan_id', process.env.TEST_TEAM_PLAN_ID!] + ])}`, + { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': process.env.TEST_USER_ID! + } + } + ); + const changeOrgPlanBody = await changeOrgPlanRes.json(); + + if (!changeOrgPlanRes.ok) throw { code: 'change_org_plan', ...changeOrgPlanBody }; + + // await stripe.paymentIntents.confirm(changeOrgPlanBody.payment_intent_id); + + // Add credits + const addCreditsRes = await fetch(`${OWL_URL}/api/v2/organizations`, { + method: 'PATCH', + headers: { + ...headers, + 'x-user-id': process.env.TEST_USER_ID!, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + id: organizationId, + credit: 10000 + }) + }); + if (!addCreditsRes.ok) throw { code: 'add_org_credits', ...(await addCreditsRes.json()) }; + + const createProjectRes = await fetch(`${OWL_URL}/api/v2/projects`, { method: 'POST', headers: { ...headers, + 'x-user-id': process.env.TEST_USER_ID!, 'Content-Type': 'application/json' }, body: JSON.stringify({ @@ -43,6 +295,8 @@ setup('create org and tables', async () => { const projectId = createProjectBody.id; + const models = JSON.parse(readFileSync(testDataFile, 'utf-8')); + const createTestTables = await Promise.allSettled([ createTable('action', 'test-action-table', [ { @@ -59,7 +313,7 @@ setup('create org and tables', async () => { index: true, gen_config: { object: 'gen_config.llm', - model: 'anthropic/claude-3-haiku-20240307', + model: models.chat_model.id, multi_turn: false } } @@ -67,7 +321,7 @@ setup('create org and tables', async () => { createTable('action', 'test-action-table-file', [ { id: 'Input', - dtype: 'file', + dtype: 'image', vlen: 0, index: true, gen_config: null @@ -79,7 +333,7 @@ setup('create org and tables', async () => { index: true, gen_config: { object: 'gen_config.llm', - model: 'openai/gpt-4o', + model: models.chat_model.id, multi_turn: false } } @@ -101,7 +355,7 @@ setup('create org and tables', async () => { index: true, gen_config: { object: 'gen_config.llm', - model: 'anthropic/claude-3-haiku-20240307', + model: models.chat_model.id, multi_turn: true } } @@ -125,7 +379,7 @@ setup('create org and tables', async () => { if (createConvs.some((val) => val.status === 'rejected')) { const rejected = await Promise.all( createConvs.flatMap((val, index) => - val.status === 'rejected' ? { index, ...val.reason } : [] + val.status === 'rejected' ? { index, ...(val.reason ? val.reason : { reason: val }) } : [] ) ); throw { code: 'create_test_convs', rejected }; @@ -136,36 +390,39 @@ setup('create org and tables', async () => { tableName: string, cols: GenTableCol[] ) { + await new Promise((r) => setTimeout(r, Math.floor(Math.random() * 3000))); + let embeddingModel; if (tableType === 'knowledge') { const modelsRes = await fetch( - `${JAMAI_URL}/api/v1/models?${new URLSearchParams({ - capabilities: 'embed' - })}`, + `${OWL_URL}/api/v2/organizations/models/catalogue?${new URLSearchParams([['organization_id', organizationId]])}`, { headers: { ...headers, - 'x-project-id': projectId + 'x-user-id': process.env.TEST_USER_ID! 
} } ); const modelsBody = await modelsRes.json(); - if (!modelsRes.ok) throw { code: 'list_models', ...modelsBody }; + if (!modelsRes.ok) throw modelsBody; - embeddingModel = (modelsBody.data as AvailableModel[])[0].id; + embeddingModel = (modelsBody.items as ModelConfig[]).find((m) => + m.capabilities.includes('embed') + )?.id; } - const response = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}`, { + const response = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}`, { method: 'POST', headers: { ...headers, 'Content-Type': 'application/json', + 'x-user-id': process.env.TEST_USER_ID!, 'x-project-id': projectId }, body: JSON.stringify({ id: tableName, - version: '0.3.0', + version: '0.5.0', cols, embedding_model: tableType === 'knowledge' ? embeddingModel : undefined }) @@ -181,14 +438,16 @@ setup('create org and tables', async () => { async function createConv(parent: string, name: string) { const response = await fetch( - `${JAMAI_URL}/api/v1/gen_tables/chat/duplicate/${parent}?${new URLSearchParams({ - create_as_child: 'true', - table_id_dst: name - })}`, + `${OWL_URL}/api/v2/gen_tables/chat/duplicate?${new URLSearchParams([ + ['table_id_src', parent], + ['table_id_dst', name], + ['create_as_child', 'true'] + ])}`, { method: 'POST', headers: { ...headers, + 'x-user-id': process.env.TEST_USER_ID!, 'x-project-id': projectId } } diff --git a/services/app/tests/main.teardown.ts b/services/app/tests/main.teardown.ts old mode 100644 new mode 100755 index db2b085..aa98415 --- a/services/app/tests/main.teardown.ts +++ b/services/app/tests/main.teardown.ts @@ -1,39 +1,46 @@ +import type { User } from '$lib/types'; +import { test as teardown } from '@playwright/test'; import 'dotenv/config'; import fs from 'fs'; -import { test as teardown } from '@playwright/test'; -import type { UserRead } from '$lib/types'; -const { JAMAI_URL, JAMAI_SERVICE_KEY, TEST_ACC_USERID } = process.env; +const { OWL_URL, OWL_SERVICE_KEY, TEST_USER_ID } = process.env; teardown('delete setup', async () => { const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` + Authorization: `Bearer ${OWL_SERVICE_KEY}` }; - const userInfoRes = await fetch(`${JAMAI_URL}/api/admin/backend/v1/users/${TEST_ACC_USERID}`, { - headers - }); + const userInfoRes = await fetch( + `${OWL_URL}/api/v2/users?${new URLSearchParams([['user_id', TEST_USER_ID!]])}`, + { + headers: { + ...headers, + 'x-user-id': process.env.TEST_USER_ID! + } + } + ); const userInfoBody = await userInfoRes.json(); - if (!userInfoRes.ok) throw new Error(userInfoBody); + if (!userInfoRes.ok) throw JSON.stringify(userInfoBody); - const testOrgs = (userInfoBody as UserRead).member_of.filter((org) => - /test-org/.test(org.organization_name) - ); + const testOrgs = (userInfoBody as User).organizations.filter((org) => /test-org/.test(org.name)); if (testOrgs.length === 0) { console.warn('Playwright test organization not found, skipping delete step'); } else { for (const testOrg of testOrgs) { const deleteOrgRes = await fetch( - `${JAMAI_URL}/api/admin/backend/v1/organizations/${testOrg?.organization_id}`, + `${OWL_URL}/api/v2/organizations?${new URLSearchParams([['organization_id', testOrg.id]])}`, { method: 'DELETE', - headers + headers: { + ...headers, + 'x-user-id': process.env.TEST_USER_ID! 
+ } } ); if (!deleteOrgRes.ok) { const deleteOrgBody = await deleteOrgRes.json(); - throw new Error(deleteOrgBody); + throw JSON.stringify(deleteOrgBody); } } } diff --git a/services/app/tests/pages/layout.page.ts b/services/app/tests/pages/layout.page.ts old mode 100644 new mode 100755 index 8ab7b1f..1dcbb17 --- a/services/app/tests/pages/layout.page.ts +++ b/services/app/tests/pages/layout.page.ts @@ -1,5 +1,7 @@ -import 'dotenv/config'; import { expect, type Locator, type Page } from '@playwright/test'; +import 'dotenv/config'; + +const ossMode = !process.env.OWL_SERVICE_KEY; /** Layout with breadcrumbs */ export class LayoutPage { @@ -12,7 +14,7 @@ export class LayoutPage { } async switchOrganization(organizationName: string) { - if (process.env.PUBLIC_IS_LOCAL === 'false') { + if (!ossMode) { const orgSelector = this.page.getByTestId('org-selector'); await expect(async () => { await this.selectOrgBtn.click(); diff --git a/services/app/tests/pages/project.page.ts b/services/app/tests/pages/project.page.ts old mode 100644 new mode 100755 index be32c0c..9ec6c65 --- a/services/app/tests/pages/project.page.ts +++ b/services/app/tests/pages/project.page.ts @@ -1,14 +1,16 @@ -import 'dotenv/config'; import { expect, type Page } from '@playwright/test'; +import 'dotenv/config'; import { LayoutPage } from './layout.page'; +const ossMode = !process.env.OWL_SERVICE_KEY; + export class ProjectPage extends LayoutPage { constructor(page: Page) { super(page); } async goto() { - if (process.env.PUBLIC_IS_LOCAL === 'false') { + if (!ossMode) { await this.page.goto('/'); await this.page.waitForURL(/.*\/project/); } else { @@ -18,7 +20,7 @@ export class ProjectPage extends LayoutPage { } async gotoProject(projectName: string) { - if (process.env.PUBLIC_IS_LOCAL === 'false') { + if (!ossMode) { await this.page .locator('a', { has: this.page.getByText(projectName, { exact: true }) }) .click(); diff --git a/services/app/tests/pages/table.page.ts b/services/app/tests/pages/table.page.ts old mode 100644 new mode 100755 index a768373..89ffaf2 --- a/services/app/tests/pages/table.page.ts +++ b/services/app/tests/pages/table.page.ts @@ -1,5 +1,5 @@ -import 'dotenv/config'; import { expect, type Locator, type Page } from '@playwright/test'; +import 'dotenv/config'; import { LayoutPage } from './layout.page'; /** Only to be instantiated in a project */ @@ -190,7 +190,7 @@ export class TablePage extends LayoutPage { } /** Add column */ - async addColumn(type: 'input' | 'output', datatype: 'str' | 'file' = 'str') { + async addColumn(type: 'input' | 'output', datatype: 'str' | 'image' = 'str') { await this.actionsBtn.click(); await this.page .getByTestId('table-actions-dropdown') @@ -205,9 +205,9 @@ export class TablePage extends LayoutPage { await newColDialog.getByLabel('Column ID').fill(`transient-${type}-column`); await newColDialog.getByTestId('datatype-select-btn').click(); if (type === 'input') { - await newColDialog - .getByTestId('datatype-select-btn') - .locator('div[role="option"]', { hasText: datatype === 'str' ? 'Text' : 'File' }) + await this.page + .getByTestId('datatype-select-list') + .locator('div[role="option"]', { hasText: datatype === 'str' ? 
'Text' : 'Image' }) .click(); } if (type === 'output') { diff --git a/services/app/tests/pages/tableList.page.ts b/services/app/tests/pages/tableList.page.ts old mode 100644 new mode 100755 diff --git a/services/app/tests/tableList.spec.ts b/services/app/tests/tableList.spec.ts old mode 100644 new mode 100755 index f9c1a31..62b3d32 --- a/services/app/tests/tableList.spec.ts +++ b/services/app/tests/tableList.spec.ts @@ -76,7 +76,7 @@ test.describe('Knowledge Table', () => { await modal.waitFor({ state: 'visible' }); await modal.locator('input[name="table_id"]').fill('transient-test-knowledge-table'); await modal.getByTestId('model-select-btn').click(); - await modal.getByTestId('model-select-btn').locator('div[role="option"]').first().click(); + await page.getByTestId('model-select-list').locator('div[role="option"]').first().click(); await modal.locator('button:has-text("Create"):visible').click(); await modal.waitFor({ state: 'hidden' }); @@ -128,7 +128,11 @@ test.describe('Chat Table', () => { await modal.waitFor({ state: 'visible' }); await modal.locator('input[name="agent-id"]').fill('transient-test-chat-agent'); await modal.locator('button[title="Select model"]').click(); - await modal.locator('div[role="option"]:visible').first().click(); + await page + .getByTestId('model-select-list') + .locator('div[role="option"]:visible') + .first() + .click(); await modal.locator('button:has-text("Add"):visible').click(); await modal.waitFor({ state: 'hidden' }); @@ -169,7 +173,7 @@ test.describe('Chat Table', () => { await modal.waitFor({ state: 'visible' }); await modal.locator('input[name="conversation-id"]').fill('transient-test-chat-conv'); await modal.locator('button[title="Select Chat Agent"]').click(); - await modal + await page .locator('div[role="option"]:visible', { has: page.getByText('transient-test-agent', { exact: true }) }) diff --git a/services/app/tests/tables/actionTable.spec.ts b/services/app/tests/tables/actionTable.spec.ts old mode 100644 new mode 100755 index 514391d..b7eedc1 --- a/services/app/tests/tables/actionTable.spec.ts +++ b/services/app/tests/tables/actionTable.spec.ts @@ -1,13 +1,13 @@ -import 'dotenv/config'; -import { expect, test as base } from '@playwright/test'; import { faker } from '@faker-js/faker'; +import { test as base, expect } from '@playwright/test'; +import 'dotenv/config'; import { ProjectPage } from '../pages/project.page'; -import { TableListPage } from '../pages/tableList.page'; import { TablePage } from '../pages/table.page'; +import { TableListPage } from '../pages/tableList.page'; -const { JAMAI_URL, JAMAI_SERVICE_KEY } = process.env; +const { OWL_URL, OWL_SERVICE_KEY } = process.env; const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` + Authorization: `Bearer ${OWL_SERVICE_KEY}` }; const test = base.extend<{ tablePage: TablePage; fileTablePage: TablePage }>({ @@ -125,7 +125,7 @@ test.describe('Action Table Page Basic', () => { const tableType = 'action'; const tableName = 'test-action-table'; - const deleteTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}/${tableName}`, { + const deleteTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}/${tableName}`, { method: 'DELETE', headers: { ...headers, @@ -137,7 +137,7 @@ test.describe('Action Table Page Basic', () => { throw await deleteTableRes.json(); } - const createTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}`, { + const createTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}`, { method: 'POST', headers: { 
...headers, @@ -146,7 +146,7 @@ test.describe('Action Table Page Basic', () => { }, body: JSON.stringify({ id: tableName, - version: '0.3.0', + version: '0.5.0', cols: [ { id: 'Input', @@ -162,7 +162,7 @@ test.describe('Action Table Page Basic', () => { index: true, gen_config: { object: 'gen_config.llm', - model: 'anthropic/claude-3-haiku-20240307', + model: 'openai/gpt-4o-mini', multi_turn: false } } @@ -252,7 +252,7 @@ test.describe('Action Table Page with File Col', () => { const tableType = 'action'; const tableName = 'test-action-table-file'; - const deleteTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}/${tableName}`, { + const deleteTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}/${tableName}`, { method: 'DELETE', headers: { ...headers, @@ -264,7 +264,7 @@ test.describe('Action Table Page with File Col', () => { throw await deleteTableRes.json(); } - const createTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}`, { + const createTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}`, { method: 'POST', headers: { ...headers, @@ -273,7 +273,7 @@ test.describe('Action Table Page with File Col', () => { }, body: JSON.stringify({ id: tableName, - version: '0.3.0', + version: '0.5.0', cols: [ { id: 'Input', diff --git a/services/app/tests/tables/chatTable.spec.ts b/services/app/tests/tables/chatTable.spec.ts old mode 100644 new mode 100755 index 6a10a8d..bef945e --- a/services/app/tests/tables/chatTable.spec.ts +++ b/services/app/tests/tables/chatTable.spec.ts @@ -1,13 +1,13 @@ -import 'dotenv/config'; -import { expect, test as base } from '@playwright/test'; import { faker } from '@faker-js/faker'; +import { test as base, expect } from '@playwright/test'; +import 'dotenv/config'; import { ProjectPage } from '../pages/project.page'; -import { TableListPage } from '../pages/tableList.page'; import { TablePage } from '../pages/table.page'; +import { TableListPage } from '../pages/tableList.page'; -const { JAMAI_URL, JAMAI_SERVICE_KEY } = process.env; +const { OWL_URL, OWL_SERVICE_KEY } = process.env; const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` + Authorization: `Bearer ${OWL_SERVICE_KEY}` }; const test = base.extend<{ tablePage: TablePage; agentTablePage: TablePage }>({ @@ -164,7 +164,7 @@ test.describe('Chat Table Page Basic', () => { const tableName = 'test-chat-conv'; const tableParent = 'temp-chat-agent'; - const deleteTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}/${tableName}`, { + const deleteTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}/${tableName}`, { method: 'DELETE', headers: { ...headers, @@ -177,7 +177,7 @@ test.describe('Chat Table Page Basic', () => { } //* Temp chat agent in case original has been changed - const createTempAgentRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}`, { + const createTempAgentRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}`, { method: 'POST', headers: { ...headers, @@ -186,7 +186,7 @@ test.describe('Chat Table Page Basic', () => { }, body: JSON.stringify({ id: tableParent, - version: '0.3.0', + version: '0.5.0', cols: [ { id: 'User', @@ -202,7 +202,7 @@ test.describe('Chat Table Page Basic', () => { index: true, gen_config: { object: 'gen_config.llm', - model: 'anthropic/claude-3-haiku-20240307', + model: 'openai/gpt-4o-mini', multi_turn: true } } @@ -216,9 +216,10 @@ test.describe('Chat Table Page Basic', () => { //* Duplicate agent to chat conv const dupeTableRes = await fetch( - 
`${JAMAI_URL}/api/v1/gen_tables/chat/duplicate/${tableParent}?${new URLSearchParams({ - create_as_child: 'true', - table_id_dst: tableName + `${OWL_URL}/api/v2/gen_tables/chat/duplicate?${new URLSearchParams({ + table_id_src: tableParent, + table_id_dst: tableName, + create_as_child: 'true' })}`, { method: 'POST', @@ -234,7 +235,7 @@ test.describe('Chat Table Page Basic', () => { } const deleteTempAgentRes = await fetch( - `${JAMAI_URL}/api/v1/gen_tables/${tableType}/${tableParent}`, + `${OWL_URL}/api/v2/gen_tables/${tableType}/${tableParent}`, { method: 'DELETE', headers: { @@ -260,7 +261,7 @@ test.describe('Chat Table Page with File Col', () => { test.describe('Column create, rename, reorder, delete', () => { test('can add new input column', async ({ agentTablePage }) => { - await agentTablePage.addColumn('input', 'file'); + await agentTablePage.addColumn('input', 'image'); }); test('can add new output column', async ({ agentTablePage }) => { @@ -326,7 +327,7 @@ test.describe('Chat Table Page with File Col', () => { const tableType = 'chat'; const tableName = 'test-chat-agent'; - const deleteTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}/${tableName}`, { + const deleteTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}/${tableName}`, { method: 'DELETE', headers: { ...headers, @@ -338,7 +339,7 @@ test.describe('Chat Table Page with File Col', () => { throw await deleteTableRes.json(); } - const createTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}`, { + const createTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}`, { method: 'POST', headers: { ...headers, @@ -347,7 +348,7 @@ test.describe('Chat Table Page with File Col', () => { }, body: JSON.stringify({ id: tableName, - version: '0.3.0', + version: '0.5.0', cols: [ { id: 'User', @@ -363,7 +364,7 @@ test.describe('Chat Table Page with File Col', () => { index: true, gen_config: { object: 'gen_config.llm', - model: 'anthropic/claude-3-haiku-20240307', + model: 'openai/gpt-4o-mini', multi_turn: true } } diff --git a/services/app/tests/tables/knowledgeTable.spec.ts b/services/app/tests/tables/knowledgeTable.spec.ts old mode 100644 new mode 100755 index 8dbcfca..0b4b970 --- a/services/app/tests/tables/knowledgeTable.spec.ts +++ b/services/app/tests/tables/knowledgeTable.spec.ts @@ -1,12 +1,12 @@ +import { test as base, expect } from '@playwright/test'; import 'dotenv/config'; -import { expect, test as base } from '@playwright/test'; import { ProjectPage } from '../pages/project.page'; -import { TableListPage } from '../pages/tableList.page'; import { TablePage } from '../pages/table.page'; +import { TableListPage } from '../pages/tableList.page'; -const { JAMAI_URL, JAMAI_SERVICE_KEY } = process.env; +const { OWL_URL, OWL_SERVICE_KEY } = process.env; const headers = { - Authorization: `Bearer ${JAMAI_SERVICE_KEY}` + Authorization: `Bearer ${OWL_SERVICE_KEY}` }; const test = base.extend<{ tablePage: TablePage; fileTablePage: TablePage }>({ @@ -197,7 +197,7 @@ test.describe('Knowledge Table Page', () => { const tableType = 'knowledge'; const tableName = 'test-knowledge-table'; - const deleteTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}/${tableName}`, { + const deleteTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}/${tableName}`, { method: 'DELETE', headers: { ...headers, @@ -211,7 +211,7 @@ test.describe('Knowledge Table Page', () => { //* Get embedding model const modelsRes = await fetch( - `${JAMAI_URL}/api/v1/models?${new 
URLSearchParams({ + `${OWL_URL}/api/v2/models?${new URLSearchParams({ capabilities: 'embed' })}`, { @@ -227,7 +227,7 @@ test.describe('Knowledge Table Page', () => { throw modelsBody; } - const createTableRes = await fetch(`${JAMAI_URL}/api/v1/gen_tables/${tableType}`, { + const createTableRes = await fetch(`${OWL_URL}/api/v2/gen_tables/${tableType}`, { method: 'POST', headers: { ...headers, @@ -236,7 +236,7 @@ test.describe('Knowledge Table Page', () => { }, body: JSON.stringify({ id: tableName, - version: '0.3.0', + version: '0.5.0', cols: [], embedding_model: modelsBody.data[0].id }) diff --git a/services/app/tsconfig.json b/services/app/tsconfig.json old mode 100644 new mode 100755 diff --git a/services/app/vite.config.ts b/services/app/vite.config.ts old mode 100644 new mode 100755 index 9629400..a48c877 --- a/services/app/vite.config.ts +++ b/services/app/vite.config.ts @@ -1,8 +1,10 @@ -import 'dotenv/config'; +import { paraglideVitePlugin } from '@inlang/paraglide-js'; import { sveltekit } from '@sveltejs/kit/vite'; -import type { ProxyOptions, ViteDevServer } from 'vite'; +import 'dotenv/config'; import express from 'express'; import expressOpenIdConnect from 'express-openid-connect'; +import type { ProxyOptions, ViteDevServer } from 'vite'; +import devtoolsJson from 'vite-plugin-devtools-json'; import { defineConfig } from 'vitest/config'; const proxy: Record = { @@ -14,35 +16,33 @@ const proxy: Record = { function expressPlugin() { const app = express(); - if (process.env.PUBLIC_IS_LOCAL === 'false') { - app.use( - expressOpenIdConnect.auth({ - authorizationParams: { - response_type: 'code', - scope: 'openid profile email offline_access' - }, - authRequired: false, - auth0Logout: true, - baseURL: `http://localhost:5173`, - clientID: process.env.AUTH0_CLIENT_ID, - clientSecret: process.env.AUTH0_CLIENT_SECRET, - issuerBaseURL: process.env.AUTH0_ISSUER_BASE_URL, - secret: process.env.AUTH0_SECRET, - attemptSilentLogin: false, - routes: { - login: false - } - }) - ); - app.get('/login', (req, res) => { - res.oidc.login({ - returnTo: (typeof req.query.returnTo === 'string' ? req.query.returnTo : '/') || '/' - }); - }); - app.get('/dev-profile', (req, res) => { - res.json(req.oidc.user ?? {}); + app.use( + expressOpenIdConnect.auth({ + authorizationParams: { + response_type: 'code', + scope: 'openid profile email offline_access' + }, + authRequired: false, + auth0Logout: true, + baseURL: `http://localhost:5173`, + clientID: process.env.AUTH0_CLIENT_ID, + clientSecret: process.env.AUTH0_CLIENT_SECRET, + issuerBaseURL: process.env.AUTH0_ISSUER_BASE_URL, + secret: process.env.AUTH0_SECRET, + attemptSilentLogin: false, + routes: { + login: false + } + }) + ); + app.get('/login', (req, res) => { + res.oidc.login({ + returnTo: (typeof req.query.returnTo === 'string' ? req.query.returnTo : '/') || '/' }); - } + }); + app.get('/dev-profile', (req, res) => { + res.json(req.oidc.user ?? 
{}); + }); return { name: 'express-plugin', @@ -67,7 +67,16 @@ export default defineConfig({ target: 'esnext' } }, - plugins: [process.env.PUBLIC_IS_LOCAL === 'false' && expressPlugin(), sveltekit()], + plugins: [ + devtoolsJson(), + paraglideVitePlugin({ + project: './project.inlang', + outdir: './src/lib/paraglide', + strategy: ['cookie', 'baseLocale'] + }), + !!process.env.OWL_SERVICE_KEY && !!process.env.AUTH0_CLIENT_SECRET && expressPlugin(), + sveltekit() + ], test: { include: ['src/**/*.{test,spec}.{js,ts}'] } diff --git a/services/docio/.env b/services/docio/.env deleted file mode 100644 index dbd0994..0000000 --- a/services/docio/.env +++ /dev/null @@ -1,2 +0,0 @@ -DOCIO_PORT=6979 -DOCIO_WORKERS=2 \ No newline at end of file diff --git a/services/docio/MANIFEST.in b/services/docio/MANIFEST.in deleted file mode 100644 index 5b0063f..0000000 --- a/services/docio/MANIFEST.in +++ /dev/null @@ -1 +0,0 @@ -exclude tests/**/* \ No newline at end of file diff --git a/services/docio/README.md b/services/docio/README.md deleted file mode 100644 index a027370..0000000 --- a/services/docio/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Document IO (DocIO) - -This is a package and a service that helps to parse file. Currently it support parsing of - -- txt -- md -- pdf - -## Compile Windows Executable File - -1. Create python virtual environment. -2. `cd services\docio`. -3. Install the python dependencies in the python virtual environment. PowerShell: `.\scripts\SetupWinExeEnv.ps1`. -4. Generate Python executable file. PowerShell: `.\scripts\GenerateWinExe.ps1`. The generate output can be found in `dist`. diff --git a/services/docio/docio.spec b/services/docio/docio.spec deleted file mode 100644 index f028833..0000000 --- a/services/docio/docio.spec +++ /dev/null @@ -1,70 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- -import sys -import docio -from pathlib import Path - -# Increase the recursion limit -sys.setrecursionlimit(sys.getrecursionlimit() * 5) - -# Print the path of the Python executable -print(sys.executable) - -from PyInstaller.utils.hooks import collect_all - -binaries_list = [] - -datas_list = [] - -hiddenimports_list = ['multipart', 'torch'] - -def add_package(package_name): - datas, binaries, hiddenimports = collect_all(package_name) - datas_list.extend(datas) - binaries_list.extend(binaries) - hiddenimports_list.extend(hiddenimports) - -add_package('pypdfium2') -add_package('pypdfium2_raw') -add_package('docio') - -a = Analysis( - [Path('src/docio/entrypoints/api.py').as_posix()], - pathex=[], - binaries=binaries_list, - datas=datas_list, - hiddenimports=hiddenimports_list, - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - noarchive=False, - optimize=0, -) -pyz = PYZ(a.pure) - -exe = EXE( - pyz, - a.scripts, - [], - exclude_binaries=True, - name='docio', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) -coll = COLLECT( - exe, - a.binaries, - a.datas, - strip=False, - upx=True, - upx_exclude=[], - name='docio', -) diff --git a/services/docio/pyproject.toml b/services/docio/pyproject.toml deleted file mode 100644 index e905b26..0000000 --- a/services/docio/pyproject.toml +++ /dev/null @@ -1,153 +0,0 @@ -# See https://gitlab.liris.cnrs.fr/pagoda/tools/mkdocs_template/-/blob/master/user_config/pyproject.toml - -# 
----------------------------------------------------------------------------- -# Pytest configuration -# https://docs.pytest.org/en/latest/customize.html?highlight=pyproject#pyproject-toml - -[tool.pytest.ini_options] -log_cli = true -asyncio_mode = "auto" -# log_cli_level = "DEBUG" -# addopts = "--cov=docio --doctest-modules" -testpaths = ["tests"] -filterwarnings = [ - "ignore::DeprecationWarning:tensorflow.*", - "ignore::DeprecationWarning:tensorboard.*", - "ignore::DeprecationWarning:matplotlib.*", - "ignore::DeprecationWarning:flatbuffers.*", -] - -# ----------------------------------------------------------------------------- -# Ruff configuration -# https://docs.astral.sh/ruff/ - -[tool.ruff] -line-length = 99 -indent-width = 4 -target-version = "py310" -extend-include = [".pyi?$", ".ipynb"] -extend-exclude = ["archive/*"] -respect-gitignore = true - -[tool.ruff.format] -# Like Black, use double quotes for strings. -quote-style = "double" - -# Like Black, indent with spaces, rather than tabs. -indent-style = "space" - -# Enable auto-formatting of code examples in docstrings. Markdown, -# reStructuredText code/literal blocks and doctests are all supported. -docstring-code-format = true - -[tool.ruff.lint] -# 1. Enable flake8-bugbear (`B`) rules, in addition to the defaults. -select = ["E1", "E4", "E7", "E9", "F", "I", "W1", "W2", "W3", "W6", "B"] - -# 2. Avoid enforcing line-length violations (`E501`) -ignore = ["E501"] - -# 3. Avoid trying to fix flake8-bugbear (`B`) violations. -unfixable = ["B"] - -# 4. Ignore `E402` (import violations) in all `__init__.py` files, and in selected subdirectories. -[tool.ruff.lint.per-file-ignores] -"__init__.py" = ["E402"] -"**/{tests,docs,tools}/*" = ["E402"] - -[tool.ruff.lint.isort] -known-first-party = ["jamaibase", "owl", "docio"] - -[tool.ruff.lint.flake8-bugbear] -# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. -extend-immutable-calls = [ - "fastapi.Depends", - "fastapi.File", - "fastapi.Form", - "fastapi.Path", - "fastapi.Query", -] - -# ----------------------------------------------------------------------------- -# setuptools -# https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html - -[build-system] -# setuptools-scm considers all files tracked by git to be data files -requires = ["setuptools>=62.0", "setuptools-scm"] -build-backend = "setuptools.build_meta" - -[project] -name = "docio" -description = "DocIO service for PDF loading and parsing." 
-readme = "README.md" -requires-python = "~=3.10" -# keywords = ["one", "two"] -license = { text = "Proprietary" } -classifiers = [ # https://pypi.org/classifiers/ - "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3 :: Only", - "Intended Audience :: Information Technology", - "Operating System :: Unix", -] -dependencies = [ - "accelerate~=0.28", - "fastapi~=0.110.0", - "gunicorn~=21.2.0", - "jamaibase>=0.0.1", - "langchain-community~=0.0.25", - "langchain~=0.1.10", - "loguru~=0.7.2", - "matplotlib", - "pandas~=2.2.2", - "pdfplumber~=0.10.4", # pdfplumber - "pydantic-settings>=2.2.1", - "pydantic~=2.6.3", - "pypdfium2~=4.27.0", - "python-multipart", - "s3fs", - "timm", - "torch~=2.2.0", - "transformers>=4.38.2", - "unstructured-client @ git+https://github.com/EmbeddedLLM/unstructured-python-client.git@fix-nested-asyncio-conflict-with-uvloop#egg=unstructured-client", - "unstructured~=0.14.9", - "uvicorn[standard]~=0.27.1", -] # Sort your dependencies https://sortmylist.com/ -dynamic = ["version"] - -[project.optional-dependencies] -lint = ["ruff~=0.5.7"] -test = [ - "flaky~=3.7.0", - "mypy~=1.5.1", - "openai~=1.9.0", - "pytest-cov~=4.1.0", - "pytest~=7.4.2", -] -docs = [ - "furo~=2023.9.10", # Sphinx theme (nice looking, with dark mode) - "myst-parser~=2.0.0", - "sphinx-autobuild~=2021.3.14", - "sphinx-copybutton~=0.5.2", - "sphinx~=7.2.6", - "sphinx_rtd_theme~=1.3.0", # Sphinx theme -] -build = [ - "build", - "twine", -] # https://realpython.com/pypi-publish-python-package/#build-your-package -all = [ - "docio[lint,test,docs,build]", # https://hynek.me/articles/python-recursive-optional-dependencies/ -] - -# [project.scripts] -# docio = "docio.scripts.example:main_cli" - -[tool.setuptools.dynamic] -version = { attr = "docio.version.__version__" } - -[tool.setuptools.packages.find] -where = ["src"] - -[tool.setuptools.package-data] -docio = ["**/*.json"] diff --git a/services/docio/scripts/GenerateWinExe.ps1 b/services/docio/scripts/GenerateWinExe.ps1 deleted file mode 100644 index f055907..0000000 --- a/services/docio/scripts/GenerateWinExe.ps1 +++ /dev/null @@ -1,7 +0,0 @@ -if (Test-Path -Path ".\dist") { - Remove-Item -Path ".\dist" -Recurse -Force -} - -Get-Command python -pyinstaller .\docio.spec -Copy-Item -Path .\.env -Destination .\dist\docio\ \ No newline at end of file diff --git a/services/docio/scripts/SetupWinExeEnv.ps1 b/services/docio/scripts/SetupWinExeEnv.ps1 deleted file mode 100644 index 81ebd38..0000000 --- a/services/docio/scripts/SetupWinExeEnv.ps1 +++ /dev/null @@ -1,5 +0,0 @@ -cd ..\..\clients\python -pip install --no-cache . -cd ..\..\services\docio -pip install --no-cache . 
-pip install --no-cache pyinstaller \ No newline at end of file diff --git a/services/docio/scripts/validate_exe.py b/services/docio/scripts/validate_exe.py deleted file mode 100644 index 7cf332e..0000000 --- a/services/docio/scripts/validate_exe.py +++ /dev/null @@ -1,39 +0,0 @@ -from mimetypes import guess_type -from pathlib import Path - -import httpx - - -def get_local_uri(): - return [ - Path("../../clients/python/tests/files/txt/weather.txt").as_posix(), - Path("../../clients/python/tests/files/pdf/ca-warn-report.pdf").as_posix(), - Path("README.md").as_posix(), - ] - - -def test_file_loader_api(file_uri: str): - # Guess the MIME type of the file based on its extension - mime_type, _ = guess_type(file_uri) - if mime_type is None: - mime_type = "application/octet-stream" # Default MIME type - - # Extract the filename from the file path - filename = file_uri.split("/")[-1] - - # Open the file in binary mode - with open(file_uri, "rb") as f: - response = httpx.post( - "http://127.0.0.1:6979/api/docio/v1/load_file", - files={ - "file": (filename, f, mime_type), - }, - timeout=None, - ) - - assert response.status_code == 200 - - -if __name__ == "__main__": - for file_url in get_local_uri(): - test_file_loader_api(file_uri=file_url) diff --git a/services/docio/src/docio/config.py b/services/docio/src/docio/config.py deleted file mode 100644 index ad25008..0000000 --- a/services/docio/src/docio/config.py +++ /dev/null @@ -1,21 +0,0 @@ -LOGS = { - "stderr": { - "level": "INFO", - "serialize": False, - "backtrace": False, - "diagnose": True, - "enqueue": True, - "catch": True, - }, - "logs/docio.log": { - "level": "INFO", - "serialize": False, - "backtrace": False, - "diagnose": True, - "enqueue": True, - "catch": True, - "rotation": "50 MB", - "delay": False, - "watch": False, - }, -} diff --git a/services/docio/src/docio/entrypoints/api.py b/services/docio/src/docio/entrypoints/api.py deleted file mode 100644 index 27c242b..0000000 --- a/services/docio/src/docio/entrypoints/api.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -API server. 
- -```shell -$ python -m docio.entrypoints.api -$ CUDA_VISIBLE_DEVICES=1 WORKERS=2 python -m docio.entrypoints.api -``` -""" - -from fastapi import FastAPI, Request, Response, status -from fastapi.encoders import jsonable_encoder -from fastapi.responses import ORJSONResponse -from loguru import logger -from pydantic_settings import BaseSettings, SettingsConfigDict - -from docio.routers import loader -from docio.utils.logging import replace_logging_handlers, setup_logger_sinks - - -class Config(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore", cli_parse_args=True - ) - docio_port: int = 6979 - docio_host: str = "0.0.0.0" - docio_workers: int = 2 - service: str | None = None - prefix: str = "/api/docio" - - -config = Config() -setup_logger_sinks() -# We purposely don't intercept uvicorn logs since it is typically not useful -# We also don't intercept transformers logs -replace_logging_handlers(["uvicorn.access"], False) - - -app = FastAPI( - logger=logger, - openapi_url=f"{config.prefix}/openapi.json", - docs_url=f"{config.prefix}/docs", - redoc_url=f"{config.prefix}/redoc", -) -services = { - "loader": (loader.router, ["Document loader"]), -} -if config.service: - try: - router, tags = services[config.service] - except KeyError: - logger.error(f"Invalid service '{config.service}', choose from: {list(services.keys())}") - raise - app.include_router(router, prefix=config.prefix, tags=tags) -else: - # Mount everything - for router, tags in services.values(): - app.include_router(router, prefix=config.prefix, tags=tags) - - -@app.on_event("startup") -async def startup(): - # Temporary for backwards compatibility - logger.info(f"Using configuration: {config}") - - -# Order of handler does not matter -@app.exception_handler(FileNotFoundError) -async def file_not_found_exc_handler(request: Request, exc: FileNotFoundError): - return ORJSONResponse( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content={"detail": [{"type": "file_not_found", "msg": str(exc)}]}, - ) - - -@app.exception_handler(Exception) -async def exception_handler(request: Request, exc: Exception): - return ORJSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=jsonable_encoder( - [ - { - "type": "unexpected_error", - "msg": f"Encountered error: {exc!r}", - } - ] - ), - ) - - -@app.get("/health") -async def health() -> Response: - """Health check.""" - return Response(status_code=200) - - -if __name__ == "__main__": - import os - - import uvicorn - - if os.name == "nt": - import asyncio - from multiprocessing import freeze_support - - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - freeze_support() - logger.info("The system is Windows.") - else: - logger.info("The system is not Windows.") - - uvicorn.run( - "docio.entrypoints.api:app", - reload=False, - host=config.docio_host, - port=config.docio_port, - workers=config.docio_workers, - ) diff --git a/services/docio/src/docio/langchain/jsonloader.py b/services/docio/src/docio/langchain/jsonloader.py deleted file mode 100644 index 72a9bd6..0000000 --- a/services/docio/src/docio/langchain/jsonloader.py +++ /dev/null @@ -1,141 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Callable, Dict, Iterator, Optional, Union - -from langchain_community.document_loaders.base import BaseLoader -from langchain_core.documents import Document - - -class JSONLoader(BaseLoader): - """Load a `JSON` file generically.""" - - def __init__( - self, - file_path: 
Union[str, Path], - content_key: Optional[str] = None, - metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, - text_content: bool = True, - json_lines: bool = False, - ): - """Initialize the JSONLoader. - - Args: - file_path (Union[str, Path]): The path to the JSON or JSON Lines file. - content_key (str): The key to use to extract the content from - the JSON if the result is a list of objects (dict). - This should be a simple string key. - metadata_func (Callable[Dict, Dict]): A function that takes in the JSON - object and the default metadata and returns a dict of the updated metadata. - text_content (bool): Boolean flag to indicate whether the content is in - string format, default to True. - json_lines (bool): Boolean flag to indicate whether the input is in - JSON Lines format. - """ - self.file_path = Path(file_path).resolve() - self._content_key = content_key - self._metadata_func = metadata_func - self._text_content = text_content - self._json_lines = json_lines - - def lazy_load(self) -> Iterator[Document]: - """Load and return documents from the JSON file.""" - index = 0 - if self._json_lines: - with self.file_path.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - for doc in self._parse(line, index): - yield doc - index += 1 - else: - for doc in self._parse(self.file_path.read_text(encoding="utf-8"), index): - yield doc - index += 1 - - def _parse(self, content: str, index: int) -> Iterator[Document]: - """Convert given content to documents.""" - data = json.loads(content) - - # Perform some validation - # This is not a perfect validation, but it should catch most cases - # and prevent the user from getting a cryptic error later on. - if self._content_key is not None: - self._validate_content_key(data) - if self._metadata_func is not None: - self._validate_metadata_func(data) - - # If the data is a dictionary, treat it as a single document - if isinstance(data, dict): - text = self._get_text(sample=data) - metadata = self._get_metadata( - sample=data, source=str(self.file_path), seq_num=index + 1 - ) - yield Document(page_content=text, metadata=metadata) - else: - for i, sample in enumerate(data, index + 1): - text = self._get_text(sample=sample) - metadata = self._get_metadata(sample=sample, source=str(self.file_path), seq_num=i) - yield Document(page_content=text, metadata=metadata) - - def _get_text(self, sample: Any) -> str: - """Convert sample to string format""" - if self._content_key is not None: - content = sample[self._content_key] - else: - content = sample - - if self._text_content and not isinstance(content, str): - raise ValueError( - f"Expected page_content is string, got {type(content)} instead. 
\ - Set `text_content=False` if the desired input for \ - `page_content` is not a string" - ) - - # In case the text is None, set it to an empty string - elif isinstance(content, str): - return content - elif isinstance(content, dict): - return json.dumps(content, ensure_ascii=False) if content else "" - else: - return str(content) if content is not None else "" - - def _get_metadata(self, sample: Dict[str, Any], **additional_fields: Any) -> Dict[str, Any]: - """ - Return a metadata dictionary base on the existence of metadata_func - :param sample: single data payload - :param additional_fields: key-word arguments to be added as metadata values - :return: - """ - if self._metadata_func is not None: - return self._metadata_func(sample, additional_fields) - else: - return additional_fields - - def _validate_content_key(self, data: Any) -> None: - """Check if a content key is valid""" - - sample = data[0] if isinstance(data, list) else data - if not isinstance(sample, dict): - raise ValueError( - f"Expected the JSON to result in a list of objects (dict), \ - so sample must be a dict but got `{type(sample)}`" - ) - - if sample.get(self._content_key) is None: - raise ValueError( - f"Expected the JSON to result in a list of objects (dict) \ - with the key `{self._content_key}`" - ) - - def _validate_metadata_func(self, data: Any) -> None: - """Check if the metadata_func output is valid""" - - sample = data[0] if isinstance(data, list) else data - if self._metadata_func is not None: - sample_metadata = self._metadata_func(sample, {}) - if not isinstance(sample_metadata, dict): - raise ValueError( - f"Expected the metadata_func to return a dict but got \ - `{type(sample_metadata)}`" - ) diff --git a/services/docio/src/docio/langchain/pdfplumber.py b/services/docio/src/docio/langchain/pdfplumber.py deleted file mode 100644 index 5d3acc8..0000000 --- a/services/docio/src/docio/langchain/pdfplumber.py +++ /dev/null @@ -1,515 +0,0 @@ -import warnings -from typing import Any, Iterable, Iterator, Mapping, Sequence - -import matplotlib.pyplot as plt -import numpy as np -import pdfplumber.page -import torch -from langchain.document_loaders.base import BaseBlobParser -from langchain.document_loaders.blob_loaders import Blob -from langchain.document_loaders.pdf import BasePDFLoader -from pydantic_settings import BaseSettings, SettingsConfigDict -from transformers import DetrFeatureExtractor, TableTransformerForObjectDetection - -from docio.protocol import Document - - -class Config(BaseSettings): - model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") - docio_device: str = "cpu" - - -config = Config() - -_PDF_FILTER_WITH_LOSS = ["DCTDecode", "DCT", "JPXDecode"] -_PDF_FILTER_WITHOUT_LOSS = [ - "LZWDecode", - "LZW", - "FlateDecode", - "Fl", - "ASCII85Decode", - "A85", - "ASCIIHexDecode", - "AHx", - "RunLengthDecode", - "RL", - "CCITTFaxDecode", - "CCF", - "JBIG2Decode", -] - -# Define colors for visualization -COLORS = [ - [0.000, 0.447, 0.741], - [0.850, 0.325, 0.098], - [0.929, 0.694, 0.125], - [0.494, 0.184, 0.556], - [0.466, 0.674, 0.188], - [0.301, 0.745, 0.933], -] - - -def extract_from_images_with_rapidocr( - images: Sequence[Iterable[np.ndarray] | bytes], -) -> str: - try: - from rapidocr_onnxruntime import RapidOCR - except ImportError as e: - raise ImportError( - "`rapidocr-onnxruntime` package not found, please install it with " - "`pip install rapidocr-onnxruntime`" - ) from e - ocr = RapidOCR() - text = "" - for img in images: - result, _ = ocr(img) - if 
result: - result = [text[1] for text in result] - text += "\n".join(result) - return text - - -class PDFPlumberParser(BaseBlobParser): - """Parse `PDF` with `PDFPlumber`.""" - - def __init__( - self, - text_kwargs: Mapping[str, Any] | None = None, - dedupe: bool = False, - extract_images: bool = False, - table_detection_conf: float = 0.7, - ) -> None: - """Initialize the parser. - - Args: - text_kwargs: Keyword arguments to pass to ``pdfplumber.Page.extract_text()`` - dedupe: Avoiding the error of duplicate characters if `dedupe=True`. - """ - self.text_kwargs = text_kwargs or {} - self.dedupe = dedupe - self.extract_images = extract_images - - self.feature_extractor = DetrFeatureExtractor() - self.model = TableTransformerForObjectDetection.from_pretrained( - "microsoft/table-transformer-detection", device_map=config.docio_device - ) - - self.table_detection_conf = table_detection_conf - - def lazy_parse(self, blob: Blob) -> Iterator[Document]: - """Lazily parse the blob.""" - import pdfplumber - - with blob.as_bytes_io() as file_path: - doc = pdfplumber.open(file_path) # open document - - yield from [ - Document( - page_content=self._process_page_content(page) - + "\n" - + self._extract_images_from_page(page), - metadata=dict( - { - "source": blob.source, - "file_path": blob.source, - "page": page.page_number - 1, - "total_pages": len(doc.pages), - }, - **{ - k: doc.metadata[k] - for k in doc.metadata - if type(doc.metadata[k]) in [str, int] - }, - ), - ) - for page in doc.pages - if page.chars != [] # to skip blank page (or page without any text) - ] - - def _table_bbox_results(self, pil_img, model, scores, labels, boxes, save_file=None): - """ - model.config.id2label: { - 0: "table", - 1: "table column", - 2: "table row", - 3: "table column header", - 4: "table projected row header", - 5: "table spanning cell", - } - """ - - # Create a figure for visualization - plt.figure(figsize=(16, 10)) - # plt.figure(figsize=(160, 100)) - plt.imshow(pil_img) - - # Get the current axis - ax = plt.gca() - - # Repeat the COLORS list multiple times for visualization - colors = COLORS * 100 - - table_bboxes = [] - - # Iterate through scores, labels, boxes, and colors for visualization - for score, label, (xmin, ymin, xmax, ymax), c in zip( - scores.tolist(), - labels.tolist(), - boxes.tolist(), - colors, - strict=True, - ): - # Add a rectangle to the image for the detected object's bounding box - ax.add_patch( - plt.Rectangle( - (xmin, ymin), - xmax - xmin, - ymax - ymin, - fill=False, - color=c, - linewidth=3, - ) - ) - table_bboxes.append((xmin, ymin, xmax, ymax)) - - # Prepare the text for the label and score - # print(f"label: {label}, score: {score:0.2f}") - text = f"{model.config.id2label[label]}: {score:0.2f}" - - # Add the label and score text to the image - ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5)) - - # Turn off the axis - plt.axis("off") - - # Display the visualization - # plt.show() - if save_file: - plt.savefig(save_file, bbox_inches="tight") - - plt.close() # close the plt - - return table_bboxes - - def _process_page_content(self, page: pdfplumber.page.Page) -> str: - image = page.to_image().original - - width, height = image.size - # print(f"width, height: {width, height}") # (596, 808) - - encoding = self.feature_extractor(image, return_tensors="pt").to(config.docio_device) - # Get the keys of the encoding dictionary - # keys = encoding.keys() - # print(f"keys: {keys}") - - # with torch.no_grad(): # to onnx - with torch.inference_mode(): # to onnx 
- outputs = self.model(**encoding) - - # Post-process the object detection outputs using the feature extractor - results = self.feature_extractor.post_process_object_detection( - outputs, threshold=self.table_detection_conf, target_sizes=[(height, width)] - )[0] - - # save_detection_file = os.path.join( - # save_detection_dir, f"{pdf_file.split('/')[-1][:-4]}_p{i+1:03d}.png" - # ) - - # print(f"model.config.id2label: {model.config.id2label}") - - # table_bboxes = self._table_bbox_results( - # image, - # self.model, - # results["scores"], - # results["labels"], - # results["boxes"], - # # save_detection_file, - # ) - table_bboxes = results["boxes"].tolist() - # PDF Parsing - pages_chars = [] - pages_words = [] - char_widths = [] - char_heights = [] - sizes = [] - full_text = "" - for c in page.chars: - char_widths.append(c["width"]) - char_heights.append(c["height"]) - sizes.append(c["size"]) - - charW_med = np.median(np.array(char_widths)) - # charW_avg = np.sum(np.array(char_widths)) / len(char_widths) - # print(f"char_width_median: {charW_med}") - # print(f"char_width_avg: {charW_avg}") - - charH_med = np.median(np.array(char_heights)) - # charH_avg = np.sum(np.array(char_heights)) / len(char_heights) - # print(f"char_height_median: {charH_med}") - # print(f"char_height_avg: {charH_avg}") - - size_med = np.median(np.array(sizes)) - # size_avg = np.sum(np.array(sizes)) / len(sizes) - # print(f"size_median: {size_med}") - # print(f"size_avg: {size_avg}") - - # for i, page in enumerate(pdf.pages): - try: - page = page.within_bbox(bbox=(0, 0, page.width, page.height)) - # print(f"page.width, page.height: {page.width, page.height}") # (595.276, 807.874) - except Exception: - pass - selected_w_info = [] - words = page.extract_words() - for w in words: - selected_w_info.append( - { - # "page_number": i + 1, - "text": w["text"], - "size": w["bottom"] - w["top"], - "x0": w["x0"], - "x1": w["x1"], - "y0": page.height - w["bottom"], - "y1": (page.height - w["bottom"]) + (w["bottom"]) - w["top"], # y0 + size - "top": w["top"], - "bottom": w["bottom"], - "doctop": w["doctop"], - } - ) - - selected_info = [] - for c in page.chars: - selected_info.append( - { - "page_number": c["page_number"], - "text": c["text"], - "size": c["size"], - "adv": c["adv"], - # "upright": c["upright"], - "height": c["height"], - "width": c["width"], - "x0": c["x0"], - "x1": c["x1"], - "y0": c["y0"], - "y1": c["y1"], - "top": c["top"], - "bottom": c["bottom"], - "doctop": c["doctop"], - } - ) - horizontal_bottom = selected_info[0]["bottom"] - horizontal_top = selected_info[0]["top"] - char_right = selected_info[0]["x1"] - # char_left = selected_info[0]["x0"] - - table_char_idxes_groups = [] - - # tmp_text = "" - # print(f"page_tables: {table_bboxes}") - - # image bbox enlargement - based on intersection of extract_words bbox - for page_table in table_bboxes: - # print(f"\ntable {j}") - - # (xmin, ymin) == top left (from image bbox) - xmin, ymin, xmax, ymax = page_table - - # convert to pdf bbox - # (xmin, ymin) == bottom left (pdf bbox) - ymin2 = page.height - ymax - ymax2 = 0 + ymin2 + (ymax - ymin) - - xminL, yminL, xmaxL, ymaxL = xmin, ymin2, xmax, ymax2 - # print(f"xmin, ymin2, xmax, ymax2: {xmin, ymin2, xmax, ymax2}") - for w_ in selected_w_info: - """ - (x0, y1) (x1, y1) - 1 ___________ 2 - | | - | | - 0 |___________| 3 - (x0, y0) (x1, y0) - """ - # check if either word textbbox coor inside the table bbox - # if yes, enlarge the table bbox - if ( - ( - (w_["x0"] >= xmin and w_["x0"] <= xmax) - and (w_["y0"] >= ymin2 
and w_["y0"] <= ymax2) - ) # bottom left - x0, y0 - or ( - (w_["x0"] >= xmin and w_["x0"] <= xmax) - and (w_["y1"] >= ymin2 and w_["y1"] <= ymax2) - ) # top left - x0, y1 - or ( - (w_["x1"] >= xmin and w_["x1"] <= xmax) - and (w_["y1"] >= ymin2 and w_["y1"] <= ymax2) - ) # top right - x1, y1 - or ( - (w_["x1"] >= xmin and w_["x1"] <= xmax) - and (w_["y0"] >= ymin2 and w_["y0"] <= ymax2) - ) # bottom right - x1, y0 - ): - xminL = min(w_["x0"], xminL) - yminL = min(w_["y0"], yminL) - xmaxL = max(w_["x1"], xmaxL) - ymaxL = max(w_["y1"], ymaxL) - - # print(f"xminL, yminL, xmaxL, ymaxL: {xminL, yminL, xmaxL, ymaxL}") - - table_char_idxes = [] - xmin, ymin, xmax, ymax = xminL, yminL, xmaxL, ymaxL - for k, c_ in enumerate(selected_info): - # check if either char bbox coor inside the enlarged table bbox - if ( - ( - (c_["x0"] >= xmin and c_["x0"] <= xmax) - and (c_["y0"] >= ymin2 and c_["y0"] <= ymax2) - ) # bottom left - x0, y0 - or ( - (c_["x0"] >= xmin and c_["x0"] <= xmax) - and (c_["y1"] >= ymin2 and c_["y1"] <= ymax2) - ) # top left - x0, y1 - or ( - (c_["x1"] >= xmin and c_["x1"] <= xmax) - and (c_["y1"] >= ymin2 and c_["y1"] <= ymax2) - ) # top right - x1, y1 - or ( - (c_["x1"] >= xmin and c_["x1"] <= xmax) - and (c_["y0"] >= ymin2 and c_["y0"] <= ymax2) - ) # bottom right - x1, y0 - ): - table_char_idxes.append(k) - # tmp_text += c_["text"] - # print(f"tmp_text: {tmp_text}") - table_char_idxes_groups.append(table_char_idxes) - table_start_idxes = [] - table_end_idxes = [] - for table_char_idxes_group in table_char_idxes_groups: - if len(table_char_idxes_group) > 0: - table_start_idxes.append(table_char_idxes_group[0]) - table_end_idxes.append(table_char_idxes_group[-1]) - - # print(f"table_start_idxes: {table_start_idxes}") - # print(f"table_end_idxes: {table_end_idxes}") - - for k, c_ in enumerate(selected_info): - if k in table_start_idxes: - full_text += "\n" - - if c_["top"] > (horizontal_bottom): - if (c_["bottom"] - horizontal_bottom) > page.height * 0.3: - # ex: CONTENTS - full_text += "\n" + c_["text"] - # elif (c_["x0"] - char_right) > charW_med * c_["adv"] * 1.9: - # # next word - # full_text += ("" if c_["text"] == " " else " ") + c_["text"] - elif (c_["x0"] - char_right) > charW_med * c_["adv"] * 1.9: - # next word - full_text += ("" if c_["text"] == " " else " ") + c_["text"] - else: - # next paragraph - full_text += ( - "\n\n" if (c_["top"] - horizontal_bottom) > charH_med * 0.8 else "\n" - ) + c_["text"] - elif c_["bottom"] < horizontal_top: - if c_["x0"] < char_right: - # ex: ANNUAL REPORT, Other Listed Company Directorship(s) - full_text += "\n\n" + c_["text"] - else: - # next column - full_text += "\n\n" + c_["text"] - elif c_["size"] > size_med * 1.7: # bigger text - # full_text += "<>" - if (c_["x0"] - char_right) > (c_["size"] / c_["adv"]): - # next word - full_text += ("" if c_["text"] == " " else " ") + c_["text"] - else: - full_text += c_["text"] # normal next char - elif c_["size"] < size_med * 1.3: # smaller text - # full_text += "<>" - if (c_["x0"] - char_right) > charH_med * 0.2: - # next word - full_text += ("" if c_["text"] == " " else " ") + c_["text"] - else: - full_text += c_["text"] # normal next char - # elif (c_["x0"] - char_right) > charW_med * 0.2: - elif (c_["x0"] - char_right) > charW_med * c_["adv"] * 1.9: - # next word - full_text += ("" if c_["text"] == " " else " ") + c_["text"] - else: - full_text += c_["text"] # normal next char - - if k in table_end_idxes: - full_text += "\n
    " - - horizontal_bottom = c_["bottom"] - horizontal_top = c_["top"] - char_right = c_["x1"] - # char_left = c_["x0"] - - pages_chars += selected_info - pages_words += selected_w_info - - # df = pd.DataFrame.from_records(pages_chars) - # df.to_csv("page_plumber.csv", float_format="%.2f") - # df = pd.DataFrame.from_records(pages_words) - # df.to_csv("page_plumber_text_words.csv", float_format="%.2f") - - return full_text - - def _extract_images_from_page(self, page: pdfplumber.page.Page) -> str: - """Extract images from page and get the text with RapidOCR.""" - if not self.extract_images: - return "" - - images = [] - for img in page.images: - if img["stream"]["Filter"].name in _PDF_FILTER_WITHOUT_LOSS: - images.append( - np.frombuffer(img["stream"].get_data(), dtype=np.uint8).reshape( - img["stream"]["Height"], img["stream"]["Width"], -1 - ) - ) - elif img["stream"]["Filter"].name in _PDF_FILTER_WITH_LOSS: - images.append(img["stream"].get_data()) - else: - warnings.warn("Unknown PDF Filter!", stacklevel=2) - - return extract_from_images_with_rapidocr(images) - - -class PDFPlumberLoader(BasePDFLoader): - """Load `PDF` files using `pdfplumber`.""" - - def __init__( - self, - file_path: str, - text_kwargs: Mapping[str, Any] | None = None, - dedupe: bool = False, - headers: dict | None = None, - extract_images: bool = False, - ) -> None: - """Initialize with a file path.""" - try: - import pdfplumber # noqa:F401 - except ImportError as e: - raise ImportError( - "pdfplumber package not found, please install it with " "`pip install pdfplumber`" - ) from e - - super().__init__(file_path, headers=headers) - self.text_kwargs = text_kwargs or {} - self.dedupe = dedupe - self.extract_images = extract_images - - def load(self) -> list[Document]: - """Load file.""" - - parser = PDFPlumberParser( - text_kwargs=self.text_kwargs, - dedupe=self.dedupe, - extract_images=self.extract_images, - ) - blob = Blob.from_path(self.file_path) - return parser.parse(blob) diff --git a/services/docio/src/docio/langchain/tsvloader.py b/services/docio/src/docio/langchain/tsvloader.py deleted file mode 100644 index 83d5dfa..0000000 --- a/services/docio/src/docio/langchain/tsvloader.py +++ /dev/null @@ -1,136 +0,0 @@ -import csv -from io import TextIOWrapper -from os.path import join -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Dict, Iterator, Optional, Sequence, Union - -import pandas as pd -from langchain_community.document_loaders.base import BaseLoader -from langchain_community.document_loaders.helpers import detect_file_encodings -from langchain_core.documents import Document -from loguru import logger - - -class TSVLoader(BaseLoader): - """ - Load a TSV file into a list of Documents. - - Each document represents one row of the TSV file. Every row is converted into a - key/value pair and outputted to a new line in the document's page_content. - - The source for each document loaded from the TSV file is set to the value of the - `file_path` argument for all documents by default. You can override this by setting - the `source_column` argument to the name of a column in the TSV file. The source of - each document will then be set to the value of the column with the name specified in - `source_column`. - - Output Example: - .. 
code-block:: txt - - column1: value1 - column2: value2 - column3: value3 - """ - - def __init__( - self, - file_path: Union[str, Path], - source_column: Optional[str] = None, - metadata_columns: Sequence[str] = (), - csv_args: Optional[Dict] = None, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - ): - """ - Initialize the TSVLoader. - - Args: - file_path: The path to the TSV file. - source_column: The name of the column in the TSV file to use as the source. - Optional. Defaults to None. - metadata_columns: A sequence of column names to use as metadata. Optional. - csv_args: A dictionary of arguments to pass to the csv.DictReader. - Optional. Defaults to None. - encoding: The encoding of the TSV file. Optional. Defaults to None. - autodetect_encoding: Whether to try to autodetect the file encoding. - """ - self.file_path = file_path - self.source_column = source_column - self.metadata_columns = metadata_columns - self.encoding = encoding - self.csv_args = csv_args or {} - self.autodetect_encoding = autodetect_encoding - - def lazy_load(self) -> Iterator[Document]: - """ - Lazily load documents from the TSV file. - - Yields: - Document: A document representing a row in the TSV file. - """ - try: - with open(self.file_path, newline="", encoding=self.encoding) as csvfile: - yield from self.__read_file(csvfile) - except UnicodeDecodeError as e: - if self.autodetect_encoding: - detected_encodings = detect_file_encodings(self.file_path) - for encoding in detected_encodings: - try: - with open( - self.file_path, newline="", encoding=encoding.encoding - ) as csvfile: - yield from self.__read_file(csvfile) - break - except UnicodeDecodeError: - continue - else: - raise RuntimeError(f"Error loading {self.file_path}") from e - except Exception as e: - raise RuntimeError(f"Error loading {self.file_path}") from e - - def __read_file(self, tsvfile: TextIOWrapper) -> Iterator[Document]: - """ - Read the TSV file and convert each row into a Document. - - Args: - tsvfile: A file object representing the TSV file. - - Yields: - Document: A document representing a row in the TSV file. - """ - with TemporaryDirectory() as tmp_dir_path: - tmp_csv_path = join(tmp_dir_path, "tmpfile.csv") - content = pd.read_csv(tsvfile, sep="\t") - content.to_csv(tmp_csv_path, index=False) - - logger.debug(f"Loading from temporary file: {tmp_csv_path}") - - with open(tmp_csv_path, "r") as tmp_csv: - csv_reader = csv.DictReader(tmp_csv, **self.csv_args) - - for i, row in enumerate(csv_reader): - try: - source = ( - row[self.source_column] - if self.source_column is not None - else str(self.file_path) - ) - except KeyError as e: - raise ValueError( - f"Source column '{self.source_column}' not found in TSV file." - ) from e - content = "\n".join( - f"{k.strip()}: {v.strip() if v is not None else v}" - for k, v in row.items() - if k not in self.metadata_columns - ) - metadata = {"source": source, "row": i} - for col in self.metadata_columns: - try: - metadata[col] = row[col] - except KeyError as e: - raise ValueError( - f"Metadata column '{col}' not found in TSV file." 
- ) from e - yield Document(page_content=content, metadata=metadata) diff --git a/services/docio/src/docio/protocol.py b/services/docio/src/docio/protocol.py deleted file mode 100644 index 93e41ef..0000000 --- a/services/docio/src/docio/protocol.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class Document(BaseModel): - """Document class for compatibility with LangChain.""" - - page_content: str - metadata: dict = {} diff --git a/services/docio/src/docio/routers/loader.py b/services/docio/src/docio/routers/loader.py deleted file mode 100644 index faaa49f..0000000 --- a/services/docio/src/docio/routers/loader.py +++ /dev/null @@ -1,118 +0,0 @@ -import sys -from os.path import join, splitext -from tempfile import TemporaryDirectory - -from fastapi import APIRouter, File, UploadFile -from fastapi.exceptions import RequestValidationError -from langchain_community import document_loaders as loaders -from loguru import logger -from pydantic import SecretStr -from pydantic_settings import BaseSettings, SettingsConfigDict - -from docio.langchain.jsonloader import JSONLoader -from docio.langchain.pdfplumber import PDFPlumberLoader -from docio.langchain.tsvloader import TSVLoader -from docio.protocol import Document - - -class Config(BaseSettings): - model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") - s3_key: str = "minioadmin" # MinIO key - s3_secret: SecretStr = "fasts3xystoragelabel" - s3_url: str = "http://10.103.68.103:9000" - unstructuredio_url: str = "http://unstructuredio:6989/general/v0/general" - unstructuredio_api_key: SecretStr = "ellm" - - @property - def s3_secret_plain(self): - return self.s3_secret.get_secret_value() - - @property - def unstructuredio_api_key_plain(self): - return self.unstructuredio_api_key.get_secret_value() - - -config = Config() -router = APIRouter() - - -# build a table mapping all non-printable characters to None -NOPRINT_TRANS_TABLE = { - i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable() and chr(i) != "\n" -} - - -def make_printable(s: str) -> str: - """ - Replace non-printable characters in a string using - `translate()` that removes characters that map to None. 
- - # https://stackoverflow.com/a/54451873 - """ - return s.translate(NOPRINT_TRANS_TABLE) - - -def load_file(file_path: str) -> list[Document]: - ext = splitext(file_path)[1].lower() - if ext in (".txt", ".md"): - loader = loaders.TextLoader(file_path) - elif ext == ".pdf": - loader = PDFPlumberLoader(file_path) - elif ext == ".csv": - loader = loaders.CSVLoader(file_path) - elif ext == ".tsv": - loader = TSVLoader(file_path) - elif ext == ".json": - loader = JSONLoader(file_path, text_content=False) - elif ext == ".jsonl": - loader = JSONLoader(file_path, text_content=False, json_lines=True) - else: - raise ValueError(f'Unsupported file type: "{ext}"') - - documents = loader.load() - logger.info(f"docio {str(documents)}") - documents = [ - Document( - # TODO: Probably can use regex for this - # Replace vertical tabs, form feed, Unicode replacement character - # page_content=d.page_content.replace("\x0c", " ") - # .replace("\x0b", " ") - # .replace("\uFFFD", ""), - # For now we use a more aggressive strategy - page_content=make_printable(d.page_content), - metadata={"page": d.metadata.get("page", 0), **d.metadata}, - ) - for d in documents - ] - return documents - - -@router.post("/v1/load_file") -async def load_file_api( - file: UploadFile = File( - description="File to be uploaded in the form of `multipart/form-data`." - ), -) -> list[Document]: - logger.info( - "Upload type: {content_type} {filename}", - content_type=file.content_type, - filename=file.filename, - ) - try: - ext = splitext(file.filename)[1] - with TemporaryDirectory() as tmp_dir_path: - tmp_path = join(tmp_dir_path, f"tmpfile{ext}") - with open(tmp_path, "wb") as tmp: - tmp.write(await file.read()) - tmp.flush() - logger.trace("Loading from temporary file: {name}", name=tmp_path) - documents = load_file(tmp_path) - for d in documents: - d.metadata["source"] = file.filename - d.metadata["document_id"] = file.filename - return documents - except RequestValidationError: - raise - except Exception: - logger.exception("Failed to load file.") - raise diff --git a/services/docio/src/docio/utils/logging.py b/services/docio/src/docio/utils/logging.py deleted file mode 100644 index 0254382..0000000 --- a/services/docio/src/docio/utils/logging.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Configure handlers and formats for application loggers. - -https://gist.github.com/nkhitrov/a3e31cfcc1b19cba8e1b626276148c49 -""" - -import inspect -import logging - -from loguru import logger - - -class InterceptHandler(logging.Handler): - def emit(self, record: logging.LogRecord) -> None: - # https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging - # Get corresponding Loguru level if it exists. - level: str | int - try: - level = logger.level(record.levelname).name - except ValueError: - level = record.levelno - - # Find caller from where originated the logged message. - frame, depth = inspect.currentframe(), 0 - while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__): - frame = frame.f_back - depth += 1 - - logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) - - -def replace_logging_handlers(names: list[str], include_submodules: bool = True): - """ - Replaces logging handlers with `InterceptHandler` for use with `loguru`. 
- """ - if not isinstance(names, (list, tuple)): - raise TypeError("`names` should be a list of str.") - logger_names = [] - for name in names: - if include_submodules: - logger_names += [n for n in logging.root.manager.loggerDict if n.startswith(name)] - else: - logger_names += [n for n in logging.root.manager.loggerDict if n == name] - logger.info(f"Replacing logger handlers: {logger_names}") - loggers = (logging.getLogger(n) for n in logger_names) - for lgg in loggers: - lgg.handlers = [InterceptHandler()] - # logging.getLogger(name).handlers = [InterceptHandler()] - - -def setup_logger_sinks(): - import sys - from copy import deepcopy - - from docio.config import LOGS - - logger.remove() - log_cfg = deepcopy(LOGS) - stderr_cfg = log_cfg.pop("stderr", None) - if stderr_cfg is not None: - logger.add(sys.stderr, **stderr_cfg) - for path, cfg in log_cfg.items(): - logger.add(sink=path, **cfg) - logger.info(f"Writing logs to: {path}") diff --git a/services/docio/src/docio/version.py b/services/docio/src/docio/version.py deleted file mode 100644 index f102a9c..0000000 --- a/services/docio/src/docio/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.1"