Skip to content

Commit 1e1be7b

Browse files
committed
Remove unnecessary consistency checks for VarNamedVector
1 parent 9a2607b commit 1e1be7b

File tree

3 files changed

+146
-67
lines changed

3 files changed

+146
-67
lines changed

benchmarks/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ LogDensityProblems = "2.1.2"
3030
Mooncake = "0.4"
3131
PrettyTables = "3"
3232
ReverseDiff = "1.15.3"
33-
StableRNGs = "1"
33+
StableRNGs = "1"

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8080
retvals = model(rng)
8181
vns = [VarName{k}() for k in keys(retvals)]
8282
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
83+
elseif varinfo_choice == :typed_vector
84+
DynamicPPL.typed_vector_varinfo(rng, model)
85+
elseif varinfo_choice == :untyped_vector
86+
DynamicPPL.typed_vector_varinfo(rng, model)
8387
else
8488
error("Unknown varinfo choice: $varinfo_choice")
8589
end

src/varnamedvector.jl

Lines changed: 141 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const CHECK_CONSISENCEY_DEFAULT = true
2+
13
"""
24
VarNamedVector
35
@@ -40,6 +42,11 @@ contents of the internal storage quickly with `getindex_internal(vnv, :)`. The o
4042
of `VarNamedVector` are mostly used to keep track of which part of the internal storage
4143
belongs to which `VarName`.
4244
45+
All constructors accept a keyword argument `check_consistency::Bool=true` that controls
46+
whether to run checks like the number of values matching the number of variables. Some of
47+
these checks can be expensive, so if you are confident in the input, you may want to turn
48+
`check_consistency` off for performance.
49+
4350
# Fields
4451
4552
$(FIELDS)
@@ -184,68 +191,71 @@ struct VarNamedVector{
184191
vals::TVal,
185192
transforms::TTrans,
186193
is_unconstrained=fill!(BitVector(undef, length(varnames)), 0),
187-
num_inactive=OrderedDict{Int,Int}(),
194+
num_inactive=OrderedDict{Int,Int}();
195+
check_consistency::Bool=CHECK_CONSISENCEY_DEFAULT,
188196
) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector}
189-
if length(varnames) != length(ranges) ||
190-
length(varnames) != length(transforms) ||
191-
length(varnames) != length(is_unconstrained) ||
192-
length(varnames) != length(varname_to_index)
193-
msg = (
194-
"Inputs to VarNamedVector have inconsistent lengths. Got lengths varnames: " *
195-
"$(length(varnames)), ranges: " *
196-
"$(length(ranges)), " *
197-
"transforms: $(length(transforms)), " *
198-
"is_unconstrained: $(length(is_unconstrained)), " *
199-
"varname_to_index: $(length(varname_to_index))."
200-
)
201-
throw(ArgumentError(msg))
202-
end
197+
if check_consistency
198+
if length(varnames) != length(ranges) ||
199+
length(varnames) != length(transforms) ||
200+
length(varnames) != length(is_unconstrained) ||
201+
length(varnames) != length(varname_to_index)
202+
msg = (
203+
"Inputs to VarNamedVector have inconsistent lengths. " *
204+
"Got lengths varnames: $(length(varnames)), " *
205+
"ranges: $(length(ranges)), " *
206+
"transforms: $(length(transforms)), " *
207+
"is_unconstrained: $(length(is_unconstrained)), " *
208+
"varname_to_index: $(length(varname_to_index))."
209+
)
210+
throw(ArgumentError(msg))
211+
end
203212

204-
num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive))
205-
if num_vals != length(vals)
206-
msg = (
207-
"The total number of elements in `vals` ($(length(vals))) does not match " *
208-
"the sum of the lengths of the ranges and the number of inactive entries " *
209-
"($(num_vals))."
210-
)
211-
throw(ArgumentError(msg))
212-
end
213+
num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive))
214+
if num_vals != length(vals)
215+
msg = (
216+
"The total number of elements in `vals` ($(length(vals))) does not " *
217+
"match the sum of the lengths of the ranges and the number of " *
218+
"inactive entries ($(num_vals))."
219+
)
220+
throw(ArgumentError(msg))
221+
end
213222

214-
if Set(values(varname_to_index)) != Set(axes(varnames, 1))
215-
msg = (
216-
"The set of values of `varname_to_index` is not the set of valid indices " *
217-
"for `varnames`."
218-
)
219-
throw(ArgumentError(msg))
220-
end
223+
if Set(values(varname_to_index)) != Set(axes(varnames, 1))
224+
msg = (
225+
"The set of values of `varname_to_index` is not the set of valid " *
226+
"indices for `varnames`."
227+
)
228+
throw(ArgumentError(msg))
229+
end
221230

222-
if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index)))
223-
msg = (
224-
"The keys of `num_inactive` are not a subset of the values of " *
225-
"`varname_to_index`."
226-
)
227-
throw(ArgumentError(msg))
228-
end
231+
if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index)))
232+
msg = (
233+
"The keys of `num_inactive` are not a subset of the values of " *
234+
"`varname_to_index`."
235+
)
236+
throw(ArgumentError(msg))
237+
end
229238

230-
# Check that the varnames don't overlap. The time cost is quadratic in number of
231-
# variables. If this ever becomes an issue, we should be able to go down to at least
232-
# N log N by sorting based on subsumes-order.
233-
for vn1 in keys(varname_to_index)
234-
for vn2 in keys(varname_to_index)
235-
vn1 === vn2 && continue
236-
if subsumes(vn1, vn2)
237-
msg = (
238-
"Variables in a VarNamedVector should not subsume each other, " *
239-
"but $vn1 subsumes $vn2, i.e. $vn2 describes a subrange of $vn1."
240-
)
241-
throw(ArgumentError(msg))
239+
# Check that the varnames don't overlap. The time cost is quadratic in number of
240+
# variables. If this ever becomes an issue, we should be able to go down to at
241+
# least N log N by sorting based on subsumes-order.
242+
for vn1 in keys(varname_to_index)
243+
for vn2 in keys(varname_to_index)
244+
vn1 === vn2 && continue
245+
if subsumes(vn1, vn2)
246+
msg = (
247+
"Variables in a VarNamedVector should not subsume each " *
248+
"other, but $vn1 subsumes $vn2."
249+
)
250+
throw(ArgumentError(msg))
251+
end
242252
end
243253
end
244-
end
245254

246-
# We could also have a test to check that the ranges don't overlap, but that sounds
247-
# unlikely to occur, and implementing it in linear time would require a tiny bit of
248-
# thought.
255+
# We could also have a test to check that the ranges don't overlap, but that
256+
# sounds unlikely to occur, and implementing it in linear time would require a
257+
# tiny bit of thought.
258+
end
249259

250260
return new{K,V,TVN,TVal,TTrans}(
251261
varname_to_index,
@@ -260,7 +270,9 @@ struct VarNamedVector{
260270
end
261271

262272
function VarNamedVector{K,V}() where {K,V}
263-
return VarNamedVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[])
273+
return VarNamedVector(
274+
OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]; check_consistency=false
275+
)
264276
end
265277

266278
# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). Simlarly the
@@ -269,15 +281,22 @@ end
269281
# making that change here opens some other cans of worms related to how VarInfo uses
270282
# BangBang, that I don't want to deal with right now.
271283
VarNamedVector() = VarNamedVector{VarName,Real}()
272-
VarNamedVector(xs::Pair...) = VarNamedVector(OrderedDict(xs...))
273-
VarNamedVector(x::AbstractDict) = VarNamedVector(keys(x), values(x))
274-
function VarNamedVector(varnames, vals)
275-
return VarNamedVector(collect_maybe(varnames), collect_maybe(vals))
284+
function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISENCEY_DEFAULT)
285+
return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency)
286+
end
287+
function VarNamedVector(x::AbstractDict; check_consistency=CHECK_CONSISENCEY_DEFAULT)
288+
return VarNamedVector(keys(x), values(x); check_consistency=check_consistency)
289+
end
290+
function VarNamedVector(varnames, vals; check_consistency=CHECK_CONSISENCEY_DEFAULT)
291+
return VarNamedVector(
292+
collect_maybe(varnames), collect_maybe(vals); check_consistency=check_consistency
293+
)
276294
end
277295
function VarNamedVector(
278296
varnames::AbstractVector,
279297
orig_vals::AbstractVector,
280-
transforms=fill(identity, length(varnames)),
298+
transforms=fill(identity, length(varnames));
299+
check_consistency=CHECK_CONSISENCEY_DEFAULT,
281300
)
282301
# Convert `vals` into a vector of vectors.
283302
vals_vecs = map(tovec, orig_vals)
@@ -301,7 +320,19 @@ function VarNamedVector(
301320
offset = r[end]
302321
end
303322

304-
return VarNamedVector(varname_to_index, varnames, ranges, vals, transforms)
323+
# Passing on check_consistency here seems wasteful. Wouldn't it be faster to do a
324+
# lightweight check of the arguments of this function, and rely on the correctness
325+
# of what this function does? However, the expensive check is whether any variable
326+
# subsumes another, and that's the same regardless of where it's done, so the
327+
# optimisation would be quite pointless.
328+
return VarNamedVector(
329+
varname_to_index,
330+
varnames,
331+
ranges,
332+
vals,
333+
transforms;
334+
check_consistency=check_consistency,
335+
)
305336
end
306337

307338
function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
@@ -832,7 +863,8 @@ function loosen_types!!(
832863
vnv.vals,
833864
Vector{transform_type}(vnv.transforms),
834865
vnv.is_unconstrained,
835-
vnv.num_inactive,
866+
vnv.num_inactive;
867+
check_consistency=false,
836868
)
837869
end
838870
end
@@ -887,7 +919,8 @@ function tighten_types(vnv::VarNamedVector)
887919
map(identity, vnv.vals),
888920
map(identity, vnv.transforms),
889921
copy(vnv.is_unconstrained),
890-
copy(vnv.num_inactive),
922+
copy(vnv.num_inactive);
923+
check_consistency=false,
891924
)
892925
end
893926

@@ -1041,6 +1074,14 @@ julia> unflatten(vnv, vnv[:]) == vnv
10411074
true
10421075
"""
10431076
function unflatten(vnv::VarNamedVector, vals::AbstractVector)
1077+
if length(vals) != vector_length(vnv)
1078+
throw(
1079+
ArgumentError(
1080+
"Length of `vals` ($(length(vals))) does not match the length of " *
1081+
"`vnv` ($(vector_length(vnv))).",
1082+
),
1083+
)
1084+
end
10441085
new_ranges = deepcopy(vnv.ranges)
10451086
recontiguify_ranges!(new_ranges)
10461087
return VarNamedVector(
@@ -1049,7 +1090,8 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector)
10491090
new_ranges,
10501091
vals,
10511092
vnv.transforms,
1052-
vnv.is_unconstrained,
1093+
vnv.is_unconstrained;
1094+
check_consistency=false,
10531095
)
10541096
end
10551097

@@ -1063,6 +1105,32 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
10631105
vns_right = right_vnv.varnames
10641106
vns_both = union(vns_left, vns_right)
10651107

1108+
# Check that varnames do not subsume each other.
1109+
for vn_left in vns_left
1110+
for vn_right in vns_right
1111+
vn_left == vn_right && continue
1112+
# TODO(mhauru) Subsumation doesn't actually need to be a showstopper. For
1113+
# instance, if right has a value for `x` and left has a value for `x[1]`, then
1114+
# right will take precedence anyway, and we could merge. However, that requires
1115+
# some extra logic that hasn't been done yet.
1116+
if subsumes(vn_left, vn_right)
1117+
throw(
1118+
ArgumentError(
1119+
"Cannot merge VarNamedVectors: variable name $vn_left " *
1120+
"subsumes $vn_right.",
1121+
),
1122+
)
1123+
elseif subsumes(vn_right, vn_left)
1124+
throw(
1125+
ArgumentError(
1126+
"Cannot merge VarNamedVectors: variable name $vn_right " *
1127+
"subsumes $vn_left.",
1128+
),
1129+
)
1130+
end
1131+
end
1132+
end
1133+
10661134
# Determine `eltype` of `vals`.
10671135
T_left = eltype(left_vnv.vals)
10681136
T_right = eltype(right_vnv.vals)
@@ -1117,7 +1185,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
11171185
end
11181186

11191187
return VarNamedVector(
1120-
varname_to_index, vns_both, ranges, vals, transforms, is_unconstrained
1188+
varname_to_index,
1189+
vns_both,
1190+
ranges,
1191+
vals,
1192+
transforms,
1193+
is_unconstrained;
1194+
check_consistency=false,
11211195
)
11221196
end
11231197

@@ -1193,7 +1267,8 @@ function Base.similar(vnv::VarNamedVector)
11931267
similar(vnv.vals, 0),
11941268
similar(vnv.transforms, 0),
11951269
BitVector(),
1196-
empty(vnv.num_inactive),
1270+
empty(vnv.num_inactive);
1271+
check_consistency=false,
11971272
)
11981273
end
11991274

0 commit comments

Comments
 (0)