From 1f427ad9b2cf0987e240b2069e850194f3a2f003 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 31 Jul 2025 10:08:40 +0300 Subject: [PATCH 1/6] Ignore claude local settings --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 5a2f3d5..13aa88a 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ attic/ /full/Project.toml .CondaPkg/ LocalPreferences.toml +/.claude/settings.local.json From 4d98d8d67060b2287bd4b57482fe52ae5ff04e4a Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 19 Aug 2025 15:58:32 +0300 Subject: [PATCH 2/6] Add show(...) for CatLoop --- src/Sim/loop.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Sim/loop.jl b/src/Sim/loop.jl index 629968c..5fb6870 100644 --- a/src/Sim/loop.jl +++ b/src/Sim/loop.jl @@ -27,6 +27,15 @@ struct CatLoop{CatEngineT} <: CatConfigBase new_response_callback end +function show(io::IO, ::MIME"text/plain", rules::CatLoop) + print(io, "Next item rule: ") + show(io, MIME("text/plain"), rules.next_item) + print(io, "Termination condition: ") + show(io, MIME("text/plain"), rules.termination_condition) + print(io, "Ability estimator: ") + show(io, MIME("text/plain"), rules.ability_estimator) +end + function CatLoop(; rules, get_response, From 832b86362292518d34d7f31970ccc85eef29ee5f Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 19 Aug 2025 15:59:19 +0300 Subject: [PATCH 3/6] Add RecordedCatLoop --- src/Aggregators/Aggregators.jl | 5 + src/Aggregators/tracked.jl | 4 + src/Sim/Sim.jl | 7 +- src/Sim/recorded_loop.jl | 193 +++++++++++++++++++++++++++++++++ src/Sim/recorder.jl | 10 ++ 5 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 src/Sim/recorded_loop.jl diff --git a/src/Aggregators/Aggregators.jl b/src/Aggregators/Aggregators.jl index 4e5b967..46162d4 100644 --- a/src/Aggregators/Aggregators.jl +++ b/src/Aggregators/Aggregators.jl @@ -52,6 +52,7 @@ export FunctionOptimizer, FunctionIntegrator export DistributionAbilityEstimator export variance, variance_given_mean, mean_1d export RiemannEnumerationIntegrator +export get_integrator # export EnumerationOptimizer # Basic types @@ -200,6 +201,10 @@ struct FunctionIntegrator{IntegratorT <: Integrator} <: AbilityIntegrator integrator::IntegratorT end +function get_integrator(integrator::FunctionIntegrator) + return integrator.integrator +end + function (integrator::FunctionIntegrator{IntegratorT})(f::F, ncomp, lh_function::LHF) where {F, LHF, IntegratorT} diff --git a/src/Aggregators/tracked.jl b/src/Aggregators/tracked.jl index 8bea270..3301462 100644 --- a/src/Aggregators/tracked.jl +++ b/src/Aggregators/tracked.jl @@ -20,6 +20,10 @@ struct TrackedLikelihoodIntegrator{IntegratorT <: Integrator} <: AbilityIntegrat tracker::GriddedAbilityTracker end +function get_integrator(integrator::TrackedLikelihoodIntegrator) + return integrator.integrator +end + function (integrator::TrackedLikelihoodIntegrator{IntegratorT})(f::F, ncomp) where {F, IntegratorT} integrator.integrator(FunctionArgProduct(f), integrator.tracker.cur_ability, ncomp) diff --git a/src/Sim/Sim.jl b/src/Sim/Sim.jl index 75511c9..df9bf9c 100644 --- a/src/Sim/Sim.jl +++ b/src/Sim/Sim.jl @@ -5,7 +5,7 @@ using ElasticArrays using ElasticArrays: sizehint_lastdim! using DocStringExtensions using StatsBase -using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse +using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, domdims using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.IndentWrappers: indent using ..ConfigBase @@ -13,6 +13,7 @@ using ..Responses using ..Rules: CatRules using ..Aggregators: TrackedResponses, add_response!, + get_integrator, Aggregators, AbilityIntegrator, AbilityEstimator, @@ -22,15 +23,17 @@ using ..Aggregators: TrackedResponses, MeanAbilityEstimator, LikelihoodAbilityEstimator, RiemannEnumerationIntegrator -using ..NextItemRules: compute_criteria, best_item +using ..NextItemRules: AbilityVariance, compute_criteria, best_item import Base: show export CatRecorder, CatRecording export CatLoop, record! +export RecordedCatLoop export run_cat, prompt_response, auto_responder include("./recorder.jl") include("./loop.jl") include("./run.jl") +include("./recorded_loop.jl") end diff --git a/src/Sim/recorded_loop.jl b/src/Sim/recorded_loop.jl new file mode 100644 index 0000000..0fae931 --- /dev/null +++ b/src/Sim/recorded_loop.jl @@ -0,0 +1,193 @@ +struct RecordedCatLoop + cat_loop::CatLoop{<: CatRules} + recorder::CatRecorder + item_bank::Union{AbstractItemBank, Nothing} +end + +function _prepare_get_response!(kwargs) + has_responses = haskey(kwargs, :responses) + has_get_response = haskey(kwargs, :get_response) + if has_responses && has_get_response + error("Cannot provide both `responses` and `get_response`.") + elseif !has_responses && !has_get_response + error("Must provide either `responses` or `get_response`.") + elseif has_get_response + return nothing, pop!(kwargs, :get_response) + else + responses = pop!(kwargs, :responses) + return responses, Sim.auto_responder(responses) + end +end + +function _walk_find_type(obj, typ, out=[]) + if obj isa typ + push!(out, obj) + end + for fieldname in propertynames(obj) + _walk_find_type(getfield(obj, fieldname), typ, out) + end + return out +end + +function _find_mean_ability(rules) + if rules.ability_estimator isa MeanAbilityEstimator + return rules.ability_estimator + end + result = _walk_find_type(rules.next_item, MeanAbilityEstimator) + if !isempty(result) + return result[1] + end + result = _walk_find_type(rules.termination_condition, MeanAbilityEstimator) + if !isempty(result) + return result[1] + end + return nothing +end + +function _find_ability_variance(rules) + result = _walk_find_type(rules.next_item, AbilityVariance) + if !isempty(result) + return result[1] + end + return nothing +end + +function enrich_recorder_requests(old_requests, rules) + requests = Dict() + for (k, v) in pairs(old_requests) + new_v = Dict{Symbol, Any}(pairs(v)) + type = get(new_v, :type, nothing) + if type in (:ability, :ability_distribution, :ability_stddev) + if haskey(new_v, :estimator) && haskey(new_v, :source) + error("Cannot provide both `estimator` and `source` for request `$k`.") + elseif !haskey(new_v, :estimator) + if !haskey(new_v, :source) + error("Must provide either `estimator` or `source` for request `$k`.") + end + source = new_v[:source] + if source != :any + error("Not implemented yet: `source = $source` for request `$k`.") + end + if type == :ability + new_v[:estimator] = rules.ability_estimator + elseif type == :ability_stddev + error("Not implemented yet: `type = :ability_stddev` for request `$k`.") + elseif type == :ability_distribution + estimator = nothing + integrator = nothing + mean_ability = _find_mean_ability(rules) + if mean_ability === nothing + ability_variance = _find_ability_variance(rules) + if ability_variance === nothing + error("Cannot find a `MeanAbilityEstimator` or `AbilityVariance` in the rules for request `$k`.") + end + estimator = ability_variance.dist_est + integrator = ability_variance.integrator + else + estimator = distribution_estimator(mean_ability) + integrator = mean_ability.integrator + end + new_v[:estimator] = estimator + if !haskey(new_v, :integrator) + new_v[:integrator] = integrator + end + if !haskey(new_v, :points) + integrator = get_integrator(new_v[:integrator]) + if !(integrator isa AnyGridIntegrator) + error("Must provide `points` for request `$k` when `integrator` is not an `AnyGridIntegrator`.") + end + new_v[:points] = get_grid(integrator) + end + end + end + end + requests[k] = NamedTuple(new_v) + end + return requests +end + +""" +```julia +RecordedCatLoop(; + rules::CatRules, + item_bank::AbstractItemBank = nothing, + responses::Union{Nothing, Vector{ResponseType}} = nothing, + dims::Union{Nothing, Tuple{Int, Int}} = nothing, + expected_responses::Int = 0, + get_response::Function = nothing, + new_response_callback::Function = nothing, + new_response_callbacks::Vector{Function} = Any[] + requests... +) +``` + +This `RecordedCatLoop` is a simplified construction of a `[CatRules](@ref)`-based `[CatLoop](@ref)` and `[CatRecorder](@ref)`. + +It can be constructed with just some cat `rules`, an `item_bank`, and a response memory `responses`, as well as usually one or more `requests` for the `[CatRecorder](@ref). +In this case `dims` are provided by the `item_bank`, and `expected_responses` is set to the length of `responses` as well as used to provide responses using `get_responses`, otherwise the respective arguments must be provided. +The arguments `get_response`, `new_response_callback`, and `new_response_callbacks` are passed to the underlying `CatLoop`. + +The resulting `RecordedCatLoop` can be run directly with run_cat. +""" +function RecordedCatLoop(; kwargs...) + kwargs = Dict(kwargs) + responses, get_response = _prepare_get_response!(kwargs) + local expected_responses, rules + if responses !== nothing + expected_responses = length(responses) + else + expected_responses = pop!(kwargs, :expected_responses, 0) + end + if haskey(kwargs, :rules) + rules = pop!(kwargs, :rules) + else + error("Must provide `rules`.") + end + new_response_callback = pop!(kwargs, :new_response_callback, nothing) + new_response_callbacks = pop!(kwargs, :new_response_callbacks, Any[]) + local dims + item_bank = nothing + if !haskey(kwargs, :item_bank) && !haskey(kwargs, :dims) + error("Must provide either `item_bank` or `dims`.") + end + if haskey(kwargs, :item_bank) + item_bank = pop!(kwargs, :item_bank) + dims = domdims(item_bank) + end + if haskey(kwargs, :dims) + dims = pop!(kwargs, :dims) + end + requests = enrich_recorder_requests(kwargs, rules) + cat_recorder = CatRecorder(dims, expected_responses; requests...) + RecordedCatLoop( + CatLoop(; + rules, + get_response, + new_response_callback, + new_response_callbacks, + recorder=cat_recorder + ), + cat_recorder, + item_bank + ) +end + +""" +$TYPEDSIGNATURES + +Run a given [RecordedCatLoop](@ref) by delegating the call to the wrapped [CatLoop](@ref). + +In case `item_bank` is not provided, the item bank provided during the construction of `RecordedCatLoop` is used. +""" +function run_cat(loop::RecordedCatLoop, + item_bank::AbstractItemBank; + ib_labels = nothing) + run_cat(loop.cat_loop, item_bank; ib_labels=ib_labels) +end + +function run_cat(loop::RecordedCatLoop; ib_labels = nothing) + if loop.item_bank === nothing + error("Trying to run a RecordedCatLoop without an item bank when no item bank was provided at construction time.") + end + run_cat(loop, loop.item_bank; ib_labels=ib_labels) +end diff --git a/src/Sim/recorder.jl b/src/Sim/recorder.jl index b897a88..66136df 100644 --- a/src/Sim/recorder.jl +++ b/src/Sim/recorder.jl @@ -255,14 +255,24 @@ function name_to_label(name) titlecase(join(split(String(name), "_"), " ")) end +function hasallkeys(haystack, needles...) + return all(n in keys(haystack) for n in needles) +end + function CatRecorder(dims::Int, expected_responses::Int; requests...) out = [] sizehint!(out, length(requests)) for (name, request) in pairs(requests) extra = (;) + if !haskey(request, :type) + error("Must provide `type` for $name.") + end if request.type in (:ability, :ability_stddev) data = empty_capacity(Float64, expected_responses) elseif request.type == :ability_distribution + if !hasallkeys(request, :points, :estimator, :integrator) + error("Must provide `points`, `estimator`, and `integrator` for $name.") + end if dims == 0 data = empty_capacity(Float64, length(request.points), expected_responses) else From cd77e3b7505a7bac62878a4205cbfead3f074341 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 22 Aug 2025 15:26:49 +0300 Subject: [PATCH 4/6] Fix spelling error Criteron => Criterion --- src/NextItemRules/NextItemRules.jl | 2 +- src/NextItemRules/combinators/scalarizers.jl | 6 +++--- src/NextItemRules/porcelain/porcelain.jl | 4 ++-- test/ability_estimator_2d.jl | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/NextItemRules/NextItemRules.jl b/src/NextItemRules/NextItemRules.jl index def1dfb..0839d37 100644 --- a/src/NextItemRules/NextItemRules.jl +++ b/src/NextItemRules/NextItemRules.jl @@ -56,7 +56,7 @@ export PointResponseExpectation, DistributionResponseExpectation export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion export InformationMatrixCriteria -export ScalarizedStateCriteron, ScalarizedItemCriteron +export ScalarizedStateCriterion, ScalarizedItemCriterion export DRuleItemCriterion, TRuleItemCriterion # Prelude diff --git a/src/NextItemRules/combinators/scalarizers.jl b/src/NextItemRules/combinators/scalarizers.jl index 9326a6d..d1a60b9 100644 --- a/src/NextItemRules/combinators/scalarizers.jl +++ b/src/NextItemRules/combinators/scalarizers.jl @@ -4,7 +4,7 @@ scalarize(::DeterminantScalarizer, mat) = det(mat) struct TraceScalarizer <: MatrixScalarizer end scalarize(::TraceScalarizer, mat) = tr(mat) -struct ScalarizedItemCriteron{ +struct ScalarizedItemCriterion{ ItemMultiCriterionT <: ItemMultiCriterion, MatrixScalarizerT <: MatrixScalarizer } <: ItemCriterion @@ -12,7 +12,7 @@ struct ScalarizedItemCriteron{ scalarizer::MatrixScalarizerT end -struct ScalarizedStateCriteron{ +struct ScalarizedStateCriterion{ StateMultiCriterionT <: StateMultiCriterion, MatrixScalarizerT <: MatrixScalarizer } <: StateCriterion @@ -20,7 +20,7 @@ struct ScalarizedStateCriteron{ scalarizer::MatrixScalarizerT end -function compute_criterion(ssc::Union{ScalarizedItemCriteron, ScalarizedStateCriteron}, +function compute_criterion(ssc::Union{ScalarizedItemCriterion, ScalarizedStateCriterion}, tracked_responses::TrackedResponses, item_idx...) res = scalarize( ssc.scalarizer, diff --git a/src/NextItemRules/porcelain/porcelain.jl b/src/NextItemRules/porcelain/porcelain.jl index dd43165..3944e42 100644 --- a/src/NextItemRules/porcelain/porcelain.jl +++ b/src/NextItemRules/porcelain/porcelain.jl @@ -1,11 +1,11 @@ function DRuleItemCriterion(ability_estimator) - ScalarizedItemCriteron( + ScalarizedItemCriterion( InformationMatrixCriteria(ability_estimator), DeterminantScalarizer()) end function TRuleItemCriterion(ability_estimator) - ScalarizedItemCriteron( + ScalarizedItemCriterion( InformationMatrixCriteria(ability_estimator), TraceScalarizer()) end diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index 7e820d7..fce036b 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -61,7 +61,7 @@ mle_mode_2d = ModeAbilityEstimator(lh_est_2d, optimizer_2d) @testset "2 dim information higher closer to current estimate" begin information_matrix_criteria = InformationMatrixCriteria(mle_mean_2d) - information_criterion = ScalarizedItemCriteron( + information_criterion = ScalarizedItemCriterion( information_matrix_criteria, DeterminantScalarizer()) # Item closer to the current estimate (1, 1) @@ -79,7 +79,7 @@ mle_mode_2d = ModeAbilityEstimator(lh_est_2d, optimizer_2d) @testset "2 dim variance smaller closer to current estimate" begin covariance_state_criterion = AbilityCovarianceStateMultiCriterion( lh_est_2d, integrator_2d) - variance_criterion = ScalarizedStateCriteron( + variance_criterion = ScalarizedStateCriterion( covariance_state_criterion, DeterminantScalarizer()) variance_item_criterion = ExpectationBasedItemCriterion( mle_mean_2d, variance_criterion) @@ -99,7 +99,7 @@ mle_mode_2d = ModeAbilityEstimator(lh_est_2d, optimizer_2d) @testset "2 dim variance is whack with trace scalarizer" begin covariance_state_criterion = AbilityCovarianceStateMultiCriterion( lh_est_2d, integrator_2d) - variance_criterion = ScalarizedStateCriteron( + variance_criterion = ScalarizedStateCriterion( covariance_state_criterion, TraceScalarizer()) variance_item_criterion = ExpectationBasedItemCriterion( mle_mean_2d, variance_criterion) From 5663460af908611798b7f0329b40dfd996e60aa3 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Sat, 30 Aug 2025 14:31:58 +0300 Subject: [PATCH 5/6] Add item bank to power_summary of RecordedCatLoop * Clear recorded data on run_cat startup --- Project.toml | 4 +- src/Aggregators/Aggregators.jl | 1 + src/Aggregators/ability_estimator.jl | 2 +- src/Rules.jl | 14 ++++- src/Sim/Sim.jl | 4 +- src/Sim/loop.jl | 58 ++++++++++++--------- src/Sim/recorded_loop.jl | 5 ++ src/Sim/recorder.jl | 77 +++++++++++++++++++++------- src/Sim/run.jl | 7 ++- 9 files changed, 121 insertions(+), 51 deletions(-) diff --git a/Project.toml b/Project.toml index 4e24eff..02f616e 100644 --- a/Project.toml +++ b/Project.toml @@ -50,7 +50,7 @@ DocStringExtensions = " ^0.9" EffectSizes = "^1.0.1" ElasticArrays = "1.2.12" FillArrays = "0.13, 1.5.0" -FittedItemBanks = "^0.7.2" +FittedItemBanks = "^0.7.3" ForwardDiff = "1" HypothesisTests = "^0.10.12, ^0.11.0" Interpolations = "^0.14, ^0.15" @@ -62,7 +62,7 @@ MacroTools = "^0.5.6" Mmap = "^1.11" Optim = "1.7.3" PrecompileTools = "1.2.1" -PsychometricsBazaarBase = "^0.8.4" +PsychometricsBazaarBase = "^0.8.6" QuickHeaps = "0.2.2" Random = "^1.11" Reexport = "1" diff --git a/src/Aggregators/Aggregators.jl b/src/Aggregators/Aggregators.jl index 46162d4..b5d24b5 100644 --- a/src/Aggregators/Aggregators.jl +++ b/src/Aggregators/Aggregators.jl @@ -20,6 +20,7 @@ using FittedItemBanks: AbstractItemBank, ContinuousDomain, using ..Responses using ..Responses: concrete_response_type, function_xs, function_ys, Responses using ..ConfigBase +using PsychometricsBazaarBase: power_summary using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type, find1_type_sloppy diff --git a/src/Aggregators/ability_estimator.jl b/src/Aggregators/ability_estimator.jl index ed2d204..2189a77 100644 --- a/src/Aggregators/ability_estimator.jl +++ b/src/Aggregators/ability_estimator.jl @@ -65,7 +65,7 @@ function show(io::IO, ::MIME"text/plain", ability_estimator::PosteriorAbilityEst println(io, "Ability posterior distribution") indent_io = indent(io, 2) print(indent_io, "Prior: ") - show(indent_io, MIME("text/plain"), ability_estimator.prior) + power_summary(indent_io, ability_estimator.prior) println(io) end diff --git a/src/Rules.jl b/src/Rules.jl index 3ec8c02..aece5b6 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -12,6 +12,7 @@ using ..NextItemRules: NextItemRule using ..TerminationConditions: TerminationCondition using ..ConfigBase import Base: show +import PsychometricsBazaarBase: power_summary """ $(TYPEDEF) @@ -82,12 +83,21 @@ function CatRules(bits...) end function show(io::IO, ::MIME"text/plain", rules::CatRules) + power_summary(io, rules; toplevel=true) +end + +function power_summary(io::IO, rules::CatRules; toplevel=false) + # TODO print(io, "Next item rule: ") show(io, MIME("text/plain"), rules.next_item) - println(io) + if toplevel + println(io) + end print(io, "Termination condition: ") show(io, MIME("text/plain"), rules.termination_condition) - println(io) + if toplevel + println(io) + end print(io, "Ability estimator: ") show(io, MIME("text/plain"), rules.ability_estimator) end diff --git a/src/Sim/Sim.jl b/src/Sim/Sim.jl index df9bf9c..2e30ac2 100644 --- a/src/Sim/Sim.jl +++ b/src/Sim/Sim.jl @@ -2,10 +2,11 @@ module Sim using DataFrames: DataFrame using ElasticArrays -using ElasticArrays: sizehint_lastdim! +using ElasticArrays: sizehint_lastdim!, resize_lastdim! using DocStringExtensions using StatsBase using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, domdims +using PsychometricsBazaarBase: show_into_buf, power_summary_into_buf using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.IndentWrappers: indent using ..ConfigBase @@ -25,6 +26,7 @@ using ..Aggregators: TrackedResponses, RiemannEnumerationIntegrator using ..NextItemRules: AbilityVariance, compute_criteria, best_item import Base: show +import PsychometricsBazaarBase: power_summary export CatRecorder, CatRecording export CatLoop, record! diff --git a/src/Sim/loop.jl b/src/Sim/loop.jl index 5fb6870..b0fdb58 100644 --- a/src/Sim/loop.jl +++ b/src/Sim/loop.jl @@ -24,16 +24,28 @@ struct CatLoop{CatEngineT} <: CatConfigBase A callback called each time there is a new responses. If provided, it is passed `(responses::TrackedResponses, terminating)`. """ - new_response_callback + new_response_callback::Any + """ + A callback called each time a CAT is run + If provided, it is passed `(item_bank::AbstractItemBank)`. + """ + init_callback::Any end function show(io::IO, ::MIME"text/plain", rules::CatLoop) - print(io, "Next item rule: ") - show(io, MIME("text/plain"), rules.next_item) - print(io, "Termination condition: ") - show(io, MIME("text/plain"), rules.termination_condition) - print(io, "Ability estimator: ") - show(io, MIME("text/plain"), rules.ability_estimator) + print(io, "Computer-Adaptive Test Loop based on ") + show(io, MIME("text/plain"), rules.rules) +end + +function collate_cat_callbacks(callbacks...) + callbacks = filter(!isnothing, callbacks) + function all_callbacks(args...) + for callback in callbacks + callback(args...) + end + nothing + end + all_callbacks end function CatLoop(; @@ -41,23 +53,19 @@ function CatLoop(; get_response, new_response_callback = nothing, new_response_callbacks = Any[], + init_callback = nothing, + init_callbacks = Any[], recorder = nothing ) - new_response_callbacks = collect(new_response_callbacks) - if new_response_callback !== nothing - push!(new_response_callbacks, new_response_callback) - end - if recorder !== nothing && showable(MIME("text/plain"), rules) - buf = IOBuffer() - show(buf, MIME("text/plain"), rules) - recorder.recording.rules_description = String(take!(buf)) - push!(new_response_callbacks, catrecorder_callback(recorder)) - end - function all_callbacks(responses, terminating) - for callback in new_response_callbacks - callback(responses, terminating) - end - nothing - end - CatLoop{typeof(rules)}(rules, get_response, all_callbacks) -end \ No newline at end of file + new_response_callback = collate_cat_callbacks( + new_response_callbacks..., + new_response_callback, + isnothing(recorder) ? nothing : recorder_response_callback(recorder) + ) + init_callback = collate_cat_callbacks( + init_callbacks..., + init_callback, + isnothing(recorder) ? nothing : recorder_init_callback(recorder) + ) + CatLoop{typeof(rules)}(rules, get_response, new_response_callback, init_callback) +end diff --git a/src/Sim/recorded_loop.jl b/src/Sim/recorded_loop.jl index 0fae931..b2aed0e 100644 --- a/src/Sim/recorded_loop.jl +++ b/src/Sim/recorded_loop.jl @@ -191,3 +191,8 @@ function run_cat(loop::RecordedCatLoop; ib_labels = nothing) end run_cat(loop, loop.item_bank; ib_labels=ib_labels) end + +function show(io::IO, ::MIME"text/plain", loop::RecordedCatLoop) + println(io, "Recorded Computer-Adaptive Test:") + power_summary(io, loop.recorder.recording; skip_first_line=true) +end diff --git a/src/Sim/recorder.jl b/src/Sim/recorder.jl index 66136df..4cfde6b 100644 --- a/src/Sim/recorder.jl +++ b/src/Sim/recorder.jl @@ -22,7 +22,8 @@ Base.@kwdef mutable struct CatRecording{LikelihoodsT <: NamedTuple} data::LikelihoodsT item_index::Vector{Int} item_correctness::Vector{Bool} - rules_description::Union{Nothing, String} = nothing + rules_description::Union{Nothing, IOBuffer} = nothing + item_bank_description::Union{Nothing, IOBuffer} = nothing end function Base.getproperty(obj::CatRecording, sym::Symbol) @@ -67,26 +68,33 @@ function prepare_dataframe(recording::CatRecording) end function show(io::IO, ::MIME"text/plain", recording::CatRecording) - println(io, "Recording of a Computer-Adaptive Test") - if recording.rules_description === nothing + power_summary(io, recording; include_cat_config = :always) +end + +function power_summary(io::IO, recording::CatRecording; include_cat_config = :always, skip_first_line=false, kwargs...) + if !skip_first_line + println(io, "Recording of a Computer-Adaptive Test") + end + if recording.rules_description === nothing && include_cat_config == :always println(io, " Unknown CAT configuration") - else + elseif include_cat_config != :never # :available or :always println(io, " CAT configuration:") - for line in split(strip(recording.rules_description, '\n'), "\n") - println(io, " ", line) - end + write(indent(io, 4), recording.rules_description) + seekstart(recording.rules_description) + println(io) + end + if recording.item_bank_description === nothing + println(io, " Unknown item bank") + else + println(io, " Item bank:") + write(indent(io, 4), recording.item_bank_description) + seekstart(recording.item_bank_description) + println(io) end - println(io) println(io, " Recorded information:") df = prepare_dataframe(recording) - buf = IOBuffer() - show(buf, MIME("text/plain"), df; summary=false, eltypes=false, rowlabel=:Number) - seekstart(buf) - for line in eachline(buf) - println(io, " ", line) - end - #println(io) - #println(io, " Final information:") + buf = show_into_buf(df; summary = false, eltypes = false, rowlabel = :Number) + write(indent(io, 4), buf) end #= @@ -132,6 +140,18 @@ function record!(recording::CatRecording, responses; data...) push!(recording.item_correctness, item_correct) end +function Base.empty!(recording::CatRecording) + empty!(recording.item_index) + empty!(recording.item_correctness) + for (name, value) in pairs(recording.data) + if value.data isa AbstractVector + empty!(value.data) + elseif value.data isa ElasticArray + resize_lastdim!(value.data, 0) + end + end +end + #= """ $(TYPEDSIGNATURES) @@ -402,6 +422,27 @@ function record!(recorder::CatRecorder, tracked_responses) record!(recorder.recording, tracked_responses.responses) end -function catrecorder_callback(recoder::CatRecorder) - return (tracked_responses, _) -> record!(recoder, tracked_responses) +function recorder_response_callback(recorder::CatRecorder) + return (tracked_responses, _) -> record!(recorder, tracked_responses) +end + +function recorder_init_callback(recorder::CatRecorder) + return function (cat_loop, item_bank) + empty!(recorder.recording) + if showable(MIME("text/plain"), cat_loop.rules) + recorder.recording.rules_description = power_summary_into_buf(cat_loop.rules; toplevel=false) + end + if showable(MIME("text/plain"), item_bank) + recorder.recording.item_bank_description = power_summary_into_buf(item_bank) + end + end +end + +function show(io::IO, ::MIME"text/plain", recorder::CatRecorder) + indent_io = indent(io, 4) + println(io, "Computer-Adaptive Test Recorder") + println(io, " Requests:") + show(indent_io, MIME"text/plain", recorder.requests) + println(io, " Recording:") + show(indent_io, MIME"text/plain", recorder.recording) end diff --git a/src/Sim/run.jl b/src/Sim/run.jl index b6327f7..e2450e9 100644 --- a/src/Sim/run.jl +++ b/src/Sim/run.jl @@ -43,7 +43,10 @@ If `ib_labels` is not given, default labels of the form function run_cat(loop::CatLoop{RulesT}, item_bank::AbstractItemBank; ib_labels = nothing) where {RulesT <: CatRules} - (; rules, get_response, new_response_callback) = loop + (; rules, get_response, new_response_callback, init_callback) = loop + if init_callback !== nothing + init_callback(loop, item_bank) + end (; next_item, termination_condition, ability_estimator, ability_tracker) = rules responses = TrackedResponses(BareResponses(ResponseType(item_bank)), item_bank, @@ -81,4 +84,4 @@ function run_cat(loop::CatLoop{RulesT}, end end (responses.responses, ability_estimator(responses)) -end \ No newline at end of file +end From 22823d9aa240a1e79239ce3aeaa88c57e1264f93 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 4 Sep 2025 08:20:21 +0300 Subject: [PATCH 6/6] Improve printing of RecordedCatLoop * Use power_summary(...) as main method for all printing * Add StdDevEstimator as a std dev measure * Add way to make recording entry before any items are administered * Add all kinds printing for Recorder and Recording * Format numbers to 3dp in Recording --- Project.toml | 8 + src/Aggregators/Aggregators.jl | 6 +- src/Aggregators/ability_estimator.jl | 16 +- src/Aggregators/optimizers.jl | 2 +- src/ConfigBase.jl | 2 + src/NextItemRules/NextItemRules.jl | 1 + src/NextItemRules/combinators/expectation.jl | 10 +- .../criteria/pointwise/information.jl | 12 +- .../criteria/state/ability_variance.jl | 10 +- src/NextItemRules/prelude/next_item_rule.jl | 6 +- src/NextItemRules/strategies/balance.jl | 6 +- src/NextItemRules/strategies/pointwise.jl | 4 +- src/NextItemRules/strategies/randomesque.jl | 4 +- src/NextItemRules/strategies/sequential.jl | 10 +- src/Responses.jl | 4 + src/Rules.jl | 6 +- src/Sim/Sim.jl | 9 +- src/Sim/loop.jl | 6 +- src/Sim/recorded_loop.jl | 25 +- src/Sim/recorder.jl | 214 ++++++++++++++---- src/Sim/run.jl | 6 +- src/TerminationConditions.jl | 3 +- 22 files changed, 267 insertions(+), 103 deletions(-) diff --git a/Project.toml b/Project.toml index 02f616e..4a646bd 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.4.0" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -23,8 +24,11 @@ LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" PsychometricsBazaarBase = "b0d9cada-d963-45e9-a4c6-4746243987f1" QuickHeaps = "30b38841-0f52-47f8-a5f8-18d5d4064379" +RDataGet = "a115732e-4334-4ecb-8ea3-f683e7f66d4d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -44,6 +48,7 @@ Accessors = "^0.1.12" Aqua = "0.8" AutoHashEquals = "2" ConstructionBase = "^1.2" +DataAPI = "1.16.0" DataFrames = "1.6.1" Distributions = "^0.25.88" DocStringExtensions = " ^0.9" @@ -62,8 +67,11 @@ MacroTools = "^0.5.6" Mmap = "^1.11" Optim = "1.7.3" PrecompileTools = "1.2.1" +PrettyPrinting = "0.4.2" +PrettyTables = "3" PsychometricsBazaarBase = "^0.8.6" QuickHeaps = "0.2.2" +RDataGet = "0.1.0" Random = "^1.11" Reexport = "1" Setfield = "^1" diff --git a/src/Aggregators/Aggregators.jl b/src/Aggregators/Aggregators.jl index b5d24b5..95558ae 100644 --- a/src/Aggregators/Aggregators.jl +++ b/src/Aggregators/Aggregators.jl @@ -20,7 +20,7 @@ using FittedItemBanks: AbstractItemBank, ContinuousDomain, using ..Responses using ..Responses: concrete_response_type, function_xs, function_ys, Responses using ..ConfigBase -using PsychometricsBazaarBase: power_summary +import PsychometricsBazaarBase: power_summary using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type, find1_type_sloppy @@ -216,8 +216,8 @@ function (integrator::FunctionIntegrator{IntegratorT})(f::F, integrator.integrator(FunctionProduct(f, lh_function), ncomp) end -function show(io::IO, ::MIME"text/plain", responses::FunctionIntegrator) - show(io, MIME("text/plain"), responses.integrator) +function power_summary(io::IO, responses::FunctionIntegrator) + power_summary(io, responses.integrator) end # Defaults diff --git a/src/Aggregators/ability_estimator.jl b/src/Aggregators/ability_estimator.jl index 2189a77..98dc613 100644 --- a/src/Aggregators/ability_estimator.jl +++ b/src/Aggregators/ability_estimator.jl @@ -26,7 +26,7 @@ function pdf(::LikelihoodAbilityEstimator, AbilityLikelihood(tracked_responses) end -function show(io::IO, ::MIME"text/plain", ability_estimator::LikelihoodAbilityEstimator) +function power_summary(io::IO, ability_estimator::LikelihoodAbilityEstimator) println(io, "Ability likelihood distribution") end @@ -61,7 +61,7 @@ function multiple_response_types_guard(tracked_responses) return false end -function show(io::IO, ::MIME"text/plain", ability_estimator::PosteriorAbilityEstimator) +function power_summary(io::IO, ability_estimator::PosteriorAbilityEstimator) println(io, "Ability posterior distribution") indent_io = indent(io, 2) print(indent_io, "Prior: ") @@ -224,11 +224,11 @@ function ModeAbilityEstimator(bits...) ModeAbilityEstimator(dist_est, optimizer) end -function show(io::IO, ::MIME"text/plain", ability_estimator::ModeAbilityEstimator) +function power_summary(io::IO, ability_estimator::ModeAbilityEstimator) println(io, "Estimate ability using its mode") indent_io = indent(io, 2) - show(indent_io, MIME("text/plain"), ability_estimator.dist_est) - show(indent_io, MIME("text/plain"), ability_estimator.optim) + power_summary(indent_io, ability_estimator.dist_est) + power_summary(indent_io, ability_estimator.optim) end struct MeanAbilityEstimator{ @@ -246,12 +246,12 @@ function MeanAbilityEstimator(bits...) MeanAbilityEstimator(dist_est, integrator) end -function show(io::IO, ::MIME"text/plain", ability_estimator::MeanAbilityEstimator) +function power_summary(io::IO, ability_estimator::MeanAbilityEstimator) println(io, "Estimate ability using its mean") indent_io = indent(io, 2) - show(indent_io, MIME("text/plain"), ability_estimator.dist_est) + power_summary(indent_io, ability_estimator.dist_est) print(indent_io, "Integrator: ") - show(indent_io, MIME("text/plain"), ability_estimator.integrator) + power_summary(indent_io, ability_estimator.integrator) end function distribution_estimator(dist_est::DistributionAbilityEstimator)::DistributionAbilityEstimator diff --git a/src/Aggregators/optimizers.jl b/src/Aggregators/optimizers.jl index 314b6c3..a43d191 100644 --- a/src/Aggregators/optimizers.jl +++ b/src/Aggregators/optimizers.jl @@ -10,7 +10,7 @@ function (optim::FunctionOptimizer)(f::F, optim.optim(comp_f) end -function show(io::IO, ::MIME"text/plain", optim::FunctionOptimizer) +function power_summary(io::IO, optim::FunctionOptimizer) indent_io = indent(io, 2) if optim.optim isa Optimizers.OneDimOptimOptimizer || optim.optim isa Optimizers.MultiDimOptimOptimizer || optim.optim isa Optimizers.NativeOneDimOptimOptimizer inner = optim.optim diff --git a/src/ConfigBase.jl b/src/ConfigBase.jl index 06292a0..230f40b 100644 --- a/src/ConfigBase.jl +++ b/src/ConfigBase.jl @@ -10,6 +10,8 @@ $(TYPEDEF) """ abstract type CatConfigBase end +show(io::IO, ::MIME"text/plain", obj::CatConfigBase) = power_summary(io, obj) + function walk(f, x::CatConfigBase, lens = identity) f(x, lens) for fieldname in fieldnames(typeof(x)) diff --git a/src/NextItemRules/NextItemRules.jl b/src/NextItemRules/NextItemRules.jl index 0839d37..1094cd2 100644 --- a/src/NextItemRules/NextItemRules.jl +++ b/src/NextItemRules/NextItemRules.jl @@ -17,6 +17,7 @@ using Random: AbstractRNG, Xoshiro using ..Responses: BareResponses using ..ConfigBase +import PsychometricsBazaarBase: power_summary using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type using PsychometricsBazaarBase.Integrators: Integrator, intval diff --git a/src/NextItemRules/combinators/expectation.jl b/src/NextItemRules/combinators/expectation.jl index 7e51841..16e7923 100644 --- a/src/NextItemRules/combinators/expectation.jl +++ b/src/NextItemRules/combinators/expectation.jl @@ -39,10 +39,10 @@ function Aggregators.response_expectation( item_idx) end -function show(io::IO, ::MIME"text/plain", point_response_expectation::PointResponseExpectation) +function power_summary(io::IO, point_response_expectation::PointResponseExpectation) println(io, "Expected response at point ability estimate") indent_io = indent(io, 2) - show(indent_io, MIME("text/plain"), point_response_expectation.ability_estimator) + power_summary(indent_io, point_response_expectation.ability_estimator) end struct DistributionResponseExpectation{ @@ -131,9 +131,9 @@ function compute_criterion( res end -function show(io::IO, ::MIME"text/plain", item_criterion::ExpectationBasedItemCriterion) +function power_summary(io::IO, item_criterion::ExpectationBasedItemCriterion) println(io, "Optimize an state/item/item-category criterion based on an expected response") indent_io = indent(io, 2) - show(indent_io, MIME"text/plain"(), item_criterion.response_expectation) - show(indent_io, MIME"text/plain"(), item_criterion.criterion) + power_summary(indent_io, item_criterion.response_expectation) + power_summary(indent_io, item_criterion.criterion) end diff --git a/src/NextItemRules/criteria/pointwise/information.jl b/src/NextItemRules/criteria/pointwise/information.jl index 1104bef..566249b 100644 --- a/src/NextItemRules/criteria/pointwise/information.jl +++ b/src/NextItemRules/criteria/pointwise/information.jl @@ -22,7 +22,7 @@ function compute_criterion_vec( -actual end -function show(io::IO, ::MIME"text/plain", ::ObservedInformationPointwiseItemCategoryCriterion) +function power_summary(io::IO, ::ObservedInformationPointwiseItemCategoryCriterion) println(io, "Observed pointwise item-category information") end @@ -51,7 +51,7 @@ function compute_criterion_vec( end -function show(io::IO, ::MIME"text/plain", ::RawEmpiricalInformationPointwiseItemCategoryCriterion) +function power_summary(io::IO, ::RawEmpiricalInformationPointwiseItemCategoryCriterion) println(io, "Raw empirical pointwise item-category information") end @@ -104,7 +104,7 @@ function compute_criterion_vec( -actual end -function show(io::IO, ::MIME"text/plain", ::EmpiricalInformationPointwiseItemCategoryCriterion) +function power_summary(io::IO, ::EmpiricalInformationPointwiseItemCategoryCriterion) println(io, "Empirical pointwise item-category information") end @@ -131,7 +131,7 @@ function compute_criterion( sum(compute_criterion_vec(tii.pcic, ir, ability)) end -function show(io::IO, ::MIME"text/plain", rule::TotalItemInformation) +function power_summary(io::IO, rule::TotalItemInformation) if rule.pcic isa ObservedInformationPointwiseItemCategoryCriterion println(io, "Observed pointwise item information") elseif rule.pcic isa RawEmpiricalInformationPointwiseItemCategoryCriterion @@ -140,6 +140,6 @@ function show(io::IO, ::MIME"text/plain", rule::TotalItemInformation) println(io, "Empirical pointwise item information") else print(io, "Total ") - show(io, MIME("text/plain"), rule.pcic) + power_summary(io, rule.pcic) end -end \ No newline at end of file +end diff --git a/src/NextItemRules/criteria/state/ability_variance.jl b/src/NextItemRules/criteria/state/ability_variance.jl index 47d7232..7af861b 100644 --- a/src/NextItemRules/criteria/state/ability_variance.jl +++ b/src/NextItemRules/criteria/state/ability_variance.jl @@ -68,11 +68,13 @@ function compute_criterion( denom) end -function show(io::IO, ::MIME"text/plain", criterion::AbilityVariance) - println(io, "Minimise variance of ability estimate") +function power_summary(io::IO, criterion::AbilityVariance; skip_first_line=false) + if !skip_first_line + println(io, "Minimise variance of ability estimate") + end indent_io = indent(io, 2) - show(indent_io, MIME("text/plain"), criterion.dist_est) - show(indent_io, MIME("text/plain"), criterion.integrator) + power_summary(indent_io, criterion.dist_est) + power_summary(indent_io, criterion.integrator) end struct AbilityCovarianceStateMultiCriterion{ diff --git a/src/NextItemRules/prelude/next_item_rule.jl b/src/NextItemRules/prelude/next_item_rule.jl index c61836c..9c93d38 100644 --- a/src/NextItemRules/prelude/next_item_rule.jl +++ b/src/NextItemRules/prelude/next_item_rule.jl @@ -53,11 +53,11 @@ function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) best_item(rule, tracked_responses, tracked_responses.item_bank) end -function Base.show(io::IO, ::MIME"text/plain", rule::ItemCriterionRule) +function power_summary(io::IO, rule::ItemCriterionRule) println(io, "Pick optimal item criterion according to strategy") indent_io = indent(io, 2) - show(indent_io, MIME"text/plain"(), rule.strategy) - show(indent_io, MIME"text/plain"(), rule.criterion) + power_summary(indent_io, rule.strategy) + power_summary(indent_io, rule.criterion) end # Default implementation diff --git a/src/NextItemRules/strategies/balance.jl b/src/NextItemRules/strategies/balance.jl index 5adcd20..42a64b7 100644 --- a/src/NextItemRules/strategies/balance.jl +++ b/src/NextItemRules/strategies/balance.jl @@ -39,11 +39,11 @@ function GreedyForcedContentBalancer(targets::AbstractVector, groups, bits...) GreedyForcedContentBalancer(targets, groups, NextItemRule(bits...)) end -function show(io::IO, ::MIME"text/plain", rule::GreedyForcedContentBalancer) +function power_summary(io::IO, rule::GreedyForcedContentBalancer) indent_io = indent(io, 2) println(io, "Greedy + forced content balancing") println(indent_io, "Target ratio: " * join(rule.targets, ", ")) - show(indent_io, MIME("text/plain"), rule.inner_rule) + power_summary(indent_io, rule.inner_rule) end function next_item_bank(targets, groups, responses, items) @@ -86,4 +86,4 @@ function compute_criteria( expanded = fill(Inf, length(items)) expanded[matching_indicator] .= criteria return expanded -end \ No newline at end of file +end diff --git a/src/NextItemRules/strategies/pointwise.jl b/src/NextItemRules/strategies/pointwise.jl index e0a5616..81454d7 100644 --- a/src/NextItemRules/strategies/pointwise.jl +++ b/src/NextItemRules/strategies/pointwise.jl @@ -16,12 +16,12 @@ function best_item(rule::PointwiseNextItemRule, responses::TrackedResponses, ite return idx end -function show(io::IO, ::MIME"text/plain", rule::PointwiseNextItemRule) +function power_summary(io::IO, rule::PointwiseNextItemRule) println(io, "Optimize a pointwise criterion at specified points") indent_io = indent(io, 2) points_desc = join(rule.points, ", ") println(indent_io, "Points: $points_desc") - show(indent_io, MIME("text/plain"), rule.criterion) + power_summary(indent_io, rule.criterion) end diff --git a/src/NextItemRules/strategies/randomesque.jl b/src/NextItemRules/strategies/randomesque.jl index b3e0ac5..cf5b4e1 100644 --- a/src/NextItemRules/strategies/randomesque.jl +++ b/src/NextItemRules/strategies/randomesque.jl @@ -53,6 +53,6 @@ function best_item( randomesque(rule.strategy.rng, rule.criterion, responses, items, rule.strategy.k)[1] end -function show(io::IO, ::MIME"text/plain", rule::RandomesqueStrategy) +function power_summary(io::IO, rule::RandomesqueStrategy) println(io, "Randomesque strategy with k = $(rule.k)") -end \ No newline at end of file +end diff --git a/src/NextItemRules/strategies/sequential.jl b/src/NextItemRules/strategies/sequential.jl index 7f98653..d62d764 100644 --- a/src/NextItemRules/strategies/sequential.jl +++ b/src/NextItemRules/strategies/sequential.jl @@ -29,14 +29,14 @@ function compute_criteria(rule::FixedRuleSequencer, responses::TrackedResponses) return compute_criteria(current_rule(rule, responses), responses) end -function show(io::IO, ::MIME"text/plain", rule::FixedRuleSequencer) +function power_summary(io::IO, rule::FixedRuleSequencer) indent_io = indent(io, 2) println(io, "Fixed rule sequencing:") print(indent_io, "Firstly: ") - show(indent_io, MIME("text/plain"), rule.rules[1]) + power_summary(indent_io, rule.rules[1]) for (responses, rule) in zip(rule.breaks, rule.rules[2:end]) print(indent_io, "After $responses responses: ") - show(indent_io, MIME("text/plain"), rule) + power_summary(indent_io, rule) end end @@ -58,11 +58,11 @@ function best_item(rule::MemoryNextItemRule, responses::TrackedResponses, _items # TODO: Add some basic error checking -- can only panic end -function show(io::IO, ::MIME"text/plain", rule::MemoryNextItemRule) +function power_summary(io::IO, rule::MemoryNextItemRule) item_list = join(rule.item_idxs, ", ") println(io, "Present the items indexed: $item_list") end function FixedFirstItem(item_idx::Int, rule::NextItemRule) FixedRuleSequencer((1,), (MemoryNextItemRule((item_idx,)), rule)) -end \ No newline at end of file +end diff --git a/src/Responses.jl b/src/Responses.jl index 086c513..9d3801b 100644 --- a/src/Responses.jl +++ b/src/Responses.jl @@ -75,6 +75,10 @@ function Base.empty!(responses::BareResponses) Base.empty!(responses.values) end +function Base.length(responses::BareResponses) + return length(responses.indices) +end + function add_response!(responses::BareResponses, response::Response)::BareResponses push!(responses.indices, response.index) push!(responses.values, response.value) diff --git a/src/Rules.jl b/src/Rules.jl index aece5b6..9f346a1 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -89,17 +89,17 @@ end function power_summary(io::IO, rules::CatRules; toplevel=false) # TODO print(io, "Next item rule: ") - show(io, MIME("text/plain"), rules.next_item) + power_summary(io, rules.next_item) if toplevel println(io) end print(io, "Termination condition: ") - show(io, MIME("text/plain"), rules.termination_condition) + power_summary(io, rules.termination_condition) if toplevel println(io) end print(io, "Ability estimator: ") - show(io, MIME("text/plain"), rules.ability_estimator) + power_summary(io, rules.ability_estimator) end function _find_ability_estimator_and_tracker(bits...) diff --git a/src/Sim/Sim.jl b/src/Sim/Sim.jl index 2e30ac2..7c39e1d 100644 --- a/src/Sim/Sim.jl +++ b/src/Sim/Sim.jl @@ -9,6 +9,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, domdims using PsychometricsBazaarBase: show_into_buf, power_summary_into_buf using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.IndentWrappers: indent +using PsychometricsBazaarBase: GridSummary using ..ConfigBase using ..Responses using ..Rules: CatRules @@ -24,9 +25,13 @@ using ..Aggregators: TrackedResponses, MeanAbilityEstimator, LikelihoodAbilityEstimator, RiemannEnumerationIntegrator -using ..NextItemRules: AbilityVariance, compute_criteria, best_item +using ..NextItemRules: AbilityVariance, compute_criteria, compute_criterion, best_item import Base: show import PsychometricsBazaarBase: power_summary +using PrettyPrinting +using PrettyTables: fmt__printf +using DataAPI: nrow +using DataFrames: DataFrames export CatRecorder, CatRecording export CatLoop, record! @@ -38,4 +43,6 @@ include("./loop.jl") include("./run.jl") include("./recorded_loop.jl") +show(io::IO, ::MIME"text/plain", obj::Union{CatRecorder, CatRecording, RecordedCatLoop}) = power_summary(io, obj) + end diff --git a/src/Sim/loop.jl b/src/Sim/loop.jl index b0fdb58..54d9d69 100644 --- a/src/Sim/loop.jl +++ b/src/Sim/loop.jl @@ -27,14 +27,14 @@ struct CatLoop{CatEngineT} <: CatConfigBase new_response_callback::Any """ A callback called each time a CAT is run - If provided, it is passed `(item_bank::AbstractItemBank)`. + If provided, it is passed `(loop::CatLoop, responses::TrackedResponses)`. """ init_callback::Any end -function show(io::IO, ::MIME"text/plain", rules::CatLoop) +function power_summary(io::IO, rules::CatLoop) print(io, "Computer-Adaptive Test Loop based on ") - show(io, MIME("text/plain"), rules.rules) + power_summary(io, rules.rules) end function collate_cat_callbacks(callbacks...) diff --git a/src/Sim/recorded_loop.jl b/src/Sim/recorded_loop.jl index b2aed0e..0ead8cb 100644 --- a/src/Sim/recorded_loop.jl +++ b/src/Sim/recorded_loop.jl @@ -52,6 +52,21 @@ function _find_ability_variance(rules) return nothing end +struct StdDevEstimator + ability_variance::AbilityVariance +end + +function (est::StdDevEstimator)(tracked_responses::TrackedResponses) + sqrt(compute_criterion(est.ability_variance, tracked_responses)) +end + +function power_summary(io::IO, est::StdDevEstimator) + println(io, "Standard deviation based on variance estimate") + power_summary(io, est.ability_variance; skip_first_line=true) +end + +show(io::IO, ::MIME"text/html", est::StdDevEstimator) = power_summary(io, est) + function enrich_recorder_requests(old_requests, rules) requests = Dict() for (k, v) in pairs(old_requests) @@ -71,7 +86,11 @@ function enrich_recorder_requests(old_requests, rules) if type == :ability new_v[:estimator] = rules.ability_estimator elseif type == :ability_stddev - error("Not implemented yet: `type = :ability_stddev` for request `$k`.") + ability_variance = _find_ability_variance(rules) + if ability_variance === nothing + error("Cannot find a `AbilityVariance` in the rules for request `$k`.") + end + new_v[:estimator] = StdDevEstimator(ability_variance) elseif type == :ability_distribution estimator = nothing integrator = nothing @@ -192,7 +211,7 @@ function run_cat(loop::RecordedCatLoop; ib_labels = nothing) run_cat(loop, loop.item_bank; ib_labels=ib_labels) end -function show(io::IO, ::MIME"text/plain", loop::RecordedCatLoop) +function power_summary(io::IO, loop::RecordedCatLoop) println(io, "Recorded Computer-Adaptive Test:") - power_summary(io, loop.recorder.recording; skip_first_line=true) + power_summary(io, loop.recorder.recording; skip_first_line=true, recorder=loop.recorder) end diff --git a/src/Sim/recorder.jl b/src/Sim/recorder.jl index 4cfde6b..2f37f3a 100644 --- a/src/Sim/recorder.jl +++ b/src/Sim/recorder.jl @@ -24,6 +24,8 @@ Base.@kwdef mutable struct CatRecording{LikelihoodsT <: NamedTuple} item_correctness::Vector{Bool} rules_description::Union{Nothing, IOBuffer} = nothing item_bank_description::Union{Nothing, IOBuffer} = nothing + has_initial::Bool = false + include_initial::Bool = true end function Base.getproperty(obj::CatRecording, sym::Symbol) @@ -37,64 +39,157 @@ end Base.@kwdef struct CatRecorder{RequestsT <: NamedTuple, LikelihoodsT <: NamedTuple} recording::CatRecording{LikelihoodsT} requests::RequestsT + include_initial=true #integrator::AbilityIntegrator #raw_estimator::LikelihoodAbilityEstimator #ability_estimator::AbilityEstimator end +""" + consume!(dict, key) do value + ... + end + +Execute the callback with the value at `key` in `dict` if it exists, and remove that key from the dictionary. +""" +function consume!(cb, dict, key) + if haskey(dict, key) + cb(dict[key]) + delete!(dict, key) + end +end + +function power_summary(io::IO, recorder::CatRecorder; include_recording = true, skip_first_line=false, kwargs...) + if !skip_first_line + println(io, "Recorder for Computer-Adaptive Tests:") + end + if include_recording + power_summary(io, recorder.recording; skip_first_line=true, recorder=recorder, kwargs...) + else + for (name, config) in pairs(recorder.requests) + println(io, " \"" * string(name) * "\"") + indent_io = indent(io, 4) + config_dict = Dict{Symbol, Any}(pairs(config)) + for k in (:label, :type, :source) + consume!(config_dict, k) do v + println(indent_io, uppercasefirst(string(k)) * ": ", v) + end + end + consume!(config_dict, :estimator) do v + power_summary(indent_io, v) + end + consume!(config_dict, :points) do v + println(indent_io, "Points:") + power_summary(indent(indent_io, 2), GridSummary(v)) + end + consume!(config_dict, :integrator) do v + power_summary(indent_io, v) + end + for (k, v) in pairs(config_dict) + println(indent_io, "Unknown key $k:") + println(indent(indent_io, 2), pprint(v)) + end + end + end +end + function CatRecording( data, - expected_responses=0 + expected_responses=0, + include_initial=true ) CatRecording(; data=data, item_index=empty_capacity(Int, expected_responses), - item_correctness=empty_capacity(Bool, expected_responses) + item_correctness=empty_capacity(Bool, expected_responses), + include_initial ) end function prepare_dataframe(recording::CatRecording) - cols::Vector{Pair{String, Vector{Any}}} = [ - "Item" => recording.item_index, - "Response" => recording.item_correctness, - ] + item_indices = convert(Vector{Union{Nothing, Int}}, recording.item_index) + responses = convert(Vector{Union{Nothing, Bool}}, recording.item_correctness) + if recording.include_initial && recording.has_initial + pushfirst!(item_indices, nothing) + pushfirst!(responses, nothing) + end + cols = (; + Item = item_indices, + Response = responses, + ) for (name, value) in pairs(recording.data) - #@show name value.type keys(value) size(value.data) if value.data isa AbstractVector - push!(cols, String(name) => value.data) + label = haskey(value, :label) ? Symbol(value.label) : name + cols = (; + cols..., + label => copy(value.data) + ) end end - return DataFrame(cols) -end - -function show(io::IO, ::MIME"text/plain", recording::CatRecording) - power_summary(io, recording; include_cat_config = :always) + return DataFrame(cols, copycols=false) end -function power_summary(io::IO, recording::CatRecording; include_cat_config = :always, skip_first_line=false, kwargs...) +function power_summary(io::IO, recording::CatRecording; include_cat_config = :always, skip_first_line=false, recorder=nothing, toplevel=true, kwargs...) if !skip_first_line println(io, "Recording of a Computer-Adaptive Test") end - if recording.rules_description === nothing && include_cat_config == :always - println(io, " Unknown CAT configuration") - elseif include_cat_config != :never # :available or :always - println(io, " CAT configuration:") - write(indent(io, 4), recording.rules_description) - seekstart(recording.rules_description) - println(io) + is_empty = ( + isnothing(recording.rules_description) && + isnothing(recording.item_bank_description) && + isempty(recording) + ) + indent_io = if toplevel + println() + io + else + indent(io, 2) + end + if !is_empty + if recording.rules_description === nothing && include_cat_config == :always + println(indent_io, "Unknown CAT configuration") + elseif include_cat_config != :never # :available or :always + println(indent_io, "CAT configuration:") + write(indent(indent_io, 2), recording.rules_description) + seekstart(recording.rules_description) + end + if toplevel + println(io) + end + if recording.item_bank_description === nothing + println(indent_io, "Unknown item bank") + else + println(indent_io, "Item bank:") + write(indent(indent_io, 2), recording.item_bank_description) + seekstart(recording.item_bank_description) + if toplevel + println(io) + end + end + end + if recorder !== nothing + println(indent_io, "Requested information:") + power_summary(indent_io, recorder; include_recording=false, skip_first_line=true, kwargs...) + if toplevel + println(io) + end end - if recording.item_bank_description === nothing - println(io, " Unknown item bank") + if is_empty + println(indent_io, "CAT has not yet been run; no recorded information") else - println(io, " Item bank:") - write(indent(io, 4), recording.item_bank_description) - seekstart(recording.item_bank_description) + println(indent_io, "Recorded information:") println(io) + df = prepare_dataframe(recording) + buf = show_into_buf( + df; + summary = false, + eltypes = false, + stubhead_label = "Administration", + row_labels = 0:(nrow(df) - 1), + compact_printing = false, + formatters = [DataFrames._pretty_tables_general_formatter, fmt__printf("%5.3f", [3, 4])] + ) + write(indent(indent_io, 2), buf) end - println(io, " Recorded information:") - df = prepare_dataframe(recording) - buf = show_into_buf(df; summary = false, eltypes = false, rowlabel = :Number) - write(indent(io, 4), buf) end #= @@ -132,15 +227,18 @@ end =# function record!(recording::CatRecording, responses; data...) - #push_ability_est!(recording.ability_ests, recording.col_idx, ability_est) - - item_index = responses.indices[end] - item_correct = responses.values[end] > 0 - push!(recording.item_index, item_index) - push!(recording.item_correctness, item_correct) + if length(responses) == 0 + recording.has_initial = true + else + item_index = responses.indices[end] + item_correct = responses.values[end] > 0 + push!(recording.item_index, item_index) + push!(recording.item_correctness, item_correct) + end end function Base.empty!(recording::CatRecording) + recording.has_initial = false empty!(recording.item_index) empty!(recording.item_correctness) for (name, value) in pairs(recording.data) @@ -152,6 +250,10 @@ function Base.empty!(recording::CatRecording) end end +function Base.isempty(recording::CatRecording) + return length(recording.item_index) == 0 && !recording.has_initial +end + #= """ $(TYPEDSIGNATURES) @@ -282,7 +384,12 @@ end function CatRecorder(dims::Int, expected_responses::Int; requests...) out = [] sizehint!(out, length(requests)) - for (name, request) in pairs(requests) + include_initial = true + requests_dict = Dict{Symbol, Any}(pairs(requests)) + consume!(requests_dict, :include_initial) do v + include_initial = v + end + for (name, request) in pairs(requests_dict) extra = (;) if !haskey(request, :type) error("Must provide `type` for $name.") @@ -308,8 +415,9 @@ function CatRecorder(dims::Int, expected_responses::Int; requests...) ))) end return CatRecorder(; - recording=CatRecording(NamedTuple(out), expected_responses), - requests=NamedTuple(requests), + recording=CatRecording(NamedTuple(out), expected_responses, include_initial), + requests=NamedTuple(requests_dict), + include_initial ) #= CatRecording( @@ -415,10 +523,18 @@ end $(TYPEDSIGNATURES) """ function record!(recorder::CatRecorder, tracked_responses) - item_index = tracked_responses.responses.indices[end] - item_correct = tracked_responses.responses.values[end] > 0 - ir = ItemResponse(tracked_responses.item_bank, item_index) - service_requests!(recorder, tracked_responses, ir, item_correct) + local ir, item_correct + if length(tracked_responses.responses) == 0 + ir = nothing + item_correct = nothing + else + item_index = tracked_responses.responses.indices[end] + item_correct = tracked_responses.responses.values[end] > 0 + ir = ItemResponse(tracked_responses.item_bank, item_index) + end + if ir !== nothing || recorder.include_initial + service_requests!(recorder, tracked_responses, ir, item_correct) + end record!(recorder.recording, tracked_responses.responses) end @@ -427,7 +543,8 @@ function recorder_response_callback(recorder::CatRecorder) end function recorder_init_callback(recorder::CatRecorder) - return function (cat_loop, item_bank) + return function (cat_loop, tracked_responses) + item_bank = tracked_responses.item_bank empty!(recorder.recording) if showable(MIME("text/plain"), cat_loop.rules) recorder.recording.rules_description = power_summary_into_buf(cat_loop.rules; toplevel=false) @@ -435,14 +552,17 @@ function recorder_init_callback(recorder::CatRecorder) if showable(MIME("text/plain"), item_bank) recorder.recording.item_bank_description = power_summary_into_buf(item_bank) end + record!(recorder, tracked_responses) end end -function show(io::IO, ::MIME"text/plain", recorder::CatRecorder) +#= +function power_summary(io::IO, recorder::CatRecorder) indent_io = indent(io, 4) println(io, "Computer-Adaptive Test Recorder") println(io, " Requests:") - show(indent_io, MIME"text/plain", recorder.requests) + power_summary(indent_io, recorder.requests) println(io, " Recording:") - show(indent_io, MIME"text/plain", recorder.recording) + power_summary(indent_io, recorder.recording) end +=# diff --git a/src/Sim/run.jl b/src/Sim/run.jl index e2450e9..adfa911 100644 --- a/src/Sim/run.jl +++ b/src/Sim/run.jl @@ -44,13 +44,13 @@ function run_cat(loop::CatLoop{RulesT}, item_bank::AbstractItemBank; ib_labels = nothing) where {RulesT <: CatRules} (; rules, get_response, new_response_callback, init_callback) = loop - if init_callback !== nothing - init_callback(loop, item_bank) - end (; next_item, termination_condition, ability_estimator, ability_tracker) = rules responses = TrackedResponses(BareResponses(ResponseType(item_bank)), item_bank, ability_tracker) + if init_callback !== nothing + init_callback(loop, responses) + end while true local next_index @debug begin diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index 45867c7..080838f 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -4,6 +4,7 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS using FittedItemBanks: AbstractItemBank using ..Aggregators: TrackedResponses using ..ConfigBase +import PsychometricsBazaarBase: power_summary using PsychometricsBazaarBase.ConfigTools: @returnsome, find1_instance using FittedItemBanks import Base: show @@ -32,7 +33,7 @@ function (condition::FixedLength)(responses::TrackedResponses, length(responses) >= condition.num_items end -function show(io::IO, ::MIME"text/plain", condition::FixedLength) +function power_summary(io::IO, condition::FixedLength) println(io, "Terminate test after administering $(condition.num_items) items") end