|
157 | 157 | end |
158 | 158 | end |
159 | 159 |
|
160 | | -@generated function ifelse( |
161 | | - m::AbstractMask, |
| 160 | +@generated function _ifelse( |
| 161 | + m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, |
162 | 162 | x::ForwardDiff.Dual{TAG,V,P}, |
163 | 163 | y::ForwardDiff.Dual{TAG,V,P} |
164 | 164 | ) where {TAG,V,P} |
|
171 | 171 | ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p)) |
172 | 172 | end |
173 | 173 | end |
174 | | -@generated function ifelse( |
175 | | - m::AbstractMask, |
| 174 | +@generated function _ifelse( |
| 175 | + m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, |
176 | 176 | x::Number, |
177 | 177 | y::ForwardDiff.Dual{TAG,V,P} |
178 | 178 | ) where {TAG,V,P} |
|
184 | 184 | ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p)) |
185 | 185 | end |
186 | 186 | end |
187 | | -@generated function ifelse( |
188 | | - m::AbstractMask, |
| 187 | +@generated function _ifelse( |
| 188 | + m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, |
189 | 189 | x::ForwardDiff.Dual{TAG,V,P}, |
190 | 190 | y::Number |
191 | 191 | ) where {TAG,V,P} |
|
197 | 197 | ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p)) |
198 | 198 | end |
199 | 199 | end |
| 200 | +@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::Number) = _ifelse(m, x, y) |
| 201 | +@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::ForwardDiff.Dual) = _ifelse(m, x, y) |
| 202 | +@inline ifelse(m::AbstractMask, y::Number, x::ForwardDiff.Dual) = _ifelse(m, y, x) |
| 203 | + |
| 204 | +@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, x::ForwardDiff.Dual, y::Number) = _ifelse(m, x, y) |
| 205 | +@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, x::ForwardDiff.Dual, y::ForwardDiff.Dual) = _ifelse(m, x, y) |
| 206 | +@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, y::Number, x::ForwardDiff.Dual) = _ifelse(m, y, x) |
| 207 | + |
200 | 208 | @inline function SLEEFPirates.softplus(x::ForwardDiff.Dual{TAG}) where {TAG} |
201 | 209 | val = ForwardDiff.value(x) |
202 | 210 | expx = exp(val) |
|
0 commit comments