From 219c0bb14ebf3f273fafe2c10383b0a7a43dbff5 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 6 Jan 2025 14:01:55 +0100 Subject: [PATCH] started implementing parameter loading --- package/src/Transpiler.jl | 54 +++++++++++++++++++++------------ package/test/TranspilerTests.jl | 2 +- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/package/src/Transpiler.jl b/package/src/Transpiler.jl index eae62fa..1b46b92 100644 --- a/package/src/Transpiler.jl +++ b/package/src/Transpiler.jl @@ -26,22 +26,23 @@ using ..ExpressionProcessing const Operand = Union{Float32, String} # Operand is either fixed value or register # To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string +# seekstart(buf1); write(buf2, buf1) function transpile(expression::ExpressionProcessing.PostfixType)::String exitJumpLocationMarker = "\$L__BB0_2" ptxBuffer = IOBuffer() + # TODO: Suboptimal solution + signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Int32, Float32]) # nrOfVarSets, nrOfVarsPerSet, Vars + guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # r0 because first entry holds the number of variables and that is always stored in %r0 + println(ptxBuffer, get_cuda_header()) - println(ptxBuffer, get_kernel_signature("ExpressionProcessing", [Int32, Float32])) + println(ptxBuffer, signature) println(ptxBuffer, "{") - # TODO: Parameter loading - # TODO: once parameters are loaded, the second parameter for the guard clause can be set - temp = get_next_free_register("r") - guardClause = get_guard_clause(exitJumpLocationMarker, temp) # since we need to know how many registers we used, we cannot yet write the guard clause to the ptxBuffer - - calc_code = generate_calculation_code(expression) + calc_code = generate_calculation_code(expression, "%parameter2") println(ptxBuffer, get_register_definitions()) + println(ptxBuffer, paramLoading) println(ptxBuffer, guardClause) println(ptxBuffer, calc_code) @@ -58,29 +59,31 @@ end function get_cuda_header()::String return " .version 7.1 -.target sm_52 +.target sm_61 .address_size 64 " end -function get_kernel_signature(kernelName::String, parameters::Vector{DataType})::String +function get_kernel_signature(kernelName::String, parameters::Vector{DataType})::Tuple{String, String} signatureBuffer = IOBuffer() + paramLoadingBuffer = IOBuffer() print(signatureBuffer, ".visible .entry ") print(signatureBuffer, kernelName) println(signatureBuffer, "(") - for i in eachindex(parameters) - type = type_to_ptx_type(parameters[i]) - print(signatureBuffer, - " .param ", type, " ", kernelName, "_param_", i) + print(signatureBuffer, " .param .u32", " ", "param_", i) + + parameterRegister = get_next_free_register("r") + println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];") + println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parameterRegister;") if i != lastindex(parameters) println(signatureBuffer, ",") end end print(signatureBuffer, ")") - return String(take!(signatureBuffer)) + return (String(take!(signatureBuffer)), String(take!(paramLoadingBuffer))) end " @@ -102,8 +105,10 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String) globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set breakCondition = get_next_free_register("p") - println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;") - println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSetsRegister;") # guard clause = index > nrOfVariableSets + nrOfVarSets = get_next_free_register("i") + println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;") + println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;") + println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets # branch to end if breakCondition is true print(guardBuffer, "@$breakCondition bra $exitJumpLocation;") @@ -111,12 +116,10 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String) return String(take!(guardBuffer)) end -# Current assumption: Expression only made out of constant values -function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String +function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String)::String codeBuffer = IOBuffer() operands = Vector{Operand}() - println(expression) for i in eachindex(expression) token = expression[i] @@ -138,6 +141,13 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType) push!(operands, resultRegister) elseif token.Type == INDEX # TODO + # %parameter1 + startIndex + Index * bytes + # startIndex: should be calculateable by global threadId and size of variables + if token.Value > 0 # varaibles + var = get_next_free_register("f") + #TODO: investigate how best to load var from global to local memory, especially when var used multiple times. (probably kind of symtable) + push!(operands, "[$variablesRegister+$(token.Value*sizeof(token.Value))]") # missing: startIndex + end end end @@ -147,6 +157,8 @@ end function type_to_ptx_type(type::DataType)::String if type == Int64 return ".s64" + elseif type == Int32 + return ".s32" elseif type == Float32 return ".f32" else @@ -231,6 +243,10 @@ let registers = Dict() # stores the count of the register already used. regType = ".f32" elseif definition.first == "r" regType = ".b32" + elseif definition.first == "parameter" + regType = ".u32" + elseif definition.first == "i" + regType = ".u32" else throw(ArgumentError("Unknown register name used. Name '$(definition.first)' cannot be mapped to a PTX type.")) end diff --git a/package/test/TranspilerTests.jl b/package/test/TranspilerTests.jl index 2e5cb0d..6ef8dd8 100644 --- a/package/test/TranspilerTests.jl +++ b/package/test/TranspilerTests.jl @@ -24,7 +24,7 @@ parameters[2][2] = 0.0 postfixExpr = expr_to_postfix(expressions[1]) postfixExprs = Vector([postfixExpr]) push!(postfixExprs, expr_to_postfix(expressions[2])) - push!(postfixExprs, expr_to_postfix(:(5^3))) + push!(postfixExprs, expr_to_postfix(:(5^3 + x1))) # generatedCode = Transpiler.transpile(postfixExpr) generatedCode = Transpiler.transpile(postfixExprs[3]) # TEMP