Skip to content

Commit a2d9efa

Browse files
fix: make reorder_parameters more type-stable
1 parent 9fa324b commit a2d9efa

File tree

1 file changed

+43
-52
lines changed

1 file changed

+43
-52
lines changed

src/systems/index_cache.jl

Lines changed: 43 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -486,54 +486,62 @@ end
486486
function reorder_parameters(
487487
sys::AbstractSystem, ps = parameters(sys; initial_parameters = true); kwargs...)
488488
if has_index_cache(sys) && get_index_cache(sys) !== nothing
489-
reorder_parameters(get_index_cache(sys), ps; kwargs...)
489+
reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...)
490490
elseif ps isa Tuple
491491
ps
492492
else
493493
(ps,)
494494
end
495495
end
496496

497-
function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true)
497+
const COMMON_DEFAULT_VAR = unwrap(only(@variables __DEF__))
498+
499+
function reorder_parameters(ic::IndexCache, ps::Vector{SymbolicT}; drop_missing = false, flatten = true)
498500
isempty(ps) && return ()
499-
param_buf = if ic.tunable_buffer_size.length == 0
500-
()
501-
else
502-
(BasicSymbolic[unwrap(variable(:DEF))
503-
for _ in 1:(ic.tunable_buffer_size.length)],)
501+
result = Vector{Union{Vector{SymbolicT}, Vector{Vector{SymbolicT}}}}()
502+
param_buf = fill(COMMON_DEFAULT_VAR, ic.tunable_buffer_size.length)
503+
push!(result, param_buf)
504+
initials_buf = fill(COMMON_DEFAULT_VAR, ic.initials_buffer_size.length)
505+
push!(result, initials_buf)
506+
507+
disc_buf = Vector{SymbolicT}[]
508+
for bufszs in ic.discrete_buffer_sizes
509+
push!(disc_buf, fill(COMMON_DEFAULT_VAR, sum(x -> x.length, bufszs)))
510+
end
511+
const_buf = Vector{SymbolicT}[]
512+
for bufsz in ic.constant_buffer_sizes
513+
push!(const_buf, fill(COMMON_DEFAULT_VAR, bufsz.length))
514+
end
515+
nonnumeric_buf = Vector{SymbolicT}[]
516+
for bufsz in ic.nonnumeric_buffer_sizes
517+
push!(nonnumeric_buf, fill(COMMON_DEFAULT_VAR, bufsz.length))
504518
end
505-
initials_buf = if ic.initials_buffer_size.length == 0
506-
()
519+
if flatten
520+
append!(result, disc_buf)
521+
append!(result, const_buf)
522+
append!(result, nonnumeric_buf)
507523
else
508-
(BasicSymbolic[unwrap(variable(:DEF))
509-
for _ in 1:(ic.initials_buffer_size.length)],)
524+
push!(result, disc_buf)
525+
push!(result, const_buf)
526+
push!(result, nonnumeric_buf)
510527
end
511-
512-
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF))
513-
for _ in 1:(sum(x -> x.length, temp))]
514-
for temp in ic.discrete_buffer_sizes)
515-
const_buf = Tuple(SymbolicT[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
516-
for temp in ic.constant_buffer_sizes)
517-
nonnumeric_buf = Tuple(SymbolicT[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
518-
for temp in ic.nonnumeric_buffer_sizes)
519528
for p in ps
520-
p = unwrap(p)
521529
if haskey(ic.discrete_idx, p)
522530
idx = ic.discrete_idx[p]
523531
disc_buf[idx.buffer_idx][idx.idx_in_buffer] = p
524532
elseif haskey(ic.tunable_idx, p)
525533
i = ic.tunable_idx[p]
526534
if i isa Int
527-
param_buf[1][i] = unwrap(p)
535+
param_buf[i] = p
528536
else
529-
param_buf[1][i] = unwrap.(collect(p))
537+
param_buf[i] = collect(p)
530538
end
531539
elseif haskey(ic.initials_idx, p)
532540
i = ic.initials_idx[p]
533541
if i isa Int
534-
initials_buf[1][i] = unwrap(p)
542+
initials_buf[i] = p
535543
else
536-
initials_buf[1][i] = unwrap.(collect(p))
544+
initials_buf[i] = collect(p)
537545
end
538546
elseif haskey(ic.constant_idx, p)
539547
i, j = ic.constant_idx[p]
@@ -546,37 +554,20 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten =
546554
end
547555
end
548556

549-
param_buf = broadcast.(unwrap, param_buf)
550-
initials_buf = broadcast.(unwrap, initials_buf)
551-
disc_buf = broadcast.(unwrap, disc_buf)
552-
const_buf = broadcast.(unwrap, const_buf)
553-
nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf)
554-
555557
if drop_missing
556-
filterer = !isequal(unwrap(variable(:DEF)))
557-
param_buf = filter.(filterer, param_buf)
558-
initials_buf = filter.(filterer, initials_buf)
559-
disc_buf = filter.(filterer, disc_buf)
560-
const_buf = filter.(filterer, const_buf)
561-
nonnumeric_buf = filter.(filterer, nonnumeric_buf)
562-
end
563-
564-
if flatten
565-
result = (
566-
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...)
567-
if all(isempty, result)
568-
return ()
569-
end
570-
return result
571-
else
572-
if isempty(param_buf)
573-
param_buf = ((),)
574-
end
575-
if isempty(initials_buf)
576-
initials_buf = ((),)
558+
filterer = !isequal(COMMON_DEFAULT_VAR)
559+
for inner in result
560+
if inner isa Vector{SymbolicT}
561+
filter!(filterer, inner)
562+
elseif inner isa Vector{Vector{SymbolicT}}
563+
for buf in inner
564+
filter!(filterer, buf)
565+
end
566+
end
577567
end
578-
return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf)
579568
end
569+
570+
return result
580571
end
581572

582573
# Given a parameter index, find the index of the buffer it is in when

0 commit comments

Comments
 (0)