Skip to content

Commit 504c6d8

Browse files
committed
chore(core.db): add missing type annotations
1 parent b17f50f commit 504c6d8

File tree

11 files changed

+167
-132
lines changed

11 files changed

+167
-132
lines changed

astrbot/core/db/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self) -> None:
3535
self.engine, class_=AsyncSession, expire_on_commit=False
3636
)
3737

38-
async def initialize(self):
38+
async def initialize(self) -> None:
3939
"""初始化数据库连接"""
4040
pass
4141

@@ -100,7 +100,7 @@ async def get_conversations(
100100
...
101101

102102
@abc.abstractmethod
103-
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
103+
async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None:
104104
"""Get a specific conversation by its ID."""
105105
...
106106

@@ -118,7 +118,7 @@ async def get_filtered_conversations(
118118
page_size: int = 20,
119119
platform_ids: list[str] | None = None,
120120
search_query: str = "",
121-
**kwargs,
121+
**kwargs: T.Any, # noqa: ANN401
122122
) -> tuple[list[ConversationV2], int]:
123123
"""Get conversations filtered by platform IDs and search query."""
124124
...
@@ -145,7 +145,7 @@ async def update_conversation(
145145
title: str | None = None,
146146
persona_id: str | None = None,
147147
content: list[dict] | None = None,
148-
) -> None:
148+
) -> ConversationV2 | None:
149149
"""Update a conversation's history."""
150150
...
151151

@@ -167,7 +167,7 @@ async def insert_platform_message_history(
167167
content: dict,
168168
sender_id: str | None = None,
169169
sender_name: str | None = None,
170-
) -> None:
170+
) -> PlatformMessageHistory | None:
171171
"""Insert a new platform message history record."""
172172
...
173173

@@ -195,12 +195,12 @@ async def insert_attachment(
195195
path: str,
196196
type: str,
197197
mime_type: str,
198-
):
198+
) -> Attachment:
199199
"""Insert a new attachment record."""
200200
...
201201

202202
@abc.abstractmethod
203-
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
203+
async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None:
204204
"""Get an attachment by its ID."""
205205
...
206206

@@ -216,7 +216,7 @@ async def insert_persona(
216216
...
217217

218218
@abc.abstractmethod
219-
async def get_persona_by_id(self, persona_id: str) -> Persona:
219+
async def get_persona_by_id(self, persona_id: str) -> Persona | None:
220220
"""Get a persona by its ID."""
221221
...
222222

@@ -249,7 +249,9 @@ async def insert_preference_or_update(
249249
...
250250

251251
@abc.abstractmethod
252-
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
252+
async def get_preference(
253+
self, scope: str, scope_id: str, key: str
254+
) -> Preference | None:
253255
"""Get a preference by scope ID and key."""
254256
...
255257

astrbot/core/db/migration/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ async def do_migration_v4(
3232
db_helper: BaseDatabase,
3333
platform_id_map: dict[str, dict[str, str]],
3434
astrbot_config: AstrBotConfig,
35-
):
35+
) -> None:
3636
"""
3737
执行数据库迁移
3838
迁移旧的 webchat_conversation 表到新的 conversation 表。

astrbot/core/db/migration/migra_3_to_4.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_platform_type(
3737

3838
async def migration_conversation_table(
3939
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
40-
):
40+
) -> None:
4141
db_helper_v3 = SQLiteV3DatabaseV3(
4242
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
4343
)
@@ -91,7 +91,7 @@ async def migration_conversation_table(
9191

9292
async def migration_platform_table(
9393
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
94-
):
94+
) -> None:
9595
db_helper_v3 = SQLiteV3DatabaseV3(
9696
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
9797
)
@@ -166,7 +166,7 @@ async def migration_platform_table(
166166

167167
async def migration_webchat_data(
168168
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
169-
):
169+
) -> None:
170170
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
171171
db_helper_v3 = SQLiteV3DatabaseV3(
172172
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
@@ -219,7 +219,7 @@ async def migration_webchat_data(
219219

220220
async def migration_persona_data(
221221
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
222-
):
222+
) -> None:
223223
"""
224224
迁移 Persona 数据到新的表中。
225225
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
@@ -261,7 +261,7 @@ async def migration_persona_data(
261261

262262
async def migration_preferences(
263263
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
264-
):
264+
) -> None:
265265
# 1. global scope migration
266266
keys = [
267267
"inactivated_llm_tools",

astrbot/core/db/migration/migra_45_to_46.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from astrbot.core.umop_config_router import UmopConfigRouter
44

55

6-
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
6+
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None:
77
abconf_data = acm.abconf_data
88

99
if not isinstance(abconf_data, dict):

astrbot/core/db/migration/shared_preferences_v3.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77

88

99
class SharedPreferences:
10-
def __init__(self, path=None):
10+
def __init__(self, path: str | None = None) -> None:
1111
if path is None:
1212
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
1313
self.path = path
1414
self._data = self._load_preferences()
1515

16-
def _load_preferences(self):
16+
def _load_preferences(self) -> dict:
1717
if os.path.exists(self.path):
1818
try:
1919
with open(self.path) as f:
@@ -22,24 +22,24 @@ def _load_preferences(self):
2222
os.remove(self.path)
2323
return {}
2424

25-
def _save_preferences(self):
25+
def _save_preferences(self) -> None:
2626
with open(self.path, "w") as f:
2727
json.dump(self._data, f, indent=4, ensure_ascii=False)
2828
f.flush()
2929

30-
def get(self, key, default: _VT = None) -> _VT:
30+
def get(self, key: str, default: _VT = None) -> _VT:
3131
return self._data.get(key, default)
3232

33-
def put(self, key, value):
33+
def put(self, key: str, value: object) -> None:
3434
self._data[key] = value
3535
self._save_preferences()
3636

37-
def remove(self, key):
37+
def remove(self, key: str) -> None:
3838
if key in self._data:
3939
del self._data[key]
4040
self._save_preferences()
4141

42-
def clear(self):
42+
def clear(self) -> None:
4343
self._data.clear()
4444
self._save_preferences()
4545

astrbot/core/db/migration/sqlite_v3.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection:
126126
conn.text_factory = str
127127
return conn
128128

129-
def _exec_sql(self, sql: str, params: tuple = None):
129+
def _exec_sql(self, sql: str, params: tuple | None = None) -> None:
130130
conn = self.conn
131131
try:
132132
c = self.conn.cursor()
@@ -143,7 +143,7 @@ def _exec_sql(self, sql: str, params: tuple = None):
143143

144144
conn.commit()
145145

146-
def insert_platform_metrics(self, metrics: dict):
146+
def insert_platform_metrics(self, metrics: dict) -> None:
147147
for k, v in metrics.items():
148148
self._exec_sql(
149149
"""
@@ -152,7 +152,7 @@ def insert_platform_metrics(self, metrics: dict):
152152
(k, v, int(time.time())),
153153
)
154154

155-
def insert_llm_metrics(self, metrics: dict):
155+
def insert_llm_metrics(self, metrics: dict) -> None:
156156
for k, v in metrics.items():
157157
self._exec_sql(
158158
"""
@@ -225,7 +225,9 @@ def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
225225

226226
return Stats(platform, [], [])
227227

228-
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
228+
def get_conversation_by_user_id(
229+
self, user_id: str, cid: str
230+
) -> Conversation | None:
229231
try:
230232
c = self.conn.cursor()
231233
except sqlite3.ProgrammingError:
@@ -246,7 +248,7 @@ def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
246248

247249
return Conversation(*res)
248250

249-
def new_conversation(self, user_id: str, cid: str):
251+
def new_conversation(self, user_id: str, cid: str) -> None:
250252
history = "[]"
251253
updated_at = int(time.time())
252254
created_at = updated_at
@@ -257,7 +259,7 @@ def new_conversation(self, user_id: str, cid: str):
257259
(user_id, cid, history, updated_at, created_at),
258260
)
259261

260-
def get_conversations(self, user_id: str) -> tuple:
262+
def get_conversations(self, user_id: str) -> list[Conversation]:
261263
try:
262264
c = self.conn.cursor()
263265
except sqlite3.ProgrammingError:
@@ -284,7 +286,7 @@ def get_conversations(self, user_id: str) -> tuple:
284286
)
285287
return conversations
286288

287-
def update_conversation(self, user_id: str, cid: str, history: str):
289+
def update_conversation(self, user_id: str, cid: str, history: str) -> None:
288290
"""更新对话,并且同时更新时间"""
289291
updated_at = int(time.time())
290292
self._exec_sql(
@@ -294,23 +296,25 @@ def update_conversation(self, user_id: str, cid: str, history: str):
294296
(history, updated_at, user_id, cid),
295297
)
296298

297-
def update_conversation_title(self, user_id: str, cid: str, title: str):
299+
def update_conversation_title(self, user_id: str, cid: str, title: str) -> None:
298300
self._exec_sql(
299301
"""
300302
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
301303
""",
302304
(title, user_id, cid),
303305
)
304306

305-
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
307+
def update_conversation_persona_id(
308+
self, user_id: str, cid: str, persona_id: str
309+
) -> None:
306310
self._exec_sql(
307311
"""
308312
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
309313
""",
310314
(persona_id, user_id, cid),
311315
)
312316

313-
def delete_conversation(self, user_id: str, cid: str):
317+
def delete_conversation(self, user_id: str, cid: str) -> None:
314318
self._exec_sql(
315319
"""
316320
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
@@ -381,11 +385,11 @@ def get_filtered_conversations(
381385
self,
382386
page: int = 1,
383387
page_size: int = 20,
384-
platforms: list[str] = None,
385-
message_types: list[str] = None,
386-
search_query: str = None,
387-
exclude_ids: list[str] = None,
388-
exclude_platforms: list[str] = None,
388+
platforms: list[str] | None = None,
389+
message_types: list[str] | None = None,
390+
search_query: str | None = None,
391+
exclude_ids: list[str] | None = None,
392+
exclude_platforms: list[str] | None = None,
389393
) -> tuple[list[dict[str, Any]], int]:
390394
"""获取筛选后的对话列表"""
391395
try:

0 commit comments

Comments
 (0)