diff --git a/package/src/Transpiler.jl b/package/src/Transpiler.jl index 141519e..8630dc5 100644 --- a/package/src/Transpiler.jl +++ b/package/src/Transpiler.jl @@ -33,14 +33,14 @@ function transpile(expression::ExpressionProcessing.PostfixType)::String # TODO: Suboptimal solution signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Int32, Float32]) # nrOfVarSets, nrOfVarsPerSet, Vars - guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # r0 because first entry holds the number of variables and that is always stored in %r0 + guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # parameter0 because first entry holds the number of variables and that is always stored in %parameter0 println(ptxBuffer, get_cuda_header()) println(ptxBuffer, signature) println(ptxBuffer, "{") - calc_code = generate_calculation_code(expression, "%parameter2") + calc_code = generate_calculation_code(expression, "%parameter1", "%parameter2") println(ptxBuffer, get_register_definitions()) println(ptxBuffer, paramLoading) println(ptxBuffer, guardClause) @@ -74,9 +74,9 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType}): for i in eachindex(parameters) print(signatureBuffer, " .param .u32", " ", "param_", i) - parameterRegister = get_next_free_register!("r") + parameterRegister = get_next_free_register("r") println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];") - println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register!("parameter")), $parameterRegister;") + println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parameterRegister;") if i != lastindex(parameters) println(signatureBuffer, ",") end @@ -94,18 +94,18 @@ Constructs the PTX code used for handling the case where too many threads are st function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String guardBuffer = IOBuffer() - threadIds = 
get_next_free_register!("r") - threadsPerCTA = get_next_free_register!("r") - currentThreadId = get_next_free_register!("r") + threadIds = get_next_free_register("r") + threadsPerCTA = get_next_free_register("r") + currentThreadId = get_next_free_register("r") # load data into above defined registers println(guardBuffer, "mov.u32 $threadIds, %ntid.x;") println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;") println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;") - globalThreadId = get_next_free_register!("r") # basically the index of the thread in the variable set - breakCondition = get_next_free_register!("p") - nrOfVarSets = get_next_free_register!("i") + globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set + breakCondition = get_next_free_register("p") + nrOfVarSets = get_next_free_register("i") println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;") println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;") println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets @@ -116,7 +116,7 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String) return String(take!(guardBuffer)) end -function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String)::String +function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String, parameterRegister::String)::String codeBuffer = IOBuffer() operands = Vector{Operand}() @@ -140,17 +140,19 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType, println(codeBuffer, operation) push!(operands, resultRegister) elseif token.Type == INDEX - # TODO - # %parameter1 + startIndex + Index * bytes - # startIndex: should be calculateable by global threadId and size of variables - # startIndex: threadId (==var-set/param) * size of 
"""
    load_into_register(register, load_location, value_index, set_index, set_size)

Return the PTX statement that loads one `f32` value from global memory into `register`.

The address is `load_location` plus a byte offset of
`(set_index * set_size + value_index) * 4`, i.e. the element index of the value
multiplied by the size of one 32-bit float.

# Arguments
- `register`: The register where the loaded value will be stored.
- `load_location`: The location from where to load the value.
- `value_index`: 0-based index of the value in the variable set/parameter set.
- `set_index`: 0-based index of the set. Needed to calculate the actual index from
  `value_index`. Is equal to the global threadId.
- `set_size`: The size of one set, interpreted as the number of values per set.
  Needed to calculate the actual index from `value_index`.

# Throws
- `ArgumentError` if `value_index`, `set_index` or `set_size` is negative.
"""
function load_into_register(
    register::String,
    load_location::String,
    value_index::Integer,
    set_index::Integer,
    set_size::Integer,
)::String
    if value_index < 0 || set_index < 0 || set_size < 0
        throw(ArgumentError(
            "value_index, set_index and set_size must be non-negative, got " *
            "value_index=$value_index, set_index=$set_index, set_size=$set_size",
        ))
    end

    # The stride is the size of the loaded value (f32 -> 4 bytes). It must NOT be derived
    # from the Julia integer type used for the indices (sizeof(Int64) would be 8).
    bytes_per_value = sizeof(Float32)

    # Widen before multiplying so narrow index types (e.g. Int8/Int32) cannot overflow.
    element_index = Int64(set_index) * Int64(set_size) + Int64(value_index)
    byte_offset = element_index * bytes_per_value

    # Accessing the very first value doesn't need any offset in the address operand.
    if byte_offset == 0
        return "ld.global.f32 $register, [$load_location];"
    end
    return "ld.global.f32 $register, [$load_location+$byte_offset];"
end