aded symtable for loading vars and params to local memory
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled

This commit is contained in:
2025-01-19 11:00:29 +01:00
parent 219c0bb14e
commit 85464083c3
4 changed files with 40 additions and 18 deletions

View File

@ -30,6 +30,7 @@ function expr_to_postfix(expr::Expr)::PostfixType
if typeof(arg) === Expr
append!(postfix, expr_to_postfix(arg))
elseif typeof(arg) === Symbol # variables/parameters
# maybe TODO: replace the parameters with their respective values, as this might make the expr evaluation faster
exprElement = convert_to_ExpressionElement(convert_var_to_int(arg))
push!(postfix, exprElement)
else

View File

@ -74,9 +74,9 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType}):
for i in eachindex(parameters)
print(signatureBuffer, " .param .u32", " ", "param_", i)
parameterRegister = get_next_free_register("r")
parameterRegister = get_next_free_register!("r")
println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];")
println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parameterRegister;")
println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register!("parameter")), $parameterRegister;")
if i != lastindex(parameters)
println(signatureBuffer, ",")
end
@ -94,18 +94,18 @@ Constructs the PTX code used for handling the case where too many threads are st
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String
guardBuffer = IOBuffer()
threadIds = get_next_free_register("r")
threadsPerCTA = get_next_free_register("r")
currentThreadId = get_next_free_register("r")
threadIds = get_next_free_register!("r")
threadsPerCTA = get_next_free_register!("r")
currentThreadId = get_next_free_register!("r")
# load data into above defined registers
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set
breakCondition = get_next_free_register("p")
nrOfVarSets = get_next_free_register("i")
globalThreadId = get_next_free_register!("r") # basically the index of the thread in the variable set
breakCondition = get_next_free_register!("p")
nrOfVarSets = get_next_free_register!("i")
println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;")
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
@ -143,10 +143,14 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
# TODO
# %parameter1 + startIndex + Index * bytes
# startIndex: should be calculateable by global threadId and size of variables
# startIndex: threadId (==var-set/param) * size of var/params
if token.Value > 0 # varaibles
var = get_next_free_register("f")
#TODO: investigate how best to load var from global to local memory, especially when var used multiple times. (probably kind of symtable)
var, first_access = get_register_for_name!("x$(token.Value)")
#TODO: if first_access is true -> generate code for loading from global to local memory
push!(operands, "[$variablesRegister+$(token.Value*sizeof(token.Value))]") # missing: startIndex
else
param, first_access = get_register_for_name!("x$(token.Value)")
#TODO: if first_access is true -> generate code for loading from global to local memory
end
end
end
@ -167,7 +171,7 @@ function type_to_ptx_type(type::DataType)::String
end
function get_operation(operator::Operator, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
resultRegister = get_next_free_register("f")
resultRegister = get_next_free_register!("f")
resultCode = ""
if is_binary_operator(operator) && isnothing(right)
@ -209,16 +213,15 @@ function get_operation(operator::Operator, left::Operand, right::Union{Operand,
end
let registers = Dict() # stores the count of the register already used.
global get_next_free_register
global get_next_free_register!
global get_register_definitions
# By convention these names correspond to the following types:
# - p -> pred
# - f -> float32
# - r -> 32 bit
# - var -> float32
# - param -> float32 !! although, they might get inserted as fixed number and not be sent to gpu?
function get_next_free_register(name::String)::String
# - var -> float32 (used for variables and params)
function get_next_free_register!(name::String)::String
if haskey(registers, name)
registers[name] += 1
else
@ -257,5 +260,20 @@ let registers = Dict() # stores the count of the register already used.
end
end
let symtable = Dict()
global get_register_for_name!
"Returns the register for this variable/parameter and true if it is used for the first time and false otherwise."
function get_register_for_name!(varName::String)
if haskey(symtable, varName)
return (symtable[varName], false)
else
reg = get_next_free_register!("var")
symtable[varName] = reg
return (reg, true)
end
end
end
end