Add CPU Interpreter and a test case.
This commit is contained in:
		@ -5,9 +5,13 @@ version = "1.0.0-DEV"
 | 
			
		||||
 | 
			
		||||
[deps]
 | 
			
		||||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 | 
			
		||||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 | 
			
		||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 | 
			
		||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 | 
			
		||||
 | 
			
		||||
[compat]
 | 
			
		||||
Printf = "1.11.0"
 | 
			
		||||
Random = "1.11.0"
 | 
			
		||||
julia = "1.6.7"
 | 
			
		||||
 | 
			
		||||
[extras]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										207
									
								
								package/src/Code.jl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										207
									
								
								package/src/Code.jl
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,207 @@
 | 
			
		||||
using Printf
 | 
			
		||||
 | 
			
		||||
@enum Opcode::UInt8 begin
 | 
			
		||||
  opc_stop = 1 # must start with 1 here TODO: remove stop
 | 
			
		||||
  opc_add 
 | 
			
		||||
  opc_sub
 | 
			
		||||
  opc_mul 
 | 
			
		||||
  opc_div
 | 
			
		||||
  opc_inv
 | 
			
		||||
  opc_log 
 | 
			
		||||
  opc_log10 
 | 
			
		||||
  opc_exp 
 | 
			
		||||
  opc_pow
 | 
			
		||||
  opc_powconst
 | 
			
		||||
  opc_powabs
 | 
			
		||||
  opc_neg
 | 
			
		||||
  opc_abs 
 | 
			
		||||
  opc_sign
 | 
			
		||||
  opc_sin
 | 
			
		||||
  opc_asin
 | 
			
		||||
  opc_tan
 | 
			
		||||
  opc_tanh
 | 
			
		||||
  opc_cos
 | 
			
		||||
  opc_cosh
 | 
			
		||||
  opc_constant 
 | 
			
		||||
  opc_param 
 | 
			
		||||
  opc_variable
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
const terminal_opcodes = [opc_stop, opc_constant, opc_param, opc_variable]
 | 
			
		||||
const unary_opcodes = [opc_log, opc_log10, opc_exp, opc_abs, opc_sign, opc_sin, opc_cos, opc_cosh, opc_asin, opc_tan, opc_tanh, opc_powconst, opc_neg, opc_inv]
 | 
			
		||||
const binary_opcodes = [opc_add, opc_sub, opc_mul, opc_div, opc_pow, opc_powabs]
 | 
			
		||||
 | 
			
		||||
function opcode(sy::Symbol)::Opcode
 | 
			
		||||
    if sy == :+ return opc_add 
 | 
			
		||||
    elseif sy == :- return opc_sub
 | 
			
		||||
    elseif sy == :* return opc_mul
 | 
			
		||||
    elseif sy == :/ return opc_div
 | 
			
		||||
    elseif sy == :inv return opc_inv
 | 
			
		||||
    elseif sy == :log return opc_log
 | 
			
		||||
    elseif sy == :log10 return opc_log10
 | 
			
		||||
    elseif sy == :exp return opc_exp
 | 
			
		||||
    elseif sy == :^ return opc_powabs # TODO: this is temporary to enforce that all powers are evaluated as pow(abs(...)) for parameter optimization
 | 
			
		||||
    elseif sy == :powabs return opc_powabs # TODO: this is temporary to enforce that all powers are evaluated as pow(abs(...)) for parameter optimization
 | 
			
		||||
    elseif sy == :abs return opc_abs
 | 
			
		||||
    elseif sy == :sign return opc_sign
 | 
			
		||||
    elseif sy == :sin return opc_sin
 | 
			
		||||
    elseif sy == :asin return opc_asin
 | 
			
		||||
    elseif sy == :cos return opc_cos
 | 
			
		||||
    elseif sy == :cosh return opc_cosh
 | 
			
		||||
    elseif sy == :tan return opc_tan
 | 
			
		||||
    elseif sy == :tanh return opc_tanh
 | 
			
		||||
    else error("no opcode for symbol $sy")
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
function degree(opc::Opcode)::Integer 
 | 
			
		||||
    if opc in terminal_opcodes return 0
 | 
			
		||||
    elseif opc in unary_opcodes return 1
 | 
			
		||||
    elseif opc in binary_opcodes return 2
 | 
			
		||||
    else error("unknown degree of opcode $opc")
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# code is a Vector{Instruction} which is a linear representation of a directed acyclic graph of expressions.
 | 
			
		||||
# The code can be evaluated from left to right.
 | 
			
		||||
struct Instruction{T}
 | 
			
		||||
    opcode::Opcode
 | 
			
		||||
    arg1idx::UInt32 # index of first argument. 0 for terminals
 | 
			
		||||
    arg2idx::UInt32 # index of second argument. 0 for functions with a single argument
 | 
			
		||||
    idx::UInt32  # for variables and parameters
 | 
			
		||||
    val::T # for constants
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
function Base.show(io::IO, instr::Instruction)
 | 
			
		||||
    Printf.format(io, Printf.format"%15s %3d %3d %3d %f", instr.opcode, instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
create_const_instruction(val::T) where {T}                                = Instruction{T}(opc_constant, UInt32(0), UInt32(0), UInt32(0), val)
 | 
			
		||||
create_var_instruction(::Type{T}, varidx) where {T}                       = Instruction{T}(opc_variable, UInt32(0), UInt32(0), UInt32(varidx), zero(T))
 | 
			
		||||
create_param_instruction(::Type{T}, paramidx; val::T = zero(T)) where {T} = Instruction{T}(opc_param, UInt32(0), UInt32(0), UInt32(paramidx), val)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
function convert_expr_to_code(::Type{T}, expr::Expr)::Vector{Instruction{T}} where {T}
 | 
			
		||||
    code = Vector{Instruction{T}}()
 | 
			
		||||
 | 
			
		||||
    Base.remove_linenums!(expr)
 | 
			
		||||
    paramTup = expr.args[1]
 | 
			
		||||
    xSy = paramTup.args[1]
 | 
			
		||||
    pSy = paramTup.args[2]
 | 
			
		||||
    body = expr.args[2]
 | 
			
		||||
 | 
			
		||||
    cache = Dict{Any,Int32}() # for de-duplication of expressions. If an expression is in the cache simply return the index of the existing code
 | 
			
		||||
 | 
			
		||||
    convert_expr_to_code!(code, cache, body, xSy, pSy)
 | 
			
		||||
 | 
			
		||||
    # for debugging
 | 
			
		||||
    # for tup in sort(cache; byvalue=true)
 | 
			
		||||
    #     println(tup)
 | 
			
		||||
    # end
 | 
			
		||||
    return code
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
# uses cache (hashcons) to de-duplicate subexpressions in the tree.
 | 
			
		||||
function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, val::TV, xSy, pSy)::UInt32 where {T,TV}
 | 
			
		||||
    if haskey(cache, val) return cache[val] end
 | 
			
		||||
 | 
			
		||||
    push!(code, create_const_instruction(T(val)))
 | 
			
		||||
    cache[val] = length(code)    
 | 
			
		||||
    return length(code)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, expr::Expr, xSy, pSy)::UInt32 where {T}
 | 
			
		||||
    # predicate to check if an expression is abs(...)
 | 
			
		||||
    is_abs(a) = a isa Expr && a.head == :call && a.args[1] == :abs
 | 
			
		||||
 | 
			
		||||
    if haskey(cache, expr) return cache[expr] end
 | 
			
		||||
 | 
			
		||||
    sy = expr.head
 | 
			
		||||
    if sy == :call
 | 
			
		||||
        func = expr.args[1]
 | 
			
		||||
        arg1idx::UInt32 = 0
 | 
			
		||||
        arg2idx::UInt32 = 0
 | 
			
		||||
        # unary functions
 | 
			
		||||
        if length(expr.args) == 2
 | 
			
		||||
            arg1idx = convert_expr_to_code!(code, cache, expr.args[2], xSy, pSy)
 | 
			
		||||
            if (func == :-)
 | 
			
		||||
                # - with one argument => negate
 | 
			
		||||
                push!(code, Instruction{T}(opc_neg, arg1idx, UInt32(0), UInt32(0), zero(T)))
 | 
			
		||||
            elseif (func == :sqrt)
 | 
			
		||||
                push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(0.5)))
 | 
			
		||||
            else
 | 
			
		||||
                push!(code, Instruction{T}(opcode(func), arg1idx, UInt32(0), UInt32(0), zero(T)))
 | 
			
		||||
            end
 | 
			
		||||
        elseif length(expr.args) == 3
 | 
			
		||||
            arg1idx = convert_expr_to_code!(code, cache, expr.args[2], xSy, pSy)
 | 
			
		||||
            if func == :^ && expr.args[3] isa Number && round(expr.args[3]) == expr.args[3] # is integer 
 | 
			
		||||
                # special case for constant powers
 | 
			
		||||
                push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(expr.args[3])))
 | 
			
		||||
            elseif func == :^ && is_abs(expr.args[2])
 | 
			
		||||
                # fuse pow(abs(x), y) --> powabs(x,y)
 | 
			
		||||
                absexpr = expr.args[2]
 | 
			
		||||
                x = absexpr.args[2]
 | 
			
		||||
                arg1idx = convert_expr_to_code!(code, cache, x, xSy, pSy) # because of hashconsing this will return the index within the code for abs(x) generated above
 | 
			
		||||
                arg2idx = convert_expr_to_code!(code, cache, expr.args[3], xSy, pSy)
 | 
			
		||||
                push!(code, Instruction{T}(opc_powabs, arg1idx, arg2idx, UInt32(0), zero(T)))
 | 
			
		||||
            else 
 | 
			
		||||
                arg2idx = convert_expr_to_code!(code, cache, expr.args[3], xSy, pSy)
 | 
			
		||||
                push!(code, Instruction{T}(opcode(func), arg1idx, arg2idx, UInt32(0), zero(T)))
 | 
			
		||||
            end
 | 
			
		||||
        else 
 | 
			
		||||
            # dump(expr)
 | 
			
		||||
            errpr("only unary and binary functions are supported ($func is not supported)")
 | 
			
		||||
        end
 | 
			
		||||
    elseif sy == :ref
 | 
			
		||||
        arrSy = expr.args[1]
 | 
			
		||||
        idx = expr.args[2]
 | 
			
		||||
        if arrSy == xSy
 | 
			
		||||
            push!(code, create_var_instruction(T, idx))
 | 
			
		||||
        elseif arrSy == pSy
 | 
			
		||||
            push!(code, create_param_instruction(T, idx))
 | 
			
		||||
        else
 | 
			
		||||
            dump(expr)
 | 
			
		||||
            throw(UndefVarError("unknown symbol"))
 | 
			
		||||
        end
 | 
			
		||||
    else 
 | 
			
		||||
        error("Unsupported symbol $sy")
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    cache[expr] = length(code)
 | 
			
		||||
    return length(code)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
function Base.show(io::IO, code::AbstractArray{Instruction{T}}) where {T}
 | 
			
		||||
    sym = Dict(
 | 
			
		||||
        opc_stop => ".",
 | 
			
		||||
        opc_add => "+",
 | 
			
		||||
        opc_sub => "-",
 | 
			
		||||
        opc_neg => "neg",
 | 
			
		||||
        opc_mul => "*",
 | 
			
		||||
        opc_div => "/",
 | 
			
		||||
        opc_inv => "inv",
 | 
			
		||||
        opc_pow => "^",
 | 
			
		||||
        opc_powabs => "abs^",
 | 
			
		||||
        opc_powconst => "^c",
 | 
			
		||||
        opc_log => "log",
 | 
			
		||||
        opc_log10 => "l10",
 | 
			
		||||
        opc_exp => "exp",
 | 
			
		||||
        opc_abs => "abs",
 | 
			
		||||
        opc_sign => "sgn",
 | 
			
		||||
        opc_sin => "sin",
 | 
			
		||||
        opc_cos => "cos",
 | 
			
		||||
        opc_variable => "var",
 | 
			
		||||
        opc_constant => "con",
 | 
			
		||||
        opc_param => "par",
 | 
			
		||||
    )
 | 
			
		||||
    
 | 
			
		||||
    for i in eachindex(code)
 | 
			
		||||
        instr = code[i]
 | 
			
		||||
        Printf.format(io, Printf.format"%4d %4s %3d %3d %3d %f", i, sym[instr.opcode], instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
 | 
			
		||||
        println(io)
 | 
			
		||||
        # printfmtln(io, "{1:>4d} {2:>4s} {3:>3d} {4:>3d} {5:>3d} {6:>}", i, sym[instr.opcode], instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
							
								
								
									
										172
									
								
								package/src/CpuInterpreter.jl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								package/src/CpuInterpreter.jl
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,172 @@
 | 
			
		||||
using Random 
 | 
			
		||||
 | 
			
		||||
struct InterpreterBuffers{T}
 | 
			
		||||
    resultcache::Matrix{T} # for forward eval
 | 
			
		||||
    diffcache::Matrix{T} # for reverse AD
 | 
			
		||||
    jaccache::Matrix{T} # for Jacobian
 | 
			
		||||
    tmp::Vector{T} # a temporary space for each of the vector operations
 | 
			
		||||
 | 
			
		||||
    function InterpreterBuffers{T}(codelen, num_param, batchsize) where {T<:AbstractFloat}
 | 
			
		||||
        buf = Matrix{T}(undef, batchsize, codelen) 
 | 
			
		||||
        rev_buf = Matrix{T}(undef, batchsize, codelen)
 | 
			
		||||
        jac_buf = Matrix{T}(undef, batchsize, num_param)
 | 
			
		||||
        tmp = Vector{T}(undef, batchsize)
 | 
			
		||||
 | 
			
		||||
        new(buf, rev_buf, jac_buf, tmp)
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
mutable struct Interpreter{T}
 | 
			
		||||
    const code::Vector{Instruction{T}}
 | 
			
		||||
    const buffers::InterpreterBuffers{T}
 | 
			
		||||
    const batchsize::UInt32
 | 
			
		||||
    pc::Int32
 | 
			
		||||
 | 
			
		||||
    function Interpreter{T}(expr::Expr, num_param; batchsize = 1024) where {T<:AbstractFloat}
 | 
			
		||||
        code = convert_expr_to_code(T, expr)
 | 
			
		||||
        # println(code)
 | 
			
		||||
        buffers = InterpreterBuffers{T}(length(code), num_param, batchsize)
 | 
			
		||||
        new(code, buffers, batchsize, 1)
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
peek_instruction(interpreter) = interpreter.code[interpreter.pc]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# batch size 1024 was fast in benchmark
 | 
			
		||||
interpret!(result::AbstractVector{T}, expr::Expr, x::AbstractMatrix{T}, p; batchsize=1024) where {T} = interpret!(result, Interpreter{T}(expr, length(p); batchsize), x, p)
 | 
			
		||||
 | 
			
		||||
# for Float evaluation use the preallocated buffer
 | 
			
		||||
function interpret!(result::AbstractVector{T}, interpreter::Interpreter{T}, x::AbstractMatrix{T}, p::AbstractArray{T}) where {T} 
 | 
			
		||||
    interpret_withbuf!(result, interpreter, interpreter.buffers.resultcache, interpreter.buffers.tmp, x, p)
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
function interpret_withbuf!(result::AbstractVector{T}, interpreter::Interpreter{T}, batchresult, tmp, x::AbstractMatrix{T}, p::AbstractArray{TD}) where {T,TD}
 | 
			
		||||
    allrows = axes(x, 1)
 | 
			
		||||
    @assert length(result) == length(allrows)
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    # all batches
 | 
			
		||||
    start = first(allrows)
 | 
			
		||||
    while start + interpreter.batchsize < last(allrows)
 | 
			
		||||
        batchrows = start:(start + interpreter.batchsize - 1)
 | 
			
		||||
        interpret_batch!(interpreter, batchresult, tmp, x, p, batchrows)
 | 
			
		||||
        copy!((@view result[batchrows]), (@view batchresult[:, end]))
 | 
			
		||||
        start += interpreter.batchsize
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    # process remaining rows
 | 
			
		||||
    remrows = start:last(allrows)
 | 
			
		||||
    if length(remrows) > 0
 | 
			
		||||
        interpret_batch!(interpreter, batchresult, tmp, x, p, remrows)
 | 
			
		||||
        copy!((@view result[remrows]), (@view batchresult[1:length(remrows), end]))
 | 
			
		||||
        # res += sum(view(batchresult, 1:length(remrows), lastcolidx))
 | 
			
		||||
    end
 | 
			
		||||
    # res
 | 
			
		||||
    result
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
function interpret_batch!(interpreter, 
 | 
			
		||||
                    batchresult, tmp,
 | 
			
		||||
                    x, p, rows)
 | 
			
		||||
    # forward pass
 | 
			
		||||
    interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
 | 
			
		||||
 | 
			
		||||
    nothing
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
function interpret_fwd!(interpreter, batchresult, tmp, x, p, rows)
 | 
			
		||||
    interpreter.pc = 1 
 | 
			
		||||
    while interpreter.pc <= length(interpreter.code)
 | 
			
		||||
        step!(interpreter, batchresult, tmp, x, p, rows)
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
function step!(interpreter, batchresult, tmp, x, p, range)
 | 
			
		||||
    instr = interpreter.code[interpreter.pc]
 | 
			
		||||
    opc = instr.opcode
 | 
			
		||||
    res = view(batchresult, :, interpreter.pc)
 | 
			
		||||
 | 
			
		||||
    if degree(opc) == 0
 | 
			
		||||
        if opc == opc_variable 
 | 
			
		||||
            copyto!(res, view(x, range, instr.idx))
 | 
			
		||||
        elseif opc == opc_param
 | 
			
		||||
            fill!(res, p[instr.idx])
 | 
			
		||||
        elseif opc == opc_constant
 | 
			
		||||
            fill!(res, instr.val)
 | 
			
		||||
        end
 | 
			
		||||
    elseif degree(opc) == 1
 | 
			
		||||
        arg = view(batchresult, :, instr.arg1idx)
 | 
			
		||||
        # is converted to a switch automatically by LLVM
 | 
			
		||||
        if     opc == opc_log      vec_log!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_log10    vec_log10!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_exp      vec_exp!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_abs      vec_abs!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_neg      vec_neg!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_inv      vec_inv!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_sign     vec_sign!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_powconst vec_powconst!(res, arg, instr.val, tmp);
 | 
			
		||||
        elseif opc == opc_sin      vec_sin!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_cos      vec_cos!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_cosh     vec_cosh!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_asin     vec_asin!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_tan      vec_tan!(res, arg, tmp)
 | 
			
		||||
        elseif opc == opc_tanh     vec_tanh!(res, arg, tmp)
 | 
			
		||||
 | 
			
		||||
        else throw(DomainError("Unsupported opcode $opc"))
 | 
			
		||||
        end
 | 
			
		||||
    elseif degree(opc) == 2
 | 
			
		||||
        left = view(batchresult, :, instr.arg1idx)
 | 
			
		||||
        right = view(batchresult, :, instr.arg2idx)
 | 
			
		||||
 | 
			
		||||
        if     opc == opc_add    vec_add!(res, left, right, tmp)
 | 
			
		||||
        elseif opc == opc_sub    vec_sub!(res, left, right, tmp)
 | 
			
		||||
        elseif opc == opc_mul    vec_mul!(res, left, right, tmp)
 | 
			
		||||
        elseif opc == opc_div    vec_div!(res, left, right, tmp)
 | 
			
		||||
        elseif opc == opc_pow    vec_pow!(res, left, right, tmp)
 | 
			
		||||
        elseif opc == opc_powabs vec_powabs!(res, left, right, tmp)
 | 
			
		||||
        else throw(DomainError("Unsupported opcode $opc"))
 | 
			
		||||
        end
 | 
			
		||||
        # if any(isnan, res) 
 | 
			
		||||
        #     throw(DomainError("got NaN for $opc $(interpreter.pc) $left $right"))
 | 
			
		||||
        # end
 | 
			
		||||
    end
 | 
			
		||||
 | 
			
		||||
    interpreter.pc += 1
 | 
			
		||||
 | 
			
		||||
    return nothing
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
for unaryfunc in (:exp, :abs, :sin, :cos, :cosh, :asin, :tan, :tanh, :sinh)
 | 
			
		||||
    funcsy = Symbol("vec_$(unaryfunc)!")
 | 
			
		||||
    @eval function $funcsy(res::AbstractVector{T}, arg::AbstractVector{T}, ::AbstractVector{T}) where T<:Real
 | 
			
		||||
        @simd for i in eachindex(res)
 | 
			
		||||
            @inbounds res[i] = Base.$unaryfunc(arg[i])
 | 
			
		||||
        end
 | 
			
		||||
    end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
function vec_add!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] + right[i] end end
 | 
			
		||||
function vec_sub!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] - right[i] end end
 | 
			
		||||
function vec_mul!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] * right[i] end end
 | 
			
		||||
function vec_div!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] / right[i] end end
 | 
			
		||||
function vec_pow!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = left[i] ^ right[i] end end
 | 
			
		||||
 | 
			
		||||
# TODO: special case scalar power
 | 
			
		||||
function vec_powconst!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::TC, ::AbstractVector{TE}) where {TE<:Real,TC<:Real} @simd for i in eachindex(res) @inbounds res[i] = left[i] ^ right end end
 | 
			
		||||
function vec_powabs!(res::AbstractVector{TE}, left::AbstractVector{TE}, right::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real   @simd for i in eachindex(res) @inbounds res[i] = abs(left[i]) ^ right[i] end end
 | 
			
		||||
 | 
			
		||||
function vec_neg!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE})   where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = -arg[i] end end
 | 
			
		||||
function vec_inv!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE})   where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = inv(arg[i]) end end
 | 
			
		||||
function vec_sign!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE})  where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = sign(arg[i]) end end
 | 
			
		||||
 | 
			
		||||
# handle log and exp specially to use NaN instead of DomainError
 | 
			
		||||
function vec_log!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE})   where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log(arg[i]) end end
 | 
			
		||||
function vec_log10!(res::AbstractVector{TE}, arg::AbstractVector{TE}, ::AbstractVector{TE}) where TE<:Real @simd for i in eachindex(res) @inbounds res[i] = arg[i] < zero(TE) ? TE(NaN) : log10(arg[i]) end end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,12 @@ module ExpressionExecutorCuda
 | 
			
		||||
include("ExpressionProcessing.jl")
 | 
			
		||||
include("Interpreter.jl")
 | 
			
		||||
 | 
			
		||||
export interpret_gpu
 | 
			
		||||
module CpuInterpreter
 | 
			
		||||
include("Code.jl")
 | 
			
		||||
include("CpuInterpreter.jl")
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
export interpret_gpu,interpret_cpu
 | 
			
		||||
export evaluate_gpu
 | 
			
		||||
export test
 | 
			
		||||
 | 
			
		||||
@ -22,8 +27,21 @@ function evaluate_gpu(exprs::Vector{Expr}, X::Matrix{Float32}, p::Vector{Vector{
 | 
			
		||||
	# Look into this to maybe speed up PTX generation: https://cuda.juliagpu.org/stable/tutorials/introduction/#Parallelization-on-the-CPU
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
# Evaluate Expressions on the CPU
 | 
			
		||||
function interpret_cpu(exprs::Vector{Expr}, X::Matrix{Float32}, p::Vector{Vector{Float32}})::Matrix{Float32}
 | 
			
		||||
	@assert axes(exprs) == axes(p)
 | 
			
		||||
	nrows = size(X, 1)
 | 
			
		||||
	
 | 
			
		||||
	# each column of the matrix has the result for an expr
 | 
			
		||||
	res = Matrix{Float32}(undef, nrows, length(exprs))
 | 
			
		||||
 | 
			
		||||
	for i in eachindex(exprs) 
 | 
			
		||||
		CpuInterpreter.interpret!((@view res[:,i]), exprs[i], X, p[i])
 | 
			
		||||
	end
 | 
			
		||||
 | 
			
		||||
	res
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Flow
 | 
			
		||||
@ -35,3 +53,5 @@ end
 | 
			
		||||
# The following can be done on the CPU
 | 
			
		||||
#     convert expression to postfix notation (mandatory)
 | 
			
		||||
#     optional: replace every parameter with the correct value (should only improve performance if data transfer is the bottleneck)
 | 
			
		||||
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										31
									
								
								package/test/CpuInterpreterTests.jl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								package/test/CpuInterpreterTests.jl
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,31 @@
 | 
			
		||||
 | 
			
		||||
function test_cpu_interpreter(nrows; parallel = false)
 | 
			
		||||
    exprs = [
 | 
			
		||||
        # CPU interpreter requires an anonymous function and array ref s
 | 
			
		||||
        :(p[1] * x[1] + p[2]), # 5 op
 | 
			
		||||
        :((((x[1] + x[2]) + x[3]) + x[4]) + x[5]), # 9 op
 | 
			
		||||
        :(log(abs(x[1]))), # 3 op
 | 
			
		||||
        :(powabs(p[2] - powabs(p[1] + x[1], 1/x[1]),p[3])) # 13 op
 | 
			
		||||
    ] # 30 op
 | 
			
		||||
    exprs = map(e -> Expr(:->, :(x,p), e), exprs)
 | 
			
		||||
    X = randn(Float32, nrows, 10)
 | 
			
		||||
    p = [randn(Float32, 10) for _ in 1:length(exprs)] # generate 10 random parameter values for each expr
 | 
			
		||||
    
 | 
			
		||||
    # warmup
 | 
			
		||||
    interpret_cpu(exprs, X, p)
 | 
			
		||||
    if parallel 
 | 
			
		||||
        t_sec = @elapsed Threads.@threads :static for i in 1:100 interpret_cpu(exprs, X, p) end
 | 
			
		||||
        println("~ $(round(30 * 100 * nrows  / 1e9 / t_sec, digits=2)) GFLOPS (single-core) ($(round(peakflops(1000, eltype=Float32, ntrials=1, parallel=false) / 1e9, digits=2)) GFLOPS (peak, single-core))")
 | 
			
		||||
    else
 | 
			
		||||
        t_sec = @elapsed for i in 1:100 interpret_cpu(exprs, X, p) end
 | 
			
		||||
        println("~ $(round(30 * 100 * nrows  / 1e9 / t_sec, digits=2)) GFLOPS ($(Threads.nthreads()) threads) ($(round(peakflops(1000, eltype=Float32, ntrials=1, parallel=false) / 1e9, digits=2)) GFLOPS (peak, single-core))")
 | 
			
		||||
    end
 | 
			
		||||
    true
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
LinearAlgebra.BLAS.set_num_threads(1) # only use a single thread for peakflops
 | 
			
		||||
 | 
			
		||||
@test test_cpu_interpreter(1000)
 | 
			
		||||
@test test_cpu_interpreter(1000, parallel=true) # start julia -t 6 for six threads
 | 
			
		||||
@test test_cpu_interpreter(10000)
 | 
			
		||||
@test test_cpu_interpreter(10000, parallel=true)
 | 
			
		||||
@ -11,3 +11,8 @@ include(joinpath(baseFolder, "src", "Transpiler.jl"))
 | 
			
		||||
	include("InterpreterTests.jl")
 | 
			
		||||
	include("TranspilerTests.jl")
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@testset "CPU Interpreter" begin
 | 
			
		||||
	include("CpuInterpreterTests.jl")
 | 
			
		||||
end
 | 
			
		||||
		Reference in New Issue
	
	Block a user