finished generating parameter loading code
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run

This commit is contained in:
Daniel 2025-01-25 11:15:54 +01:00
parent b2774322a1
commit 7598c51df8
2 changed files with 32 additions and 29 deletions

View File

@ -27,20 +27,20 @@ const Operand = Union{Float32, String} # Operand is either fixed value or regist
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
# seekstart(buf1); write(buf2, buf1)
function transpile(expression::ExpressionProcessing.PostfixType)::String
function transpile(expression::ExpressionProcessing.PostfixType, varSetSize::Integer, paramSetSize::Integer)::String
exitJumpLocationMarker = "\$L__BB0_2"
ptxBuffer = IOBuffer()
# TODO: Suboptimal solution
signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Int32, Float32]) # nrOfVarSets, nrOfVarsPerSet, Vars
guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # parameter0 because first entry holds the number of variables and that is always stored in %parameter0
signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Float32, Float32]) # nrOfVarSets, Vars, Params
guardClause, threadIdReg = get_guard_clause(exitJumpLocationMarker, "%parameter0") # parameter0 because first entry holds the number of variable sets and that is always stored in %parameter0
println(ptxBuffer, get_cuda_header())
println(ptxBuffer, signature)
println(ptxBuffer, "{")
calc_code = generate_calculation_code(expression, "%parameter1", "%parameter2")
calc_code = generate_calculation_code(expression, "%parameter1", varSetSize, "%parameter2", paramSetSize, threadIdReg)
println(ptxBuffer, get_register_definitions())
println(ptxBuffer, paramLoading)
println(ptxBuffer, guardClause)
@ -60,7 +60,7 @@ function get_cuda_header()::String
return "
.version 7.1
.target sm_61
.address_size 64
.address_size 32
"
end
@ -74,9 +74,9 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType}):
for i in eachindex(parameters)
print(signatureBuffer, " .param .u32", " ", "param_", i)
parameterRegister = get_next_free_register("r")
println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];")
println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parameterRegister;")
parametersReg = get_next_free_register("r")
println(paramLoadingBuffer, "ld.param.u32 $parametersReg, [param_$i];")
println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parametersReg;")
if i != lastindex(parameters)
println(signatureBuffer, ",")
end
@ -91,7 +91,7 @@ Constructs the PTX code used for handling the case where too many threads are st
- param ```nrOfVarSetsRegister```: The register which holds the total amount of variable sets for the kernel
"
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::Tuple{String, String}
guardBuffer = IOBuffer()
threadIds = get_next_free_register("r")
@ -106,17 +106,18 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)
globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set
breakCondition = get_next_free_register("p")
nrOfVarSets = get_next_free_register("i")
println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;")
println(guardBuffer, "ld.global.u32 $nrOfVarSets, [$nrOfVarSetsRegister];")
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
# branch to end if breakCondition is true
print(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
return String(take!(guardBuffer))
return (String(take!(guardBuffer)), globalThreadId)
end
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String, parameterRegister::String)::String
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesReg::String, variablesSetSize::Integer,
parametersReg::String, parametersSetSize::Integer, threadIdReg::String)::String
codeBuffer = IOBuffer()
operands = Vector{Operand}()
@ -143,14 +144,14 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
if token.Value > 0 # varaibles
var, first_access = get_register_for_name("x$(token.Value)")
if first_access
println(codeBuffer, load_into_register(var, variablesRegister, token.Value, , ))
println(codeBuffer, load_into_register(var, variablesReg, token.Value, threadIdReg, variablesSetSize))
end
push!(operands, var)
else
absVal = abs(token.Value)
param, first_access = get_register_for_name("p$absVal")
if first_access
println(codeBuffer, load_into_register(param, parameterRegister, absVal, , ))
println(codeBuffer, load_into_register(param, parametersReg, absVal, threadIdReg, parametersSetSize))
end
push!(operands, param)
end
@ -162,19 +163,21 @@ end
"
- param ```register```: The register where the loaded value will be stored
- param ```load_location```: The location from where to load the value
- param ```value_index```: 0-based index of the value in the variable set/parameter set
- param ```set_index```: 0-based index of the set. Needed to calculate the actual index from the ```value_index```. Is equal to the global threadId
- param ```set_size```: The size of one set. Needed to calculate the actual index from the ```value_index```
- param ```loadLocation```: The location from where to load the value
- param ```valueIndex```: 0-based index of the value in the variable set/parameter set
- param ```setIndexReg```: 0-based index of the set. Needed to calculate the actual index from the ```valueIndex```. Is equal to the global threadId
- param ```setSize```: The size of one set. Needed to calculate the actual index from the ```valueIndex```
"
function load_into_register(register::String, load_location::String, value_index::Integer, set_index::Integer, set_size::Integer)::String
# load_location + startIndex + value_index * bytes (4 in our case)
# startIndex: set_index * set_size
if value_index == 0 && set_index == 0 # accessing the very first value doesn't need any further calculations
return "ld.global.f32 $register, [$load_location]"
else
return "ld.global.f32 $register, [$load_location+$(set_size*set_index + value_index*sizeof(value_index))]"
end
function load_into_register(register::String, loadLocation::String, valueIndex::Integer, setIndexReg::String, setSize::Integer)::String
# loadLocation + startIndex + valueIndex * bytes (4 in our case)
# startIndex: setIndex * setSize
tempReg = get_next_free_register("i")
# we are using "sizeof(valueIndex)" because it has to use the same amount of bytes as the actual stored values, even though it could use more bytes
return "
mul.lo.u32 $tempReg, $setIndexReg, $setSize;
add.u32 $tempReg, $tempReg, $(valueIndex*sizeof(valueIndex));
add.u32 $tempReg, $loadLocation, $tempReg;
ld.global.f32 $register, [$tempReg];"
end
function type_to_ptx_type(type::DataType)::String
@ -190,7 +193,7 @@ function type_to_ptx_type(type::DataType)::String
end
function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
resultRegister = get_next_free_register!("f")
resultRegister = get_next_free_register("f")
resultCode = ""
if is_binary_operator(operator) && isnothing(right)
@ -287,7 +290,7 @@ let symtable = Dict()
if haskey(symtable, varName)
return (symtable[varName], false)
else
reg = get_next_free_register!("var")
reg = get_next_free_register("var")
symtable[varName] = reg
return (reg, true)
end

View File

@ -27,7 +27,7 @@ parameters[2][2] = 0.0
push!(postfixExprs, expr_to_postfix(:(5^3 + x1)))
# generatedCode = Transpiler.transpile(postfixExpr)
generatedCode = Transpiler.transpile(postfixExprs[3]) # TEMP
generatedCode = Transpiler.transpile(postfixExprs[3], 2, 3) # TEMP
# CUDA.@sync interpret(postfixExprs, variables, parameters)
# This is just here for testing. This will be called inside the execute method in the Transpiler module