diff --git a/package/Project.toml b/package/Project.toml index cf1ead6..7944c28 100644 --- a/package/Project.toml +++ b/package/Project.toml @@ -5,9 +5,13 @@ version = "1.0.0-DEV" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +Printf = "1.11.0" +Random = "1.11.0" julia = "1.6.7" [extras] diff --git a/package/src/Code.jl b/package/src/Code.jl new file mode 100644 index 0000000..faefb35 --- /dev/null +++ b/package/src/Code.jl @@ -0,0 +1,207 @@ +using Printf + +@enum Opcode::UInt8 begin + opc_stop = 1 # must start with 1 here TODO: remove stop + opc_add + opc_sub + opc_mul + opc_div + opc_inv + opc_log + opc_log10 + opc_exp + opc_pow + opc_powconst + opc_powabs + opc_neg + opc_abs + opc_sign + opc_sin + opc_asin + opc_tan + opc_tanh + opc_cos + opc_cosh + opc_constant + opc_param + opc_variable +end + +const terminal_opcodes = [opc_stop, opc_constant, opc_param, opc_variable] +const unary_opcodes = [opc_log, opc_log10, opc_exp, opc_abs, opc_sign, opc_sin, opc_cos, opc_cosh, opc_asin, opc_tan, opc_tanh, opc_powconst, opc_neg, opc_inv] +const binary_opcodes = [opc_add, opc_sub, opc_mul, opc_div, opc_pow, opc_powabs] + +function opcode(sy::Symbol)::Opcode + if sy == :+ return opc_add + elseif sy == :- return opc_sub + elseif sy == :* return opc_mul + elseif sy == :/ return opc_div + elseif sy == :inv return opc_inv + elseif sy == :log return opc_log + elseif sy == :log10 return opc_log10 + elseif sy == :exp return opc_exp + elseif sy == :^ return opc_powabs # TODO: this is temporary to enforce that all powers are evaluated as pow(abs(...)) for parameter optimization + elseif sy == :powabs return opc_powabs # TODO: this is temporary to enforce that all powers are evaluated as pow(abs(...)) for parameter optimization + elseif sy == :abs return opc_abs + elseif sy == :sign return opc_sign + elseif sy == :sin return opc_sin + elseif sy == :asin return opc_asin + elseif sy == :cos return opc_cos + elseif sy == :cosh return opc_cosh + elseif sy == :tan return opc_tan + elseif sy == :tanh return opc_tanh + else error("no opcode for symbol $sy") + end +end + +function degree(opc::Opcode)::Integer + if opc in terminal_opcodes return 0 + elseif opc in unary_opcodes return 1 + elseif opc in binary_opcodes return 2 + else error("unknown degree of opcode $opc") + end +end + + +# code is a Vector{Instruction} which is a linear representation of a directed acyclic graph of expressions. +# The code can be evaluated from left to right. +struct Instruction{T} + opcode::Opcode + arg1idx::UInt32 # index of first argument. 0 for terminals + arg2idx::UInt32 # index of second argument. 0 for functions with a single argument + idx::UInt32 # for variables and parameters + val::T # for constants +end + + +function Base.show(io::IO, instr::Instruction) + Printf.format(io, Printf.format"%15s %3d %3d %3d %f", instr.opcode, instr.arg1idx, instr.arg2idx, instr.idx, instr.val) +end + +create_const_instruction(val::T) where {T} = Instruction{T}(opc_constant, UInt32(0), UInt32(0), UInt32(0), val) +create_var_instruction(::Type{T}, varidx) where {T} = Instruction{T}(opc_variable, UInt32(0), UInt32(0), UInt32(varidx), zero(T)) +create_param_instruction(::Type{T}, paramidx; val::T = zero(T)) where {T} = Instruction{T}(opc_param, UInt32(0), UInt32(0), UInt32(paramidx), val) + + +function convert_expr_to_code(::Type{T}, expr::Expr)::Vector{Instruction{T}} where {T} + code = Vector{Instruction{T}}() + + Base.remove_linenums!(expr) + paramTup = expr.args[1] + xSy = paramTup.args[1] + pSy = paramTup.args[2] + body = expr.args[2] + + cache = Dict{Any,Int32}() # for de-duplication of expressions. If an expression is in the cache simply return the index of the existing code + + convert_expr_to_code!(code, cache, body, xSy, pSy) + + # for debugging + # for tup in sort(cache; byvalue=true) + # println(tup) + # end + return code +end + +# uses cache (hashcons) to de-duplicate subexpressions in the tree. +function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, val::TV, xSy, pSy)::UInt32 where {T,TV} + if haskey(cache, val) return cache[val] end + + push!(code, create_const_instruction(T(val))) + cache[val] = length(code) + return length(code) +end + +function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, expr::Expr, xSy, pSy)::UInt32 where {T} + # predicate to check if an expression is abs(...) + is_abs(a) = a isa Expr && a.head == :call && a.args[1] == :abs + + if haskey(cache, expr) return cache[expr] end + + sy = expr.head + if sy == :call + func = expr.args[1] + arg1idx::UInt32 = 0 + arg2idx::UInt32 = 0 + # unary functions + if length(expr.args) == 2 + arg1idx = convert_expr_to_code!(code, cache, expr.args[2], xSy, pSy) + if (func == :-) + # - with one argument => negate + push!(code, Instruction{T}(opc_neg, arg1idx, UInt32(0), UInt32(0), zero(T))) + elseif (func == :sqrt) + push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(0.5))) + else + push!(code, Instruction{T}(opcode(func), arg1idx, UInt32(0), UInt32(0), zero(T))) + end + elseif length(expr.args) == 3 + arg1idx = convert_expr_to_code!(code, cache, expr.args[2], xSy, pSy) + if func == :^ && expr.args[3] isa Number && round(expr.args[3]) == expr.args[3] # is integer + # special case for constant powers + push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(expr.args[3]))) + elseif func == :^ && is_abs(expr.args[2]) + # fuse pow(abs(x), y) --> powabs(x,y) + absexpr = expr.args[2] + x = absexpr.args[2] + arg1idx = convert_expr_to_code!(code, cache, x, xSy, pSy) # because of hashconsing this will return the index within the code for abs(x) generated above + arg2idx = convert_expr_to_code!(code, cache, expr.args[3], xSy, pSy) + push!(code, Instruction{T}(opc_powabs, arg1idx, arg2idx, UInt32(0), zero(T))) + else + arg2idx = convert_expr_to_code!(code, cache, expr.args[3], xSy, pSy) + push!(code, Instruction{T}(opcode(func), arg1idx, arg2idx, UInt32(0), zero(T))) + end + else + # dump(expr) + errpr("only unary and binary functions are supported ($func is not supported)") + end + elseif sy == :ref + arrSy = expr.args[1] + idx = expr.args[2] + if arrSy == xSy + push!(code, create_var_instruction(T, idx)) + elseif arrSy == pSy + push!(code, create_param_instruction(T, idx)) + else + dump(expr) + throw(UndefVarError("unknown symbol")) + end + else + error("Unsupported symbol $sy") + end + + cache[expr] = length(code) + return length(code) +end + + +function Base.show(io::IO, code::AbstractArray{Instruction{T}}) where {T} + sym = Dict( + opc_stop => ".", + opc_add => "+", + opc_sub => "-", + opc_neg => "neg", + opc_mul => "*", + opc_div => "/", + opc_inv => "inv", + opc_pow => "^", + opc_powabs => "abs^", + opc_powconst => "^c", + opc_log => "log", + opc_log10 => "l10", + opc_exp => "exp", + opc_abs => "abs", + opc_sign => "sgn", + opc_sin => "sin", + opc_cos => "cos", + opc_variable => "var", + opc_constant => "con", + opc_param => "par", + ) + + for i in eachindex(code) + instr = code[i] + Printf.format(io, Printf.format"%4d %4s %3d %3d %3d %f", i, sym[instr.opcode], instr.arg1idx, instr.arg2idx, instr.idx, instr.val) + println(io) + # printfmtln(io, "{1:>4d} {2:>4s} {3:>3d} {4:>3d} {5:>3d} {6:>}", i, sym[instr.opcode], instr.arg1idx, instr.arg2idx, instr.idx, instr.val) + end +end \ No newline at end of file diff --git a/package/src/CpuInterpreter.jl b/package/src/CpuInterpreter.jl new file mode 100644 index 0000000..dc8c287 --- /dev/null +++ b/package/src/CpuInterpreter.jl @@ -0,0 +1,172 @@ +using Random + +struct InterpreterBuffers{T} + resultcache::Matrix{T} # for forward eval + diffcache::Matrix{T} # for reverse AD + jaccache::Matrix{T} # for Jacobian + tmp::Vector{T} # a temporary space for each of the vector operations + + function InterpreterBuffers{T}(codelen, num_param, batchsize) where {T<:AbstractFloat} + buf = Matrix{T}(undef, batchsize, codelen) + rev_buf = Matrix{T}(undef, batchsize, codelen) + jac_buf = Matrix{T}(undef, batchsize, num_param) + tmp = Vector{T}(undef, batchsize) + + new(buf, rev_buf, jac_buf, tmp) + end +end + +mutable struct Interpreter{T} + const code::Vector{Instruction{T}} + const buffers::InterpreterBuffers{T} + const batchsize::UInt32 + pc::Int32 + + function Interpreter{T}(expr::Expr, num_param; batchsize = 1024) where {T<:AbstractFloat} + code = convert_expr_to_code(T, expr) + # println(code) + buffers = InterpreterBuffers{T}(length(code), num_param, batchsize) + new(code, buffers, batchsize, 1) + end +end + +peek_instruction(interpreter) = interpreter.code[interpreter.pc] + + + +# batch size 1024 was fast in benchmark +interpret!(result::AbstractVector{T}, expr::Expr, x::AbstractMatrix{T}, p; batchsize=1024) where {T} = interpret!(result, Interpreter{T}(expr, length(p); batchsize), x, p) + +# for Float evaluation use the preallocated buffer +function interpret!(result::AbstractVector{T}, interpreter::Interpreter{T}, x::AbstractMatrix{T}, p::AbstractArray{T}) where {T} + interpret_withbuf!(result, interpreter, interpreter.buffers.resultcache, interpreter.buffers.tmp, x, p) +end + +function interpret_withbuf!(result::AbstractVector{T}, interpreter::Interpreter{T}, batchresult, tmp, x::AbstractMatrix{T}, p::AbstractArray{TD}) where {T,TD} + allrows = axes(x, 1) + @assert length(result) == length(allrows) + + + # all batches + start = first(allrows) + while start + interpreter.batchsize < last(allrows) + batchrows = start:(start + interpreter.batchsize - 1) + interpret_batch!(interpreter, batchresult, tmp, x, p, batchrows) + copy!((@view result[batchrows]), (@view batchresult[:, end])) + start += interpreter.batchsize + end + + + # process remaining rows + remrows = start:last(allrows) + if length(remrows) > 0 + interpret_batch!(interpreter, batchresult, tmp, x, p, remrows) + copy!((@view result[remrows]), (@view batchresult[1:length(remrows), end])) + # res += sum(view(batchresult, 1:length(remrows), lastcolidx)) + end + # res + result +end + +function interpret_batch!(interpreter, + batchresult, tmp, + x, p, rows) + # forward pass + interpret_fwd!(interpreter, batchresult, tmp, x, p, rows) + + nothing +end + +function interpret_fwd!(interpreter, batchresult, tmp, x, p, rows) + interpreter.pc = 1 + while interpreter.pc <= length(interpreter.code) + step!(interpreter, batchresult, tmp, x, p, rows) + end +end + + +function step!(interpreter, batchresult, tmp, x, p, range) + instr = interpreter.code[interpreter.pc] + opc = instr.opcode + res = view(batchresult, :, interpreter.pc) + + if degree(opc) == 0 + if opc == opc_variable + copyto!(res, view(x, range, instr.idx)) + elseif opc == opc_param + fill!(res, p[instr.idx]) + elseif opc == opc_constant + fill!(res, instr.val) + end + elseif degree(opc) == 1 + arg = view(batchresult, :, instr.arg1idx) + # is converted to a switch automatically by LLVM + if opc == opc_log vec_log!(res, arg, tmp) + elseif opc == opc_log10 vec_log10!(res, arg, tmp) + elseif opc == opc_exp vec_exp!(res, arg, tmp) + elseif opc == opc_abs vec_abs!(res, arg, tmp) + elseif opc == opc_neg vec_neg!(res, arg, tmp) + elseif opc == opc_inv vec_inv!(res, arg, tmp) + elseif opc == opc_sign vec_sign!(res, arg, tmp) + elseif opc == opc_powconst vec_powconst!(res, arg, instr.val, tmp); + elseif opc == opc_sin vec_sin!(res, arg, tmp) + elseif opc == opc_cos vec_cos!(res, arg, tmp) + elseif opc == opc_cosh vec_cosh!(res, arg, tmp) + elseif opc == opc_asin vec_asin!(res, arg, tmp) + elseif opc == opc_tan vec_tan!(res, arg, tmp) + elseif opc == opc_tanh vec_tanh!(res, arg, tmp) + + else throw(DomainError("Unsupported opcode $opc")) + end + elseif degree(opc) == 2 + left = view(batchresult, :, instr.arg1idx) + right = view(batchresult, :, instr.arg2idx) + + if opc == opc_add vec_add!(res, left, right, tmp) + elseif opc == opc_sub vec_sub!(res, left, right, tmp) + elseif opc == opc_mul vec_mul!(res, left, right, tmp) + elseif opc == opc_div vec_div!(res, left, right, tmp) + elseif opc == opc_pow vec_pow!(res, left, right, tmp) + elseif opc == opc_powabs vec_powabs!(res, left, right, tmp) + else throw(DomainError("Unsupported opcode $opc")) + end + # if any(isnan, res) + # throw(DomainError("got NaN for $opc $(interpreter.pc) $left $right")) + # end + end + + interpreter.pc += 1 + + return nothing +end + + +for unaryfunc in (:exp, :abs, :sin, :cos, :cosh, :asin, :tan, :tanh, :sinh) + funcsy = Symbol("vec_$(unaryfunc)!") + @eval function $funcsy(res::AbstractVector{T}, arg::AbstractVector{T}, ::AbstractVector{T}) where T<:Real + @simd for i in eachindex(res) + @inbounds res[i] = Base.$unaryfunc(arg[i]) + end + end +end + + +function vec_add!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] + right[i] end end +function vec_sub!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] - right[i] end end +function vec_mul!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] * right[i] end end +function vec_div!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] / right[i] end end +function vec_pow!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] ^ right[i] end end + +# TODO: special case scalar power +function vec_powconst!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::TC, ::AbstractVector{TE}) where {TE<:Real,TC<:Real} @simd for i in eachindex(res) @inbounds res[i] = left[i] ^ right end end +function vec_powabs!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = abs(left[i]) ^ right[i] end end + +function vec_neg!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = -arg[i] end end +function vec_inv!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = inv(arg[i]) end end +function vec_sign!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = sign(arg[i]) end end + +# handle log and exp specially to use NaN instead of DomainError +function vec_log!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log(arg[i]) end end +function vec_log10!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log10(arg[i]) end end + + diff --git a/package/src/ExpressionExecutorCuda.jl b/package/src/ExpressionExecutorCuda.jl index 934509d..2192dcf 100644 --- a/package/src/ExpressionExecutorCuda.jl +++ b/package/src/ExpressionExecutorCuda.jl @@ -2,7 +2,12 @@ module ExpressionExecutorCuda include("ExpressionProcessing.jl") include("Interpreter.jl") -export interpret_gpu +module CpuInterpreter +include("Code.jl") +include("CpuInterpreter.jl") +end + +export interpret_gpu,interpret_cpu export evaluate_gpu export test @@ -22,8 +27,28 @@ function evaluate_gpu(exprs::Vector{Expr}, X::Matrix{Float32}, p::Vector{Vector{ # Look into this to maybe speed up PTX generation: https://cuda.juliagpu.org/stable/tutorials/introduction/#Parallelization-on-the-CPU end -end +# Evaluate Expressions on the CPU +function interpret_cpu(exprs::Vector{Expr}, X::Matrix{Float32}, p::Vector{Vector{Float32}}; repetitions=1)::Matrix{Float32} + @assert axes(exprs) == axes(p) + nrows = size(X, 1) + + # each column of the matrix has the result for an expr + res = Matrix{Float32}(undef, nrows, length(exprs)) + + for i in eachindex(exprs) + # The interpreter holds the postfix code and buffers for evaluation. It is costly to create + interpreter = CpuInterpreter.Interpreter{Float32}(exprs[i], length(p[i])) + + # If an expression has to be evaluated multiple times (e.g. for different parameters), + # it is worthwhile to reuse the interpreter to reduce the number of allocations + for rep in 1:repetitions + CpuInterpreter.interpret!((@view res[:,i]), interpreter, X, p[i]) + end + end + + res +end # Flow @@ -35,3 +60,5 @@ end # The following can be done on the CPU # convert expression to postfix notation (mandatory) # optional: replace every parameter with the correct value (should only improve performance if data transfer is the bottleneck) + +end diff --git a/package/test/CpuInterpreterTests.jl b/package/test/CpuInterpreterTests.jl new file mode 100644 index 0000000..356a4a6 --- /dev/null +++ b/package/test/CpuInterpreterTests.jl @@ -0,0 +1,35 @@ +using LinearAlgebra + +function test_cpu_interpreter(nrows; parallel = false) + exprs = [ + # CPU interpreter requires an anonymous function and array ref s + :(p[1] * x[1] + p[2]), # 5 op + :((((x[1] + x[2]) + x[3]) + x[4]) + x[5]), # 9 op + :(log(abs(x[1]))), # 3 op + :(powabs(p[2] - powabs(p[1] + x[1], 1/x[1]),p[3])) # 13 op + ] # 30 op + exprs = map(e -> Expr(:->, :(x,p), e), exprs) + X = randn(Float32, nrows, 10) + p = [randn(Float32, 10) for _ in 1:length(exprs)] # generate 10 random parameter values for each expr + + # warmup + interpret_cpu(exprs, X, p) + expr_reps = 100 # for each expr + reps= 100 + + if parallel + t_sec = @elapsed fetch.([Threads.@spawn interpret_cpu(exprs, X, p; repetitions=expr_reps) for i in 1:reps]) + println("~ $(round(30 * reps * expr_reps * nrows / 1e9 / t_sec, digits=2)) GFLOPS ($(Threads.nthreads()) threads) ($(round(peakflops(1000, eltype=Float32, ntrials=1) / 1e9, digits=2)) GFLOPS (peak, single-core))") + else + t_sec = @elapsed for i in 1:reps interpret_cpu(exprs, X, p; repetitions=expr_reps) end + println("~ $(round(30 * reps * expr_reps * nrows / 1e9 / t_sec, digits=2)) GFLOPS (single-core) ($(round(peakflops(1000, eltype=Float32, ntrials=1) / 1e9, digits=2)) GFLOPS (peak, single-core))") + end + true +end + +LinearAlgebra.BLAS.set_num_threads(1) # only use a single thread for peakflops + +@test test_cpu_interpreter(1000) +@test test_cpu_interpreter(1000, parallel=true) # start julia -t 6 for six threads +@test test_cpu_interpreter(10000) +@test test_cpu_interpreter(10000, parallel=true) diff --git a/package/test/runtests.jl b/package/test/runtests.jl index fd6da3f..ee68520 100644 --- a/package/test/runtests.jl +++ b/package/test/runtests.jl @@ -11,3 +11,8 @@ include(joinpath(baseFolder, "src", "Transpiler.jl")) include("InterpreterTests.jl") include("TranspilerTests.jl") end + + +@testset "CPU Interpreter" begin + include("CpuInterpreterTests.jl") +end \ No newline at end of file