22from dataclasses import dataclass , field
33from pathlib import Path
44from textwrap import indent
5+ from typing import Protocol
56
67import tree_sitter_lua
78import tree_sitter_markdown
89from tree_sitter import Language , Parser
910
1011
12+ class LuaType (Protocol ):
13+ @property
14+ def value (self ) -> str : ...
15+
16+ def add (self , value : str ) -> None : ...
17+
18+ def name (self ) -> str : ...
19+
20+ def to_user (self ) -> str | None : ...
21+
22+ def to_str (self ) -> str : ...
23+
24+
1125@dataclass (frozen = True )
1226class LuaAlias :
1327 value : str
@@ -21,21 +35,17 @@ def name(self) -> str:
2135 return self .value .split ()[1 ]
2236
2337 def to_user (self ) -> str | None :
24- if not self .config ():
38+ simple = self .name ().split ("." )[- 1 ]
39+ if simple != "Configs" :
2540 return None
2641
27- def user (s : str ) -> str :
28- return s .replace (".Config" , ".UserConfig" )
29-
30- lines : list [str ] = [user (self .value )]
31- for s in self .options :
32- option = user (s )
33- lines .append (option )
42+ lines : list [str ] = []
43+ values = [self .value ] + self .options
44+ for value in values :
45+ value = value .replace (".Config" , ".UserConfig" )
46+ lines .append (value )
3447 return "\n " .join (lines )
3548
36- def config (self ) -> bool :
37- return self .name ().split ("." )[- 1 ] == "Configs"
38-
3949 def to_str (self ) -> str :
4050 return "\n " .join ([self .value ] + self .options )
4151
@@ -55,27 +65,22 @@ def name(self) -> str:
5565 return self .value .split (":" )[0 ].split ()[- 1 ]
5666
5767 def to_user (self ) -> str | None :
58- if not self .exact () or not self .config ():
68+ kind = self .value .split ()[1 ]
69+ simple = self .name ().split ("." )[- 1 ]
70+ if kind != "(exact)" or simple != "Config" :
5971 return None
6072
61- def user (s : str ) -> str :
62- return s .replace (".Config" , ".UserConfig" )
63-
64- lines : list [str ] = [user (self .value )]
65- for s in self .fields :
66- field = user (s )
67- name = field .split ()[1 ]
68- if not name .endswith ("?" ):
69- field = field .replace (f" { name } " , f" { name } ? " )
70- lines .append (field )
73+ lines : list [str ] = []
74+ values = [self .value ] + self .fields
75+ for i , value in enumerate (values ):
76+ value = value .replace (".Config" , ".UserConfig" )
77+ if i > 0 :
78+ name = value .split ()[1 ]
79+ if not name .endswith ("?" ):
80+ value = value .replace (f" { name } " , f" { name } ? " )
81+ lines .append (value )
7182 return "\n " .join (lines )
7283
73- def exact (self ) -> bool :
74- return self .value .split ()[1 ] == "(exact)"
75-
76- def config (self ) -> bool :
77- return self .name ().split ("." )[- 1 ] == "Config"
78-
7984 def to_str (self ) -> str :
8085 return "\n " .join ([self .value ] + self .fields )
8186
@@ -91,14 +96,14 @@ def update_types(root: Path) -> None:
9196 files : list [Path ] = [root .joinpath ("init.lua" )]
9297 files .extend (sorted (root .joinpath ("config" ).iterdir ()))
9398
94- classes : list [str ] = ["---@meta" ]
95- for definition in get_definitions (files ):
96- user = definition .to_user ()
99+ sections : list [str ] = ["---@meta" ]
100+ for lua_type in get_lua_types (files ):
101+ user = lua_type .to_user ()
97102 if user is not None :
98- classes .append (user )
103+ sections .append (user )
99104
100105 types = root .joinpath ("types.lua" )
101- types .write_text ("\n \n " .join (classes ) + "\n " )
106+ types .write_text ("\n \n " .join (sections ) + "\n " )
102107
103108
104109def update_readme (root : Path ) -> None :
@@ -148,7 +153,7 @@ def update_handlers(root: Path) -> None:
148153 root .joinpath ("config/handlers.lua" ),
149154 root .joinpath ("lib/marks.lua" ),
150155 ]
151- name_lua = {lua .name (): lua for lua in get_definitions (files )}
156+ lua_types = {lua_type .name (): lua_type for lua_type in get_lua_types (files )}
152157 names = [
153158 "render.md.Handler" ,
154159 "render.md.handler.Context" ,
@@ -158,17 +163,17 @@ def update_handlers(root: Path) -> None:
158163 "render.md.mark.Text" ,
159164 "render.md.mark.Hl" ,
160165 ]
161- definitions = [name_lua [name ] for name in names ]
166+ sections = [lua_types [name ]. to_str () for name in names ]
162167
163168 handlers = Path ("doc/custom-handlers.md" )
164- old = get_code_block (handlers , definitions [0 ]. value , 1 )
165- new = "\n " .join ([ lua . to_str ( ) + "\n " for lua in definitions ])
169+ old = get_code_block (handlers , names [0 ], 1 )
170+ new = "\n \n " .join (sections ) + "\n "
166171 text = handlers .read_text ().replace (old , new )
167172 handlers .write_text (text )
168173
169174
170- def get_definitions (files : list [Path ]) -> list [LuaAlias | LuaClass ]:
171- result : list [LuaAlias | LuaClass ] = []
175+ def get_lua_types (files : list [Path ]) -> list [LuaType ]:
176+ result : list [LuaType ] = []
172177 for file in files :
173178 for comment in get_comments (file ):
174179 # ---@class md.Init: md.Api -> class
@@ -178,12 +183,12 @@ def get_definitions(files: list[Path]) -> list[LuaAlias | LuaClass]:
178183 # ---@type md.Config -> type
179184 # ---@param opts? md.UserConfig -> param
180185 # -- Inlined with 'image' elements -> --
181- annotation = comment .split ()[0 ].split ("@" )[- 1 ]
182- if annotation == "alias" :
186+ kind = comment .split ()[0 ].split ("@" )[- 1 ]
187+ if kind == "alias" :
183188 result .append (LuaAlias (comment ))
184- elif annotation == "class" :
189+ elif kind == "class" :
185190 result .append (LuaClass (comment ))
186- elif annotation in ["field" , "---|" ]:
191+ elif kind in ["field" , "---|" ]:
187192 result [- 1 ].add (comment )
188193 return result
189194
0 commit comments