diff --git a/package/src/Interpreter.jl b/package/src/Interpreter.jl index ec6e7d3..f3c392d 100644 --- a/package/src/Interpreter.jl +++ b/package/src/Interpreter.jl @@ -25,7 +25,7 @@ function interpret(expressions::Vector{Expr}, variables::Matrix{Float32}, parame cudaParams = Utils.create_cuda_array(parameters, NaN32) # column corresponds to data for one expression cudaExprs = Utils.create_cuda_array(exprs, ExpressionElement(EMPTY, 0)) # column corresponds to data for one expression; TODO: replace this 0 with 'undef' if possible # put into seperate cuArray, as this is static and would be inefficient to send seperatly to every kernel - cudaStepsize = CuArray([Utils.get_max_inner_length(parameters), size(variables, 1)]) # max num of values per expression; max nam of parameters per expression; number of variables per expression + cudaStepsize = CuArray([Utils.get_max_inner_length(exprs), Utils.get_max_inner_length(parameters), size(variables, 1)]) # max num of values per expression; max nam of parameters per expression; number of variables per expression # each expression has nr. of variable sets (nr. of columns of the variables) results and there are n expressions cudaResults = CuArray{Float32}(undef, variableCols, length(exprs)) @@ -47,21 +47,24 @@ end const MAX_STACK_SIZE = 25 # The depth of the stack to store the values and intermediate results function interpret_expression(expressions::CuDeviceArray{ExpressionElement}, variables::CuDeviceArray{Float32}, parameters::CuDeviceArray{Float32}, results::CuDeviceArray{Float32}, stepsize::CuDeviceArray{Int}, exprIndex::Int) varSetIndex = (blockIdx().x - 1) * blockDim().x + threadIdx().x # ctaid.x * ntid.x + tid.x (1-based) - @inbounds variableCols = length(variables) / stepsize[2] # number of variable sets + @inbounds variableCols = length(variables) / stepsize[3] # number of variable sets if varSetIndex > variableCols return end - @inbounds firstParamIndex = ((exprIndex - 1) * stepsize[1]) # Exclusive + @inbounds firstExprIndex = ((exprIndex - 1) * stepsize[1]) + 1 # Inclusive + @inbounds lastExprIndex = firstExprIndex + stepsize[1] - 1 # Inclusive + @inbounds firstParamIndex = ((exprIndex - 1) * stepsize[2]) # Exclusive # TODO: Use @cuDynamicSharedMem/@cuStaticSharedMem for variables and or parameters operationStack = MVector{MAX_STACK_SIZE, Float32}(undef) # Try to get this to function with variable size too, to allow better memory usage operationStackTop = 0 # stores index of the last defined/valid value - @inbounds firstVariableIndex = ((varSetIndex-1) * stepsize[2]) # Exclusive + @inbounds firstVariableIndex = ((varSetIndex-1) * stepsize[3]) # Exclusive - @inbounds for expr in expressions + @inbounds for i in firstExprIndex:lastExprIndex + expr = expressions[i] if expr.Type == EMPTY break elseif expr.Type == INDEX diff --git a/package/test/ExpressionProcessingTests.jl b/package/test/ExpressionProcessingTests.jl index a5545a2..4a976eb 100644 --- a/package/test/ExpressionProcessingTests.jl +++ b/package/test/ExpressionProcessingTests.jl @@ -26,7 +26,8 @@ end append!(reference, [ExpressionProcessing.convert_to_ExpressionElement(1), ExpressionProcessing.convert_to_ExpressionElement(1.0), ExpressionProcessing.convert_to_ExpressionElement(2), ExpressionProcessing.convert_to_ExpressionElement(MULTIPLY), ExpressionProcessing.convert_to_ExpressionElement(ADD), ExpressionProcessing.convert_to_ExpressionElement(-1), ExpressionProcessing.convert_to_ExpressionElement(ADD)]) - postfix = expr_to_postfix(expressions[1]) + cache = Dict{Expr, PostfixType}() + postfix = expr_to_postfix(expressions[1], cache) @test isequal(reference, postfix) diff --git a/package/test/runtests.jl b/package/test/runtests.jl index 468e456..46afba3 100644 --- a/package/test/runtests.jl +++ b/package/test/runtests.jl @@ -10,8 +10,8 @@ include(joinpath(baseFolder, "src", "Interpreter.jl")) include(joinpath(baseFolder, "src", "Transpiler.jl")) @testset "Functionality tests" begin - # include("ExpressionProcessingTests.jl") - # include("InterpreterTests.jl") + include("ExpressionProcessingTests.jl") + include("InterpreterTests.jl") include("TranspilerTests.jl") end