reworked code to use new 'register manager'
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled

This commit is contained in:
Daniel 2024-12-08 10:50:09 +01:00
parent 203900bb26
commit 67ef9a5139

View File

@ -2,82 +2,6 @@ module Transpiler
using CUDA
using ..ExpressionProcessing
# culoadtest(N, rand(["add.f32", "sub.f32", "mul.f32", "div.approx.f32"]))
function culoadtest(N::Int32, op = "add.f32")
vadd_code = ".version 7.1
.target sm_52
.address_size 64
// .globl VecAdd_kernel
.visible .entry VecAdd_kernel(
.param .u64 VecAdd_kernel_param_0,
.param .u64 VecAdd_kernel_param_1,
.param .u64 VecAdd_kernel_param_2,
.param .u32 VecAdd_kernel_param_3
)
{
.reg .pred %p<2>;
.reg .f32 %f<4>;
.reg .b32 %r<6>;
.reg .b64 %rd<11>;
ld.param.u64 %rd1, [VecAdd_kernel_param_0];
ld.param.u64 %rd2, [VecAdd_kernel_param_1];
ld.param.u64 %rd3, [VecAdd_kernel_param_2];
ld.param.u32 %r2, [VecAdd_kernel_param_3];
mov.u32 %r3, %ntid.x;
mov.u32 %r4, %ctaid.x;
mov.u32 %r5, %tid.x;
mad.lo.s32 %r1, %r3, %r4, %r5;
setp.ge.s32 %p1, %r1, %r2;
@%p1 bra \$L__BB0_2;
cvta.to.global.u64 %rd4, %rd1;
mul.wide.s32 %rd5, %r1, 4;
add.s64 %rd6, %rd4, %rd5;
cvta.to.global.u64 %rd7, %rd2;
add.s64 %rd8, %rd7, %rd5;
ld.global.f32 %f1, [%rd8];
ld.global.f32 %f2, [%rd6];" *
op *
" %f3, %f2, %f1;
cvta.to.global.u64 %rd9, %rd3;
add.s64 %rd10, %rd9, %rd5;
st.global.f32 [%rd10], %f3;
\$L__BB0_2:
ret;
}"
linker = CuLink()
add_data!(linker, "VecAdd_kernel", vadd_code)
image = complete(linker)
mod = CuModule(image)
func = CuFunction(mod, "VecAdd_kernel")
d_a = CUDA.fill(1.0f0, N)
d_b = CUDA.fill(2.0f0, N)
d_c = CUDA.fill(0.0f0, N)
# Grid/Block configuration
threadsPerBlock = 256;
blocksPerGrid = (N + threadsPerBlock - 1) ÷ threadsPerBlock;
@time CUDA.@sync cudacall(func, Tuple{CuPtr{Cfloat},CuPtr{Cfloat},CuPtr{Cfloat},Cint}, d_a, d_b, d_c, N; threads=threadsPerBlock, blocks=blocksPerGrid)
end
# Number of threads per block/SM + max number of registers
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
# Need to assume a max of 2048 threads per Streaming Multiprocessor (SM)
@ -85,8 +9,7 @@ end
# One thread can at max use 255 registers
# Meaning one has access to at most 32 registers in the worst case. Using 64 bit values this number gets halfed (see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#multiprocessor-level (almost at the end of the linked section))
# I think I will go with max 16 registers for now and leave a better register allocation technique for future work
# Maybe helpful for future tuning: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#maximum-number-of-registers-per-thread
# Maybe helpful for future performance tuning: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#maximum-number-of-registers-per-thread
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#multiprocessor-level
# This states, that using fewer registers allows more threads to reside on a single SM which improves performance.
@ -101,27 +24,23 @@ end
#
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
const exitJumpLocationMarker = "\$L__BB0_2"
function transpile(expression::ExpressionProcessing.PostfixType)::String
exitJumpLocationMarker = "\$L__BB0_2"
ptxBuffer = IOBuffer()
println(ptxBuffer, get_cuda_header())
println(ptxBuffer, get_kernel_signature("ExpressionProcessing", [Int32, Float32]))
println(ptxBuffer, "{")
# TODO: Actually calculate the number of needed registers and extend to more register kinds
println(ptxBuffer, get_register_definitions(1, 5, 0)) # apparently I can define registers anywhere. This might make things easier
# TODO: Parameter loading
println(ptxBuffer, get_guard_clause())
# top down create the code. keep track of the max number of variables/parameters used (needed for later iterations. See section "Plan" in "PTX_understanding.md")
# return this alongside the generated code
# Generate registers based off of the above number
# Variables have: %var0 to %varn - 1
# Parameters have: %param0 to %paramn - 1
# Code goes here
(calc_code, fRegisterCount) = generate_calculation_code(expression)
println(ptxBuffer, get_register_definitions(0, 0, fRegisterCount))
# TODO: once parameters are loaded, the second parameter for the guard clause can be set
temp = get_next_free_register("r")
guardClause = get_guard_clause(exitJumpLocationMarker, temp) # since we need to know how many registers we used, we cannot yet write the guard clause to the ptxBuffer
calc_code = generate_calculation_code(expression)
println(ptxBuffer, get_register_definitions())
println(ptxBuffer, guardClause)
println(ptxBuffer, calc_code)
# exit jump location
@ -165,50 +84,36 @@ end
"
Constructs the PTX code used for handling the case where too many threads are started.
Assumes the following:
- There are the unused ```32 bit``` registers ```r0, r1, r2, r3 (index of the variable set)```
- There is an unused ```predicate``` register ```p0```
- The ```32 bit``` register ```r4``` contains the number of variable sets
- param ```nrOfVarSetsRegister```: The register which holds the total amount of variable sets for the kernel
"
function get_guard_clause()::String
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String
guardBuffer = IOBuffer()
println(guardBuffer, "mov.u32 %r0, %ntid.x;") # nr of thread ids
println(guardBuffer, "mov.u32 %r1, %ctaid.x;") # nr of threads per cta
println(guardBuffer, "mov.u32 %r2, %tid.x;") # id of the current thread
threadIds = get_next_free_register("r")
threadsPerCTA = get_next_free_register("r")
currentThreadId = get_next_free_register("r")
println(guardBuffer, "mad.lo.s32 %r3, %r0, %r1, %r2;") # the current index (basically index of variable set)
println(guardBuffer, "setp.ge.s32 %p0, %r3, %r4;") # guard clause (p0 = r3 > r4 -> index > nrOfVariableSets)
# load data into above defined registers
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
# branch to end if p0 is true
print(guardBuffer, "@%p0 bra $exitJumpLocationMarker;")
globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set
breakCondition = get_next_free_register("p")
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSetsRegister;") # guard clause = index > nrOfVariableSets
# branch to end if breakCondition is true
print(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
return String(take!(guardBuffer))
end
# TODO: refactor this for better usage. Maybe make this generate only one register definition and pass in the details
function get_register_definitions(nrPred::Int, nr32Bit::Int, nrFloat32::Int)::String
registersBuffer = IOBuffer()
if nrPred > 0
println(registersBuffer, ".reg .pred %p<$nrPred>;")
end
if nr32Bit > 0
println(registersBuffer, ".reg .b32 %r<$nr32Bit>;")
end
if nrFloat32 > 0
println(registersBuffer, ".reg .f32 %f<$nrFloat32>;")
end
return String(take!(registersBuffer))
end
# Current assumption: Expression only made out of constant values
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::Tuple{String, Int}
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
codeBuffer = IOBuffer()
operands = Vector{Union{Float32, String}}() # Maybe make it of type ANY. Then I could put the register name "on the stack" instead and build up the code like that. Could also make it easier implementing variables/params
operands = Vector{Union{Float32, String}}()
registerCounter = 0
println(expression)
for i in eachindex(expression)
token = expression[i]
@ -216,9 +121,9 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
if token.Type == FLOAT32
push!(operands, reinterpret(Float32, token.Value))
elseif token.Type == OPERATOR
# function call to see if operator is unary -> adapt below calculation; probably able to reuse register
# get_ptx_operator will be reworked completly to return the code for the operation
operator = get_ptx_operator(reinterpret(Operator, token.Value))
register = "%f$registerCounter"
register = get_next_free_register("f")
print(codeBuffer, " $operator $register, ")
ops = last(operands, 2)
@ -227,13 +132,12 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
println(codeBuffer, ";")
push!(operands, register)
registerCounter += 1
elseif token.Type == INDEX
# TODO
end
end
return (String(take!(codeBuffer)), registerCounter)
return String(take!(codeBuffer))
end
function type_to_ptx_type(type::DataType)::String
@ -246,7 +150,7 @@ function type_to_ptx_type(type::DataType)::String
end
end
# TODO: Probably change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns).
# TODO: Change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns).
function get_ptx_operator(operator::Operator)::String
if operator == ADD
return "add.f32"
@ -273,12 +177,12 @@ end
let registers = Dict() # stores the count of the register already used.
global get_next_free_register
global get_used_registers
global get_register_definitions
# By convention these names correspond to the following types:
# - p -> pred
# - f32 -> float32
# - b32 -> 32 bit
# - f -> float32
# - r -> 32 bit
# - var -> float32
# - param -> float32 !! although, they might get inserted as fixed number and not be sent to gpu?
function get_next_free_register(name::String)::String
@ -288,11 +192,31 @@ let registers = Dict() # stores the count of the register already used.
registers[name] = 1
end
return string("%", name, registers[name])
return string("%", name, registers[name] - 1)
end
function get_used_registers()
return pairs(registers)
function get_register_definitions()::String
registersBuffer = IOBuffer()
for definition in registers
regType = ""
if definition.first == "p"
regType = ".pred"
elseif definition.first == "f"
regType = ".f32"
elseif definition.first == "var"
regType = ".f32"
elseif definition.first == "param"
regType = ".f32"
elseif definition.first == "r"
regType = ".b32"
else
throw(ArgumentError("Unknown register name used. Name '$(definition.first)' cannot be mapped to a PTX type."))
end
println(registersBuffer, ".reg $regType %$(definition.first)<$(definition.second)>;")
end
return String(take!(registersBuffer))
end
end