diff --git a/package/src/ExpressionProcessing.jl b/package/src/ExpressionProcessing.jl index cfbbeb1..a45e8d4 100644 --- a/package/src/ExpressionProcessing.jl +++ b/package/src/ExpressionProcessing.jl @@ -1,6 +1,6 @@ module ExpressionProcessing -export expr_to_postfix +export expr_to_postfix, is_binary_operator export PostfixType export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT export ElementType, EMPTY, FLOAT32, OPERATOR, INDEX @@ -108,6 +108,30 @@ function convert_to_ExpressionElement(element::Operator)::ExpressionElement return ExpressionElement(OPERATOR, value) end +function is_binary_operator(operator::Operator)::Bool + if operator == ADD + return true + elseif operator == SUBTRACT + return true + elseif operator == MULTIPLY + return true + elseif operator == DIVIDE + return true + elseif operator == POWER + return true + elseif operator == ABS + return false + elseif operator == LOG + return false + elseif operator == EXP + return false + elseif operator == SQRT + return false + else + throw(ArgumentError("Unknown operator '$operator'. Cannot determine if it is binary or not.")) + end +end + # # Everything below is currently not needed. Left here for potential future use # diff --git a/package/src/Transpiler.jl b/package/src/Transpiler.jl index e23453d..3150a27 100644 --- a/package/src/Transpiler.jl +++ b/package/src/Transpiler.jl @@ -23,6 +23,8 @@ using ..ExpressionProcessing # Note: Maybe make an additional function that transpiles and executed the code. This would then be the function the user calls # +const Operand = Union{Float32, String} # Operand is either fixed value or register + # To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string function transpile(expression::ExpressionProcessing.PostfixType)::String exitJumpLocationMarker = "\$L__BB0_2" @@ -112,7 +114,7 @@ end # Current assumption: Expression only made out of constant values function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String codeBuffer = IOBuffer() - operands = Vector{Union{Float32, String}}() + operands = Vector{Operand}() println(expression) for i in eachindex(expression) @@ -121,17 +123,19 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType) if token.Type == FLOAT32 push!(operands, reinterpret(Float32, token.Value)) elseif token.Type == OPERATOR - # get_ptx_operator will be reworked completly to return the code for the operation - operator = get_ptx_operator(reinterpret(Operator, token.Value)) - register = get_next_free_register("f") - print(codeBuffer, " $operator $register, ") + operator = reinterpret(Operator, token.Value) - ops = last(operands, 2) - pop!(operands);pop!(operands) - print(codeBuffer, join(ops, ", ")) - println(codeBuffer, ";") - - push!(operands, register) + right = nothing + if is_binary_operator(operator) + right = pop!(operands) + left = pop!(operands) + else + left = pop!(operands) + end + operation, resultRegister = get_operation(operator, left, right) + + println(codeBuffer, operation) + push!(operands, resultRegister) elseif token.Type == INDEX # TODO end @@ -150,29 +154,37 @@ function type_to_ptx_type(type::DataType)::String end end -# TODO: Change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns). -function get_ptx_operator(operator::Operator)::String - if operator == ADD - return "add.f32" - elseif operator == SUBTRACT - return "sub.f32" - elseif operator == MULTIPLY - return "mul.f32" - elseif operator == DIVIDE - return "div.approx.f32" - elseif operator == POWER - return "" - elseif operator == ABS - return "abs.f32" - elseif operator == LOG - return "lg2.approx.f32" - elseif operator == EXP - return "" - elseif operator == SQRT - return "sqrt.approx.f32" - else - throw(ArgumentError("Operator conversion to ptx not implemented for $operator")) +function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String} + resultRegister = get_next_free_register("f") + resultCode = "" + + if is_binary_operator(operator) && isnothing(right) + throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operator has been given.")) end + + if operator == ADD + resultCode = "add.f32 $resultRegister, $left, $right;" + elseif operator == SUBTRACT + resultCode = "sub.f32 $resultRegister, $left, $right;" + elseif operator == MULTIPLY + resultCode = "mul.f32 $resultRegister, $left, $right;" + elseif operator == DIVIDE + resultCode = "div.approx.f32 $resultRegister, $left, $right;" + elseif operator == POWER + resultCode = " $resultRegister, $left;" # TODO + elseif operator == ABS + resultCode = "abs.f32 $resultRegister, $left;" + elseif operator == LOG + resultCode = "lg2.approx.f32 $resultRegister, $left;" + elseif operator == EXP + resultCode = " $resultRegister, $left;" # TODO + elseif operator == SQRT + resultCode = "sqrt.approx.f32 $resultRegister, $left;" + else + throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'")) + end + + return (resultCode, resultRegister) end let registers = Dict() # stores the count of the register already used. diff --git a/package/test/TranspilerTests.jl b/package/test/TranspilerTests.jl index d3026cc..865c3dd 100644 --- a/package/test/TranspilerTests.jl +++ b/package/test/TranspilerTests.jl @@ -6,8 +6,8 @@ expressions = Vector{Expr}(undef, 2) variables = Matrix{Float32}(undef, 2,2) parameters = Vector{Vector{Float32}}(undef, 2) -# Resulting value should be 10 for the first expression -expressions[1] = :(1 + 3 * 5 / 7 - 1) +# Resulting value should be 1.14... for the first expression +expressions[1] = :(1 + 3 * 5 / 7 - sqrt(4)) expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2) variables[1,1] = 2.0 variables[2,1] = 3.0