Skip to content

Commit 34de616

Browse files
committed
支持根据参数类型选择函数定义
1 parent 6fb7e18 commit 34de616

File tree

6 files changed

+96
-11
lines changed

6 files changed

+96
-11
lines changed

changelog.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
# changelog
22

3-
## 3.8.4
3+
## 3.9.0
44
* `NEW` goto implementation
5+
* `NEW` narrow the function prototype based on the parameter type
6+
```lua
7+
---@overload fun(a: boolean): A
8+
---@overload fun(a: number): B
9+
local function f(...) end
10+
11+
local r1 = f(true) --> r1 is `A`
12+
local r2 = f(10) --> r2 is `B`
13+
```
514

615
## 3.8.3
716
`2024-4-23`

script/vm/compiler.lua

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,11 +550,14 @@ local function matchCall(source)
550550
or call.node ~= source then
551551
return
552552
end
553-
local funcs = vm.getMatchedFunctions(source, call.args)
554553
local myNode = vm.getNode(source)
555554
if not myNode then
556555
return
557556
end
557+
local funcs = vm.getExactMatchedFunctions(source, call.args)
558+
if not funcs then
559+
return
560+
end
558561
local needRemove
559562
for n in myNode:eachObject() do
560563
if n.type == 'function'

script/vm/function.lua

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ end
267267
---@return integer def
268268
function vm.countReturnsOfCall(func, args, mark)
269269
local funcs = vm.getMatchedFunctions(func, args, mark)
270+
if not funcs then
271+
return 0, math.huge, 0
272+
end
270273
---@type integer?, number?, integer?
271274
local min, max, def
272275
for _, f in ipairs(funcs) do
@@ -329,10 +332,52 @@ function vm.countList(list, mark)
329332
return min, max, def
330333
end
331334

335+
---@param uri uri
336+
---@param args parser.object[]
337+
---@return boolean
338+
local function isAllParamMatched(uri, args, params)
339+
if not params then
340+
return false
341+
end
342+
for i = 1, #args do
343+
if not params[i] then
344+
break
345+
end
346+
local argNode = vm.compileNode(args[i])
347+
local defNode = vm.compileNode(params[i])
348+
if not vm.canCastType(uri, defNode, argNode) then
349+
return false
350+
end
351+
end
352+
return true
353+
end
354+
332355
---@param func parser.object
333-
---@param args parser.object[]?
356+
---@param args? parser.object[]
357+
---@return parser.object[]?
358+
function vm.getExactMatchedFunctions(func, args)
359+
local funcs = vm.getMatchedFunctions(func, args)
360+
if not args or not funcs then
361+
return funcs
362+
end
363+
local uri = guide.getUri(func)
364+
local result = {}
365+
for _, n in ipairs(funcs) do
366+
if not vm.isVarargFunctionWithOverloads(n)
367+
and isAllParamMatched(uri, args, n.args) then
368+
result[#result+1] = n
369+
end
370+
end
371+
if #result == 0 then
372+
return nil
373+
end
374+
return result
375+
end
376+
377+
---@param func parser.object
378+
---@param args? parser.object[]
334379
---@param mark? table
335-
---@return parser.object[]
380+
---@return parser.object[]?
336381
function vm.getMatchedFunctions(func, args, mark)
337382
local funcs = {}
338383
local node = vm.compileNode(func)
@@ -342,9 +387,6 @@ function vm.getMatchedFunctions(func, args, mark)
342387
funcs[#funcs+1] = n
343388
end
344389
end
345-
if #funcs <= 1 then
346-
return funcs
347-
end
348390

349391
local amin, amax = vm.countList(args, mark)
350392

@@ -357,7 +399,7 @@ function vm.getMatchedFunctions(func, args, mark)
357399
end
358400

359401
if #matched == 0 then
360-
return funcs
402+
return nil
361403
else
362404
return matched
363405
end

script/vm/infer.lua

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,6 @@ local viewNodeSwitch;viewNodeSwitch = util.switch()
242242
return vm.viewKey(source, uri)
243243
end)
244244

245-
---@class vm.node
246-
---@field lastInfer? vm.infer
247-
248245
---@param node? vm.node
249246
---@return vm.infer
250247
local function createInfer(node)

script/vm/node.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ vm.nodeCache = setmetatable({}, util.MODE_K)
1616
---@field [vm.node.object] true
1717
---@field fields? table<vm.node|string, vm.node>
1818
---@field undefinedGlobal boolean?
19+
---@field lastInfer? vm.infer
1920
local mt = {}
2021
mt.__index = mt
2122
mt.id = 0
@@ -31,6 +32,7 @@ function mt:merge(node)
3132
if not node then
3233
return self
3334
end
35+
self.lastInfer = nil
3436
if node.type == 'vm.node' then
3537
if node == self then
3638
return self

test/type_inference/param_match.lua

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,35 @@ local r1
105105
106106
local <?x?> = f(r1())
107107
]]
108+
109+
TEST '1' [[
110+
---@overload fun(a: 'x'): 1
111+
---@overload fun(a: 'y'): 2
112+
local function f(...) end
113+
114+
local <?r?> = f('x')
115+
]]
116+
117+
TEST '2' [[
118+
---@overload fun(a: 'x'): 1
119+
---@overload fun(a: 'y'): 2
120+
local function f(...) end
121+
122+
local <?r?> = f('y')
123+
]]
124+
125+
TEST '1' [[
126+
---@overload fun(a: boolean): 1
127+
---@overload fun(a: number): 2
128+
local function f(...) end
129+
130+
local <?r?> = f(true)
131+
]]
132+
133+
TEST '2' [[
134+
---@overload fun(a: boolean): 1
135+
---@overload fun(a: number): 2
136+
local function f(...) end
137+
138+
local <?r?> = f(10)
139+
]]

0 commit comments

Comments
 (0)