rewrote function for generating code for operators. now the entire operation will be returned and not just the operator
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
This commit is contained in:
parent
67ef9a5139
commit
8d129dbfcc
|
@ -1,6 +1,6 @@
|
|||
module ExpressionProcessing
|
||||
|
||||
export expr_to_postfix
|
||||
export expr_to_postfix, is_binary_operator
|
||||
export PostfixType
|
||||
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
|
||||
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
|
||||
|
@ -108,6 +108,30 @@ function convert_to_ExpressionElement(element::Operator)::ExpressionElement
|
|||
return ExpressionElement(OPERATOR, value)
|
||||
end
|
||||
|
||||
function is_binary_operator(operator::Operator)::Bool
|
||||
if operator == ADD
|
||||
return true
|
||||
elseif operator == SUBTRACT
|
||||
return true
|
||||
elseif operator == MULTIPLY
|
||||
return true
|
||||
elseif operator == DIVIDE
|
||||
return true
|
||||
elseif operator == POWER
|
||||
return true
|
||||
elseif operator == ABS
|
||||
return false
|
||||
elseif operator == LOG
|
||||
return false
|
||||
elseif operator == EXP
|
||||
return false
|
||||
elseif operator == SQRT
|
||||
return false
|
||||
else
|
||||
throw(ArgumentError("Unknown operator '$operator'. Cannot determine if it is binary or not."))
|
||||
end
|
||||
end
|
||||
|
||||
#
|
||||
# Everything below is currently not needed. Left here for potential future use
|
||||
#
|
||||
|
|
|
@ -23,6 +23,8 @@ using ..ExpressionProcessing
|
|||
# Note: Maybe make an additional function that transpiles and executed the code. This would then be the function the user calls
|
||||
#
|
||||
|
||||
const Operand = Union{Float32, String} # Operand is either fixed value or register
|
||||
|
||||
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
|
||||
function transpile(expression::ExpressionProcessing.PostfixType)::String
|
||||
exitJumpLocationMarker = "\$L__BB0_2"
|
||||
|
@ -112,7 +114,7 @@ end
|
|||
# Current assumption: Expression only made out of constant values
|
||||
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
|
||||
codeBuffer = IOBuffer()
|
||||
operands = Vector{Union{Float32, String}}()
|
||||
operands = Vector{Operand}()
|
||||
|
||||
println(expression)
|
||||
for i in eachindex(expression)
|
||||
|
@ -121,17 +123,19 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
|
|||
if token.Type == FLOAT32
|
||||
push!(operands, reinterpret(Float32, token.Value))
|
||||
elseif token.Type == OPERATOR
|
||||
# get_ptx_operator will be reworked completly to return the code for the operation
|
||||
operator = get_ptx_operator(reinterpret(Operator, token.Value))
|
||||
register = get_next_free_register("f")
|
||||
print(codeBuffer, " $operator $register, ")
|
||||
operator = reinterpret(Operator, token.Value)
|
||||
|
||||
ops = last(operands, 2)
|
||||
pop!(operands);pop!(operands)
|
||||
print(codeBuffer, join(ops, ", "))
|
||||
println(codeBuffer, ";")
|
||||
right = nothing
|
||||
if is_binary_operator(operator)
|
||||
right = pop!(operands)
|
||||
left = pop!(operands)
|
||||
else
|
||||
left = pop!(operands)
|
||||
end
|
||||
operation, resultRegister = get_operation(operator, left, right)
|
||||
|
||||
push!(operands, register)
|
||||
println(codeBuffer, operation)
|
||||
push!(operands, resultRegister)
|
||||
elseif token.Type == INDEX
|
||||
# TODO
|
||||
end
|
||||
|
@ -150,29 +154,37 @@ function type_to_ptx_type(type::DataType)::String
|
|||
end
|
||||
end
|
||||
|
||||
# TODO: Change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns).
|
||||
function get_ptx_operator(operator::Operator)::String
|
||||
if operator == ADD
|
||||
return "add.f32"
|
||||
elseif operator == SUBTRACT
|
||||
return "sub.f32"
|
||||
elseif operator == MULTIPLY
|
||||
return "mul.f32"
|
||||
elseif operator == DIVIDE
|
||||
return "div.approx.f32"
|
||||
elseif operator == POWER
|
||||
return ""
|
||||
elseif operator == ABS
|
||||
return "abs.f32"
|
||||
elseif operator == LOG
|
||||
return "lg2.approx.f32"
|
||||
elseif operator == EXP
|
||||
return ""
|
||||
elseif operator == SQRT
|
||||
return "sqrt.approx.f32"
|
||||
else
|
||||
throw(ArgumentError("Operator conversion to ptx not implemented for $operator"))
|
||||
function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
|
||||
resultRegister = get_next_free_register("f")
|
||||
resultCode = ""
|
||||
|
||||
if is_binary_operator(operator) && isnothing(right)
|
||||
throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operator has been given."))
|
||||
end
|
||||
|
||||
if operator == ADD
|
||||
resultCode = "add.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == SUBTRACT
|
||||
resultCode = "sub.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == MULTIPLY
|
||||
resultCode = "mul.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == DIVIDE
|
||||
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == POWER
|
||||
resultCode = " $resultRegister, $left;" # TODO
|
||||
elseif operator == ABS
|
||||
resultCode = "abs.f32 $resultRegister, $left;"
|
||||
elseif operator == LOG
|
||||
resultCode = "lg2.approx.f32 $resultRegister, $left;"
|
||||
elseif operator == EXP
|
||||
resultCode = " $resultRegister, $left;" # TODO
|
||||
elseif operator == SQRT
|
||||
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
|
||||
else
|
||||
throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
|
||||
end
|
||||
|
||||
return (resultCode, resultRegister)
|
||||
end
|
||||
|
||||
let registers = Dict() # stores the count of the register already used.
|
||||
|
|
|
@ -6,8 +6,8 @@ expressions = Vector{Expr}(undef, 2)
|
|||
variables = Matrix{Float32}(undef, 2,2)
|
||||
parameters = Vector{Vector{Float32}}(undef, 2)
|
||||
|
||||
# Resulting value should be 10 for the first expression
|
||||
expressions[1] = :(1 + 3 * 5 / 7 - 1)
|
||||
# Resulting value should be 1.14... for the first expression
|
||||
expressions[1] = :(1 + 3 * 5 / 7 - sqrt(4))
|
||||
expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2)
|
||||
variables[1,1] = 2.0
|
||||
variables[2,1] = 3.0
|
||||
|
|
Loading…
Reference in New Issue
Block a user