207 lines
7.4 KiB
Julia
207 lines
7.4 KiB
Julia
using Printf
|
|
|
|
# Instruction opcodes. Members are stored as UInt8 and numbered consecutively from 1 in the
# order listed below, so the member order determines each opcode's numeric value.
@enum Opcode::UInt8 begin
    opc_stop = 1   # must start with 1 here TODO: remove stop
    # arithmetic
    opc_add
    opc_sub
    opc_mul
    opc_div
    opc_inv
    # logarithm / exponential
    opc_log
    opc_log10
    opc_exp
    # powers
    opc_pow
    opc_powconst
    opc_powabs
    # sign manipulation
    opc_neg
    opc_abs
    opc_sign
    # trigonometric / hyperbolic
    opc_sin
    opc_asin
    opc_tan
    opc_tanh
    opc_cos
    opc_cosh
    # leaves
    opc_constant
    opc_param
    opc_variable
end

# Classification of the opcodes by number of operands (0, 1 or 2).
# Together the three tables list every Opcode member exactly once.
const terminal_opcodes = [opc_stop, opc_constant, opc_param, opc_variable]
const unary_opcodes = [
    opc_log, opc_log10, opc_exp,
    opc_abs, opc_sign,
    opc_sin, opc_cos, opc_cosh, opc_asin, opc_tan, opc_tanh,
    opc_powconst,   # single operand; the constant exponent lives in the instruction itself
    opc_neg, opc_inv,
]
const binary_opcodes = [
    opc_add, opc_sub, opc_mul, opc_div,
    opc_pow, opc_powabs,
]
|
|
|
|
# Lookup table from the name of a Julia function to the opcode implementing it.
# TODO: `^` is temporarily mapped to opc_powabs (like `powabs`) to enforce that all powers
# are evaluated as pow(abs(...)) for parameter optimization.
const _symbol_opcodes = Dict{Symbol,Opcode}(
    :+ => opc_add,
    :- => opc_sub,
    :* => opc_mul,
    :/ => opc_div,
    :inv => opc_inv,
    :log => opc_log,
    :log10 => opc_log10,
    :exp => opc_exp,
    :^ => opc_powabs,
    :powabs => opc_powabs,
    :abs => opc_abs,
    :sign => opc_sign,
    :sin => opc_sin,
    :asin => opc_asin,
    :cos => opc_cos,
    :cosh => opc_cosh,
    :tan => opc_tan,
    :tanh => opc_tanh,
)

"""
    opcode(sy::Symbol) -> Opcode

Return the opcode for the function named `sy` (e.g. `:+` gives `opc_add`).
Throws an `ErrorException` when no opcode exists for `sy`.
"""
function opcode(sy::Symbol)::Opcode
    found = get(_symbol_opcodes, sy, nothing)
    isnothing(found) && error("no opcode for symbol $sy")
    return found
end
|
|
|
|
"""
    degree(opc::Opcode) -> Integer

Return the number of operands of `opc`: 0 for terminals, 1 for unary and 2 for binary
opcodes, according to the `terminal_opcodes`, `unary_opcodes` and `binary_opcodes` tables.
Throws an `ErrorException` if `opc` appears in none of them.
"""
function degree(opc::Opcode)::Integer
    opc in terminal_opcodes && return 0
    opc in unary_opcodes && return 1
    opc in binary_opcodes && return 2
    error("unknown degree of opcode $opc")
end
|
|
|
|
|
|
# A program ("code") is a Vector{Instruction}: a linear representation of a directed acyclic
# graph of expressions, in which operands are referenced by their position in the vector.
# The code can be evaluated from left to right.
struct Instruction{T}
    opcode::Opcode
    arg1idx::UInt32  # position of the first operand; 0 for terminals
    arg2idx::UInt32  # position of the second operand; 0 for instructions with a single operand
    idx::UInt32      # variable / parameter index (opc_variable, opc_param)
    val::T           # payload: constant value (opc_constant), exponent (opc_powconst), param value
end

# Fixed-width, single-line rendering: opcode, both operand positions, index and value.
function Base.show(io::IO, instr::Instruction)
    @printf(io, "%15s %3d %3d %3d %f",
            instr.opcode, instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
end

# Builders for leaf instructions. Leaves have no operands, so both operand positions are 0.

# Constant `val`; the instruction's value type is the type of `val`.
function create_const_instruction(val::T) where {T}
    unused = UInt32(0)
    return Instruction{T}(opc_constant, unused, unused, unused, val)
end

# Reference to variable number `varidx`; `val` is unused and set to zero(T).
function create_var_instruction(::Type{T}, varidx) where {T}
    unused = UInt32(0)
    return Instruction{T}(opc_variable, unused, unused, UInt32(varidx), zero(T))
end

# Reference to parameter number `paramidx`, carrying `val` (zero(T) unless given).
function create_param_instruction(::Type{T}, paramidx; val::T = zero(T)) where {T}
    unused = UInt32(0)
    return Instruction{T}(opc_param, unused, unused, UInt32(paramidx), val)
end
|
|
|
|
|
|
"""
    convert_expr_to_code(T, expr::Expr) -> Vector{Instruction{T}}

Compile `expr`, an anonymous-function expression such as `:((x, p) -> x[1] * p[1] + p[2])`,
into a linear instruction sequence whose instructions carry values of type `T`.

The first argument name of the function denotes the variable vector (`x[i]` is lowered to a
variable reference) and the second one the parameter vector (`p[i]` is lowered to a parameter
reference). Line-number nodes are ignored, and a body that is a block holding a single
expression (which is how the parser stores `(x, p) -> e`) is unwrapped. Structurally equal
subexpressions are emitted only once. `expr` itself is not modified.
"""
function convert_expr_to_code(::Type{T}, expr::Expr)::Vector{Instruction{T}} where {T}
    # strip line numbers from a copy so that the caller's expression is left untouched
    lambda = Base.remove_linenums!(copy(expr))

    paramTup = lambda.args[1]  # argument tuple, e.g. (x, p)
    xSy = paramTup.args[1]     # name of the variable vector
    pSy = paramTup.args[2]     # name of the parameter vector
    body = lambda.args[2]
    # the parser wraps a lambda body in a :block; lowering only understands the expression
    if body isa Expr && body.head == :block && length(body.args) == 1
        body = body.args[1]
    end

    code = Vector{Instruction{T}}()
    # hash-consing cache: (sub)expression => position of the instruction computing it
    cache = Dict{Any,UInt32}()
    convert_expr_to_code!(code, cache, body, xSy, pSy)
    return code
end
|
|
|
|
# Leaf case of the lowering: any non-Expr value (presumably a numeric literal — it must be
# convertible with T(val)) becomes a constant instruction holding T(val).
# `cache` (hash-consing) makes equal (`isequal`) values share a single instruction; the
# return value is the position of that instruction in `code`.
function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, val::TV, xSy, pSy)::UInt32 where {T,TV}
    return get!(cache, val) do
        push!(code, create_const_instruction(T(val)))
        length(code)
    end
end
|
|
|
|
"""
    convert_expr_to_code!(code, cache, expr::Expr, xSy, pSy) -> UInt32

Append the instructions computing `expr` to `code` and return the position of the instruction
that yields its value. Operands are lowered before the instruction using them, so operand
positions always refer to earlier instructions. `cache` maps already-lowered (sub)expressions
to their position (hash-consing), so structurally equal subexpressions are emitted only once.

Supported forms (`xSy` / `pSy` are the names of the variable / parameter vectors):
- `f(a)`: unary call; `-(a)` becomes `opc_neg`, `sqrt(a)` becomes `opc_powconst` with
  exponent 0.5, other names are mapped with `opcode`.
- `f(a, b)`: binary call; `a ^ c` with a constant integral number `c` becomes
  `opc_powconst`, `abs(a) ^ b` is fused into `opc_powabs` of `a` and `b`.
- `+(a, b, c, ...)` / `*(a, b, c, ...)`: n-ary sums/products (how Julia parses `a + b + c`)
  are folded left to right into binary instructions.
- `xSy[i]` / `pSy[i]`: variable / parameter number `i`.

Throws an `ErrorException` for other call arities and expression heads, and an
`UndefVarError` when indexing into an array other than `xSy` or `pSy`.
"""
function convert_expr_to_code!(code::Vector{Instruction{T}}, cache, expr::Expr, xSy, pSy)::UInt32 where {T}
    # predicate to check if an expression is abs(...)
    is_abs(a) = a isa Expr && a.head == :call && a.args[1] == :abs

    haskey(cache, expr) && return cache[expr]

    sy = expr.head
    if sy == :call
        func = expr.args[1]
        nargs = length(expr.args) - 1
        if nargs > 2 && (func == :+ || func == :*)
            # `a + b + c` is parsed as `+(a, b, c)`: rewrite it as `(a + b) + c` (the same
            # left-to-right order Base uses for n-ary + and *) and lower the binary form.
            folded = Expr(:call, func, expr.args[2], expr.args[3])
            for operand in expr.args[4:end]
                folded = Expr(:call, func, folded, operand)
            end
            idx = convert_expr_to_code!(code, cache, folded, xSy, pSy)
            cache[expr] = idx
            return idx
        elseif nargs == 1
            arg1idx = convert_expr_to_code!(code, cache, expr.args[2], xSy, pSy)
            if func == :-
                # - with one argument => negate
                push!(code, Instruction{T}(opc_neg, arg1idx, UInt32(0), UInt32(0), zero(T)))
            elseif func == :sqrt
                # sqrt(a) is evaluated as a ^ 0.5; the exponent is stored in `val`
                push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(0.5)))
            else
                push!(code, Instruction{T}(opcode(func), arg1idx, UInt32(0), UInt32(0), zero(T)))
            end
        elseif nargs == 2
            lhs, rhs = expr.args[2], expr.args[3]
            if func == :^ && rhs isa Number && round(rhs) == rhs # is integer
                # special case for constant powers; the exponent is stored in `val`
                arg1idx = convert_expr_to_code!(code, cache, lhs, xSy, pSy)
                push!(code, Instruction{T}(opc_powconst, arg1idx, UInt32(0), UInt32(0), T(rhs)))
            elseif func == :^ && is_abs(lhs)
                # fuse pow(abs(x), y) --> powabs(x, y). Lower `x` directly: lowering `abs(x)`
                # first would leave an unused abs instruction in the code.
                arg1idx = convert_expr_to_code!(code, cache, lhs.args[2], xSy, pSy)
                arg2idx = convert_expr_to_code!(code, cache, rhs, xSy, pSy)
                push!(code, Instruction{T}(opc_powabs, arg1idx, arg2idx, UInt32(0), zero(T)))
            else
                arg1idx = convert_expr_to_code!(code, cache, lhs, xSy, pSy)
                arg2idx = convert_expr_to_code!(code, cache, rhs, xSy, pSy)
                push!(code, Instruction{T}(opcode(func), arg1idx, arg2idx, UInt32(0), zero(T)))
            end
        else
            error("only unary and binary functions are supported ($func is not supported)")
        end
    elseif sy == :ref
        arrSy, idx = expr.args[1], expr.args[2]
        if arrSy == xSy
            push!(code, create_var_instruction(T, idx))
        elseif arrSy == pSy
            push!(code, create_param_instruction(T, idx))
        else
            dump(expr)
            # UndefVarError must be constructed from a Symbol (a String raises a MethodError)
            throw(UndefVarError(Symbol(arrSy)))
        end
    else
        error("Unsupported symbol $sy")
    end

    cache[expr] = length(code)
    return length(code)
end
|
|
|
|
|
|
# Print a code vector as a listing, one instruction per line:
#   <position> <mnemonic> <arg1idx> <arg2idx> <idx> <val>
function Base.show(io::IO, code::AbstractArray{Instruction{T}}) where {T}
    # short mnemonics for every Opcode member; `get` below falls back to the enum name so
    # that an opcode without an entry can never make printing throw a KeyError
    sym = Dict(
        opc_stop => ".",
        opc_add => "+",
        opc_sub => "-",
        opc_neg => "neg",
        opc_mul => "*",
        opc_div => "/",
        opc_inv => "inv",
        opc_pow => "^",
        opc_powabs => "abs^",
        opc_powconst => "^c",
        opc_log => "log",
        opc_log10 => "l10",
        opc_exp => "exp",
        opc_abs => "abs",
        opc_sign => "sgn",
        opc_sin => "sin",
        opc_asin => "asin",
        opc_cos => "cos",
        opc_cosh => "cosh",
        opc_tan => "tan",
        opc_tanh => "tanh",
        opc_variable => "var",
        opc_constant => "con",
        opc_param => "par",
    )

    for i in eachindex(code)
        instr = code[i]
        mnemonic = get(sym, instr.opcode, string(instr.opcode))
        @printf(io, "%4d %4s %3d %3d %3d %f\n",
                i, mnemonic, instr.arg1idx, instr.arg2idx, instr.idx, instr.val)
    end
end