@@ -150,8 +150,11 @@ function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
150150 function ColVecs_pullback (:: AbstractVector{<:AbstractVector{<:Real}} )
151151 return error (
152152 " Pullback on AbstractVector{<:AbstractVector}.\n " *
153- " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n " *
154- " To solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`" ,
153+ " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n " *
154+ " or because some external computation has acted on `ColVecs` to produce a vector of vectors." *
155+ " In the former case, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `ColVecs`." *
156+ " In the latter case, one needs to track down the `rrule` whose pullback returns a `Vector{Vector{T}}`," *
157+ " rather than a `Tangent`, as the cotangent / gradient for `ColVecs` input, and circumvent it."
155158 )
156159 end
157160 return ColVecs (X), ColVecs_pullback
@@ -162,8 +165,9 @@ function ChainRulesCore.rrule(::Type{<:RowVecs}, X::AbstractMatrix)
162165 function RowVecs_pullback (:: AbstractVector{<:AbstractVector{<:Real}} )
163166 return error (
164167 " Pullback on AbstractVector{<:AbstractVector}.\n " *
165- " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`.\n " *
166- " To solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`" ,
168+ " This might happen if you try to use gradients on the generic `kernelmatrix` or `kernelmatrix_diag`,\n " *
169+ " or because some external computation has acted on `RowVecs` to produce a vector of vectors." *
170+ " If it is the former, to solve this issue overload `kernelmatrix(_diag)` for your kernel for `RowVecs`" ,
167171 )
168172 end
169173 return RowVecs (X), RowVecs_pullback
0 commit comments