rewrote function for generating code for operators. now the entire operation will be returned and not just the operator
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled

This commit is contained in:
Daniel 2024-12-10 22:58:18 +01:00
parent 67ef9a5139
commit 8d129dbfcc
3 changed files with 72 additions and 36 deletions

View File

@ -1,6 +1,6 @@
module ExpressionProcessing module ExpressionProcessing
export expr_to_postfix export expr_to_postfix, is_binary_operator
export PostfixType export PostfixType
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
@ -108,6 +108,30 @@ function convert_to_ExpressionElement(element::Operator)::ExpressionElement
return ExpressionElement(OPERATOR, value) return ExpressionElement(OPERATOR, value)
end end
function is_binary_operator(operator::Operator)::Bool
if operator == ADD
return true
elseif operator == SUBTRACT
return true
elseif operator == MULTIPLY
return true
elseif operator == DIVIDE
return true
elseif operator == POWER
return true
elseif operator == ABS
return false
elseif operator == LOG
return false
elseif operator == EXP
return false
elseif operator == SQRT
return false
else
throw(ArgumentError("Unknown operator '$operator'. Cannot determine if it is binary or not."))
end
end
# #
# Everything below is currently not needed. Left here for potential future use # Everything below is currently not needed. Left here for potential future use
# #

View File

@ -23,6 +23,8 @@ using ..ExpressionProcessing
# Note: Maybe make an additional function that transpiles and executed the code. This would then be the function the user calls # Note: Maybe make an additional function that transpiles and executed the code. This would then be the function the user calls
# #
const Operand = Union{Float32, String} # Operand is either fixed value or register
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string # To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
function transpile(expression::ExpressionProcessing.PostfixType)::String function transpile(expression::ExpressionProcessing.PostfixType)::String
exitJumpLocationMarker = "\$L__BB0_2" exitJumpLocationMarker = "\$L__BB0_2"
@ -112,7 +114,7 @@ end
# Current assumption: Expression only made out of constant values # Current assumption: Expression only made out of constant values
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
codeBuffer = IOBuffer() codeBuffer = IOBuffer()
operands = Vector{Union{Float32, String}}() operands = Vector{Operand}()
println(expression) println(expression)
for i in eachindex(expression) for i in eachindex(expression)
@ -121,17 +123,19 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
if token.Type == FLOAT32 if token.Type == FLOAT32
push!(operands, reinterpret(Float32, token.Value)) push!(operands, reinterpret(Float32, token.Value))
elseif token.Type == OPERATOR elseif token.Type == OPERATOR
# get_ptx_operator will be reworked completly to return the code for the operation operator = reinterpret(Operator, token.Value)
operator = get_ptx_operator(reinterpret(Operator, token.Value))
register = get_next_free_register("f")
print(codeBuffer, " $operator $register, ")
ops = last(operands, 2) right = nothing
pop!(operands);pop!(operands) if is_binary_operator(operator)
print(codeBuffer, join(ops, ", ")) right = pop!(operands)
println(codeBuffer, ";") left = pop!(operands)
else
left = pop!(operands)
end
operation, resultRegister = get_operation(operator, left, right)
push!(operands, register) println(codeBuffer, operation)
push!(operands, resultRegister)
elseif token.Type == INDEX elseif token.Type == INDEX
# TODO # TODO
end end
@ -150,29 +154,37 @@ function type_to_ptx_type(type::DataType)::String
end end
end end
# TODO: Change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns). function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
function get_ptx_operator(operator::Operator)::String resultRegister = get_next_free_register("f")
if operator == ADD resultCode = ""
return "add.f32"
elseif operator == SUBTRACT if is_binary_operator(operator) && isnothing(right)
return "sub.f32" throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operator has been given."))
elseif operator == MULTIPLY
return "mul.f32"
elseif operator == DIVIDE
return "div.approx.f32"
elseif operator == POWER
return ""
elseif operator == ABS
return "abs.f32"
elseif operator == LOG
return "lg2.approx.f32"
elseif operator == EXP
return ""
elseif operator == SQRT
return "sqrt.approx.f32"
else
throw(ArgumentError("Operator conversion to ptx not implemented for $operator"))
end end
if operator == ADD
resultCode = "add.f32 $resultRegister, $left, $right;"
elseif operator == SUBTRACT
resultCode = "sub.f32 $resultRegister, $left, $right;"
elseif operator == MULTIPLY
resultCode = "mul.f32 $resultRegister, $left, $right;"
elseif operator == DIVIDE
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
elseif operator == POWER
resultCode = " $resultRegister, $left;" # TODO
elseif operator == ABS
resultCode = "abs.f32 $resultRegister, $left;"
elseif operator == LOG
resultCode = "lg2.approx.f32 $resultRegister, $left;"
elseif operator == EXP
resultCode = " $resultRegister, $left;" # TODO
elseif operator == SQRT
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
else
throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
end
return (resultCode, resultRegister)
end end
let registers = Dict() # stores the count of the register already used. let registers = Dict() # stores the count of the register already used.

View File

@ -6,8 +6,8 @@ expressions = Vector{Expr}(undef, 2)
variables = Matrix{Float32}(undef, 2,2) variables = Matrix{Float32}(undef, 2,2)
parameters = Vector{Vector{Float32}}(undef, 2) parameters = Vector{Vector{Float32}}(undef, 2)
# Resulting value should be 10 for the first expression # Resulting value should be 1.14... for the first expression
expressions[1] = :(1 + 3 * 5 / 7 - 1) expressions[1] = :(1 + 3 * 5 / 7 - sqrt(4))
expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2) expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2)
variables[1,1] = 2.0 variables[1,1] = 2.0
variables[2,1] = 3.0 variables[2,1] = 3.0