diff --git a/ext/PigeonsDynamicPPLExt/state.jl b/ext/PigeonsDynamicPPLExt/state.jl index 96ef2d27d..8e75b6589 100644 --- a/ext/PigeonsDynamicPPLExt/state.jl +++ b/ext/PigeonsDynamicPPLExt/state.jl @@ -52,7 +52,7 @@ end function Pigeons.step!(explorer::AutoMALA, replica, shared, vi::DynamicPPL.TypedVarInfo) log_potential = Pigeons.find_log_potential(replica, shared.tempering, shared) state = DynamicPPL.getall(vi) - Pigeons._extract_commons_and_run_auto_mala!(explorer, replica, shared, log_potential, state) + Pigeons._extract_commons_and_run!(explorer, replica, shared, log_potential, state) DynamicPPL.setall!(replica.state, state) end diff --git a/test/test_turing.jl b/test/test_turing.jl index ff4a47414..5b8e26815 100644 --- a/test/test_turing.jl +++ b/test/test_turing.jl @@ -1,21 +1,35 @@ include("supporting/turing_models.jl") + + + #= The reason this test is excluded is described in ADgradient() in TuringLogPotential.jl =# -# @testset "Turing-gradient" begin -# target = Pigeons.toy_turing_unid_target() +@testset "Turing-gradient" begin + target = Pigeons.toy_turing_unid_target() -# @show Threads.nthreads() + @show Threads.nthreads() -# logz_mala = Pigeons.stepping_stone_pair(pigeons(; target, explorer = AutoMALA(adapt_pre_conditioning = false))) -# logz_slicer = Pigeons.stepping_stone_pair(pigeons(; target, explorer = SliceSampler())) + logz_mala = Pigeons.stepping_stone_pair(pigeons(; target, explorer = AutoMALA(preconditioner = Pigeons.IdentityPreconditioner()))) + logz_slicer = Pigeons.stepping_stone_pair(pigeons(; target, explorer = SliceSampler())) -# @test abs(logz_mala[1] - logz_slicer[1]) < 0.1 -# end + @test abs(logz_mala[1] - logz_slicer[1]) < 0.1 +end @testset "Turing-variable-names" begin pt = pigeons(target = TuringLogPotential(model_with_vectors()), n_rounds = 2); @test length(variable_names(pt)) == 4 +end + +@testset "Non-turing-gradient" begin + target = Pigeons.toy_mvn_target(2) + + @show Threads.nthreads() + + logz_mala = Pigeons.stepping_stone_pair(pigeons(; target, explorer = AutoMALA(preconditioner = Pigeons.IdentityPreconditioner()))) + logz_slicer = Pigeons.stepping_stone_pair(pigeons(; target, explorer = SliceSampler())) + + @test abs(logz_mala[1] - logz_slicer[1]) < 0.1 end \ No newline at end of file