173 lines
7.9 KiB
Julia
173 lines
7.9 KiB
Julia
using Random
|
|
|
|
"""
    InterpreterBuffers{T}(codelen, num_param, batchsize)

Preallocated work arrays for the batched interpreter. All arrays are created with `undef`,
so their contents are arbitrary until written.

- `resultcache` (`batchsize × codelen`): values of the forward evaluation
- `diffcache` (`batchsize × codelen`): storage for reverse-mode AD
- `jaccache` (`batchsize × num_param`): storage for the Jacobian
- `tmp` (length `batchsize`): scratch space for the vector operations
"""
struct InterpreterBuffers{T}
    resultcache::Matrix{T}  # forward evaluation
    diffcache::Matrix{T}    # reverse-mode AD
    jaccache::Matrix{T}     # Jacobian
    tmp::Vector{T}          # scratch space for the vector operations

    function InterpreterBuffers{T}(codelen, num_param, batchsize) where {T<:AbstractFloat}
        fwd = Array{T}(undef, batchsize, codelen)
        rev = Array{T}(undef, batchsize, codelen)
        jac = Array{T}(undef, batchsize, num_param)
        scratch = Array{T}(undef, batchsize)
        return new(fwd, rev, jac, scratch)
    end
end
|
|
|
|
"""
    Interpreter{T}(expr::Expr, num_param; batchsize = 1024)

Batched interpreter for `expr`. The expression is converted once into a flat instruction
vector by `convert_expr_to_code`, and the work buffers are preallocated for `batchsize` rows.

Fields:
- `code`: the converted instructions; buffer column `k` holds the values of `code[k]`
- `buffers`: preallocated `InterpreterBuffers` sized from `length(code)`, `num_param` and `batchsize`
- `batchsize`: maximum number of rows evaluated per batch
- `pc`: program counter, a 1-based index into `code`, used while stepping

Throws `ArgumentError` if `batchsize` is not positive.
"""
mutable struct Interpreter{T}
    const code::Vector{Instruction{T}}
    const buffers::InterpreterBuffers{T}
    const batchsize::UInt32
    pc::Int32

    function Interpreter{T}(expr::Expr, num_param; batchsize = 1024) where {T<:AbstractFloat}
        # With a zero batch size the batching loop could never advance past the first row.
        batchsize > 0 || throw(ArgumentError("batchsize must be positive, got $batchsize"))
        code = convert_expr_to_code(T, expr)
        buffers = InterpreterBuffers{T}(length(code), num_param, batchsize)
        return new(code, buffers, batchsize, 1)
    end
end
|
|
|
|
"""
    peek_instruction(interpreter)

Return the instruction at the current program counter `interpreter.pc` without advancing it.
"""
function peek_instruction(interpreter)
    current = interpreter.pc
    return interpreter.code[current]
end
|
|
|
|
|
|
|
|
# batch size 1024 was fast in benchmark
|
|
"""
    interpret!(result, expr::Expr, x, p; batchsize = 1024)

Convenience method. Build a fresh `Interpreter{T}` for `expr` with `length(p)` parameters and
the given `batchsize`, then evaluate it on the rows of `x`, writing into `result`.
The default batch size of 1024 was found to be fast in benchmarks.
"""
function interpret!(result::AbstractVector{T}, expr::Expr, x::AbstractMatrix{T}, p;
                    batchsize = 1024) where {T}
    interp = Interpreter{T}(expr, length(p); batchsize = batchsize)
    return interpret!(result, interp, x, p)
end
|
|
|
|
# for Float evaluation use the preallocated buffer
|
|
"""
    interpret!(result, interpreter, x, p)

Evaluate `interpreter`'s program for every row of `x` with parameters `p`, writing one value
per row into `result`.

The interpreter's own preallocated `resultcache` and `tmp` buffers are reused instead of
allocating new ones. Returns whatever `interpret_withbuf!` returns (the filled `result`).
"""
function interpret!(result::AbstractVector{T}, interpreter::Interpreter{T},
                    x::AbstractMatrix{T}, p::AbstractArray{T}) where {T}
    (; resultcache, tmp) = interpreter.buffers
    return interpret_withbuf!(result, interpreter, resultcache, tmp, x, p)
end
|
|
|
|
"""
    interpret_withbuf!(result, interpreter, batchresult, tmp, x, p)

Evaluate the interpreter's program on every row of `x` with parameters `p`. For each row,
store the value of the final instruction in `result`.

Rows are processed in consecutive batches of at most `interpreter.batchsize` rows. Each batch
is evaluated into `batchresult`, with `tmp` as scratch space.
NOTE(review): `batchresult` is assumed to have at least `batchsize` rows and one column per
instruction, as the interpreter's `resultcache` does — confirm for other callers.

Rows of `x` map to `result` positionally, so the two only need equal lengths, not equal
index ranges.

Returns `result`. Throws `DimensionMismatch` if `length(result) != size(x, 1)`.
"""
function interpret_withbuf!(result::AbstractVector{T}, interpreter::Interpreter{T},
                            batchresult, tmp, x::AbstractMatrix{T},
                            p::AbstractArray{TD}) where {T,TD}
    allrows = axes(x, 1)
    length(result) == length(allrows) || throw(DimensionMismatch(
        "result has length $(length(result)) but x has $(length(allrows)) rows"))

    batchsize = Int(interpreter.batchsize)
    lastcol = lastindex(batchresult, 2)          # column holding the final instruction
    shift = firstindex(result) - first(allrows)  # row index of x -> index into result
    stop = last(allrows)

    start = first(allrows)
    while start <= stop
        # Clamp the batch end so every batch, including the last partial one, has at most
        # `batchsize` rows and therefore fits into `batchresult`.
        batchrows = start:min(start + batchsize - 1, stop)
        interpret_batch!(interpreter, batchresult, tmp, x, p, batchrows)
        copyto!(view(result, batchrows .+ shift),
                view(batchresult, 1:length(batchrows), lastcol))
        start += batchsize
    end

    return result
end
|
|
|
|
"""
    interpret_batch!(interpreter, batchresult, tmp, x, p, rows)

Evaluate one batch (the rows `rows` of `x`) into `batchresult`. At present this runs the
forward pass only. Returns `nothing`.
"""
function interpret_batch!(interpreter, batchresult, tmp, x, p, rows)
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)  # forward pass
    return nothing
end
|
|
|
|
"""
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)

Run the whole program forward for the batch `rows`. The program counter is reset to the first
instruction, and `step!` (which advances the counter) is called until the counter has moved
past the last instruction. Returns `nothing`.
"""
function interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
    interpreter.pc = 1
    while true
        interpreter.pc > length(interpreter.code) && break
        step!(interpreter, batchresult, tmp, x, p, rows)
    end
    return nothing
end
|
|
|
|
|
|
"""
    step!(interpreter, batchresult, tmp, x, p, range)

Execute the instruction at `interpreter.pc` for the rows `range` of `x`. Its values are
stored in column `interpreter.pc` of `batchresult`, and the program counter then advances
by one.

Operands are read from columns `instr.arg1idx` / `instr.arg2idx` of `batchresult`, which
earlier steps filled. Only the first `length(range)` rows of each column are read or written.
In a final, partial batch the remaining rows hold values from a previous batch or
uninitialized memory. Evaluating them is wasted work and can throw, e.g. a `DomainError`
from `asin` or `^` on a garbage value.

Degree-0 instructions load:
- column `instr.idx` of `x` (`opc_variable`),
- the scalar `p[instr.idx]` (`opc_param`), or
- `instr.val` (`opc_constant`).

Unsupported degree-1 and degree-2 opcodes throw a `DomainError`. Returns `nothing`.
"""
function step!(interpreter, batchresult, tmp, x, p, range)
    pc = interpreter.pc
    instr = interpreter.code[pc]
    opc = instr.opcode
    deg = degree(opc)

    # Rows of the batch buffer that belong to the current batch.
    rows = 1:length(range)
    res = view(batchresult, rows, pc)

    if deg == 0
        if opc == opc_variable
            copyto!(res, view(x, range, instr.idx))
        elseif opc == opc_param
            fill!(res, p[instr.idx])
        elseif opc == opc_constant
            fill!(res, instr.val)
        end
        # NOTE(review): any other degree-0 opcode silently leaves `res` unwritten.
    elseif deg == 1
        arg = view(batchresult, rows, instr.arg1idx)
        # if/elseif chain over opcodes; expected to be lowered to a switch by LLVM (unverified).
        # NOTE(review): a vec_sinh! kernel exists in this file but no opcode dispatches to it.
        if opc == opc_log
            vec_log!(res, arg, tmp)
        elseif opc == opc_log10
            vec_log10!(res, arg, tmp)
        elseif opc == opc_exp
            vec_exp!(res, arg, tmp)
        elseif opc == opc_abs
            vec_abs!(res, arg, tmp)
        elseif opc == opc_neg
            vec_neg!(res, arg, tmp)
        elseif opc == opc_inv
            vec_inv!(res, arg, tmp)
        elseif opc == opc_sign
            vec_sign!(res, arg, tmp)
        elseif opc == opc_powconst
            vec_powconst!(res, arg, instr.val, tmp)
        elseif opc == opc_sin
            vec_sin!(res, arg, tmp)
        elseif opc == opc_cos
            vec_cos!(res, arg, tmp)
        elseif opc == opc_cosh
            vec_cosh!(res, arg, tmp)
        elseif opc == opc_asin
            vec_asin!(res, arg, tmp)
        elseif opc == opc_tan
            vec_tan!(res, arg, tmp)
        elseif opc == opc_tanh
            vec_tanh!(res, arg, tmp)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    elseif deg == 2
        left = view(batchresult, rows, instr.arg1idx)
        right = view(batchresult, rows, instr.arg2idx)
        if opc == opc_add
            vec_add!(res, left, right, tmp)
        elseif opc == opc_sub
            vec_sub!(res, left, right, tmp)
        elseif opc == opc_mul
            vec_mul!(res, left, right, tmp)
        elseif opc == opc_div
            vec_div!(res, left, right, tmp)
        elseif opc == opc_pow
            vec_pow!(res, left, right, tmp)
        elseif opc == opc_powabs
            vec_powabs!(res, left, right, tmp)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    end

    interpreter.pc = pc + 1
    return nothing
end
|
|
|
|
|
|
# Generate kernels `vec_<f>!(res, arg, tmp)` that store `Base.<f>(arg[i])` in `res[i]`.
# `tmp` is unused; every vec_* kernel accepts it so callers can pass it uniformly.
# Unlike vec_log!, these propagate errors from Base (e.g. `asin(2.0)` throws DomainError).
for unaryfunc in (:exp, :abs, :sin, :cos, :cosh, :asin, :tan, :tanh, :sinh)
    funcsy = Symbol("vec_$(unaryfunc)!")
    @eval function $funcsy(res::AbstractVector{T}, arg::AbstractVector{T},
                           ::AbstractVector{T}) where {T<:Real}
        # eachindex(res, arg) throws DimensionMismatch if the indices differ, which is
        # what makes the @inbounds access below safe.
        @simd for i in eachindex(res, arg)
            @inbounds res[i] = Base.$unaryfunc(arg[i])
        end
        return nothing
    end
end
|
|
|
|
|
|
"""
    vec_add!(res, left, right, tmp)

Store `left[i] + right[i]` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res`, `left` and `right` have different indices.
"""
function vec_add!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of all operands checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] + right[i]
    end
    return nothing
end
|
|
"""
    vec_sub!(res, left, right, tmp)

Store `left[i] - right[i]` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res`, `left` and `right` have different indices.
"""
function vec_sub!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of all operands checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] - right[i]
    end
    return nothing
end
|
|
"""
    vec_mul!(res, left, right, tmp)

Store `left[i] * right[i]` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res`, `left` and `right` have different indices.
"""
function vec_mul!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of all operands checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] * right[i]
    end
    return nothing
end
|
|
"""
    vec_div!(res, left, right, tmp)

Store `left[i] / right[i]` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res`, `left` and `right` have different indices.
"""
function vec_div!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of all operands checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] / right[i]
    end
    return nothing
end
|
|
"""
    vec_pow!(res, left, right, tmp)

Store `left[i] ^ right[i]` in `res[i]`. `tmp` is unused.
A negative base with a non-integer exponent throws `DomainError` (Base `^` semantics).
Throws `DimensionMismatch` if `res`, `left` and `right` have different indices.
"""
function vec_pow!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of all operands checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i]^right[i]
    end
    return nothing
end
|
|
|
|
# TODO: special case scalar power
|
|
"""
    vec_powconst!(res, left, right, tmp)

Store `left[i] ^ right` in `res[i]` for a scalar exponent `right`. `tmp` is unused.
Throws `DimensionMismatch` if `res` and `left` have different indices.
"""
function vec_powconst!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::TC,
                       ::AbstractVector{TE}) where {TE<:Real,TC<:Real}
    # Iterating over the indices of both vectors checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left)
        @inbounds res[i] = left[i]^right
    end
    return nothing
end
|
|
"""
    vec_powabs!(res, left, right, tmp)

Store `abs(left[i]) ^ right[i]` in `res[i]`. Taking `abs` first avoids the `DomainError`
that Base `^` raises for a negative base with a non-integer exponent. `tmp` is unused.
Throws `DimensionMismatch` if `res`, `left` and `right` have different indices.
"""
function vec_powabs!(res::AbstractVector{TE}, left::AbstractVector{TE},
                     right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of all operands checks them once, making @inbounds safe.
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = abs(left[i])^right[i]
    end
    return nothing
end
|
|
|
|
"""
    vec_neg!(res, arg, tmp)

Store `-arg[i]` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res` and `arg` have different indices.
"""
function vec_neg!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                  ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of both vectors checks them once, making @inbounds safe.
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = -arg[i]
    end
    return nothing
end
|
|
"""
    vec_inv!(res, arg, tmp)

Store `inv(arg[i])` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res` and `arg` have different indices.
"""
function vec_inv!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                  ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of both vectors checks them once, making @inbounds safe.
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = inv(arg[i])
    end
    return nothing
end
|
|
"""
    vec_sign!(res, arg, tmp)

Store `sign(arg[i])` in `res[i]`. `tmp` is unused.
Throws `DimensionMismatch` if `res` and `arg` have different indices.
"""
function vec_sign!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                   ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of both vectors checks them once, making @inbounds safe.
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = sign(arg[i])
    end
    return nothing
end
|
|
|
|
# handle log and log10 specially: return NaN for negative inputs instead of throwing a DomainError
|
|
"""
    vec_log!(res, arg, tmp)

Store `log(arg[i])` in `res[i]`. Negative inputs produce `NaN` instead of the `DomainError`
that `Base.log` throws. `tmp` is unused.
Throws `DimensionMismatch` if `res` and `arg` have different indices.
"""
function vec_log!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                  ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of both vectors checks them once, making @inbounds safe.
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log(arg[i])
    end
    return nothing
end
|
|
"""
    vec_log10!(res, arg, tmp)

Store `log10(arg[i])` in `res[i]`. Negative inputs produce `NaN` instead of the `DomainError`
that `Base.log10` throws. `tmp` is unused.
Throws `DimensionMismatch` if `res` and `arg` have different indices.
"""
function vec_log10!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                    ::AbstractVector{TE}) where {TE<:Real}
    # Iterating over the indices of both vectors checks them once, making @inbounds safe.
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log10(arg[i])
    end
    return nothing
end
|
|
|
|
|