diff --git a/package/src/Transpiler.jl b/package/src/Transpiler.jl index 8630dc5..508a36f 100644 --- a/package/src/Transpiler.jl +++ b/package/src/Transpiler.jl @@ -27,20 +27,20 @@ const Operand = Union{Float32, String} # Operand is either fixed value or regist # To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string # seekstart(buf1); write(buf2, buf1) -function transpile(expression::ExpressionProcessing.PostfixType)::String +function transpile(expression::ExpressionProcessing.PostfixType, varSetSize::Integer, paramSetSize::Integer)::String exitJumpLocationMarker = "\$L__BB0_2" ptxBuffer = IOBuffer() # TODO: Suboptimal solution - signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Int32, Float32]) # nrOfVarSets, nrOfVarsPerSet, Vars - guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # parameter0 because first entry holds the number of variables and that is always stored in %parameter0 + signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Float32, Float32]) # nrOfVarSets, Vars, Params + guardClause, threadIdReg = get_guard_clause(exitJumpLocationMarker, "%parameter0") # parameter0 because first entry holds the number of variable sets and that is always stored in %parameter0 println(ptxBuffer, get_cuda_header()) println(ptxBuffer, signature) println(ptxBuffer, "{") - calc_code = generate_calculation_code(expression, "%parameter1", "%parameter2") + calc_code = generate_calculation_code(expression, "%parameter1", varSetSize, "%parameter2", paramSetSize, threadIdReg) println(ptxBuffer, get_register_definitions()) println(ptxBuffer, paramLoading) println(ptxBuffer, guardClause) @@ -60,7 +60,7 @@ function get_cuda_header()::String return " .version 7.1 .target sm_61 -.address_size 64 +.address_size 32 " end @@ -74,9 +74,9 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType}): for i in eachindex(parameters) 
print(signatureBuffer, " .param .u32", " ", "param_", i) - parameterRegister = get_next_free_register("r") - println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];") - println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parameterRegister;") + parametersReg = get_next_free_register("r") + println(paramLoadingBuffer, "ld.param.u32 $parametersReg, [param_$i];") + println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parametersReg;") if i != lastindex(parameters) println(signatureBuffer, ",") end @@ -91,7 +91,7 @@ Constructs the PTX code used for handling the case where too many threads are st - param ```nrOfVarSetsRegister```: The register which holds the total amount of variable sets for the kernel " -function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String +function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::Tuple{String, String} guardBuffer = IOBuffer() threadIds = get_next_free_register("r") @@ -106,17 +106,18 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String) globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set breakCondition = get_next_free_register("p") nrOfVarSets = get_next_free_register("i") - println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;") + println(guardBuffer, "ld.global.u32 $nrOfVarSets, [$nrOfVarSetsRegister];") println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;") println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets # branch to end if breakCondition is true print(guardBuffer, "@$breakCondition bra $exitJumpLocation;") - return String(take!(guardBuffer)) + return (String(take!(guardBuffer)), globalThreadId) end -function 
generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String, parameterRegister::String)::String +function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesReg::String, variablesSetSize::Integer, + parametersReg::String, parametersSetSize::Integer, threadIdReg::String)::String codeBuffer = IOBuffer() operands = Vector{Operand}() @@ -143,14 +144,14 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType, if token.Value > 0 # varaibles var, first_access = get_register_for_name("x$(token.Value)") if first_access - println(codeBuffer, load_into_register(var, variablesRegister, token.Value, , )) + println(codeBuffer, load_into_register(var, variablesReg, token.Value, threadIdReg, variablesSetSize)) end push!(operands, var) else absVal = abs(token.Value) param, first_access = get_register_for_name("p$absVal") if first_access - println(codeBuffer, load_into_register(param, parameterRegister, absVal, , )) + println(codeBuffer, load_into_register(param, parametersReg, absVal, threadIdReg, parametersSetSize)) end push!(operands, param) end @@ -162,19 +163,21 @@ end " - param ```register```: The register where the loaded value will be stored -- param ```load_location```: The location from where to load the value -- param ```value_index```: 0-based index of the value in the variable set/parameter set -- param ```set_index```: 0-based index of the set. Needed to calculate the actual index from the ```value_index```. Is equal to the global threadId -- param ```set_size```: The size of one set. Needed to calculate the actual index from the ```value_index``` +- param ```loadLocation```: The location from where to load the value +- param ```valueIndex```: 0-based index of the value in the variable set/parameter set +- param ```setIndexReg```: 0-based index of the set. Needed to calculate the actual index from the ```valueIndex```. 
Is equal to the global threadId +- param ```setSize```: The size of one set. Needed to calculate the actual index from the ```valueIndex``` " -function load_into_register(register::String, load_location::String, value_index::Integer, set_index::Integer, set_size::Integer)::String - # load_location + startIndex + value_index * bytes (4 in our case) - # startIndex: set_index * set_size - if value_index == 0 && set_index == 0 # accessing the very first value doesn't need any further calculations - return "ld.global.f32 $register, [$load_location]" - else - return "ld.global.f32 $register, [$load_location+$(set_size*set_index + value_index*sizeof(value_index))]" - end +function load_into_register(register::String, loadLocation::String, valueIndex::Integer, setIndexReg::String, setSize::Integer)::String + # loadLocation + startIndex + valueIndex * bytes (4 in our case) + # startIndex: setIndex * setSize + tempReg = get_next_free_register("i") + # byte offset of the value inside its set: valueIndex * sizeof(valueIndex). sizeof returns the byte size of the index's own integer type, so this matches the 4 bytes of a stored value only when valueIndex is a 32-bit integer. NOTE(review): confirm the integer type the caller passes, and whether the set offset (setIndex * setSize) must also be scaled to bytes + return " + mul.lo.u32 $tempReg, $setIndexReg, $setSize; + add.u32 $tempReg, $tempReg, $(valueIndex*sizeof(valueIndex)); + add.u32 $tempReg, $loadLocation, $tempReg; + ld.global.f32 $register, [$tempReg];" end function type_to_ptx_type(type::DataType)::String @@ -190,7 +193,7 @@ function type_to_ptx_type(type::DataType)::String end function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String} - resultRegister = get_next_free_register!("f") + resultRegister = get_next_free_register("f") resultCode = "" if is_binary_operator(operator) && isnothing(right) @@ -287,7 +290,7 @@ let symtable = Dict() if haskey(symtable, varName) return (symtable[varName], false) else - reg = get_next_free_register!("var") + reg = get_next_free_register("var") symtable[varName] = reg return (reg, true) end diff --git 
a/package/test/TranspilerTests.jl b/package/test/TranspilerTests.jl index 6ef8dd8..01f859d 100644 --- a/package/test/TranspilerTests.jl +++ b/package/test/TranspilerTests.jl @@ -27,7 +27,7 @@ parameters[2][2] = 0.0 push!(postfixExprs, expr_to_postfix(:(5^3 + x1))) # generatedCode = Transpiler.transpile(postfixExpr) - generatedCode = Transpiler.transpile(postfixExprs[3]) # TEMP + generatedCode = Transpiler.transpile(postfixExprs[3], 2, 3) # TEMP # CUDA.@sync interpret(postfixExprs, variables, parameters) # This is just here for testing. This will be called inside the execute method in the Transpiler module