rewrote function for generating code for operators. now the entire operation will be returned and not just the operator
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
This commit is contained in:
parent
67ef9a5139
commit
8d129dbfcc
|
@ -1,6 +1,6 @@
|
||||||
module ExpressionProcessing
|
module ExpressionProcessing
|
||||||
|
|
||||||
export expr_to_postfix
|
export expr_to_postfix, is_binary_operator
|
||||||
export PostfixType
|
export PostfixType
|
||||||
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
|
export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT
|
||||||
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
|
export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX
|
||||||
|
@ -108,6 +108,30 @@ function convert_to_ExpressionElement(element::Operator)::ExpressionElement
|
||||||
return ExpressionElement(OPERATOR, value)
|
return ExpressionElement(OPERATOR, value)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function is_binary_operator(operator::Operator)::Bool
|
||||||
|
if operator == ADD
|
||||||
|
return true
|
||||||
|
elseif operator == SUBTRACT
|
||||||
|
return true
|
||||||
|
elseif operator == MULTIPLY
|
||||||
|
return true
|
||||||
|
elseif operator == DIVIDE
|
||||||
|
return true
|
||||||
|
elseif operator == POWER
|
||||||
|
return true
|
||||||
|
elseif operator == ABS
|
||||||
|
return false
|
||||||
|
elseif operator == LOG
|
||||||
|
return false
|
||||||
|
elseif operator == EXP
|
||||||
|
return false
|
||||||
|
elseif operator == SQRT
|
||||||
|
return false
|
||||||
|
else
|
||||||
|
throw(ArgumentError("Unknown operator '$operator'. Cannot determine if it is binary or not."))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
#
|
#
|
||||||
# Everything below is currently not needed. Left here for potential future use
|
# Everything below is currently not needed. Left here for potential future use
|
||||||
#
|
#
|
||||||
|
|
|
@ -23,6 +23,8 @@ using ..ExpressionProcessing
|
||||||
# Note: Maybe make an additional function that transpiles and executed the code. This would then be the function the user calls
|
# Note: Maybe make an additional function that transpiles and executed the code. This would then be the function the user calls
|
||||||
#
|
#
|
||||||
|
|
||||||
|
const Operand = Union{Float32, String} # Operand is either fixed value or register
|
||||||
|
|
||||||
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
|
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
|
||||||
function transpile(expression::ExpressionProcessing.PostfixType)::String
|
function transpile(expression::ExpressionProcessing.PostfixType)::String
|
||||||
exitJumpLocationMarker = "\$L__BB0_2"
|
exitJumpLocationMarker = "\$L__BB0_2"
|
||||||
|
@ -112,7 +114,7 @@ end
|
||||||
# Current assumption: Expression only made out of constant values
|
# Current assumption: Expression only made out of constant values
|
||||||
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
|
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
|
||||||
codeBuffer = IOBuffer()
|
codeBuffer = IOBuffer()
|
||||||
operands = Vector{Union{Float32, String}}()
|
operands = Vector{Operand}()
|
||||||
|
|
||||||
println(expression)
|
println(expression)
|
||||||
for i in eachindex(expression)
|
for i in eachindex(expression)
|
||||||
|
@ -121,17 +123,19 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
|
||||||
if token.Type == FLOAT32
|
if token.Type == FLOAT32
|
||||||
push!(operands, reinterpret(Float32, token.Value))
|
push!(operands, reinterpret(Float32, token.Value))
|
||||||
elseif token.Type == OPERATOR
|
elseif token.Type == OPERATOR
|
||||||
# get_ptx_operator will be reworked completly to return the code for the operation
|
operator = reinterpret(Operator, token.Value)
|
||||||
operator = get_ptx_operator(reinterpret(Operator, token.Value))
|
|
||||||
register = get_next_free_register("f")
|
|
||||||
print(codeBuffer, " $operator $register, ")
|
|
||||||
|
|
||||||
ops = last(operands, 2)
|
right = nothing
|
||||||
pop!(operands);pop!(operands)
|
if is_binary_operator(operator)
|
||||||
print(codeBuffer, join(ops, ", "))
|
right = pop!(operands)
|
||||||
println(codeBuffer, ";")
|
left = pop!(operands)
|
||||||
|
else
|
||||||
push!(operands, register)
|
left = pop!(operands)
|
||||||
|
end
|
||||||
|
operation, resultRegister = get_operation(operator, left, right)
|
||||||
|
|
||||||
|
println(codeBuffer, operation)
|
||||||
|
push!(operands, resultRegister)
|
||||||
elseif token.Type == INDEX
|
elseif token.Type == INDEX
|
||||||
# TODO
|
# TODO
|
||||||
end
|
end
|
||||||
|
@ -150,29 +154,37 @@ function type_to_ptx_type(type::DataType)::String
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: Change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns).
|
function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
|
||||||
function get_ptx_operator(operator::Operator)::String
|
resultRegister = get_next_free_register("f")
|
||||||
if operator == ADD
|
resultCode = ""
|
||||||
return "add.f32"
|
|
||||||
elseif operator == SUBTRACT
|
if is_binary_operator(operator) && isnothing(right)
|
||||||
return "sub.f32"
|
throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operator has been given."))
|
||||||
elseif operator == MULTIPLY
|
|
||||||
return "mul.f32"
|
|
||||||
elseif operator == DIVIDE
|
|
||||||
return "div.approx.f32"
|
|
||||||
elseif operator == POWER
|
|
||||||
return ""
|
|
||||||
elseif operator == ABS
|
|
||||||
return "abs.f32"
|
|
||||||
elseif operator == LOG
|
|
||||||
return "lg2.approx.f32"
|
|
||||||
elseif operator == EXP
|
|
||||||
return ""
|
|
||||||
elseif operator == SQRT
|
|
||||||
return "sqrt.approx.f32"
|
|
||||||
else
|
|
||||||
throw(ArgumentError("Operator conversion to ptx not implemented for $operator"))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if operator == ADD
|
||||||
|
resultCode = "add.f32 $resultRegister, $left, $right;"
|
||||||
|
elseif operator == SUBTRACT
|
||||||
|
resultCode = "sub.f32 $resultRegister, $left, $right;"
|
||||||
|
elseif operator == MULTIPLY
|
||||||
|
resultCode = "mul.f32 $resultRegister, $left, $right;"
|
||||||
|
elseif operator == DIVIDE
|
||||||
|
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
|
||||||
|
elseif operator == POWER
|
||||||
|
resultCode = " $resultRegister, $left;" # TODO
|
||||||
|
elseif operator == ABS
|
||||||
|
resultCode = "abs.f32 $resultRegister, $left;"
|
||||||
|
elseif operator == LOG
|
||||||
|
resultCode = "lg2.approx.f32 $resultRegister, $left;"
|
||||||
|
elseif operator == EXP
|
||||||
|
resultCode = " $resultRegister, $left;" # TODO
|
||||||
|
elseif operator == SQRT
|
||||||
|
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
|
||||||
|
else
|
||||||
|
throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
|
||||||
|
end
|
||||||
|
|
||||||
|
return (resultCode, resultRegister)
|
||||||
end
|
end
|
||||||
|
|
||||||
let registers = Dict() # stores the count of the register already used.
|
let registers = Dict() # stores the count of the register already used.
|
||||||
|
|
|
@ -6,8 +6,8 @@ expressions = Vector{Expr}(undef, 2)
|
||||||
variables = Matrix{Float32}(undef, 2,2)
|
variables = Matrix{Float32}(undef, 2,2)
|
||||||
parameters = Vector{Vector{Float32}}(undef, 2)
|
parameters = Vector{Vector{Float32}}(undef, 2)
|
||||||
|
|
||||||
# Resulting value should be 10 for the first expression
|
# Resulting value should be 1.14... for the first expression
|
||||||
expressions[1] = :(1 + 3 * 5 / 7 - 1)
|
expressions[1] = :(1 + 3 * 5 / 7 - sqrt(4))
|
||||||
expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2)
|
expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2)
|
||||||
variables[1,1] = 2.0
|
variables[1,1] = 2.0
|
||||||
variables[2,1] = 3.0
|
variables[2,1] = 3.0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user