using Random

"""
    InterpreterBuffers{T}(codelen, num_param, batchsize)

Preallocated (uninitialized) scratch storage used by an `Interpreter`.

Every buffer has `batchsize` rows, one per data row of a batch.
- `resultcache` (`batchsize × codelen`): forward-evaluation values, one column per instruction.
- `diffcache` (`batchsize × codelen`): storage for reverse AD.
- `jaccache` (`batchsize × num_param`): storage for the Jacobian.
- `tmp` (`batchsize`): temporary space handed to each vector operation.

`diffcache` and `jaccache` are not used by the forward pass defined here.
"""
struct InterpreterBuffers{T}
    resultcache::Matrix{T} # for forward eval
    diffcache::Matrix{T}   # for reverse AD
    jaccache::Matrix{T}    # for Jacobian
    tmp::Vector{T}         # a temporary space for each of the vector operations

    function InterpreterBuffers{T}(codelen, num_param, batchsize) where {T<:AbstractFloat}
        buf = Matrix{T}(undef, batchsize, codelen)
        rev_buf = Matrix{T}(undef, batchsize, codelen)
        jac_buf = Matrix{T}(undef, batchsize, num_param)
        tmp = Vector{T}(undef, batchsize)
        return new(buf, rev_buf, jac_buf, tmp)
    end
end

"""
    Interpreter{T}(expr::Expr, num_param; batchsize = 1024)

Compile `expr` (via `convert_expr_to_code`) into a vector of instructions and
allocate `InterpreterBuffers` sized for `batchsize` rows.

`pc` is the program counter: the index of the next instruction `step!` executes.

# Throws
- `ArgumentError` if `batchsize` is not positive (a zero batch size would make the
  batch loop in `interpret_withbuf!` never advance).
"""
mutable struct Interpreter{T}
    const code::Vector{Instruction{T}}
    const buffers::InterpreterBuffers{T}
    const batchsize::UInt32
    pc::Int32

    function Interpreter{T}(expr::Expr, num_param; batchsize = 1024) where {T<:AbstractFloat}
        batchsize > 0 ||
            throw(ArgumentError("batchsize must be positive, got $batchsize"))
        code = convert_expr_to_code(T, expr)
        buffers = InterpreterBuffers{T}(length(code), num_param, batchsize)
        return new(code, buffers, batchsize, 1)
    end
end

"""
    peek_instruction(interpreter)

Return the instruction at the current program counter without advancing it.
"""
peek_instruction(interpreter) = interpreter.code[interpreter.pc]

"""
    interpret!(result, expr::Expr, x, p; batchsize = 1024)
    interpret!(result, interpreter::Interpreter, x, p)

Evaluate the compiled expression for every row of `x` using parameters `p`.

For each row `r`, `result[r]` receives the value of the last instruction of the code.
Variable `i` of the expression reads column `i` of `x`; parameter `j` reads `p[j]`.
The `Expr` method builds a fresh `Interpreter` first. The `Interpreter` method
reuses its preallocated forward buffer.
"""
# batch size 1024 was fast in benchmark
interpret!(result::AbstractVector{T}, expr::Expr, x::AbstractMatrix{T}, p;
           batchsize=1024) where {T} =
    interpret!(result, Interpreter{T}(expr, length(p); batchsize), x, p)

# for Float evaluation use the preallocated buffer
function interpret!(result::AbstractVector{T}, interpreter::Interpreter{T},
                    x::AbstractMatrix{T}, p::AbstractArray{T}) where {T}
    return interpret_withbuf!(result, interpreter, interpreter.buffers.resultcache,
                              interpreter.buffers.tmp, x, p)
end

"""
    interpret_withbuf!(result, interpreter, batchresult, tmp, x, p)

Evaluate `interpreter`'s code on all rows of `x` in batches of `interpreter.batchsize`
rows. Each batch uses `batchresult` (one column per instruction) and `tmp` as scratch.
The last column of `batchresult` is copied into the matching rows of `result`.

`batchresult` and `tmp` must have at least `interpreter.batchsize` rows/elements.
Returns `result`.

# Throws
- `DimensionMismatch` if `length(result) != size(x, 1)`.
"""
function interpret_withbuf!(result::AbstractVector{T}, interpreter::Interpreter{T},
                            batchresult, tmp, x::AbstractMatrix{T},
                            p::AbstractArray{TD}) where {T,TD}
    allrows = axes(x, 1)
    length(result) == length(allrows) || throw(DimensionMismatch(
        "result has length $(length(result)) but x has $(length(allrows)) rows"))

    bs = Int(interpreter.batchsize)
    start = first(allrows)
    stop = last(allrows)

    # Full batches: continue while a complete batch of `bs` rows still fits.
    # (The condition must be `<=`: stopping earlier could leave bs+1 remaining rows,
    # which would overflow the bs-row buffers in the remainder step.)
    while start + bs - 1 <= stop
        batchrows = start:(start + bs - 1)
        interpret_batch!(interpreter, batchresult, tmp, x, p, batchrows)
        copy!((@view result[batchrows]), (@view batchresult[:, end]))
        start += bs
    end

    # Remaining rows (strictly fewer than bs).
    remrows = start:stop
    if !isempty(remrows)
        interpret_batch!(interpreter, batchresult, tmp, x, p, remrows)
        copy!((@view result[remrows]), (@view batchresult[1:length(remrows), end]))
    end
    return result
end

"""
    interpret_batch!(interpreter, batchresult, tmp, x, p, rows)

Run the forward pass for the rows `rows` of `x`. Returns `nothing`.
"""
function interpret_batch!(interpreter, batchresult, tmp, x, p, rows)
    # forward pass
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
    return nothing
end

"""
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)

Reset the program counter to 1 and execute every instruction once, in order.
Instruction `k`'s values for the batch are written to column `k` of `batchresult`.
"""
function interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
    interpreter.pc = 1
    while interpreter.pc <= length(interpreter.code)
        step!(interpreter, batchresult, tmp, x, p, rows)
    end
    return nothing
end

"""
    step!(interpreter, batchresult, tmp, x, p, range)

Execute the instruction at `interpreter.pc` for the data rows `range` of `x`, then
advance the program counter.

Only rows `1:length(range)` of `batchresult` are read or written, so rows left over
from a larger previous batch (or never initialized) never reach the math kernels.

# Throws
- `DomainError` for an opcode or arity this interpreter does not support.
"""
function step!(interpreter, batchresult, tmp, x, p, range)
    pc = interpreter.pc
    instr = interpreter.code[pc]
    opc = instr.opcode
    active = 1:length(range)
    res = view(batchresult, active, pc)
    deg = degree(opc)

    if deg == 0
        if opc == opc_variable
            copyto!(res, view(x, range, instr.idx))
        elseif opc == opc_param
            fill!(res, p[instr.idx])
        elseif opc == opc_constant
            fill!(res, instr.val)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    elseif deg == 1
        arg = view(batchresult, active, instr.arg1idx)
        # is converted to a switch automatically by LLVM
        if opc == opc_log
            vec_log!(res, arg, tmp)
        elseif opc == opc_log10
            vec_log10!(res, arg, tmp)
        elseif opc == opc_exp
            vec_exp!(res, arg, tmp)
        elseif opc == opc_abs
            vec_abs!(res, arg, tmp)
        elseif opc == opc_neg
            vec_neg!(res, arg, tmp)
        elseif opc == opc_inv
            vec_inv!(res, arg, tmp)
        elseif opc == opc_sign
            vec_sign!(res, arg, tmp)
        elseif opc == opc_powconst
            vec_powconst!(res, arg, instr.val, tmp)
        elseif opc == opc_sin
            vec_sin!(res, arg, tmp)
        elseif opc == opc_cos
            vec_cos!(res, arg, tmp)
        elseif opc == opc_cosh
            vec_cosh!(res, arg, tmp)
        elseif opc == opc_asin
            vec_asin!(res, arg, tmp)
        elseif opc == opc_tan
            vec_tan!(res, arg, tmp)
        elseif opc == opc_tanh
            vec_tanh!(res, arg, tmp)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    elseif deg == 2
        left = view(batchresult, active, instr.arg1idx)
        right = view(batchresult, active, instr.arg2idx)
        if opc == opc_add
            vec_add!(res, left, right, tmp)
        elseif opc == opc_sub
            vec_sub!(res, left, right, tmp)
        elseif opc == opc_mul
            vec_mul!(res, left, right, tmp)
        elseif opc == opc_div
            vec_div!(res, left, right, tmp)
        elseif opc == opc_pow
            vec_pow!(res, left, right, tmp)
        elseif opc == opc_powabs
            vec_powabs!(res, left, right, tmp)
        else
            throw(DomainError("Unsupported opcode $opc"))
        end
    else
        throw(DomainError("Unsupported opcode $opc"))
    end

    interpreter.pc += 1
    return nothing
end

# Element-wise vector kernels: `res[i] = f(args[i]...)`. The trailing vector argument
# is the scratch buffer; none of these kernels use it. `eachindex(res, args...)`
# throws DimensionMismatch on differing lengths, which is what makes `@inbounds` safe.

# Unary kernels that apply the Base function directly (may throw DomainError, e.g. asin(2.0)).
for unaryfunc in (:exp, :abs, :sin, :cos, :cosh, :asin, :tan, :tanh, :sinh)
    funcsy = Symbol("vec_$(unaryfunc)!")
    @eval function $funcsy(res::AbstractVector{T}, arg::AbstractVector{T},
                           ::AbstractVector{T}) where {T<:Real}
        @simd for i in eachindex(res, arg)
            @inbounds res[i] = Base.$unaryfunc(arg[i])
        end
        return nothing
    end
end

function vec_add!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] + right[i]
    end
    return nothing
end

function vec_sub!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] - right[i]
    end
    return nothing
end

function vec_mul!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] * right[i]
    end
    return nothing
end

function vec_div!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] / right[i]
    end
    return nothing
end

function vec_pow!(res::AbstractVector{TE}, left::AbstractVector{TE},
                  right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = left[i] ^ right[i]
    end
    return nothing
end

# TODO: special case scalar power
function vec_powconst!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::TC,
                       ::AbstractVector{TE}) where {TE<:Real,TC<:Real}
    @simd for i in eachindex(res, left)
        @inbounds res[i] = left[i] ^ right
    end
    return nothing
end

# |left|^right: avoids the DomainError of a negative base with a non-integer exponent.
function vec_powabs!(res::AbstractVector{TE}, left::AbstractVector{TE},
                     right::AbstractVector{TE}, ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, left, right)
        @inbounds res[i] = abs(left[i]) ^ right[i]
    end
    return nothing
end

function vec_neg!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                  ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = -arg[i]
    end
    return nothing
end

function vec_inv!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                  ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = inv(arg[i])
    end
    return nothing
end

function vec_sign!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                   ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = sign(arg[i])
    end
    return nothing
end

# log and log10 are handled specially: negative inputs yield NaN instead of
# throwing a DomainError.
function vec_log!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                  ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log(arg[i])
    end
    return nothing
end

function vec_log10!(res::AbstractVector{TE}, arg::AbstractVector{TE},
                    ::AbstractVector{TE}) where {TE<:Real}
    @simd for i in eachindex(res, arg)
        @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log10(arg[i])
    end
    return nothing
end