|
| 1 | +function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}) where {T<:Real,N} |
| 2 | + |
| 3 | + # Check that x_temp has correct dimensions |
| 4 | + v = zeros(ndims(x_temp)) |
| 5 | + v[N] = 2 |
| 6 | + @assert [size(x_temp)...]+v == [size(M)...] |
| 7 | + |
| 8 | + # Check that axis of differentiation is in the dimensions of M and x_temp |
| 9 | + ndimsM = ndims(M) |
| 10 | + @assert N <= ndimsM |
| 11 | + |
| 12 | + dimsM = [axes(M)...] |
| 13 | + alldims = [1:ndims(M);] |
| 14 | + otherdims = setdiff(alldims, N) |
| 15 | + |
| 16 | + idx = Any[first(ind) for ind in axes(M)] |
| 17 | + itershape = tuple(dimsM[otherdims]...) |
| 18 | + nidx = length(otherdims) |
| 19 | + indices = Iterators.drop(CartesianIndices(itershape), 0) |
| 20 | + |
| 21 | + setindex!(idx, :, N) |
| 22 | + for I in indices |
| 23 | + Base.replace_tuples!(nidx, idx, idx, otherdims, I) |
| 24 | + mul!(view(x_temp, idx...), A, view(M, idx...)) |
| 25 | + end |
| 26 | +end |
| 27 | + |
| 28 | +for MT in [2,3] |
| 29 | + @eval begin |
| 30 | + function LinearAlgebra.mul!(x_temp::AbstractArray{T,$MT}, A::DerivativeOperator{T,N,Wind,T2,S1}, M::AbstractArray{T,$MT}) where {T<:Real,N,Wind,T2,SL,S1<:SArray{Tuple{SL},T,1,SL}} |
| 31 | + |
| 32 | + # Check that x_temp has correct dimensions |
| 33 | + v = zeros(ndims(x_temp)) |
| 34 | + v[N] = 2 |
| 35 | + @assert [size(x_temp)...]+v == [size(M)...] |
| 36 | + |
| 37 | + # Check that axis of differentiation is in the dimensions of M and x_temp |
| 38 | + ndimsM = ndims(M) |
| 39 | + @assert N <= ndimsM |
| 40 | + |
| 41 | + # Respahe x_temp for NNlib.conv! |
| 42 | + new_size = Any[size(x_temp)...] |
| 43 | + bpc = A.boundary_point_count |
| 44 | + setindex!(new_size, new_size[N]- 2*bpc, N) |
| 45 | + new_shape = [] |
| 46 | + for i in 1:ndimsM |
| 47 | + if i != N |
| 48 | + push!(new_shape,:) |
| 49 | + else |
| 50 | + push!(new_shape,bpc+1:new_size[N]+bpc) |
| 51 | + end |
| 52 | + end |
| 53 | + _x_temp = reshape(view(x_temp, new_shape...), (new_size...,1,1)) |
| 54 | + |
| 55 | + # Reshape M for NNlib.conv! |
| 56 | + _M = reshape(M, (size(M)...,1,1)) |
| 57 | + s = A.stencil_coefs |
| 58 | + sl = A.stencil_length |
| 59 | + |
| 60 | + # Setup W, the kernel for NNlib.conv! |
| 61 | + Wdims = ones(Int64, ndims(_x_temp)) |
| 62 | + Wdims[N] = sl |
| 63 | + W = zeros(Wdims...) |
| 64 | + Widx = Any[Wdims...] |
| 65 | + setindex!(Widx,:,N) |
| 66 | + W[Widx...] = s ./ A.dx^A.derivative_order # this will change later |
| 67 | + cv = DenseConvDims(_M, W) |
| 68 | + |
| 69 | + conv!(_x_temp, _M, W, cv) |
| 70 | + |
| 71 | + # Now deal with boundaries |
| 72 | + dimsM = [axes(M)...] |
| 73 | + alldims = [1:ndims(M);] |
| 74 | + otherdims = setdiff(alldims, N) |
| 75 | + |
| 76 | + idx = Any[first(ind) for ind in axes(M)] |
| 77 | + itershape = tuple(dimsM[otherdims]...) |
| 78 | + nidx = length(otherdims) |
| 79 | + indices = Iterators.drop(CartesianIndices(itershape), 0) |
| 80 | + |
| 81 | + setindex!(idx, :, N) |
| 82 | + for I in indices |
| 83 | + Base.replace_tuples!(nidx, idx, idx, otherdims, I) |
| 84 | + convolve_BC_left!(view(x_temp, idx...), view(M, idx...), A) |
| 85 | + convolve_BC_right!(view(x_temp, idx...), view(M, idx...), A) |
| 86 | + end |
| 87 | + end |
| 88 | + end |
| 89 | +end |
| 90 | + |
| 91 | +function *(A::DerivativeOperator{T,N},M::AbstractArray{T}) where {T<:Real,N} |
| 92 | + size_x_temp = [size(M)...] |
| 93 | + size_x_temp[N] -= 2 |
| 94 | + x_temp = zeros(promote_type(eltype(A),eltype(M)), size_x_temp...) |
| 95 | + LinearAlgebra.mul!(x_temp, A, M) |
| 96 | + return x_temp |
| 97 | +end |
0 commit comments