|
6 | 6 | import traceback |
7 | 7 |
|
8 | 8 | from abc import ABC |
| 9 | +from datetime import datetime, timezone |
9 | 10 | from typing import Any |
10 | 11 |
|
11 | 12 | from tqdm import tqdm |
@@ -399,7 +400,7 @@ def get_memory( |
399 | 400 |
|
400 | 401 | if not all(isinstance(info[field], str) for field in required_fields): |
401 | 402 | raise ValueError("user_id and session_id must be strings") |
402 | | - |
| 403 | + scene_data = self._complete_chat_time(scene_data, type) |
403 | 404 | list_scene_data_info = self.get_scene_data_info(scene_data, type) |
404 | 405 |
|
405 | 406 | memory_list = [] |
@@ -508,6 +509,31 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: |
508 | 509 |
|
509 | 510 | return results |
510 | 511 |
|
| 512 | + def _complete_chat_time(self, scene_data: list[list[dict]], type: str): |
| 513 | + if type != "chat": |
| 514 | + return scene_data |
| 515 | + complete_scene_data = [] |
| 516 | + |
| 517 | + for items in scene_data: |
| 518 | + chat_time_value = None |
| 519 | + |
| 520 | + for item in items: |
| 521 | + if "chat_time" in item: |
| 522 | + chat_time_value = item["chat_time"] |
| 523 | + break |
| 524 | + |
| 525 | + if chat_time_value is None: |
| 526 | + session_date = datetime.now(timezone.utc) |
| 527 | + date_format = "%I:%M %p on %d %B, %Y UTC" |
| 528 | + chat_time_value = session_date.strftime(date_format) |
| 529 | + |
| 530 | + for i in range(len(items)): |
| 531 | + if "chat_time" not in items[i]: |
| 532 | + items[i]["chat_time"] = chat_time_value |
| 533 | + |
| 534 | + complete_scene_data.append(items) |
| 535 | + return complete_scene_data |
| 536 | + |
511 | 537 | def _process_doc_data(self, scene_data_info, info, **kwargs): |
512 | 538 | mode = kwargs.get("mode", "fine") |
513 | 539 | if mode == "fast": |
|
0 commit comments