From 90a41942839a1033e5badca8a2537d3ac51fe332 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 19 Apr 2025 10:54:41 +0200 Subject: [PATCH] expression processing: fixed error if expression contained nested unary operators such as log(sqrt(4)) --- package/src/ExpressionProcessing.jl | 4 +++- package/test/InterpreterTests.jl | 4 ++-- package/test/TranspilerTests.jl | 10 +++++----- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/package/src/ExpressionProcessing.jl b/package/src/ExpressionProcessing.jl index 44235b5..9c8259a 100644 --- a/package/src/ExpressionProcessing.jl +++ b/package/src/ExpressionProcessing.jl @@ -9,6 +9,8 @@ export ExpressionElement @enum Operator ADD=1 SUBTRACT=2 MULTIPLY=3 DIVIDE=4 POWER=5 ABS=6 LOG=7 EXP=8 SQRT=9 @enum ElementType EMPTY=0 FLOAT32=1 OPERATOR=2 INDEX=3 +const unary_operators = [ABS, LOG, EXP, SQRT] + struct ExpressionElement Type::ElementType Value::Int32 # Reinterpret the stored value to type "ElementType" when using it @@ -46,7 +48,7 @@ function expr_to_postfix(expr::Expr)::PostfixType end # For the case this expression has an operator that only takes in a single value like "abs(x)" - if length(postfix) == 1 + if operator in unary_operators push!(postfix, convert_to_ExpressionElement(operator)) end return postfix diff --git a/package/test/InterpreterTests.jl b/package/test/InterpreterTests.jl index c4e017b..9845265 100644 --- a/package/test/InterpreterTests.jl +++ b/package/test/InterpreterTests.jl @@ -132,8 +132,8 @@ end # var set 1 @test isapprox(result[1,1], 37.32, atol=0.01) # expr1 - @test isapprox(result[1,2], 64.74, atol=0.01) # expr2 + @test isapprox(result[1,2], 64.75, atol=0.01) # expr2 # var set 2 @test isapprox(result[2,1], 37.32, atol=0.01) # expr1 - @test isapprox(result[2,2], -83.65, atol=0.01) # expr2 + @test isapprox(result[2,2], -83.66, atol=0.01) # expr2 end diff --git a/package/test/TranspilerTests.jl b/package/test/TranspilerTests.jl index e90e776..4b591f9 100644 --- a/package/test/TranspilerTests.jl +++ b/package/test/TranspilerTests.jl @@ -6,7 +6,7 @@ expressions = Vector{Expr}(undef, 3) variables = Matrix{Float32}(undef, 5, 4) parameters = Vector{Vector{Float32}}(undef, 3) -expressions[1] = :(1 + 3 * 5 / 7 - sqrt(4)) +expressions[1] = :(1 + 3 * 5 / 7 - sqrt(log(4))) expressions[2] = :(5 + x1 + 1 * x2 + p1 + p2 + x1^x3) expressions[3] = :(log(x1) / x2 * sqrt(p1) + x3^x4 - exp(x5)) @@ -46,10 +46,10 @@ parameters[3][1] = 16.0 # dump(expressions[3]; maxdepth=10) # Expr 1: - @test isapprox(results[1,1], 1.14286) - @test isapprox(results[2,1], 1.14286) - @test isapprox(results[3,1], 1.14286) - @test isapprox(results[4,1], 1.14286) + @test isapprox(results[1,1], 1.96545) + @test isapprox(results[2,1], 1.96545) + @test isapprox(results[3,1], 1.96545) + @test isapprox(results[4,1], 1.96545) #Expr 2: @test isapprox(results[1,2], 16.0) @test isapprox(results[2,2], 25.0)