Added support for variables and parameters as arrays. Also improved the conversion of variables and parameters into ExpressionElement
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run

This commit is contained in:
Daniel 2025-05-09 11:04:10 +02:00
parent aaa3f2c7c0
commit 2c8a9cd2d8
7 changed files with 101 additions and 67 deletions

View File

@ -3,11 +3,11 @@ module ExpressionProcessing
export expr_to_postfix, is_binary_operator export expr_to_postfix, is_binary_operator
export PostfixType export PostfixType
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX export ElementType, EMPTY, FLOAT32, OPERATOR, VARIABLE, PARAMETER
export ExpressionElement export ExpressionElement
@enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9 @enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9
@enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 INDEX=3 @enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 VARIABLE=3 PARAMETER=4
const binary_operators = [ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER] const binary_operators = [ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER]
const unary_operators = [ABS, LOG, EXP, SQRT] const unary_operators = [ABS, LOG, EXP, SQRT]
@ -24,12 +24,28 @@ Converts a julia expression to its postfix notation.
NOTE: All 64-Bit values will be converted to 32-Bit. Be aware of the lost precision. NOTE: All 64-Bit values will be converted to 32-Bit. Be aware of the lost precision.
NOTE: This function is not thread safe, especially cache access is not thread safe NOTE: This function is not thread safe, especially cache access is not thread safe
" "
function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixType function expr_to_postfix(expression::Expr, cache::Dict{Expr, PostfixType})::PostfixType
expr = expression
if expression.head === :->
# if the expression equals (x, p) -> (...) then the below statement extracts the expression to evaluate
expr = expression.args[2].args[2]
end
if haskey(cache, expr) if haskey(cache, expr)
return cache[expr] return cache[expr]
end end
postfix = PostfixType() postfix = PostfixType()
# Special handling in the case where the expression is an array access
# This can happen if the token is a variable/parameter of the form x[n]/p[n]
if expr.head == :ref
exprElement = convert_to_ExpressionElement(expr.args[1], expr.args[2]) # we assume that an array access never contains an expression, as this would make not much sense in this case
push!(postfix, exprElement)
cache[expr] = postfix
return postfix
end
@inbounds operator = get_operator(expr.args[1]) @inbounds operator = get_operator(expr.args[1])
@inbounds for j in 2:length(expr.args) @inbounds for j in 2:length(expr.args)
@ -37,9 +53,8 @@ function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixTyp
if typeof(arg) === Expr if typeof(arg) === Expr
append!(postfix, expr_to_postfix(arg, cache)) append!(postfix, expr_to_postfix(arg, cache))
elseif typeof(arg) === Symbol # variables/parameters elseif typeof(arg) === Symbol # variables/parameters of the form xn/pn
# maybe TODO: replace the parameters with their respective values, as this might make the expr evaluation faster exprElement = convert_to_ExpressionElement(arg)
exprElement = convert_to_ExpressionElement(convert_var_to_int(arg))
push!(postfix, exprElement) push!(postfix, exprElement)
else else
exprElement = convert_to_ExpressionElement(convert(Float32, arg)) exprElement = convert_to_ExpressionElement(convert(Float32, arg))
@ -47,7 +62,7 @@ function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixTyp
end end
# only add operator if at least 2 values are added. Needed because e.g. multiple consecutive additions are one subtree with one operator, but multiple operators need to be added to the postfix notation. # only add operator if at least 2 values are added. Needed because e.g. multiple consecutive additions are one subtree with one operator, but multiple operators need to be added to the postfix notation.
# For the case where another expression has already been added, we check if we are at the first iteration or not ( j != 2) # For the case where another expression has already been added to the final postfix notation, we check if we are at the first iteration or not ( j != 2)
if length(postfix) >= 2 && j != 2 if length(postfix) >= 2 && j != 2
exprElement = convert_to_ExpressionElement(operator) exprElement = convert_to_ExpressionElement(operator)
push!(postfix, exprElement) push!(postfix, exprElement)
@ -74,6 +89,8 @@ function get_operator(op::Symbol)::Operator
return DIVIDE return DIVIDE
elseif op == :^ elseif op == :^
return POWER return POWER
elseif op == :powabs
return POWER # TODO: Fix this
elseif op == :abs elseif op == :abs
return ABS return ABS
elseif op == :log elseif op == :log
@ -82,8 +99,6 @@ function get_operator(op::Symbol)::Operator
return EXP return EXP
elseif op == :sqrt elseif op == :sqrt
return SQRT return SQRT
elseif op == :powabs
return POWER # TODO: Fix this
else else
throw("Operator unknown") throw("Operator unknown")
end end
@ -103,14 +118,30 @@ function convert_var_to_int(var::Symbol)::Int32
return number return number
end end
function convert_to_ExpressionElement(element::Int32)::ExpressionElement "parses a symbol to be either a variable or a parameter and returns the corresponding Expressionelement"
value = reinterpret(Int32, element) function convert_to_ExpressionElement(element::Symbol)::ExpressionElement
return ExpressionElement(INDEX, value) varStr = String(element)
index = parse(Int32, SubString(varStr, 2))
if varStr[1] == 'x'
return ExpressionElement(VARIABLE, index)
elseif varStr[1] == 'p'
return ExpressionElement(PARAMETER, index)
else
throw("Cannot parse symbol to be either a variable or a parameter. Symbol was '$varStr'")
end
end end
function convert_to_ExpressionElement(element::Int64)::ExpressionElement "parses a symbol to be either a variable or a parameter and returns the corresponding Expressionelement"
value = reinterpret(Int32, convert(Int32, element)) function convert_to_ExpressionElement(element::Symbol, index::Integer)::ExpressionElement
return ExpressionElement(INDEX, value) if element == :x
return ExpressionElement(VARIABLE, convert(Int32, index))
elseif element == :p
return ExpressionElement(PARAMETER, convert(Int32, index))
else
throw("Cannot parse symbol to be either a variable or a parameter. Symbol was '$varStr'")
end
end end
function convert_to_ExpressionElement(element::Float32)::ExpressionElement function convert_to_ExpressionElement(element::Float32)::ExpressionElement
value = reinterpret(Int32, element) value = reinterpret(Int32, element)
return ExpressionElement(FLOAT32, value) return ExpressionElement(FLOAT32, value)

View File

@ -64,24 +64,20 @@ function interpret_expression(expressions::CuDeviceArray{ExpressionElement}, var
@inbounds firstVariableIndex = ((varSetIndex-1) * stepsize[3]) # Exclusive @inbounds firstVariableIndex = ((varSetIndex-1) * stepsize[3]) # Exclusive
@inbounds for i in firstExprIndex:lastExprIndex @inbounds for i in firstExprIndex:lastExprIndex
expr = expressions[i] token = expressions[i]
if expr.Type == EMPTY if token.Type == EMPTY
break break
elseif expr.Type == INDEX elseif token.Type == VARIABLE
val = expr.Value
operationStackTop += 1 operationStackTop += 1
operationStack[operationStackTop] = variables[firstVariableIndex + token.Value]
if val > 0 elseif token.Type == PARAMETER
operationStack[operationStackTop] = variables[firstVariableIndex + val]
else
val = abs(val)
operationStack[operationStackTop] = parameters[firstParamIndex + val]
end
elseif expr.Type == FLOAT32
operationStackTop += 1 operationStackTop += 1
operationStack[operationStackTop] = reinterpret(Float32, expr.Value) operationStack[operationStackTop] = parameters[firstParamIndex + token.Value]
elseif expr.Type == OPERATOR elseif token.Type == FLOAT32
opcode = reinterpret(Operator, expr.Value) operationStackTop += 1
operationStack[operationStackTop] = reinterpret(Float32, token.Value)
elseif token.Type == OPERATOR
opcode = reinterpret(Operator, token.Value)
if opcode == ADD if opcode == ADD
operationStackTop -= 1 operationStackTop -= 1
operationStack[operationStackTop] = operationStack[operationStackTop] + operationStack[operationStackTop + 1] operationStack[operationStackTop] = operationStack[operationStackTop] + operationStack[operationStackTop + 1]

View File

@ -220,21 +220,20 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
println(codeBuffer, operation) println(codeBuffer, operation)
push!(operands, resultRegister) push!(operands, resultRegister)
elseif token.Type == INDEX elseif token.Type == VARIABLE
if token.Value > 0 # varaibles var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)") if first_access
if first_access println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
end
push!(operands, var)
else
absVal = abs(token.Value)
param, first_access = Utils.get_register_for_name(regManager, "p$absVal")
if first_access
println(codeBuffer, load_into_register(param, parametersLocation, absVal, exprId64Reg, parametersSetSize, regManager))
end
push!(operands, param)
end end
push!(operands, var)
elseif token.Type == PARAMETER
param, first_access = Utils.get_register_for_name(regManager, "p$(token.Value)")
if first_access
println(codeBuffer, load_into_register(param, parametersLocation, token.Value, exprId64Reg, parametersSetSize, regManager))
end
push!(operands, param)
else
throw("Token unkown. Token was '$(token)'")
end end
end end

View File

@ -45,10 +45,10 @@ end
# LinearAlgebra.BLAS.set_num_threads(1) # only use a single thread for peakflops # LinearAlgebra.BLAS.set_num_threads(1) # only use a single thread for peakflops
@test test_cpu_interpreter(1000) # @test test_cpu_interpreter(1000)
@test test_cpu_interpreter(1000, parallel=true) # start julia -t 6 for six threads # @test test_cpu_interpreter(1000, parallel=true) # start julia -t 6 for six threads
@test test_cpu_interpreter(10000) # @test test_cpu_interpreter(10000)
@test test_cpu_interpreter(10000, parallel=true) # @test test_cpu_interpreter(10000, parallel=true)
function test_cpu_interpreter_nikuradse() function test_cpu_interpreter_nikuradse()
@ -62,14 +62,25 @@ function test_cpu_interpreter_nikuradse()
# data/esr_nvar2_len10.txt.gz_9.txt.gz has ~250_000 exprs # data/esr_nvar2_len10.txt.gz_9.txt.gz has ~250_000 exprs
# data/esr_nvar2_len10.txt.gz_10.txt.gz has ~800_000 exprs # data/esr_nvar2_len10.txt.gz_10.txt.gz has ~800_000 exprs
GZip.open("data/esr_nvar2_len10.txt.gz_9.txt.gz") do io GZip.open("data/esr_nvar2_len10.txt.gz_9.txt.gz") do io
i = 0
for line in eachline(io) for line in eachline(io)
expr, p = parse_infix(line, varnames, paramnames) expr, p = parse_infix(line, varnames, paramnames)
if i > 10
return
end
println(expr)
push!(exprs, expr) push!(exprs, expr)
push!(parameters, randn(Float32, length(p))) push!(parameters, randn(Float32, length(p)))
i += 1
end end
end end
interpret_cpu(exprs, X, parameters) # TODO: sufficient to do up to 10 repetitions per expression, interpret_cpu(exprs, X, parameters) # TODO: sufficient to do up to 10 repetitions per expression,
end end
@test test_cpu_interpreter_nikuradse()

View File

@ -1,35 +1,30 @@
using .ExpressionProcessing using .ExpressionProcessing
expressions = Vector{Expr}(undef, 1) expressions = Vector{Expr}(undef, 2)
variables = Matrix{Float32}(undef, 1,2)
parameters = Vector{Vector{Float32}}(undef, 1)
# Resulting value should be 10
expressions[1] = :(x1 + 1 * x2 + p1) expressions[1] = :(x1 + 1 * x2 + p1)
variables[1,1] = 2 expressions[2] = :(x[1] + 1 * x[2] + p[1])
variables[1,2] = 3
parameters[1] = Vector{Float32}(undef, 1)
parameters[1][1] = 5
@testset "Test conversion expression element" begin @testset "Test conversion expression element" begin
reference1 = ExpressionElement(FLOAT32, reinterpret(Int32, 1f0)) reference1 = ExpressionElement(FLOAT32, reinterpret(Int32, 1f0))
reference2 = ExpressionElement(INDEX, reinterpret(Int32, Int32(1))) reference2 = ExpressionElement(VARIABLE, Int32(1))
reference3 = ExpressionElement(OPERATOR, reinterpret(Int32, ADD)) reference3 = ExpressionElement(OPERATOR, reinterpret(Int32, ADD))
@test isequal(reference1, ExpressionProcessing.convert_to_ExpressionElement(1.0)) @test isequal(reference1, ExpressionProcessing.convert_to_ExpressionElement(1.0))
@test isequal(reference2, ExpressionProcessing.convert_to_ExpressionElement(1)) @test isequal(reference2, ExpressionProcessing.convert_to_ExpressionElement(:x1))
@test isequal(reference3, ExpressionProcessing.convert_to_ExpressionElement(ADD)) @test isequal(reference3, ExpressionProcessing.convert_to_ExpressionElement(ADD))
end end
@testset "Test conversion to postfix" begin @testset "Test conversion to postfix" begin
reference = PostfixType() reference = PostfixType()
append!(reference, [ExpressionProcessing.convert_to_ExpressionElement(1), ExpressionProcessing.convert_to_ExpressionElement(1.0), ExpressionProcessing.convert_to_ExpressionElement(2), ExpressionProcessing.convert_to_ExpressionElement(MULTIPLY), append!(reference, [ExpressionProcessing.convert_to_ExpressionElement(:x1), ExpressionProcessing.convert_to_ExpressionElement(1.0), ExpressionProcessing.convert_to_ExpressionElement(:x2), ExpressionProcessing.convert_to_ExpressionElement(MULTIPLY),
ExpressionProcessing.convert_to_ExpressionElement(ADD), ExpressionProcessing.convert_to_ExpressionElement(-1), ExpressionProcessing.convert_to_ExpressionElement(ADD)]) ExpressionProcessing.convert_to_ExpressionElement(ADD), ExpressionProcessing.convert_to_ExpressionElement(:p1), ExpressionProcessing.convert_to_ExpressionElement(ADD)])
cache = Dict{Expr, PostfixType}() postfixVarsAsSymbol = expr_to_postfix(expressions[1], Dict{Expr, PostfixType}())
postfix = expr_to_postfix(expressions[1], cache) postfixVarsAsArray = expr_to_postfix(expressions[2], Dict{Expr, PostfixType}())
@test isequal(reference, postfix) @test isequal(reference, postfixVarsAsSymbol)
@test isequal(reference, postfixVarsAsArray)
# TODO: Do more complex expressions because these have led to errors in the past # TODO: Do more complex expressions because these have led to errors in the past
end end

View File

@ -2,6 +2,8 @@
BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26" BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

View File

@ -16,9 +16,9 @@ include(joinpath(baseFolder, "src", "Transpiler.jl"))
end end
# @testset "CPU Interpreter" begin @testset "CPU Interpreter" begin
# include("CpuInterpreterTests.jl") # include("CpuInterpreterTests.jl")
# end end
@testset "Performance tests" begin @testset "Performance tests" begin
# include("PerformanceTuning.jl") # include("PerformanceTuning.jl")