From f7926c343814133f8651ec884549fe735f040501 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 5 Jan 2025 11:19:03 +0100 Subject: [PATCH] finished implementing operators --- package/src/Interpreter.jl | 5 +---- package/src/Transpiler.jl | 17 +++++++++++++---- package/test/TranspilerTests.jl | 4 +++- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/package/src/Interpreter.jl b/package/src/Interpreter.jl index 877076b..4ffeea7 100644 --- a/package/src/Interpreter.jl +++ b/package/src/Interpreter.jl @@ -49,10 +49,7 @@ function interpret_expression(expressions::CuDeviceArray{ExpressionElement}, var operationStack = MVector{MAX_STACK_SIZE, Float32}(undef) # Try to get this to function with variable size too, to allow better memory usage operationStackTop = 0 # stores index of the last defined/valid value - @cuprintln("before loop") - for varSetIndex in index:stride:length(expressions) - @cuprintln("in loop with index '$varSetIndex'") - + for varSetIndex in index:stride firstVariableIndex = ((varSetIndex - 1) * stepsize[3]) # Exclusive for i in firstExprIndex:lastExprIndex diff --git a/package/src/Transpiler.jl b/package/src/Transpiler.jl index 3150a27..eae62fa 100644 --- a/package/src/Transpiler.jl +++ b/package/src/Transpiler.jl @@ -171,19 +171,28 @@ function get_operation(operator::Operator, left::Operand, right::Union{Operand, elseif operator == DIVIDE resultCode = "div.approx.f32 $resultRegister, $left, $right;" elseif operator == POWER - resultCode = " $resultRegister, $left;" # TODO + # x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)" + resultCode = " + lg2.approx.f32 $resultRegister, $left; + mul.f32 $resultRegister, $right, $resultRegister; + ex2.approx.f32 $resultRegister, $resultRegister;" elseif operator == ABS resultCode = "abs.f32 $resultRegister, $left;" elseif operator == LOG - resultCode = "lg2.approx.f32 $resultRegister, $left;" + # log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)" + resultCode = " + lg2.approx.f32 $resultRegister, $left; + mul.f32 $resultRegister, $resultRegister, 0.693147182;" elseif operator == EXP - resultCode = " $resultRegister, $left;" # TODO + # e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)" + resultCode = " + mul.f32 $resultRegister, $left, 1.44269502; + ex2.approx.f32 $resultRegister, $resultRegister;" elseif operator == SQRT resultCode = "sqrt.approx.f32 $resultRegister, $left;" else throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'")) end - return (resultCode, resultRegister) end diff --git a/package/test/TranspilerTests.jl b/package/test/TranspilerTests.jl index 865c3dd..2e5cb0d 100644 --- a/package/test/TranspilerTests.jl +++ b/package/test/TranspilerTests.jl @@ -24,8 +24,10 @@ parameters[2][2] = 0.0 postfixExpr = expr_to_postfix(expressions[1]) postfixExprs = Vector([postfixExpr]) push!(postfixExprs, expr_to_postfix(expressions[2])) + push!(postfixExprs, expr_to_postfix(:(5^3))) - generatedCode = Transpiler.transpile(postfixExpr) + # generatedCode = Transpiler.transpile(postfixExpr) + generatedCode = Transpiler.transpile(postfixExprs[3]) # TEMP # CUDA.@sync interpret(postfixExprs, variables, parameters) # This is just here for testing. This will be called inside the execute method in the Transpiler module