|
1 | 1 | import asyncio |
| 2 | +import json |
| 3 | +import logging |
| 4 | +import os |
2 | 5 |
|
3 | 6 | from dotenv import load_dotenv |
4 | 7 | from sqlalchemy import select |
|
10 | 13 | from fastapi_app.postgres_engine import create_postgres_engine_from_env |
11 | 14 | from fastapi_app.postgres_models import Item |
12 | 15 |
|
| 16 | +logger = logging.getLogger("ragapp") |
13 | 17 |
|
def _resolve_embedding_column() -> str:
    """Return the name of the Item attribute that receives the new embeddings.

    The embedding host is read from ``OPENAI_EMBED_HOST``; each host has its
    own environment variable for overriding the column and its own default.
    """
    embed_host = os.getenv("OPENAI_EMBED_HOST")
    if embed_host == "azure":
        return os.getenv("AZURE_OPENAI_EMBEDDING_COLUMN", "embedding_ada002")
    if embed_host == "ollama":
        return os.getenv("OLLAMA_EMBEDDING_COLUMN", "embedding_nomic")
    # Any other value (including an unset variable) uses the openai.com setting.
    return os.getenv("OPENAICOM_EMBEDDING_COLUMN", "embedding_ada002")


async def update_embeddings(in_seed_data=False):
    """Recompute the embedding of every item and store it in the configured column.

    Args:
        in_seed_data: When true, the items are read from ``seed_data.json``
            located next to this file, their embeddings are recomputed, and the
            file is rewritten in place; nothing is written through the database
            session. When false (the default), every ``Item`` row is loaded
            from the database and updated inside a single transaction.
    """
    azure_credential = await get_azure_credentials()
    # NOTE(review): the engine is not used in seed mode; it is still created
    # here so the sequence of setup calls stays the same as before.
    engine = await create_postgres_engine_from_env(azure_credential)
    try:
        openai_embed_client = await create_openai_embed_client(azure_credential)
        common_params = await common_parameters()

        embedding_column = _resolve_embedding_column()
        logger.info("Updating embeddings in column: %s", embedding_column)

        async def embed(item):
            # Single place for the embedding call so both modes stay in sync.
            return await compute_text_embedding(
                item.to_str_for_embedding(),
                openai_client=openai_embed_client,
                embed_model=common_params.openai_embed_model,
                embed_deployment=common_params.openai_embed_deployment,
                embedding_dimensions=common_params.openai_embed_dimensions,
            )

        if in_seed_data:
            current_dir = os.path.dirname(os.path.realpath(__file__))
            seed_path = os.path.join(current_dir, "seed_data.json")
            # Explicit encoding: the default depends on the platform locale.
            with open(seed_path, encoding="utf-8") as f:
                catalog_items = json.load(f)

            items = []
            for catalog_item in catalog_items:
                item = Item(
                    id=catalog_item["id"],
                    type=catalog_item["type"],
                    brand=catalog_item["brand"],
                    name=catalog_item["name"],
                    description=catalog_item["description"],
                    price=catalog_item["price"],
                    embedding_ada002=catalog_item["embedding_ada002"],
                    embedding_nomic=catalog_item.get("embedding_nomic"),
                )
                setattr(item, embedding_column, await embed(item))
                items.append(item)

            # Serialize before opening the file for writing: opening with "w"
            # truncates it, so a serialization error must surface first.
            serialized = json.dumps([item.to_dict(include_embedding=True) for item in items], indent=4)
            with open(seed_path, "w", encoding="utf-8") as f:
                f.write(serialized)
            return

        async with async_sessionmaker(engine, expire_on_commit=False)() as session:
            # session.begin() commits when the block exits without an error,
            # so no explicit commit is needed.
            async with session.begin():
                items = (await session.scalars(select(Item))).all()

                for item in items:
                    setattr(item, embedding_column, await embed(item))
    finally:
        # Release pooled connections on every exit path, including errors and
        # the early return in seed mode.
        await engine.dispose()
33 | 81 |
|
34 | 82 |
|
if __name__ == "__main__":
    # Root logger stays at WARNING; only this script's logger is raised to INFO.
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)
    # Values from the .env file take precedence over variables already set.
    load_dotenv(override=True)
    main_coroutine = update_embeddings()
    asyncio.run(main_coroutine)
0 commit comments