added support for variables and parameters as array. also improved conversion of variables and parameters into Expressionelement
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
This commit is contained in:
parent
aaa3f2c7c0
commit
2c8a9cd2d8
|
@ -3,11 +3,11 @@ module ExpressionProcessing
|
|||
export expr_to_postfix, is_binary_operator
|
||||
export PostfixType
|
||||
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
|
||||
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
|
||||
export ElementType, EMPTY, FLOAT32, OPERATOR, VARIABLE, PARAMETER
|
||||
export ExpressionElement
|
||||
|
||||
@enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9
|
||||
@enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 INDEX=3
|
||||
@enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 VARIABLE=3 PARAMETER=4
|
||||
|
||||
const binary_operators = [ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER]
|
||||
const unary_operators = [ABS, LOG, EXP, SQRT]
|
||||
|
@ -24,12 +24,28 @@ Converts a julia expression to its postfix notation.
|
|||
NOTE: All 64-Bit values will be converted to 32-Bit. Be aware of the lost precision.
|
||||
NOTE: This function is not thread save, especially cache access is not thread save
|
||||
"
|
||||
function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixType
|
||||
function expr_to_postfix(expression::Expr, cache::Dict{Expr, PostfixType})::PostfixType
|
||||
expr = expression
|
||||
if expression.head === :->
|
||||
# if the expression equals (x, p) -> (...) then the below statement extracts the expression to evaluate
|
||||
expr = expression.args[2].args[2]
|
||||
end
|
||||
|
||||
if haskey(cache, expr)
|
||||
return cache[expr]
|
||||
end
|
||||
|
||||
postfix = PostfixType()
|
||||
|
||||
# Special handling in the case where the expression is an array access
|
||||
# This can happen if the token is a variable/parameter of the form x[n]/p[n]
|
||||
if expr.head == :ref
|
||||
exprElement = convert_to_ExpressionElement(expr.args[1], expr.args[2]) # we assume that an array access never contains an expression, as this would make not much sense in this case
|
||||
push!(postfix, exprElement)
|
||||
cache[expr] = postfix
|
||||
return postfix
|
||||
end
|
||||
|
||||
@inbounds operator = get_operator(expr.args[1])
|
||||
|
||||
@inbounds for j in 2:length(expr.args)
|
||||
|
@ -37,9 +53,8 @@ function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixTyp
|
|||
|
||||
if typeof(arg) === Expr
|
||||
append!(postfix, expr_to_postfix(arg, cache))
|
||||
elseif typeof(arg) === Symbol # variables/parameters
|
||||
# maybe TODO: replace the parameters with their respective values, as this might make the expr evaluation faster
|
||||
exprElement = convert_to_ExpressionElement(convert_var_to_int(arg))
|
||||
elseif typeof(arg) === Symbol # variables/parameters of the form xn/pn
|
||||
exprElement = convert_to_ExpressionElement(arg)
|
||||
push!(postfix, exprElement)
|
||||
else
|
||||
exprElement = convert_to_ExpressionElement(convert(Float32, arg))
|
||||
|
@ -47,7 +62,7 @@ function expr_to_postfix(expr::Expr, cache::Dict{Expr, PostfixType})::PostfixTyp
|
|||
end
|
||||
|
||||
# only add operator if at least 2 values are added. Needed because e.g. multiple consecutive additions are one subtree with one operator, but multiple operators need to be added to the postfix notation.
|
||||
# For the case where another expression has already been added, we check if we are at the first iteration or not ( j != 2)
|
||||
# For the case where another expression has already been added to the final postfix notation, we check if we are at the first iteration or not ( j != 2)
|
||||
if length(postfix) >= 2 && j != 2
|
||||
exprElement = convert_to_ExpressionElement(operator)
|
||||
push!(postfix, exprElement)
|
||||
|
@ -74,6 +89,8 @@ function get_operator(op::Symbol)::Operator
|
|||
return DIVIDE
|
||||
elseif op == :^
|
||||
return POWER
|
||||
elseif op == :powabs
|
||||
return POWER # TODO: Fix this
|
||||
elseif op == :abs
|
||||
return ABS
|
||||
elseif op == :log
|
||||
|
@ -82,8 +99,6 @@ function get_operator(op::Symbol)::Operator
|
|||
return EXP
|
||||
elseif op == :sqrt
|
||||
return SQRT
|
||||
elseif op == :powabs
|
||||
return POWER # TODO: Fix this
|
||||
else
|
||||
throw("Operator unknown")
|
||||
end
|
||||
|
@ -103,14 +118,30 @@ function convert_var_to_int(var::Symbol)::Int32
|
|||
return number
|
||||
end
|
||||
|
||||
function convert_to_ExpressionElement(element::Int32)::ExpressionElement
|
||||
value = reinterpret(Int32, element)
|
||||
return ExpressionElement(INDEX, value)
|
||||
"parses a symbol to be either a variable or a parameter and returns the corresponding Expressionelement"
|
||||
function convert_to_ExpressionElement(element::Symbol)::ExpressionElement
|
||||
varStr = String(element)
|
||||
index = parse(Int32, SubString(varStr, 2))
|
||||
|
||||
if varStr[1] == 'x'
|
||||
return ExpressionElement(VARIABLE, index)
|
||||
elseif varStr[1] == 'p'
|
||||
return ExpressionElement(PARAMETER, index)
|
||||
else
|
||||
throw("Cannot parse symbol to be either a variable or a parameter. Symbol was '$varStr'")
|
||||
end
|
||||
function convert_to_ExpressionElement(element::Int64)::ExpressionElement
|
||||
value = reinterpret(Int32, convert(Int32, element))
|
||||
return ExpressionElement(INDEX, value)
|
||||
end
|
||||
"parses a symbol to be either a variable or a parameter and returns the corresponding Expressionelement"
|
||||
function convert_to_ExpressionElement(element::Symbol, index::Integer)::ExpressionElement
|
||||
if element == :x
|
||||
return ExpressionElement(VARIABLE, convert(Int32, index))
|
||||
elseif element == :p
|
||||
return ExpressionElement(PARAMETER, convert(Int32, index))
|
||||
else
|
||||
throw("Cannot parse symbol to be either a variable or a parameter. Symbol was '$varStr'")
|
||||
end
|
||||
end
|
||||
|
||||
function convert_to_ExpressionElement(element::Float32)::ExpressionElement
|
||||
value = reinterpret(Int32, element)
|
||||
return ExpressionElement(FLOAT32, value)
|
||||
|
|
|
@ -64,24 +64,20 @@ function interpret_expression(expressions::CuDeviceArray{ExpressionElement}, var
|
|||
@inbounds firstVariableIndex = ((varSetIndex-1) * stepsize[3]) # Exclusive
|
||||
|
||||
@inbounds for i in firstExprIndex:lastExprIndex
|
||||
expr = expressions[i]
|
||||
if expr.Type == EMPTY
|
||||
token = expressions[i]
|
||||
if token.Type == EMPTY
|
||||
break
|
||||
elseif expr.Type == INDEX
|
||||
val = expr.Value
|
||||
elseif token.Type == VARIABLE
|
||||
operationStackTop += 1
|
||||
|
||||
if val > 0
|
||||
operationStack[operationStackTop] = variables[firstVariableIndex + val]
|
||||
else
|
||||
val = abs(val)
|
||||
operationStack[operationStackTop] = parameters[firstParamIndex + val]
|
||||
end
|
||||
elseif expr.Type == FLOAT32
|
||||
operationStack[operationStackTop] = variables[firstVariableIndex + token.Value]
|
||||
elseif token.Type == PARAMETER
|
||||
operationStackTop += 1
|
||||
operationStack[operationStackTop] = reinterpret(Float32, expr.Value)
|
||||
elseif expr.Type == OPERATOR
|
||||
opcode = reinterpret(Operator, expr.Value)
|
||||
operationStack[operationStackTop] = parameters[firstParamIndex + token.Value]
|
||||
elseif token.Type == FLOAT32
|
||||
operationStackTop += 1
|
||||
operationStack[operationStackTop] = reinterpret(Float32, token.Value)
|
||||
elseif token.Type == OPERATOR
|
||||
opcode = reinterpret(Operator, token.Value)
|
||||
if opcode == ADD
|
||||
operationStackTop -= 1
|
||||
operationStack[operationStackTop] = operationStack[operationStackTop] + operationStack[operationStackTop + 1]
|
||||
|
|
|
@ -220,21 +220,20 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
|
|||
|
||||
println(codeBuffer, operation)
|
||||
push!(operands, resultRegister)
|
||||
elseif token.Type == INDEX
|
||||
if token.Value > 0 # varaibles
|
||||
elseif token.Type == VARIABLE
|
||||
var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
|
||||
if first_access
|
||||
println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
|
||||
end
|
||||
push!(operands, var)
|
||||
else
|
||||
absVal = abs(token.Value)
|
||||
param, first_access = Utils.get_register_for_name(regManager, "p$absVal")
|
||||
elseif token.Type == PARAMETER
|
||||
param, first_access = Utils.get_register_for_name(regManager, "p$(token.Value)")
|
||||
if first_access
|
||||
println(codeBuffer, load_into_register(param, parametersLocation, absVal, exprId64Reg, parametersSetSize, regManager))
|
||||
println(codeBuffer, load_into_register(param, parametersLocation, token.Value, exprId64Reg, parametersSetSize, regManager))
|
||||
end
|
||||
push!(operands, param)
|
||||
end
|
||||
else
|
||||
throw("Token unkown. Token was '$(token)'")
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -45,10 +45,10 @@ end
|
|||
|
||||
# LinearAlgebra.BLAS.set_num_threads(1) # only use a single thread for peakflops
|
||||
|
||||
@test test_cpu_interpreter(1000)
|
||||
@test test_cpu_interpreter(1000, parallel=true) # start julia -t 6 for six threads
|
||||
@test test_cpu_interpreter(10000)
|
||||
@test test_cpu_interpreter(10000, parallel=true)
|
||||
# @test test_cpu_interpreter(1000)
|
||||
# @test test_cpu_interpreter(1000, parallel=true) # start julia -t 6 for six threads
|
||||
# @test test_cpu_interpreter(10000)
|
||||
# @test test_cpu_interpreter(10000, parallel=true)
|
||||
|
||||
|
||||
function test_cpu_interpreter_nikuradse()
|
||||
|
@ -62,14 +62,25 @@ function test_cpu_interpreter_nikuradse()
|
|||
# data/esr_nvar2_len10.txt.gz_9.txt.gz has ~250_000 exprs
|
||||
# data/esr_nvar2_len10.txt.gz_10.txt.gz has ~800_000 exrps
|
||||
GZip.open("data/esr_nvar2_len10.txt.gz_9.txt.gz") do io
|
||||
i = 0
|
||||
for line in eachline(io)
|
||||
expr, p = parse_infix(line, varnames, paramnames)
|
||||
|
||||
if i > 10
|
||||
return
|
||||
end
|
||||
println(expr)
|
||||
|
||||
push!(exprs, expr)
|
||||
push!(parameters, randn(Float32, length(p)))
|
||||
|
||||
i += 1
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
interpret_cpu(exprs, X, parameters) # TODO: sufficient to do up to 10 repetitions per expression,
|
||||
end
|
||||
|
||||
|
||||
@test test_cpu_interpreter_nikuradse()
|
||||
|
|
|
@ -1,35 +1,30 @@
|
|||
using .ExpressionProcessing
|
||||
|
||||
expressions = Vector{Expr}(undef, 1)
|
||||
variables = Matrix{Float32}(undef, 1,2)
|
||||
parameters = Vector{Vector{Float32}}(undef, 1)
|
||||
expressions = Vector{Expr}(undef, 2)
|
||||
|
||||
# Resulting value should be 10
|
||||
expressions[1] = :(x1 + 1 * x2 + p1)
|
||||
variables[1,1] = 2
|
||||
variables[1,2] = 3
|
||||
parameters[1] = Vector{Float32}(undef, 1)
|
||||
parameters[1][1] = 5
|
||||
expressions[2] = :(x[1] + 1 * x[2] + p[1])
|
||||
|
||||
@testset "Test conversion expression element" begin
|
||||
reference1 = ExpressionElement(FLOAT32, reinterpret(Int32, 1f0))
|
||||
reference2 = ExpressionElement(INDEX, reinterpret(Int32, Int32(1)))
|
||||
reference2 = ExpressionElement(VARIABLE, Int32(1))
|
||||
reference3 = ExpressionElement(OPERATOR, reinterpret(Int32, ADD))
|
||||
|
||||
@test isequal(reference1, ExpressionProcessing.convert_to_ExpressionElement(1.0))
|
||||
@test isequal(reference2, ExpressionProcessing.convert_to_ExpressionElement(1))
|
||||
@test isequal(reference2, ExpressionProcessing.convert_to_ExpressionElement(:x1))
|
||||
@test isequal(reference3, ExpressionProcessing.convert_to_ExpressionElement(ADD))
|
||||
end
|
||||
|
||||
@testset "Test conversion to postfix" begin
|
||||
reference = PostfixType()
|
||||
|
||||
append!(reference, [ExpressionProcessing.convert_to_ExpressionElement(1), ExpressionProcessing.convert_to_ExpressionElement(1.0), ExpressionProcessing.convert_to_ExpressionElement(2), ExpressionProcessing.convert_to_ExpressionElement(MULTIPLY),
|
||||
ExpressionProcessing.convert_to_ExpressionElement(ADD), ExpressionProcessing.convert_to_ExpressionElement(-1), ExpressionProcessing.convert_to_ExpressionElement(ADD)])
|
||||
cache = Dict{Expr, PostfixType}()
|
||||
postfix = expr_to_postfix(expressions[1], cache)
|
||||
append!(reference, [ExpressionProcessing.convert_to_ExpressionElement(:x1), ExpressionProcessing.convert_to_ExpressionElement(1.0), ExpressionProcessing.convert_to_ExpressionElement(:x2), ExpressionProcessing.convert_to_ExpressionElement(MULTIPLY),
|
||||
ExpressionProcessing.convert_to_ExpressionElement(ADD), ExpressionProcessing.convert_to_ExpressionElement(:p1), ExpressionProcessing.convert_to_ExpressionElement(ADD)])
|
||||
postfixVarsAsSymbol = expr_to_postfix(expressions[1], Dict{Expr, PostfixType}())
|
||||
postfixVarsAsArray = expr_to_postfix(expressions[2], Dict{Expr, PostfixType}())
|
||||
|
||||
@test isequal(reference, postfix)
|
||||
@test isequal(reference, postfixVarsAsSymbol)
|
||||
@test isequal(reference, postfixVarsAsArray)
|
||||
|
||||
# TODO: Do more complex expressions because these have led to errors in the past
|
||||
end
|
|
@ -2,6 +2,8 @@
|
|||
BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26"
|
||||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
|
||||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
|
||||
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||
GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
|
||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
|
||||
|
|
|
@ -16,9 +16,9 @@ include(joinpath(baseFolder, "src", "Transpiler.jl"))
|
|||
end
|
||||
|
||||
|
||||
# @testset "CPU Interpreter" begin
|
||||
@testset "CPU Interpreter" begin
|
||||
# include("CpuInterpreterTests.jl")
|
||||
# end
|
||||
end
|
||||
|
||||
@testset "Performance tests" begin
|
||||
# include("PerformanceTuning.jl")
|
||||
|
|
Loading…
Reference in New Issue
Block a user