added support for variables and parameters as array. also improved conversion of variables and parameters into Expressionelement
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run

This commit is contained in:
2025-05-09 11:04:10 +02:00
parent aaa3f2c7c0
commit 2c8a9cd2d8
7 changed files with 101 additions and 67 deletions

View File

@ -3,11 +3,11 @@ module ExpressionProcessing
export expr_to_postfix, is_binary_operator
export PostfixType
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
export ElementType, EMPTY, FLOAT32, OPERATOR, VARIABLE, PARAMETER
export ExpressionElement
@enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9
@enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 INDEX=3
@enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 VARIABLE=3 PARAMETER=4
const binary_operators = [ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER]
const unary_operators = [ABS, LOG, EXP, SQRT]
@ -24,12 +24,28 @@ Converts a julia expression to its postfix notation.
NOTE: All 64-Bit values will be converted to 32-Bit. Be aware of the lost precision.
NOTE: This function is not thread save, especially cache access is not thread save
"
function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixType
function expr_to_postfix(expression::Expr, cache::Dict{Expr, PostfixType})::PostfixType
expr = expression
if expression.head === :->
# if the expression equals (x, p) -> (...) then the below statement extracts the expression to evaluate
expr = expression.args[2].args[2]
end
if haskey(cache, expr)
return cache[expr]
end
postfix = PostfixType()
# Special handling in the case where the expression is an array access
# This can happen if the token is a variable/parameter of the form x[n]/p[n]
if expr.head == :ref
exprElement = convert_to_ExpressionElement(expr.args[1], expr.args[2]) # we assume that an array access never contains an expression, as this would make not much sense in this case
push!(postfix, exprElement)
cache[expr] = postfix
return postfix
end
@inbounds operator = get_operator(expr.args[1])
@inbounds for j in 2:length(expr.args)
@ -37,9 +53,8 @@ function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixTyp
if typeof(arg) === Expr
append!(postfix, expr_to_postfix(arg, cache))
elseif typeof(arg) === Symbol # variables/parameters
# maybe TODO: replace the parameters with their respective values, as this might make the expr evaluation faster
exprElement = convert_to_ExpressionElement(convert_var_to_int(arg))
elseif typeof(arg) === Symbol # variables/parameters of the form xn/pn
exprElement = convert_to_ExpressionElement(arg)
push!(postfix, exprElement)
else
exprElement = convert_to_ExpressionElement(convert(Float32, arg))
@ -47,7 +62,7 @@ function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixTyp
end
# only add operator if at least 2 values are added. Needed because e.g. multiple consecutive additions are one subtree with one operator, but multiple operators need to be added to the postfix notation.
# For the case where another expression has already been added, we check if we are at the first iteration or not ( j != 2)
# For the case where another expression has already been added to the final postfix notation, we check if we are at the first iteration or not ( j != 2)
if length(postfix) >= 2 && j != 2
exprElement = convert_to_ExpressionElement(operator)
push!(postfix, exprElement)
@ -74,6 +89,8 @@ function get_operator(op::Symbol)::Operator
return DIVIDE
elseif op == :^
return POWER
elseif op == :powabs
return POWER # TODO: Fix this
elseif op == :abs
return ABS
elseif op == :log
@ -82,8 +99,6 @@ function get_operator(op::Symbol)::Operator
return EXP
elseif op == :sqrt
return SQRT
elseif op == :powabs
return POWER # TODO: Fix this
else
throw("Operator unknown")
end
@ -103,14 +118,30 @@ function convert_var_to_int(var::Symbol)::Int32
return number
end
function convert_to_ExpressionElement(element::Int32)::ExpressionElement
value = reinterpret(Int32, element)
return ExpressionElement(INDEX, value)
"parses a symbol to be either a variable or a parameter and returns the corresponding Expressionelement"
function convert_to_ExpressionElement(element::Symbol)::ExpressionElement
varStr = String(element)
index = parse(Int32, SubString(varStr, 2))
if varStr[1] == 'x'
return ExpressionElement(VARIABLE, index)
elseif varStr[1] == 'p'
return ExpressionElement(PARAMETER, index)
else
throw("Cannot parse symbol to be either a variable or a parameter. Symbol was '$varStr'")
end
end
function convert_to_ExpressionElement(element::Int64)::ExpressionElement
value = reinterpret(Int32, convert(Int32, element))
return ExpressionElement(INDEX, value)
"parses a symbol to be either a variable or a parameter and returns the corresponding Expressionelement"
function convert_to_ExpressionElement(element::Symbol, index::Integer)::ExpressionElement
if element == :x
return ExpressionElement(VARIABLE, convert(Int32, index))
elseif element == :p
return ExpressionElement(PARAMETER, convert(Int32, index))
else
throw("Cannot parse symbol to be either a variable or a parameter. Symbol was '$varStr'")
end
end
function convert_to_ExpressionElement(element::Float32)::ExpressionElement
value = reinterpret(Int32, element)
return ExpressionElement(FLOAT32, value)

View File

@ -64,24 +64,20 @@ function interpret_expression(expressions::CuDeviceArray{ExpressionElement}, var
@inbounds firstVariableIndex = ((varSetIndex-1) * stepsize[3]) # Exclusive
@inbounds for i in firstExprIndex:lastExprIndex
expr = expressions[i]
if expr.Type == EMPTY
token = expressions[i]
if token.Type == EMPTY
break
elseif expr.Type == INDEX
val = expr.Value
elseif token.Type == VARIABLE
operationStackTop += 1
if val > 0
operationStack[operationStackTop] = variables[firstVariableIndex + val]
else
val = abs(val)
operationStack[operationStackTop] = parameters[firstParamIndex + val]
end
elseif expr.Type == FLOAT32
operationStack[operationStackTop] = variables[firstVariableIndex + token.Value]
elseif token.Type == PARAMETER
operationStackTop += 1
operationStack[operationStackTop] = reinterpret(Float32, expr.Value)
elseif expr.Type == OPERATOR
opcode = reinterpret(Operator, expr.Value)
operationStack[operationStackTop] = parameters[firstParamIndex + token.Value]
elseif token.Type == FLOAT32
operationStackTop += 1
operationStack[operationStackTop] = reinterpret(Float32, token.Value)
elseif token.Type == OPERATOR
opcode = reinterpret(Operator, token.Value)
if opcode == ADD
operationStackTop -= 1
operationStack[operationStackTop] = operationStack[operationStackTop] + operationStack[operationStackTop + 1]

View File

@ -220,21 +220,20 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
println(codeBuffer, operation)
push!(operands, resultRegister)
elseif token.Type == INDEX
if token.Value > 0 # varaibles
var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
if first_access
println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
end
push!(operands, var)
else
absVal = abs(token.Value)
param, first_access = Utils.get_register_for_name(regManager, "p$absVal")
if first_access
println(codeBuffer, load_into_register(param, parametersLocation, absVal, exprId64Reg, parametersSetSize, regManager))
end
push!(operands, param)
elseif token.Type == VARIABLE
var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
if first_access
println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
end
push!(operands, var)
elseif token.Type == PARAMETER
param, first_access = Utils.get_register_for_name(regManager, "p$(token.Value)")
if first_access
println(codeBuffer, load_into_register(param, parametersLocation, token.Value, exprId64Reg, parametersSetSize, regManager))
end
push!(operands, param)
else
throw("Token unkown. Token was '$(token)'")
end
end