using Printf

# Operation codes of the instructions in a compiled expression (see `Instruction`).
@enum Opcode::UInt8 begin
    opc_stop = 1 # must start with 1 here TODO: remove stop
    opc_add
    opc_sub
    opc_mul
    opc_div
    opc_inv
    opc_log
    opc_log10
    opc_exp
    opc_pow
    opc_powconst
    opc_powabs
    opc_neg
    opc_abs
    opc_sign
    opc_sin
    opc_asin
    opc_tan
    opc_tanh
    opc_cos
    opc_cosh
    opc_constant
    opc_param
    opc_variable
end

const terminal_opcodes = [opc_stop, opc_constant, opc_param, opc_variable]
const unary_opcodes = [opc_log, opc_log10, opc_exp, opc_abs, opc_sign, opc_sin, opc_cos, opc_cosh,
                       opc_asin, opc_tan, opc_tanh, opc_powconst, opc_neg, opc_inv]
const binary_opcodes = [opc_add, opc_sub, opc_mul, opc_div, opc_pow, opc_powabs]

# Function symbols of a Julia expression and the opcode each one compiles to.
# TODO: `^` maps to `opc_powabs` temporarily, to enforce that all powers are evaluated
# as pow(abs(...)) for parameter optimization.
const symbol_opcodes = Dict{Symbol,Opcode}(
    :+ => opc_add,
    :- => opc_sub,
    :* => opc_mul,
    :/ => opc_div,
    :inv => opc_inv,
    :log => opc_log,
    :log10 => opc_log10,
    :exp => opc_exp,
    :^ => opc_powabs,
    :powabs => opc_powabs,
    :abs => opc_abs,
    :sign => opc_sign,
    :sin => opc_sin,
    :asin => opc_asin,
    :cos => opc_cos,
    :cosh => opc_cosh,
    :tan => opc_tan,
    :tanh => opc_tanh,
)

# Short mnemonics used when printing a code vector; covers every `Opcode` value.
const opcode_display_names = Dict{Opcode,String}(
    opc_stop => ".",
    opc_add => "+",
    opc_sub => "-",
    opc_neg => "neg",
    opc_mul => "*",
    opc_div => "/",
    opc_inv => "inv",
    opc_pow => "^",
    opc_powabs => "abs^",
    opc_powconst => "^c",
    opc_log => "log",
    opc_log10 => "l10",
    opc_exp => "exp",
    opc_abs => "abs",
    opc_sign => "sgn",
    opc_sin => "sin",
    opc_asin => "asin",
    opc_cos => "cos",
    opc_cosh => "cosh",
    opc_tan => "tan",
    opc_tanh => "tanh",
    opc_variable => "var",
    opc_constant => "con",
    opc_param => "par",
)

"""
    opcode(sy::Symbol) -> Opcode

Return the opcode for the function symbol `sy` (e.g. `:+`, `:sin`).
Throws an `ErrorException` when no opcode exists for `sy`.
"""
function opcode(sy::Symbol)::Opcode
    return get(symbol_opcodes, sy) do
        error("no opcode for symbol $sy")
    end
end

"""
    degree(opc::Opcode) -> Int

Number of arguments taken by `opc`: 0 for terminals, 1 for unary and 2 for binary opcodes.
Throws an `ErrorException` for an opcode that is in none of the three tables.
"""
function degree(opc::Opcode)::Int
    opc in terminal_opcodes && return 0
    opc in unary_opcodes && return 1
    opc in binary_opcodes && return 2
    error("unknown degree of opcode $opc")
end

# code is a Vector{Instruction} which is a linear representation of a directed acyclic graph of expressions.
# The code can be evaluated from left to right.
struct Instruction{T}
    opcode::Opcode
    arg1idx::UInt32 # index of first argument. 0 for terminals
    arg2idx::UInt32 # index of second argument. 0 for functions with a single argument
    idx::UInt32     # for variables and parameters
    val::T          # for constants (also the parameter value, and the exponent of opc_powconst)
end

function Base.show(io::IO, instr::Instruction)
    Printf.format(io, Printf.format"%15s %3d %3d %3d %f",
                  instr.opcode, instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
end

create_const_instruction(val::T) where {T} =
    Instruction{T}(opc_constant, UInt32(0), UInt32(0), UInt32(0), val)
create_var_instruction(::Type{T}, varidx) where {T} =
    Instruction{T}(opc_variable, UInt32(0), UInt32(0), UInt32(varidx), zero(T))
create_param_instruction(::Type{T}, paramidx; val::T = zero(T)) where {T} =
    Instruction{T}(opc_param, UInt32(0), UInt32(0), UInt32(paramidx), val)

"""
    convert_expr_to_code(::Type{T}, expr::Expr) -> Vector{Instruction{T}}

Compile an expression of the form `(x, p) -> body` into a linear instruction vector.
`x[i]` in `body` becomes a variable instruction and `p[i]` a parameter instruction.
Identical subexpressions are emitted only once.

Note: `expr` is modified in place (its line-number nodes are removed).
"""
function convert_expr_to_code(::Type{T}, expr::Expr)::Vector{Instruction{T}} where {T}
    code = Vector{Instruction{T}}()
    Base.remove_linenums!(expr)
    paramTup = expr.args[1]
    xSy = paramTup.args[1]
    pSy = paramTup.args[2]
    body = expr.args[2]

    # For de-duplication of expressions: an expression already in the cache is not emitted
    # again; the index of its existing instruction is reused.
    cache = Dict{Any,UInt32}()
    convert_expr_to_code!(code, cache, body, xSy, pSy)
    return code
end

# Non-expression leaves (numeric literals) become constant instructions.
# Uses cache (hashcons) to de-duplicate equal constants.
function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, val::TV, xSy, pSy)::UInt32 where {T,TV}
    haskey(cache, val) && return cache[val]
    push!(code, create_const_instruction(T(val)))
    cache[val] = length(code)
    return length(code)
end

# Uses cache (hashcons) to de-duplicate subexpressions in the tree.
# Returns the index in `code` of the instruction that computes `expr`.
function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, expr::Expr, xSy, pSy)::UInt32 where {T}
    cached = get(cache, expr, nothing)
    cached === nothing || return cached

    sy = expr.head
    idx = if sy == :call
        convert_call_to_code!(code, cache, expr, xSy, pSy)
    elseif sy == :ref
        convert_ref_to_code!(code, expr, xSy, pSy)
    elseif sy == :block && length(expr.args) == 1
        # e.g. the body of a quoted `(x, p) -> body` once line numbers are removed
        convert_expr_to_code!(code, cache, expr.args[1], xSy, pSy)
    else
        error("Unsupported symbol $sy")
    end
    cache[expr] = idx
    return idx
end

# Predicate: is `a` a call expression abs(...)?
is_abs_call(a) = a isa Expr && a.head == :call && a.args[1] == :abs

# Compile a function call `func(args...)`; returns the index of the resulting instruction.
function convert_call_to_code!(code::Vector{Instruction{T}}, cache, expr::Expr, xSy, pSy)::UInt32 where {T}
    func = expr.args[1]
    nargs = length(expr.args) - 1

    if nargs > 2 && func in (:+, :*)
        # Julia parses `a + b + c` as one n-ary call; left-fold it into binary operations.
        folded = foldl((a, b) -> Expr(:call, func, a, b), expr.args[2:end])
        return convert_expr_to_code!(code, cache, folded, xSy, pSy)
    elseif nargs == 1
        arg1idx = convert_expr_to_code!(code, cache, expr.args[2], xSy, pSy)
        instr = if func == :-
            # - with one argument => negate
            Instruction{T}(opc_neg, arg1idx, UInt32(0), UInt32(0), zero(T))
        elseif func == :sqrt
            Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(0.5))
        else
            Instruction{T}(opcode(func), arg1idx, UInt32(0), UInt32(0), zero(T))
        end
        push!(code, instr)
    elseif nargs == 2
        lhs, rhs = expr.args[2], expr.args[3]
        arg1idx = convert_expr_to_code!(code, cache, lhs, xSy, pSy)
        if func == :^ && rhs isa Number && round(rhs) == rhs # is integer
            # special case for constant integer powers
            push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(rhs)))
        elseif func == :^ && is_abs_call(lhs)
            # fuse pow(abs(x), y) --> powabs(x, y). The abs(x) instruction emitted above is
            # not referenced by the fused instruction. Because of hashconsing, converting x
            # again returns the index of the x instruction emitted while compiling abs(x).
            xidx = convert_expr_to_code!(code, cache, lhs.args[2], xSy, pSy)
            arg2idx = convert_expr_to_code!(code, cache, rhs, xSy, pSy)
            push!(code, Instruction{T}(opc_powabs, xidx, arg2idx, UInt32(0), zero(T)))
        else
            arg2idx = convert_expr_to_code!(code, cache, rhs, xSy, pSy)
            push!(code, Instruction{T}(opcode(func), arg1idx, arg2idx, UInt32(0), zero(T)))
        end
    else
        error("only unary and binary functions are supported ($func is not supported)")
    end
    return length(code)
end

# Compile an indexing expression `x[i]` (variable) or `p[i]` (parameter).
function convert_ref_to_code!(code::Vector{Instruction{T}}, expr::Expr, xSy, pSy)::UInt32 where {T}
    arrSy = expr.args[1]
    idx = expr.args[2]
    if arrSy == xSy
        push!(code, create_var_instruction(T, idx))
    elseif arrSy == pSy
        push!(code, create_param_instruction(T, idx))
    else
        # UndefVarError requires a Symbol; Symbol(...) also covers non-symbol array expressions.
        throw(UndefVarError(Symbol(arrSy)))
    end
    return length(code)
end

function Base.show(io::IO, code::AbstractArray{Instruction{T}}) where {T}
    for i in eachindex(code)
        instr = code[i]
        # fall back to the enum name for any opcode missing from the display table
        name = get(opcode_display_names, instr.opcode, string(instr.opcode))
        Printf.format(io, Printf.format"%4d %4s %3d %3d %3d %f",
                      i, name, instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
        println(io)
    end
end