From 6d3c3164cf030d5fd32deec6d044d5fd3058888d Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 10 May 2025 09:36:02 +0200 Subject: [PATCH] expression processing: added support for inverse/reciprocal --- package/src/ExpressionProcessing.jl | 4 +++- package/src/Interpreter.jl | 2 ++ package/src/Transpiler.jl | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/package/src/ExpressionProcessing.jl b/package/src/ExpressionProcessing.jl index 7bc05d6..0acebc0 100644 --- a/package/src/ExpressionProcessing.jl +++ b/package/src/ExpressionProcessing.jl @@ -6,7 +6,7 @@ export Operator, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, ABS, LOG, EXP, SQRT export ElementType, EMPTY, FLOAT32, OPERATOR, VARIABLE, PARAMETER export ExpressionElement -@enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9 +@enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9 INV=10 @enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 VARIABLE=3 PARAMETER=4 const binary_operators = [ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER] @@ -99,6 +99,8 @@ function get_operator(op::Symbol)::Operator return EXP elseif op == :sqrt return SQRT + elseif op == :inv + return INV else throw("Operator unknown. Operator was $op") end diff --git a/package/src/Interpreter.jl b/package/src/Interpreter.jl index cd6c81e..450773d 100644 --- a/package/src/Interpreter.jl +++ b/package/src/Interpreter.jl @@ -98,6 +98,8 @@ function interpret_expression(expressions::CuDeviceArray{ExpressionElement}, var operationStack[operationStackTop] = exp(operationStack[operationStackTop]) elseif opcode == SQRT operationStack[operationStackTop] = sqrt(operationStack[operationStackTop]) + elseif opcode == INV + operationStack[operationStackTop] = inv(operationStack[operationStackTop]) end else operationStack[operationStackTop] = NaN32 diff --git a/package/src/Transpiler.jl b/package/src/Transpiler.jl index a502a10..738956f 100644 --- a/package/src/Transpiler.jl +++ b/package/src/Transpiler.jl @@ -316,6 +316,8 @@ function get_operation(operator::Operator, regManager::Utils.RegisterManager, le ex2.approx.f32 $resultRegister, $resultRegister;" elseif operator == SQRT resultCode = "sqrt.approx.f32 $resultRegister, $left;" + elseif operator == INV + resultCode = "rcp.approx.f32 $resultRegister, $left;" else throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'")) end