Skip to content

Commit 125d943

Browse files
committed
add ---@cast
1 parent 163da75 commit 125d943

File tree

9 files changed

+242
-27
lines changed

9 files changed

+242
-27
lines changed

changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
local x = true
2424
local y = x--[[@as integer]] -- y is `integer` here
2525
```
26+
* `NEW` add `---@cast`
27+
* `---@cast localname type`
28+
* `---@cast localname +type`
29+
* `---@cast localname -type`
30+
* `---@cast localname +?`
31+
* `---@cast localname -?`
2632
* `NEW` generic: resolve `T[]` by `table<integer, type>` or `---@field [integer] type`
2733
* `NEW` resolve `class[1]` by `---@field [integer] type`
2834
* `NEW` diagnostic: `missing-parameter`

script/core/diagnostics/global-in-nil-env.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ return function (uri, callback)
1616
local env = guide.getENV(root)
1717

1818
local nilDefs = {}
19-
if not env.ref then
19+
if not env or not env.ref then
2020
return
2121
end
2222
for _, ref in ipairs(env.ref) do

script/parser/guide.lua

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ local type = type
5858
---@field step parser.object
5959
---@field redundant { max: integer, passed: integer }
6060
---@field filter parser.object
61+
---@field loc string
62+
---@field keyword integer[]
63+
---@field casts parser.object[]
64+
---@field mode? '+' | '-'
6165
---@field hasGoTo? true
6266
---@field hasReturn? true
6367
---@field hasBreak? true
@@ -148,6 +152,8 @@ local childMap = {
148152
['doc.version'] = {'#versions'},
149153
['doc.diagnostic'] = {'#names'},
150154
['doc.as'] = {'as'},
155+
['doc.cast'] = {'loc', '#casts'},
156+
['doc.cast.block'] = {'extends'},
151157
}
152158

153159
---@type table<string, fun(obj: parser.object, list: parser.object[])>
@@ -420,7 +426,7 @@ function m.getUri(obj)
420426
return ''
421427
end
422428

423-
---@return parser.object
429+
---@return parser.object?
424430
function m.getENV(source, start)
425431
if not start then
426432
start = 1
@@ -454,19 +460,17 @@ function m.getFunctionVarArgs(func)
454460
end
455461

456462
--- 获取指定区块中可见的局部变量
457-
---@param block table
458-
---@param name string {comment = '变量名'}
459-
---@param pos integer {comment = '可见位置'}
460-
function m.getLocal(block, name, pos)
461-
block = m.getBlock(block)
462-
for _ = 1, 10000 do
463-
if not block then
464-
return nil
465-
end
466-
local locals = block.locals
467-
local res
463+
---@param source parser.object
464+
---@param name string # 变量名
465+
---@param pos integer # 可见位置
466+
---@return parser.object?
467+
function m.getLocal(source, name, pos)
468+
local root = m.getRoot(source)
469+
local res
470+
m.eachSourceContain(root, pos, function (src)
471+
local locals = src.locals
468472
if not locals then
469-
goto CONTINUE
473+
return
470474
end
471475
for i = 1, #locals do
472476
local loc = locals[i]
@@ -479,13 +483,8 @@ function m.getLocal(block, name, pos)
479483
end
480484
end
481485
end
482-
if res then
483-
return res, res
484-
end
485-
::CONTINUE::
486-
block = m.getParentBlock(block)
487-
end
488-
error('guide.getLocal overstack')
486+
end)
487+
return res
489488
end
490489

491490
--- 获取指定区块中所有的可见局部变量名称
@@ -610,6 +609,9 @@ local function addChilds(list, obj)
610609
end
611610

612611
--- 遍历所有包含position的source
612+
---@param ast parser.object
613+
---@param position integer
614+
---@param callback fun(src: parser.object)
613615
function m.eachSourceContain(ast, position, callback)
614616
local list = { ast }
615617
local mark = {}

script/parser/luadoc.lua

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Symbol <- ({} {
5353
/ '...'
5454
/ '['
5555
/ ']'
56+
/ '-' !'-'
5657
} {})
5758
-> Symbol
5859
]], {
@@ -1205,6 +1206,70 @@ local docSwitch = util.switch()
12051206
result.finish = getFinish()
12061207
return result
12071208
end)
1209+
: case 'cast'
1210+
: call(function ()
1211+
local result = {
1212+
type = 'doc.cast',
1213+
start = getFinish(),
1214+
finish = getFinish(),
1215+
casts = {},
1216+
}
1217+
1218+
local loc = parseName('doc.cast.name', result)
1219+
if not loc then
1220+
pushWarning {
1221+
type = 'LUADOC_MISS_LOCAL_NAME',
1222+
start = getFinish(),
1223+
finish = getFinish(),
1224+
}
1225+
return result
1226+
end
1227+
1228+
result.loc = loc
1229+
result.finish = loc.finish
1230+
1231+
while true do
1232+
local block = {
1233+
type = 'doc.cast.block',
1234+
parent = result,
1235+
start = getFinish(),
1236+
finish = getFinish(),
1237+
}
1238+
result.casts[#result.casts+1] = block
1239+
if checkToken('symbol', '+', 1) then
1240+
block.mode = '+'
1241+
nextToken()
1242+
block.start = getStart()
1243+
block.finish = getFinish()
1244+
elseif checkToken('symbol', '-', 1) then
1245+
block.mode = '-'
1246+
nextToken()
1247+
block.start = getStart()
1248+
block.finish = getFinish()
1249+
end
1250+
1251+
if checkToken('symbol', '?', 1) then
1252+
block.optional = true
1253+
nextToken()
1254+
block.start = block.start or getStart()
1255+
block.finish = block.finish
1256+
else
1257+
block.extends = parseType(block)
1258+
if block.extends then
1259+
block.start = block.start or block.extends.start
1260+
block.finish = block.extends.finish
1261+
end
1262+
end
1263+
1264+
if checkToken('symbol', ',', 1) then
1265+
nextToken()
1266+
else
1267+
break
1268+
end
1269+
end
1270+
1271+
return result
1272+
end)
12081273

12091274
local function convertTokens()
12101275
local tp, text = nextToken()
@@ -1313,6 +1378,9 @@ local function isNextLine(binded, doc)
13131378
return false
13141379
end
13151380
end
1381+
if doc.type == 'doc.cast' then
1382+
return false
1383+
end
13161384
local lastRow = guide.rowColOf(lastDoc.finish)
13171385
local newRow = guide.rowColOf(doc.start)
13181386
return newRow - lastRow == 1

script/parser/newparser.lua

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -691,9 +691,6 @@ local function parseLocalAttrs()
691691
end
692692

693693
local function createLocal(obj, attrs)
694-
if not obj then
695-
return nil
696-
end
697694
obj.type = 'local'
698695
obj.effect = obj.finish
699696

@@ -2893,7 +2890,11 @@ local function parseLocal()
28932890
pushActionIntoCurrentChunk(loc)
28942891
skipSpace()
28952892
parseMultiVars(loc, parseName, true)
2896-
loc.effect = lastRightPosition()
2893+
if loc.value then
2894+
loc.effect = loc.value.finish
2895+
else
2896+
loc.effect = loc.finish
2897+
end
28972898

28982899
return loc
28992900
end

script/vm/global-manager.lua

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ end
358358
---@param source parser.object
359359
function m.compileAst(source)
360360
local env = guide.getENV(source)
361+
if not env then
362+
return
363+
end
361364
m.compileObject(env)
362365
guide.eachSpecialOf(source, 'rawset', function (src)
363366
m.compileObject(src.parent)

script/vm/node.lua

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ function mt:remove(name)
204204
end
205205
end
206206

207+
---@param node vm.node
208+
function mt:removeNode(node)
209+
for _, c in ipairs(node) do
210+
if c.type == 'global' and c.cate == 'type' then
211+
self:remove(c.name)
212+
end
213+
end
214+
end
215+
207216
---@return fun():vm.object
208217
function mt:eachObject()
209218
local i = 0

script/vm/runner.lua

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@ mt.__index = mt
1313
mt.index = 1
1414

1515
---@class parser.object
16-
---@field _hasSorted boolean
16+
---@field _casts parser.object[]
1717

1818
---@class vm.runner.step
19-
---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge'
19+
---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast'
2020
---@field pos integer
2121
---@field order? integer
2222
---@field node? vm.node
2323
---@field object? parser.object
2424
---@field name? string
25+
---@field cast? parser.object
2526
---@field tag? string
2627
---@field copy? boolean
2728
---@field new? boolean
@@ -250,21 +251,58 @@ function mt:_compileBlock(block)
250251
end
251252
end
252253

254+
---@return parser.object[]
255+
function mt:_getCasts()
256+
local root = guide.getRoot(self.loc)
257+
if not root._casts then
258+
root._casts = {}
259+
local docs = root.docs
260+
for _, doc in ipairs(docs) do
261+
if doc.type == 'doc.cast' and doc.loc then
262+
root._casts[#root._casts+1] = doc
263+
end
264+
end
265+
end
266+
return root._casts
267+
end
268+
253269
function mt:_preCompile()
270+
local startPos = self.loc.start
271+
local finishPos = 0
272+
254273
for _, ref in ipairs(self.loc.ref) do
255274
self.steps[#self.steps+1] = {
256275
type = 'object',
257276
object = ref,
258277
pos = ref.range or ref.start,
259278
}
279+
if ref.start > finishPos then
280+
finishPos = ref.start
281+
end
260282
local block = guide.getParentBlock(ref)
261283
self:_compileBlock(block)
262284
end
285+
263286
for i, step in ipairs(self.steps) do
264287
if step.type ~= 'object' then
265288
step.order = i
266289
end
267290
end
291+
292+
local casts = self:_getCasts()
293+
for _, cast in ipairs(casts) do
294+
if cast.loc[1] == self.loc[1]
295+
and cast.start > startPos
296+
and cast.finish < finishPos
297+
and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then
298+
self.steps[#self.steps+1] = {
299+
type = 'cast',
300+
cast = cast,
301+
pos = cast.start,
302+
}
303+
end
304+
end
305+
268306
table.sort(self.steps, function (a, b)
269307
if a.pos == b.pos then
270308
return (a.order or 0) < (b.order or 0)
@@ -363,6 +401,30 @@ function mt:launch(callback)
363401
topNode = node
364402
elseif step.type == 'merge' then
365403
node:merge(step.ref2.node)
404+
elseif step.type == 'cast' then
405+
topNode = node:copy()
406+
for _, cast in ipairs(step.cast.casts) do
407+
if cast.mode == '+' then
408+
if cast.optional then
409+
topNode:addOptional()
410+
end
411+
if cast.extends then
412+
topNode:merge(vm.compileNode(cast.extends))
413+
end
414+
elseif cast.mode == '-' then
415+
if cast.optional then
416+
topNode:removeOptional()
417+
end
418+
if cast.extends then
419+
topNode:removeNode(vm.compileNode(cast.extends))
420+
end
421+
else
422+
if cast.extends then
423+
topNode:clear()
424+
topNode:merge(vm.compileNode(cast.extends))
425+
end
426+
end
427+
end
366428
end
367429
end
368430
end

0 commit comments

Comments
 (0)