|
| 1 | +using IPUToolkit.IPUCompiler, IPUToolkit.Poplar |
| 2 | +using Enzyme |
| 3 | + |
# Runtime option handed to Poplar through the environment; per the option name this
# is the host-sync timeout (presumably in seconds — confirm against the Poplar docs).
ENV["POPLAR_RUNTIME_OPTIONS"] = """{"target.hostSyncTimeout":"30"}"""

# Going by the flag name, keeps the LLVM files produced while compiling codelets.
IPUCompiler.KEEP_LLVM_FILES[] = true

# Acquire an IPU, then build an empty graph for its target.
device = Poplar.get_ipu_device()
target = Poplar.DeviceGetTarget(device)
graph = Poplar.Graph(target)

# Tile count as a plain `Int`, used below for sizing the data and the tile range.
num_tiles = Poplar.TargetGetNumTiles(target) |> Int
| 12 | + |
"""
    ∂!(f, x, f′)

Differentiate `f` at `x` with Enzyme in reverse mode, passing `f′` as the shadow of
`x` (`Duplicated(x, f′)`).

The derivative is *added* to the shadow, so the caller is responsible for zeroing
`f′` beforehand. Returns whatever `autodiff_deferred` returns.
"""
function ∂!(f, x, f′)
    return autodiff_deferred(Reverse, f, Duplicated(x, f′))
end
| 14 | + |
"""
    neg_log_density(q::AbstractVector{T}) where {T}

Evaluate `(q[1]^2 - q[2])^2 + (q[1] - 1)^2 / 100`, with the constants converted to
the element type `T`.

Only the first two entries of `q` are read. As the name says, the samplers in this
file use the value as the negative log-density of the target.
"""
function neg_log_density(q::AbstractVector{T}) where {T}
    x = q[1]
    y = q[2]
    curved_term = (x^2 - y)^2
    shift_term = (x - one(T))^2 / T(100)
    return curved_term + shift_term
end
| 16 | + |
"""
    grad_neg_log_density!(f′::V, x::V) where {T,V<:AbstractVector{T}}

Overwrite `f′` with the gradient of `neg_log_density` at `x` and return `f′`.

Note: the input and the output must have exactly the same type (including *all*
type parameters); the shared type variable `V` enforces this.
"""
function grad_neg_log_density!(f′::V, x::V) where {T,V<:AbstractVector{T}}
    # The derivative is accumulated into the shadow of a duplicated argument, so the
    # buffer has to be cleared before differentiating.
    fill!(f′, zero(T))
    ∂!(neg_log_density, x, f′)
    return f′
end
| 25 | + |
"""
    leapfrog!(q, p, f′, dt)

Advance `q` and `p` in place by one leapfrog step of size `dt`: half an update of `p`,
a full update of `q`, then the second half update of `p`, each using the gradient of
`neg_log_density`.

`f′` is a scratch buffer; on return it holds the gradient at the updated `q`.
Returns `p`.
"""
function leapfrog!(
    q::AbstractVector{T},
    p::AbstractVector{T},
    f′::AbstractVector{T},
    dt::T,
) where {T}
    half_dt = dt / 2

    # First half update of p, using the gradient at the current q.
    grad_neg_log_density!(f′, q)
    @. p -= half_dt * f′

    # Full update of q with the half-updated p.
    @. q += dt * p

    # Second half update of p, using the gradient at the new q.
    grad_neg_log_density!(f′, q)
    @. p -= half_dt * f′
    return p
end
| 33 | + |
"""
    sample_transition!(q_proposed, p, f′, q, dt, n_step)

Run one accept/reject transition starting from `q`.

`p` is refilled by `randn2!`, `q_proposed` is initialised from `q` and evolved with
`n_step` calls to `leapfrog!`. The proposal is kept when a `rand(T)` draw is below
the acceptance probability `exp(min(0, h_diff))`; otherwise `q_proposed` is reset to
`q`. A `NaN` energy difference gives an acceptance probability of zero.

`q` is only read. `p` and `f′` are overwritten. Returns the acceptance probability
(whether or not the proposal was accepted).
"""
function sample_transition!(
    q_proposed::AbstractVector{T},
    p::AbstractVector{T},
    f′::AbstractVector{T},
    q::AbstractVector{T},
    dt::T,
    n_step,
) where {T}
    # Fresh momentum (presumably standard-normal draws — see `randn2!`).
    randn2!(p)
    energy_before = sum(abs2, p) / 2 + neg_log_density(q)

    # Simulate the dynamics on a copy of the current state.
    q_proposed .= q
    for _ in UInt32(1):n_step
        leapfrog!(q_proposed, p, f′, dt)
    end
    energy_after = sum(abs2, p) / 2 + neg_log_density(q_proposed)

    h_diff = energy_before - energy_after
    accept_prob = if isnan(h_diff)
        zero(T)
    else
        exp(min(0, h_diff))
    end

    # Rejected: fall back to the state we started from.
    if rand(T) >= accept_prob
        q_proposed .= q
    end
    return accept_prob
end
| 48 | + |
"""
    sample_chain!(q_chain, buffer_q, p, f′, orig_q, n_sample, n_step, dt)

Run `n_sample` consecutive transitions starting from `orig_q` and record every
visited state in `q_chain`.

Sample `s` occupies `q_chain[length(buffer_q) * (s - 1) .+ eachindex(buffer_q)]`;
the writes are `@inbounds`, so `q_chain` must hold at least
`length(buffer_q) * n_sample` elements.

`buffer_q`, `p` and `f′` are scratch buffers and are overwritten. `orig_q` is used
as the current state of the chain: it is updated after every transition and holds
the last recorded sample on return (it is left untouched when `n_sample` is zero).

Returns the mean acceptance probability over the `n_sample` transitions.
"""
function sample_chain!(
    q_chain::AbstractVector{T},
    buffer_q::AbstractVector{T},
    p::AbstractVector{T},
    f′::AbstractVector{T},
    orig_q::AbstractVector{T},
    n_sample,
    n_step,
    dt::T,
) where {T}
    sum_accept_prob = zero(T)
    buffer_q .= orig_q
    n_q = length(buffer_q)
    for sample in UInt32(1):n_sample
        accept_prob = sample_transition!(buffer_q, p, f′, orig_q, dt, n_step)

        # Record the new state in this sample's slot of the flat output.
        offset = n_q * (sample - 1)
        for idx in eachindex(buffer_q)
            @inbounds q_chain[offset + idx] = buffer_q[idx]
        end

        # Advance the chain: the next transition must start from the state just
        # recorded. Without this every transition restarts from the initial point
        # and the samples never move away from it.
        orig_q .= buffer_q

        sum_accept_prob += accept_prob
    end
    return sum_accept_prob / n_sample
end
| 61 | + |
# Sampler settings; these are interpolated into the codelet definition as literals.
n_sample = convert(UInt32, 10)  # number of transitions recorded
n_step = convert(UInt32, 10)    # leapfrog steps per transition
dt = 0.1f0                      # leapfrog step size
| 65 | + |
# Codelet wrapping `sample_chain!`. `@eval` lets the current values of `n_sample`,
# `n_step` and `dt` be spliced into the body with `$` before `@codelet` sees it.
@eval @codelet graph function HamiltonianMonteCarlo(
    q_chain::VertexVector{Float32,InOut},
    buffer_q::VertexVector{Float32,InOut},
    p::VertexVector{Float32,InOut},
    gradient::VertexVector{Float32,InOut},
    orig_q::VertexVector{Float32,InOut},
)
    sample_chain!(q_chain, buffer_q, p, gradient, orig_q, $n_sample, $n_step, $dt)
end
| 75 | + |
# Host-side data: random starting values (two per tile) and the uninitialised matrix
# that will receive the recorded samples.
orig_q = randn(Float32, 2 * num_tiles)
q_chain = Matrix{Float32}(undef, length(orig_q), n_sample)

# Starting values on the device, initialised from the host copy.
orig_q_ipu = Poplar.GraphAddVariable(
    graph, Poplar.FLOAT(), UInt64[length(orig_q)], "orig_q"
)
copyto!(graph, orig_q_ipu, orig_q)

# Work buffers modelled on `orig_q`.
buffer_q_ipu = similar(graph, orig_q, "buffer_q")
p_ipu = similar(graph, orig_q, "p")
gradient_ipu = similar(graph, orig_q, "gradient")

# Flat output variable holding `n_sample` states for every entry of `orig_q`.
q_chain_ipu = Poplar.GraphAddVariable(
    graph, Poplar.FLOAT(), UInt64[length(orig_q) * n_sample], "q_chain"
)
| 85 | + |
# Engine options.
flags = Poplar.OptionFlags()
Poplar.OptionFlagsSet(flags, "debug.instrument", "false")

# Program: add the codelet as a vertex over tiles 0 to num_tiles - 1.
prog = Poplar.ProgramSequence()
add_vertex(
    graph,
    prog,
    0:(num_tiles - 1),
    HamiltonianMonteCarlo,
    q_chain_ipu,
    buffer_q_ipu,
    p_ipu,
    gradient_ipu,
    orig_q_ipu,
)

# Register a host read handle so the samples can be copied back after the run.
Poplar.GraphCreateHostRead(graph, "q-chain-read", q_chain_ipu)

# Build the engine, run the program, and fetch the samples into `q_chain`.
engine = Poplar.Engine(graph, prog, flags)
Poplar.EngineLoadAndRun(engine, device)
Poplar.EngineReadTensor(engine, "q-chain-read", q_chain)

Poplar.detach_devices()
| 101 | + |
#=
Optional: scatter plot of one column of `q_chain` (requires Plots.jl).

using Plots

sample = 10
scatter(q_chain[1:2:end, sample], q_chain[2:2:end, sample]; xlims=(-3, 3), ylims=(-3, 6))

=#
0 commit comments