1- """
2- @add_kwonly function_definition
3-
4- Define keyword-only version of the `function_definition`.
5-
6- @add_kwonly function f(x; y=1)
7- ...
8- end
9-
10- expands to:
11-
12- function f(x; y=1)
13- ...
14- end
15- function f(; x = error("No argument x"), y=1)
16- ...
17- end
18- """
19- macro add_kwonly (ex)
20- esc (add_kwonly (ex))
21- end
22-
23- add_kwonly (ex:: Expr ) = add_kwonly (Val{ex. head}, ex)
24-
25- function add_kwonly (:: Type{<:Val} , ex)
26- error (" add_only does not work with expression $(ex. head) " )
27- end
28-
29- function add_kwonly (:: Union {Type{Val{:function }},
30- Type{Val{:(= )}}}, ex:: Expr )
31- body = ex. args[2 : end ] # function body
32- default_call = ex. args[1 ] # e.g., :(f(a, b=2; c=3))
33- kwonly_call = add_kwonly (default_call)
34- if kwonly_call === nothing
35- return ex
36- end
37-
38- return quote
39- begin
40- $ ex
41- $ (Expr (ex. head, kwonly_call, body... ))
42- end
43- end
44- end
45-
46- function add_kwonly (:: Type{Val{:where}} , ex:: Expr )
47- default_call = ex. args[1 ]
48- rest = ex. args[2 : end ]
49- kwonly_call = add_kwonly (default_call)
50- if kwonly_call === nothing
51- return nothing
52- end
53- return Expr (:where , kwonly_call, rest... )
54- end
55-
56- function add_kwonly (:: Type{Val{:call}} , default_call:: Expr )
57- # default_call is, e.g., :(f(a, b=2; c=3))
58- funcname = default_call. args[1 ] # e.g., :f
59- required = [] # required positional arguments; e.g., [:a]
60- optional = [] # optional positional arguments; e.g., [:(b=2)]
61- default_kwargs = []
62- for arg in default_call. args[2 : end ]
63- if isa (arg, Symbol)
64- push! (required, arg)
65- elseif arg. head == :(:: )
66- push! (required, arg)
67- elseif arg. head == :kw
68- push! (optional, arg)
69- elseif arg. head == :parameters
70- @assert default_kwargs == [] # can I have :parameters twice?
71- default_kwargs = arg. args
72- else
73- error (" Not expecting to see: $arg " )
74- end
75- end
76- if isempty (required) && isempty (optional)
77- # If the function is already keyword-only, do nothing:
78- return nothing
79- end
80- if isempty (required)
81- # It's not clear what should be done. Let's not support it at
82- # the moment:
83- error (" At least one positional mandatory argument is required." )
84- end
85-
86- kwonly_kwargs = Expr (:parameters ,
87- [Expr (:kw , pa, :(error ($ (" No argument $pa " ))))
88- for pa in required]. .. , optional... , default_kwargs... )
89- kwonly_call = Expr (:call , funcname, kwonly_kwargs)
90- # e.g., :(f(; a=error(...), b=error(...), c=1, d=2))
91-
92- return kwonly_call
93- end
94-
95- function num_types_in_tuple (sig)
96- length (sig. parameters)
97- end
98-
99- function num_types_in_tuple (sig:: UnionAll )
100- length (Base. unwrap_unionall (sig). parameters)
101- end
1021
1032@inline UNITLESS_ABS2 (x) = real (abs2 (x))
1043@inline DEFAULT_NORM (u:: Union{AbstractFloat, Complex} ) = @fastmath abs (u)
1054@inline function DEFAULT_NORM (u:: Array{T} ) where {T <: Union{AbstractFloat, Complex} }
1065 sqrt (real (sum (abs2, u)) / length (u))
1076end
108- @inline function DEFAULT_NORM (u:: StaticArray{T} ) where {T <: Union{AbstractFloat, Complex} }
7+ @inline function DEFAULT_NORM (u:: StaticArraysCore.StaticArray{T} ) where {
8+ T <: Union {
9+ AbstractFloat,
10+ Complex}}
10911 sqrt (real (sum (abs2, u)) / length (u))
11012end
11113@inline function DEFAULT_NORM (u:: RecursiveArrayTools.AbstractVectorOfArray )
11416@inline DEFAULT_NORM (u:: AbstractArray ) = sqrt (real (sum (UNITLESS_ABS2, u)) / length (u))
11517@inline DEFAULT_NORM (u) = norm (u)
11618
117- """
118- prevfloat_tdir(x, x0, x1)
119-
120- Move `x` one floating point towards x0.
121- """
122- function prevfloat_tdir (x, x0, x1)
123- x1 > x0 ? prevfloat (x) : nextfloat (x)
124- end
125-
126- function nextfloat_tdir (x, x0, x1)
127- x1 > x0 ? nextfloat (x) : prevfloat (x)
128- end
129-
130- function max_tdir (a, b, x0, x1)
131- x1 > x0 ? max (a, b) : min (a, b)
132- end
133-
13419alg_autodiff (alg:: AbstractNewtonAlgorithm{CS, AD} ) where {CS, AD} = AD
13520alg_autodiff (alg) = false
13621
@@ -146,15 +31,14 @@ function value_derivative(f::F, x::R) where {F, R}
14631end
14732
14833# Todo: improve this dispatch
149- value_derivative (f:: F , x:: SVector ) where {F} = f (x), ForwardDiff. jacobian (f, x)
34+ function value_derivative (f:: F , x:: StaticArraysCore.SVector ) where {F}
35+ f (x), ForwardDiff. jacobian (f, x)
36+ end
15037
15138value (x) = x
15239value (x:: Dual ) = ForwardDiff. value (x)
15340value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
15441
155- _unwrap_val (:: Val{B} ) where {B} = B
156- _unwrap_val (B) = B
157-
15842_vec (v) = vec (v)
15943_vec (v:: Number ) = v
16044_vec (v:: AbstractVector ) = v
0 commit comments