added loading of variables from memory into registers. Note: Needed to leave, so code currently not compiling
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
This commit is contained in:
parent
85464083c3
commit
b2774322a1
|
@ -33,14 +33,14 @@ function transpile(expression::ExpressionProcessing.PostfixType)::String
|
||||||
|
|
||||||
# TODO: Suboptimal solution
|
# TODO: Suboptimal solution
|
||||||
signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Int32, Float32]) # nrOfVarSets, nrOfVarsPerSet, Vars
|
signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Int32, Int32, Float32]) # nrOfVarSets, nrOfVarsPerSet, Vars
|
||||||
guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # r0 because first entry holds the number of variables and that is always stored in %r0
|
guardClause = get_guard_clause(exitJumpLocationMarker, "%parameter0") # parameter0 because first entry holds the number of variables and that is always stored in %parameter0
|
||||||
|
|
||||||
println(ptxBuffer, get_cuda_header())
|
println(ptxBuffer, get_cuda_header())
|
||||||
println(ptxBuffer, signature)
|
println(ptxBuffer, signature)
|
||||||
println(ptxBuffer, "{")
|
println(ptxBuffer, "{")
|
||||||
|
|
||||||
|
|
||||||
calc_code = generate_calculation_code(expression, "%parameter2")
|
calc_code = generate_calculation_code(expression, "%parameter1", "%parameter2")
|
||||||
println(ptxBuffer, get_register_definitions())
|
println(ptxBuffer, get_register_definitions())
|
||||||
println(ptxBuffer, paramLoading)
|
println(ptxBuffer, paramLoading)
|
||||||
println(ptxBuffer, guardClause)
|
println(ptxBuffer, guardClause)
|
||||||
|
@ -74,9 +74,9 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType}):
|
||||||
for i in eachindex(parameters)
|
for i in eachindex(parameters)
|
||||||
print(signatureBuffer, " .param .u32", " ", "param_", i)
|
print(signatureBuffer, " .param .u32", " ", "param_", i)
|
||||||
|
|
||||||
parameterRegister = get_next_free_register!("r")
|
parameterRegister = get_next_free_register("r")
|
||||||
println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];")
|
println(paramLoadingBuffer, "ld.param.u32 $parameterRegister, [param_$i];")
|
||||||
println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register!("parameter")), $parameterRegister;")
|
println(paramLoadingBuffer, "cvta.to.global.u32 $(get_next_free_register("parameter")), $parameterRegister;")
|
||||||
if i != lastindex(parameters)
|
if i != lastindex(parameters)
|
||||||
println(signatureBuffer, ",")
|
println(signatureBuffer, ",")
|
||||||
end
|
end
|
||||||
|
@ -94,18 +94,18 @@ Constructs the PTX code used for handling the case where too many threads are st
|
||||||
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String
|
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String
|
||||||
guardBuffer = IOBuffer()
|
guardBuffer = IOBuffer()
|
||||||
|
|
||||||
threadIds = get_next_free_register!("r")
|
threadIds = get_next_free_register("r")
|
||||||
threadsPerCTA = get_next_free_register!("r")
|
threadsPerCTA = get_next_free_register("r")
|
||||||
currentThreadId = get_next_free_register!("r")
|
currentThreadId = get_next_free_register("r")
|
||||||
|
|
||||||
# load data into above defined registers
|
# load data into above defined registers
|
||||||
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
|
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
|
||||||
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
|
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
|
||||||
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
|
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
|
||||||
|
|
||||||
globalThreadId = get_next_free_register!("r") # basically the index of the thread in the variable set
|
globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set
|
||||||
breakCondition = get_next_free_register!("p")
|
breakCondition = get_next_free_register("p")
|
||||||
nrOfVarSets = get_next_free_register!("i")
|
nrOfVarSets = get_next_free_register("i")
|
||||||
println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;")
|
println(guardBuffer, "ld.global.u32 $nrOfVarSets, $nrOfVarSetsRegister;")
|
||||||
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
|
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
|
||||||
println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
|
println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
|
||||||
|
@ -116,7 +116,7 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)
|
||||||
return String(take!(guardBuffer))
|
return String(take!(guardBuffer))
|
||||||
end
|
end
|
||||||
|
|
||||||
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String)::String
|
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesRegister::String, parameterRegister::String)::String
|
||||||
codeBuffer = IOBuffer()
|
codeBuffer = IOBuffer()
|
||||||
operands = Vector{Operand}()
|
operands = Vector{Operand}()
|
||||||
|
|
||||||
|
@ -140,17 +140,19 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
|
||||||
println(codeBuffer, operation)
|
println(codeBuffer, operation)
|
||||||
push!(operands, resultRegister)
|
push!(operands, resultRegister)
|
||||||
elseif token.Type == INDEX
|
elseif token.Type == INDEX
|
||||||
# TODO
|
|
||||||
# %parameter1 + startIndex + Index * bytes
|
|
||||||
# startIndex: should be calculateable by global threadId and size of variables
|
|
||||||
# startIndex: threadId (==var-set/param) * size of var/params
|
|
||||||
if token.Value > 0 # varaibles
|
if token.Value > 0 # varaibles
|
||||||
var, first_access = get_register_for_name!("x$(token.Value)")
|
var, first_access = get_register_for_name("x$(token.Value)")
|
||||||
#TODO: if first_access is true -> generate code for loading from global to local memory
|
if first_access
|
||||||
push!(operands, "[$variablesRegister+$(token.Value*sizeof(token.Value))]") # missing: startIndex
|
println(codeBuffer, load_into_register(var, variablesRegister, token.Value, , ))
|
||||||
|
end
|
||||||
|
push!(operands, var)
|
||||||
else
|
else
|
||||||
param, first_access = get_register_for_name!("x$(token.Value)")
|
absVal = abs(token.Value)
|
||||||
#TODO: if first_access is true -> generate code for loading from global to local memory
|
param, first_access = get_register_for_name("p$absVal")
|
||||||
|
if first_access
|
||||||
|
println(codeBuffer, load_into_register(param, parameterRegister, absVal, , ))
|
||||||
|
end
|
||||||
|
push!(operands, param)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -158,6 +160,23 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType,
|
||||||
return String(take!(codeBuffer))
|
return String(take!(codeBuffer))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"
|
||||||
|
- param ```register```: The register where the loaded value will be stored
|
||||||
|
- param ```load_location```: The location from where to load the value
|
||||||
|
- param ```value_index```: 0-based index of the value in the variable set/parameter set
|
||||||
|
- param ```set_index```: 0-based index of the set. Needed to calculate the actual index from the ```value_index```. Is equal to the global threadId
|
||||||
|
- param ```set_size```: The size of one set. Needed to calculate the actual index from the ```value_index```
|
||||||
|
"
|
||||||
|
function load_into_register(register::String, load_location::String, value_index::Integer, set_index::Integer, set_size::Integer)::String
|
||||||
|
# load_location + startIndex + value_index * bytes (4 in our case)
|
||||||
|
# startIndex: set_index * set_size
|
||||||
|
if value_index == 0 && set_index == 0 # accessing the very first value doesn't need any further calculations
|
||||||
|
return "ld.global.f32 $register, [$load_location]"
|
||||||
|
else
|
||||||
|
return "ld.global.f32 $register, [$load_location+$(set_size*set_index + value_index*sizeof(value_index))]"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function type_to_ptx_type(type::DataType)::String
|
function type_to_ptx_type(type::DataType)::String
|
||||||
if type == Int64
|
if type == Int64
|
||||||
return ".s64"
|
return ".s64"
|
||||||
|
@ -213,7 +232,7 @@ function get_operation(operator::Operator, left::Operand, right::Union{Operand,
|
||||||
end
|
end
|
||||||
|
|
||||||
let registers = Dict() # stores the count of the register already used.
|
let registers = Dict() # stores the count of the register already used.
|
||||||
global get_next_free_register!
|
global get_next_free_register
|
||||||
global get_register_definitions
|
global get_register_definitions
|
||||||
|
|
||||||
# By convention these names correspond to the following types:
|
# By convention these names correspond to the following types:
|
||||||
|
@ -221,7 +240,7 @@ let registers = Dict() # stores the count of the register already used.
|
||||||
# - f -> float32
|
# - f -> float32
|
||||||
# - r -> 32 bit
|
# - r -> 32 bit
|
||||||
# - var -> float32 (used for variables and params)
|
# - var -> float32 (used for variables and params)
|
||||||
function get_next_free_register!(name::String)::String
|
function get_next_free_register(name::String)::String
|
||||||
if haskey(registers, name)
|
if haskey(registers, name)
|
||||||
registers[name] += 1
|
registers[name] += 1
|
||||||
else
|
else
|
||||||
|
@ -261,10 +280,10 @@ let registers = Dict() # stores the count of the register already used.
|
||||||
end
|
end
|
||||||
|
|
||||||
let symtable = Dict()
|
let symtable = Dict()
|
||||||
global get_register_for_name!
|
global get_register_for_name
|
||||||
|
|
||||||
"Returns the register for this variable/parameter and true if it is used for the first time and false otherwise."
|
"Returns the register for this variable/parameter and true if it is used for the first time and false otherwise."
|
||||||
function get_register_for_name!(varName::String)
|
function get_register_for_name(varName::String)
|
||||||
if haskey(symtable, varName)
|
if haskey(symtable, varName)
|
||||||
return (symtable[varName], false)
|
return (symtable[varName], false)
|
||||||
else
|
else
|
||||||
|
|
Loading…
Reference in New Issue
Block a user