Files
master-thesis/package/src/CpuInterpreter.jl
2025-02-19 16:38:11 +01:00

173 lines
7.9 KiB
Julia

using Random
"""
    InterpreterBuffers{T}(codelen, num_param, batchsize)

Preallocated working memory for the interpreter.

All arrays are allocated uninitialized (`undef`):
- `resultcache` and `diffcache` are `batchsize × codelen` matrices,
- `jaccache` is a `batchsize × num_param` matrix,
- `tmp` is a vector of length `batchsize`.

The constructor is only defined for `T <: AbstractFloat`.
"""
struct InterpreterBuffers{T}
    resultcache::Matrix{T}  # values of the forward evaluation
    diffcache::Matrix{T}    # storage for reverse-mode AD
    jaccache::Matrix{T}     # storage for the Jacobian
    tmp::Vector{T}          # temporary space handed to each vector operation

    function InterpreterBuffers{T}(codelen, num_param, batchsize) where {T<:AbstractFloat}
        return new(
            Matrix{T}(undef, batchsize, codelen),
            Matrix{T}(undef, batchsize, codelen),
            Matrix{T}(undef, batchsize, num_param),
            Vector{T}(undef, batchsize),
        )
    end
end
"""
    Interpreter{T}(expr::Expr, num_param; batchsize=1024)

Interpreter state for one expression.

The constructor converts `expr` with `convert_expr_to_code(T, expr)` and allocates
`InterpreterBuffers{T}` sized for `length(code)` instructions, `num_param` parameters
and `batchsize` rows. The program counter `pc` starts at `1`.

# Throws
- `ArgumentError` if `batchsize` is not positive.
"""
mutable struct Interpreter{T}
    const code::Vector{Instruction{T}}    # instruction sequence produced from the expression
    const buffers::InterpreterBuffers{T}  # preallocated working memory
    const batchsize::UInt32               # number of rows evaluated per batch
    pc::Int32                             # index into `code` of the current instruction

    function Interpreter{T}(expr::Expr, num_param; batchsize=1024) where {T<:AbstractFloat}
        # A batch size of zero would never advance the batching loop in
        # `interpret_withbuf!`, so reject it up front.
        batchsize > 0 || throw(ArgumentError("batchsize must be positive, got $batchsize"))
        code = convert_expr_to_code(T, expr)
        buffers = InterpreterBuffers{T}(length(code), num_param, batchsize)
        return new(code, buffers, batchsize, 1)
    end
end
"""
    peek_instruction(interpreter)

Return the instruction at the current program counter (`interpreter.pc`) without
advancing it.
"""
function peek_instruction(interpreter)
    return getindex(interpreter.code, interpreter.pc)
end
"""
    interpret!(result, expr::Expr, x, p; batchsize=1024)

Build an `Interpreter{T}` for `expr` with `length(p)` parameters and evaluate it on
`x` and `p` by forwarding to the `Interpreter`-based `interpret!` method.

The default batch size of 1024 was fast in benchmarks.
"""
function interpret!(
    result::AbstractVector{T}, expr::Expr, x::AbstractMatrix{T}, p; batchsize=1024
) where {T}
    interpreter = Interpreter{T}(expr, length(p); batchsize=batchsize)
    return interpret!(result, interpreter, x, p)
end
"""
    interpret!(result, interpreter::Interpreter, x, p)

Evaluate `interpreter` on `x` and `p` using the buffers preallocated inside the
interpreter (`resultcache` and `tmp`), forwarding to `interpret_withbuf!`.
"""
function interpret!(
    result::AbstractVector{T},
    interpreter::Interpreter{T},
    x::AbstractMatrix{T},
    p::AbstractArray{T},
) where {T}
    buffers = interpreter.buffers
    return interpret_withbuf!(result, interpreter, buffers.resultcache, buffers.tmp, x, p)
end
"""
    interpret_withbuf!(result, interpreter, batchresult, tmp, x, p)

Evaluate `interpreter` for every row of `x`, in batches of `interpreter.batchsize`
rows, using `batchresult` and `tmp` as working memory.

After each batch the last column of `batchresult` is copied into the matching rows of
`result`. `length(result)` must equal the number of rows of `x`.

Returns `result`.
"""
function interpret_withbuf!(
    result::AbstractVector{T},
    interpreter::Interpreter{T},
    batchresult,
    tmp,
    x::AbstractMatrix{T},
    p::AbstractArray{TD},
) where {T,TD}
    allrows = axes(x, 1)
    @assert length(result) == length(allrows)

    batchsize = Int(interpreter.batchsize)
    lastrow = last(allrows)

    # Full batches: keep going while a complete batch still fits. The bound must be
    # inclusive; with a strict `start + batchsize < lastrow` the remainder could grow
    # to `batchsize + 1` rows, which overflows the batch buffer.
    start = first(allrows)
    while start + batchsize - 1 <= lastrow
        batchrows = start:(start + batchsize - 1)
        interpret_batch!(interpreter, batchresult, tmp, x, p, batchrows)
        copy!((@view result[batchrows]), (@view batchresult[:, end]))
        start += batchsize
    end

    # Remaining rows: always fewer than one full batch at this point.
    remrows = start:lastrow
    if !isempty(remrows)
        interpret_batch!(interpreter, batchresult, tmp, x, p, remrows)
        copy!((@view result[remrows]), (@view batchresult[1:length(remrows), end]))
    end
    return result
end
"""
    interpret_batch!(interpreter, batchresult, tmp, x, p, rows)

Evaluate one batch: run the forward pass over `rows` of `x`, writing into
`batchresult`. Returns `nothing`.
"""
function interpret_batch!(interpreter, batchresult, tmp, x, p, rows)
    # only the forward pass is needed here
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
    return nothing
end
"""
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)

Forward pass: reset the program counter to `1` and call `step!` until the program
counter has moved past the last instruction. Returns `nothing`.
"""
function interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
    program = interpreter.code
    interpreter.pc = 1
    # `step!` advances `interpreter.pc` by one per call
    while interpreter.pc <= length(program)
        step!(interpreter, batchresult, tmp, x, p, rows)
    end
    return nothing
end
"""
    step!(interpreter, batchresult, tmp, x, p, range)

Execute the instruction at `interpreter.pc` and advance `interpreter.pc` by one.

The result is written to column `interpreter.pc` of `batchresult`, restricted to the
first `length(range)` rows. Operand columns (`instr.arg1idx`, `instr.arg2idx`) are read
from `batchresult` over the same rows.

- Degree 0: copy rows `range` of column `instr.idx` of `x`, or fill with `p[instr.idx]`
  or `instr.val`. Any other degree-0 opcode leaves the column untouched.
- Degree 1 and 2: dispatch to the matching `vec_*!` kernel; an unknown opcode throws a
  `DomainError`.
- Opcodes of any other degree only advance the program counter.

Returns `nothing`.
"""
function step!(interpreter, batchresult, tmp, x, p, range)
    instr = interpreter.code[interpreter.pc]
    opc = instr.opcode

    # Only touch the rows that belong to the current batch. Rows beyond it may hold
    # stale or uninitialized values (the interpreter's buffers are allocated with
    # `undef`); running kernels such as `^`, `asin` or `sin` on them wastes work and
    # can raise a spurious `DomainError`.
    active = 1:length(range)
    res = view(batchresult, active, interpreter.pc)

    if degree(opc) == 0
        if opc == opc_variable
            copyto!(res, view(x, range, instr.idx))
        elseif opc == opc_param
            fill!(res, p[instr.idx])
        elseif opc == opc_constant
            fill!(res, instr.val)
        end
    elseif degree(opc) == 1
        arg = view(batchresult, active, instr.arg1idx)
        # is converted to a switch automatically by LLVM
        if opc == opc_log
            vec_log!(res, arg, tmp)
        elseif opc == opc_log10
            vec_log10!(res, arg, tmp)
        elseif opc == opc_exp
            vec_exp!(res, arg, tmp)
        elseif opc == opc_abs
            vec_abs!(res, arg, tmp)
        elseif opc == opc_neg
            vec_neg!(res, arg, tmp)
        elseif opc == opc_inv
            vec_inv!(res, arg, tmp)
        elseif opc == opc_sign
            vec_sign!(res, arg, tmp)
        elseif opc == opc_powconst
            vec_powconst!(res, arg, instr.val, tmp)
        elseif opc == opc_sin
            vec_sin!(res, arg, tmp)
        elseif opc == opc_cos
            vec_cos!(res, arg, tmp)
        elseif opc == opc_cosh
            vec_cosh!(res, arg, tmp)
        elseif opc == opc_asin
            vec_asin!(res, arg, tmp)
        elseif opc == opc_tan
            vec_tan!(res, arg, tmp)
        elseif opc == opc_tanh
            vec_tanh!(res, arg, tmp)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    elseif degree(opc) == 2
        left = view(batchresult, active, instr.arg1idx)
        right = view(batchresult, active, instr.arg2idx)
        if opc == opc_add
            vec_add!(res, left, right, tmp)
        elseif opc == opc_sub
            vec_sub!(res, left, right, tmp)
        elseif opc == opc_mul
            vec_mul!(res, left, right, tmp)
        elseif opc == opc_div
            vec_div!(res, left, right, tmp)
        elseif opc == opc_pow
            vec_pow!(res, left, right, tmp)
        elseif opc == opc_powabs
            vec_powabs!(res, left, right, tmp)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    end

    interpreter.pc += 1
    return nothing
end
# Generate the element-wise kernels `vec_exp!`, `vec_abs!`, `vec_sin!`, `vec_cos!`,
# `vec_cosh!`, `vec_asin!`, `vec_tan!`, `vec_tanh!` and `vec_sinh!`.
# Each kernel stores `Base.f(arg[i])` in `res[i]` for every index of `res`, ignores its
# third argument, and returns `nothing`.
for fname in (:exp, :abs, :sin, :cos, :cosh, :asin, :tan, :tanh, :sinh)
    kernel = Symbol("vec_", fname, "!")
    @eval begin
        function $kernel(
            res::AbstractVector{T}, arg::AbstractVector{T}, ::AbstractVector{T}
        ) where {T<:Real}
            @simd for k in eachindex(res)
                @inbounds res[k] = Base.$fname(arg[k])
            end
            return nothing
        end
    end
end
"""
    vec_add!(res, left, right, tmp)

Store `left[i] + right[i]` in `res[i]` for every index of `res`. The last argument is
unused. Returns `nothing`.
"""
function vec_add!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::AbstractVector{TE},
    ::AbstractVector{TE},
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = left[k] + right[k]
    end
    return nothing
end
"""
    vec_sub!(res, left, right, tmp)

Store `left[i] - right[i]` in `res[i]` for every index of `res`. The last argument is
unused. Returns `nothing`.
"""
function vec_sub!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::AbstractVector{TE},
    ::AbstractVector{TE},
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = left[k] - right[k]
    end
    return nothing
end
"""
    vec_mul!(res, left, right, tmp)

Store `left[i] * right[i]` in `res[i]` for every index of `res`. The last argument is
unused. Returns `nothing`.
"""
function vec_mul!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::AbstractVector{TE},
    ::AbstractVector{TE},
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = left[k] * right[k]
    end
    return nothing
end
"""
    vec_div!(res, left, right, tmp)

Store `left[i] / right[i]` in `res[i]` for every index of `res`. The last argument is
unused. Returns `nothing`.
"""
function vec_div!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::AbstractVector{TE},
    ::AbstractVector{TE},
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = left[k] / right[k]
    end
    return nothing
end
"""
    vec_pow!(res, left, right, tmp)

Store `left[i] ^ right[i]` in `res[i]` for every index of `res`. The last argument is
unused. Returns `nothing`.
"""
function vec_pow!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::AbstractVector{TE},
    ::AbstractVector{TE},
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = left[k] ^ right[k]
    end
    return nothing
end
# TODO: special case scalar power
"""
    vec_powconst!(res, left, right, tmp)

Store `left[i] ^ right` in `res[i]` for every index of `res`, where `right` is a single
scalar exponent. The last argument is unused. Returns `nothing`.
"""
function vec_powconst!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::TC,
    ::AbstractVector{TE},
) where {TE<:Real,TC<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = left[k] ^ right
    end
    return nothing
end
"""
    vec_powabs!(res, left, right, tmp)

Store `abs(left[i]) ^ right[i]` in `res[i]` for every index of `res`. The last argument
is unused. Returns `nothing`.
"""
function vec_powabs!(
    res::AbstractVector{TE},
    left::AbstractVector{TE},
    right::AbstractVector{TE},
    ::AbstractVector{TE},
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = abs(left[k]) ^ right[k]
    end
    return nothing
end
"""
    vec_neg!(res, arg, tmp)

Store `-arg[i]` in `res[i]` for every index of `res`. The last argument is unused.
Returns `nothing`.
"""
function vec_neg!(
    res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = -arg[k]
    end
    return nothing
end
"""
    vec_inv!(res, arg, tmp)

Store `inv(arg[i])` in `res[i]` for every index of `res`. The last argument is unused.
Returns `nothing`.
"""
function vec_inv!(
    res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = inv(arg[k])
    end
    return nothing
end
"""
    vec_sign!(res, arg, tmp)

Store `sign(arg[i])` in `res[i]` for every index of `res`. The last argument is unused.
Returns `nothing`.
"""
function vec_sign!(
    res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds res[k] = sign(arg[k])
    end
    return nothing
end
# handle log and log10 specially: return NaN for negative arguments instead of throwing a DomainError
"""
    vec_log!(res, arg, tmp)

Store the natural logarithm of `arg[i]` in `res[i]` for every index of `res`.
Negative inputs produce `NaN` instead of calling `log`. The last argument is unused.
Returns `nothing`.
"""
function vec_log!(
    res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds begin
            v = arg[k]
            if v < zero(TE)
                res[k] = TE(NaN)
            else
                res[k] = log(v)
            end
        end
    end
    return nothing
end
"""
    vec_log10!(res, arg, tmp)

Store the base-10 logarithm of `arg[i]` in `res[i]` for every index of `res`.
Negative inputs produce `NaN` instead of calling `log10`. The last argument is unused.
Returns `nothing`.
"""
function vec_log10!(
    res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}
) where {TE<:Real}
    @simd for k in eachindex(res)
        @inbounds begin
            v = arg[k]
            if v < zero(TE)
                res[k] = TE(NaN)
            else
                res[k] = log10(v)
            end
        end
    end
    return nothing
end