# master-thesis/package/src/ExpressionProcessing.jl

module ExpressionProcessing
export expr_to_postfix, is_binary_operator
export PostfixType
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
export ExpressionElement

# Operators that can appear in a postfix expression. For elements of type `OPERATOR`,
# the raw bits of the enum value are stored in `ExpressionElement.Value`.
@enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9

# Tag that tells how the 32 bits in `ExpressionElement.Value` have to be decoded.
@enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 INDEX=3

"""
    ExpressionElement

One entry of a postfix expression.

- `Type`: the `ElementType` tag describing how `Value` must be decoded.
- `Value`: the raw 32 bits of the entry. Reinterpret them according to `Type`:
  `FLOAT32` -> `Float32` constant, `OPERATOR` -> `Operator`, `INDEX` -> `Int32`
  (positive `n` for a variable `xn`, negative `-n` for a parameter `pn`).
"""
struct ExpressionElement
    Type::ElementType
    Value::Int32
end

# A complete expression in postfix notation.
const PostfixType = Vector{ExpressionElement}
"""
    expr_to_postfix(expr::Expr)::PostfixType

Convert a Julia call expression such as `:(x1 + p1 * 2.0)` to postfix notation.

Variables `xn` become `INDEX` elements holding `n`, and parameters `pn` become `INDEX`
elements holding `-n`. Numeric literals become `FLOAT32` elements, and the called
function becomes an `OPERATOR` element.

Calls with more than two operands (e.g. `x1 + x2 + x3`) are folded left to right into a
chain of binary operations (`x1 x2 + x3 +`). A call with a single operand (e.g. `abs(x1)`
or `exp(x1 + x2)`) is emitted as the operand followed by the operator. A unary minus
`-(a)` has no dedicated operator, so it is emitted as `0 a -` because `SUBTRACT` is binary.

NOTE: All 64-bit values are converted to 32-bit. Be aware of the lost precision.
"""
function expr_to_postfix(expr::Expr)::PostfixType
    postfix = PostfixType()
    operator = get_operator(expr.args[1])
    operands = @view expr.args[2:end]

    # Unary minus: rewrite `-(a)` as `0 - a` so SUBTRACT always receives two operands.
    if operator == SUBTRACT && length(operands) == 1
        push!(postfix, convert_to_ExpressionElement(0.0f0))
    end

    for (i, arg) in enumerate(operands)
        append_operand!(postfix, arg)
        # Every operand after the first is combined with the result so far, so n operands
        # yield n - 1 binary operations applied left to right.
        if i > 1
            push!(postfix, convert_to_ExpressionElement(operator))
        end
    end

    # Single-operand calls still need their operator. Decide by operand count and not by
    # `length(postfix)`: a nested operand such as `x1 + x2` pushes several elements.
    if length(operands) == 1
        push!(postfix, convert_to_ExpressionElement(operator))
    end
    return postfix
end

# Append the postfix form of one call operand to `postfix`, selected by operand type.
append_operand!(postfix::PostfixType, arg::Expr) = append!(postfix, expr_to_postfix(arg))
# Variables/parameters (`xn` / `pn`) become INDEX elements.
append_operand!(postfix::PostfixType, arg::Symbol) =
    push!(postfix, convert_to_ExpressionElement(convert_var_to_int(arg)))
# Anything else is treated as a numeric literal and stored as Float32.
append_operand!(postfix::PostfixType, arg) =
    push!(postfix, convert_to_ExpressionElement(convert(Float32, arg)))
"""
    get_operator(op::Symbol)::Operator

Map the function symbol of a call expression (e.g. `:+`, `:sqrt`) to its `Operator`.

Throws an `ArgumentError` for symbols that have no corresponding `Operator`. Previously
these fell through to `nothing`, which failed with an opaque `convert` error.
"""
function get_operator(op::Symbol)::Operator
    if op === :+
        return ADD
    elseif op === :-
        return SUBTRACT
    elseif op === :*
        return MULTIPLY
    elseif op === :/
        return DIVIDE
    elseif op === :^
        return POWER
    elseif op === :abs
        return ABS
    elseif op === :log
        return LOG
    elseif op === :exp
        return EXP
    elseif op === :sqrt
        return SQRT
    else
        throw(ArgumentError("Unknown operator '$op'. Cannot convert it to an 'Operator'."))
    end
end
"""
    convert_var_to_int(var::Symbol)::Int32

Extract the index from a variable symbol `xn` or a parameter symbol `pn`.

Variables yield `n` and parameters yield `-n`, so the sign distinguishes the two kinds.
For that reason ```x0``` and ```p0``` are not allowed: both would map to `0`.

Throws an `ArgumentError` if the symbol does not start with `x` or `p`, or if the rest
of the symbol is not a positive integer that fits into an `Int32`.
"""
function convert_var_to_int(var::Symbol)::Int32
    varStr = String(var)
    prefix = isempty(varStr) ? '\0' : first(varStr)
    if prefix != 'x' && prefix != 'p'
        throw(ArgumentError("Invalid symbol '$var'. Expected a variable 'xn' or a parameter 'pn'."))
    end
    # The prefix is a single ASCII byte, so the number starts at byte index 2.
    number = tryparse(Int32, SubString(varStr, 2))
    if number === nothing || number < 1
        throw(ArgumentError("Invalid symbol '$var'. The index must be a positive integer (x0 and p0 are not allowed)."))
    end
    return prefix == 'p' ? -number : number
end
"""
    convert_to_ExpressionElement(element)::ExpressionElement

Pack `element` into the raw 32 bits of an `ExpressionElement`, tagged with the type needed
to decode it again:

- `Int32` / `Int64` -> `INDEX` (an `Int64` is narrowed to `Int32` first)
- `Float32` / `Float64` -> `FLOAT32` (a `Float64` is rounded to `Float32` first)
- `Operator` -> `OPERATOR`
"""
function convert_to_ExpressionElement end

convert_to_ExpressionElement(element::Int32)::ExpressionElement =
    ExpressionElement(INDEX, reinterpret(Int32, element))

# Narrow to 32 bits, then reuse the Int32 method.
convert_to_ExpressionElement(element::Int64)::ExpressionElement =
    convert_to_ExpressionElement(convert(Int32, element))

convert_to_ExpressionElement(element::Float32)::ExpressionElement =
    ExpressionElement(FLOAT32, reinterpret(Int32, element))

# Round to single precision, then reuse the Float32 method.
convert_to_ExpressionElement(element::Float64)::ExpressionElement =
    convert_to_ExpressionElement(convert(Float32, element))

convert_to_ExpressionElement(element::Operator)::ExpressionElement =
    ExpressionElement(OPERATOR, reinterpret(Int32, element))
"""
    is_binary_operator(operator::Operator)::Bool

Return `true` for the binary operators `ADD`, `SUBTRACT`, `MULTIPLY`, `DIVIDE` and `POWER`,
and `false` for the unary operators `ABS`, `LOG`, `EXP` and `SQRT`.

Throws an `ArgumentError` for any other `Operator` value.
"""
function is_binary_operator(operator::Operator)::Bool
    operator in (ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER) && return true
    operator in (ABS, LOG, EXP, SQRT) && return false
    throw(ArgumentError("Unknown operator '$operator'. Cannot determine if it is binary or not."))
end
#
# Everything below is currently not needed. It is kept here for potential future use.
#

# Lookup table keyed by (expression, variable/parameter symbol); the value is the Float32
# that replaces that symbol inside that expression.
const SymbolTable32 = Dict{Tuple{Expr, Symbol}, Float32}
"""
    replace_variables!(ex::Expr, symtable::SymbolTable32, originalExpr::Expr)

Replace, in place, every variable/parameter symbol in `ex` (including nested
sub-expressions) with its value from `symtable`. Symbols without an entry are left as
they are.

# Arguments
- `ex::Expr`: the expression to modify.
- `symtable::SymbolTable32`: the values of all variables for each expression.
- `originalExpr::Expr`: a deep copy of the original expression. It is the first half of
  every lookup key, linking the expression and its variables to their values in
  `symtable`.
"""
function replace_variables!(ex::Expr, symtable::SymbolTable32, originalExpr::Expr)
    for idx in eachindex(ex.args)
        current = ex.args[idx]
        if current isa Expr
            # Nested call: recurse, but keep looking values up under the root expression.
            replace_variables!(current, symtable, originalExpr)
            continue
        end
        key = (originalExpr, current)
        if haskey(symtable, key)
            ex.args[idx] = symtable[key]
        end
    end
    return nothing
end
# TODO: Completely rewrite this function because I misunderstood it. Not every column is linked to an expression, therefore all other functions need to be reworked as well. It is probably no longer possible to replace the variables in Julia; look into this (see ExpressionExecutorCuda.jl for more info).
# Before rewriting, proceed with just creating a postfix notation and sending the variable matrix as well as the parameter "matrix" to the GPU to perform first calculations.
"""
    construct_symtable(expressions, mat, params)::SymbolTable32

Build a symbol table for `expressions`. Expression `i` gets its variables `x1, x2, ...`
from row `i` of `mat` and its parameters `p1, p2, ...` from `params[i]`.
"""
function construct_symtable(expressions::Vector{Expr}, mat::Matrix{Float32}, params::Vector{Vector{Float32}})::SymbolTable32
    table = SymbolTable32()
    for (row, expression) in enumerate(expressions)
        variableValues, parameterValues = mat[row, :], params[row]
        fill_symtable!(expression, table, variableValues, "x")
        fill_symtable!(expression, table, parameterValues, "p")
    end
    return table
end
"""
    fill_symtable!(expr, symtable, values, symbolPrefix)

Store `values[k]` in `symtable` under the key `(expr, Symbol(symbolPrefix, k))`. For
example, prefix `"x"` produces the symbols `x1, x2, ...`.
"""
function fill_symtable!(expr::Expr, symtable::SymbolTable32, values::Vector{Float32}, symbolPrefix::String)
    for (k, value) in enumerate(values)
        symtable[(expr, Symbol(symbolPrefix, k))] = value
    end
    return nothing
end
end