reworked code to use new 'register manager'
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
This commit is contained in:
parent
203900bb26
commit
67ef9a5139
|
@ -2,82 +2,6 @@ module Transpiler
|
||||||
using CUDA
|
using CUDA
|
||||||
using ..ExpressionProcessing
|
using ..ExpressionProcessing
|
||||||
|
|
||||||
# culoadtest(N, rand(["add.f32", "sub.f32", "mul.f32", "div.approx.f32"]))
|
|
||||||
function culoadtest(N::Int32, op = "add.f32")
|
|
||||||
|
|
||||||
vadd_code = ".version 7.1
|
|
||||||
|
|
||||||
.target sm_52
|
|
||||||
.address_size 64
|
|
||||||
|
|
||||||
// .globl VecAdd_kernel
|
|
||||||
|
|
||||||
.visible .entry VecAdd_kernel(
|
|
||||||
.param .u64 VecAdd_kernel_param_0,
|
|
||||||
.param .u64 VecAdd_kernel_param_1,
|
|
||||||
.param .u64 VecAdd_kernel_param_2,
|
|
||||||
.param .u32 VecAdd_kernel_param_3
|
|
||||||
)
|
|
||||||
|
|
||||||
{
|
|
||||||
.reg .pred %p<2>;
|
|
||||||
.reg .f32 %f<4>;
|
|
||||||
.reg .b32 %r<6>;
|
|
||||||
.reg .b64 %rd<11>;
|
|
||||||
|
|
||||||
ld.param.u64 %rd1, [VecAdd_kernel_param_0];
|
|
||||||
ld.param.u64 %rd2, [VecAdd_kernel_param_1];
|
|
||||||
ld.param.u64 %rd3, [VecAdd_kernel_param_2];
|
|
||||||
ld.param.u32 %r2, [VecAdd_kernel_param_3];
|
|
||||||
|
|
||||||
mov.u32 %r3, %ntid.x;
|
|
||||||
mov.u32 %r4, %ctaid.x;
|
|
||||||
mov.u32 %r5, %tid.x;
|
|
||||||
|
|
||||||
mad.lo.s32 %r1, %r3, %r4, %r5;
|
|
||||||
|
|
||||||
setp.ge.s32 %p1, %r1, %r2;
|
|
||||||
|
|
||||||
@%p1 bra \$L__BB0_2;
|
|
||||||
|
|
||||||
cvta.to.global.u64 %rd4, %rd1;
|
|
||||||
|
|
||||||
mul.wide.s32 %rd5, %r1, 4;
|
|
||||||
add.s64 %rd6, %rd4, %rd5;
|
|
||||||
cvta.to.global.u64 %rd7, %rd2;
|
|
||||||
add.s64 %rd8, %rd7, %rd5;
|
|
||||||
|
|
||||||
ld.global.f32 %f1, [%rd8];
|
|
||||||
ld.global.f32 %f2, [%rd6];" *
|
|
||||||
op *
|
|
||||||
" %f3, %f2, %f1;
|
|
||||||
cvta.to.global.u64 %rd9, %rd3;
|
|
||||||
add.s64 %rd10, %rd9, %rd5;
|
|
||||||
st.global.f32 [%rd10], %f3;
|
|
||||||
|
|
||||||
\$L__BB0_2:
|
|
||||||
ret;
|
|
||||||
}"
|
|
||||||
|
|
||||||
linker = CuLink()
|
|
||||||
add_data!(linker, "VecAdd_kernel", vadd_code)
|
|
||||||
|
|
||||||
image = complete(linker)
|
|
||||||
|
|
||||||
mod = CuModule(image)
|
|
||||||
func = CuFunction(mod, "VecAdd_kernel")
|
|
||||||
|
|
||||||
d_a = CUDA.fill(1.0f0, N)
|
|
||||||
d_b = CUDA.fill(2.0f0, N)
|
|
||||||
d_c = CUDA.fill(0.0f0, N)
|
|
||||||
|
|
||||||
# Grid/Block configuration
|
|
||||||
threadsPerBlock = 256;
|
|
||||||
blocksPerGrid = (N + threadsPerBlock - 1) ÷ threadsPerBlock;
|
|
||||||
|
|
||||||
@time CUDA.@sync cudacall(func, Tuple{CuPtr{Cfloat},CuPtr{Cfloat},CuPtr{Cfloat},Cint}, d_a, d_b, d_c, N; threads=threadsPerBlock, blocks=blocksPerGrid)
|
|
||||||
end
|
|
||||||
|
|
||||||
# Number of threads per block/SM + max number of registers
|
# Number of threads per block/SM + max number of registers
|
||||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
|
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
|
||||||
# Need to assume a max of 2048 threads per Streaming Multiprocessor (SM)
|
# Need to assume a max of 2048 threads per Streaming Multiprocessor (SM)
|
||||||
|
@ -85,8 +9,7 @@ end
|
||||||
# One thread can at max use 255 registers
|
# One thread can at max use 255 registers
|
||||||
# Meaning one has access to at most 32 registers in the worst case. Using 64 bit values this number gets halfed (see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#multiprocessor-level (almost at the end of the linked section))
|
# Meaning one has access to at most 32 registers in the worst case. Using 64 bit values this number gets halfed (see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#multiprocessor-level (almost at the end of the linked section))
|
||||||
|
|
||||||
# I think I will go with max 16 registers for now and leave a better register allocation technique for future work
|
# Maybe helpful for future performance tuning: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#maximum-number-of-registers-per-thread
|
||||||
# Maybe helpful for future tuning: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#maximum-number-of-registers-per-thread
|
|
||||||
|
|
||||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#multiprocessor-level
|
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#multiprocessor-level
|
||||||
# This states, that using fewer registers allows more threads to reside on a single SM which improves performance.
|
# This states, that using fewer registers allows more threads to reside on a single SM which improves performance.
|
||||||
|
@ -101,27 +24,23 @@ end
|
||||||
#
|
#
|
||||||
|
|
||||||
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
|
# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
|
||||||
const exitJumpLocationMarker = "\$L__BB0_2"
|
|
||||||
function transpile(expression::ExpressionProcessing.PostfixType)::String
|
function transpile(expression::ExpressionProcessing.PostfixType)::String
|
||||||
|
exitJumpLocationMarker = "\$L__BB0_2"
|
||||||
ptxBuffer = IOBuffer()
|
ptxBuffer = IOBuffer()
|
||||||
|
|
||||||
println(ptxBuffer, get_cuda_header())
|
println(ptxBuffer, get_cuda_header())
|
||||||
println(ptxBuffer, get_kernel_signature("ExpressionProcessing", [Int32, Float32]))
|
println(ptxBuffer, get_kernel_signature("ExpressionProcessing", [Int32, Float32]))
|
||||||
println(ptxBuffer, "{")
|
println(ptxBuffer, "{")
|
||||||
|
|
||||||
# TODO: Actually calculate the number of needed registers and extend to more register kinds
|
|
||||||
println(ptxBuffer, get_register_definitions(1, 5, 0)) # apparently I can define registers anywhere. This might make things easier
|
|
||||||
# TODO: Parameter loading
|
# TODO: Parameter loading
|
||||||
println(ptxBuffer, get_guard_clause())
|
|
||||||
|
|
||||||
# top down create the code. keep track of the max number of variables/parameters used (needed for later iterations. See section "Plan" in "PTX_understanding.md")
|
# TODO: once parameters are loaded, the second parameter for the guard clause can be set
|
||||||
# return this alongside the generated code
|
temp = get_next_free_register("r")
|
||||||
# Generate registers based off of the above number
|
guardClause = get_guard_clause(exitJumpLocationMarker, temp) # since we need to know how many registers we used, we cannot yet write the guard clause to the ptxBuffer
|
||||||
# Variables have: %var0 to %varn - 1
|
|
||||||
# Parameters have: %param0 to %paramn - 1
|
calc_code = generate_calculation_code(expression)
|
||||||
# Code goes here
|
println(ptxBuffer, get_register_definitions())
|
||||||
(calc_code, fRegisterCount) = generate_calculation_code(expression)
|
println(ptxBuffer, guardClause)
|
||||||
println(ptxBuffer, get_register_definitions(0, 0, fRegisterCount))
|
|
||||||
println(ptxBuffer, calc_code)
|
println(ptxBuffer, calc_code)
|
||||||
|
|
||||||
# exit jump location
|
# exit jump location
|
||||||
|
@ -165,50 +84,36 @@ end
|
||||||
"
|
"
|
||||||
Constructs the PTX code used for handling the case where too many threads are started.
|
Constructs the PTX code used for handling the case where too many threads are started.
|
||||||
|
|
||||||
Assumes the following:
|
- param ```nrOfVarSetsRegister```: The register which holds the total amount of variable sets for the kernel
|
||||||
- There are the unused ```32 bit``` registers ```r0, r1, r2, r3 (index of the variable set)```
|
|
||||||
- There is an unused ```predicate``` register ```p0```
|
|
||||||
- The ```32 bit``` register ```r4``` contains the number of variable sets
|
|
||||||
"
|
"
|
||||||
function get_guard_clause()::String
|
function get_guard_clause(exitJumpLocation::String, nrOfVarSetsRegister::String)::String
|
||||||
guardBuffer = IOBuffer()
|
guardBuffer = IOBuffer()
|
||||||
|
|
||||||
println(guardBuffer, "mov.u32 %r0, %ntid.x;") # nr of thread ids
|
threadIds = get_next_free_register("r")
|
||||||
println(guardBuffer, "mov.u32 %r1, %ctaid.x;") # nr of threads per cta
|
threadsPerCTA = get_next_free_register("r")
|
||||||
println(guardBuffer, "mov.u32 %r2, %tid.x;") # id of the current thread
|
currentThreadId = get_next_free_register("r")
|
||||||
|
|
||||||
println(guardBuffer, "mad.lo.s32 %r3, %r0, %r1, %r2;") # the current index (basically index of variable set)
|
# load data into above defined registers
|
||||||
println(guardBuffer, "setp.ge.s32 %p0, %r3, %r4;") # guard clause (p0 = r3 > r4 -> index > nrOfVariableSets)
|
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
|
||||||
|
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
|
||||||
|
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
|
||||||
|
|
||||||
# branch to end if p0 is true
|
globalThreadId = get_next_free_register("r") # basically the index of the thread in the variable set
|
||||||
print(guardBuffer, "@%p0 bra $exitJumpLocationMarker;")
|
breakCondition = get_next_free_register("p")
|
||||||
|
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
|
||||||
|
println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSetsRegister;") # guard clause = index > nrOfVariableSets
|
||||||
|
|
||||||
|
# branch to end if breakCondition is true
|
||||||
|
print(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
|
||||||
|
|
||||||
return String(take!(guardBuffer))
|
return String(take!(guardBuffer))
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: refactor this for better usage. Maybe make this generate only one register definition and pass in the details
|
|
||||||
function get_register_definitions(nrPred::Int, nr32Bit::Int, nrFloat32::Int)::String
|
|
||||||
registersBuffer = IOBuffer()
|
|
||||||
|
|
||||||
if nrPred > 0
|
|
||||||
println(registersBuffer, ".reg .pred %p<$nrPred>;")
|
|
||||||
end
|
|
||||||
if nr32Bit > 0
|
|
||||||
println(registersBuffer, ".reg .b32 %r<$nr32Bit>;")
|
|
||||||
end
|
|
||||||
if nrFloat32 > 0
|
|
||||||
println(registersBuffer, ".reg .f32 %f<$nrFloat32>;")
|
|
||||||
end
|
|
||||||
|
|
||||||
return String(take!(registersBuffer))
|
|
||||||
end
|
|
||||||
|
|
||||||
# Current assumption: Expression only made out of constant values
|
# Current assumption: Expression only made out of constant values
|
||||||
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::Tuple{String, Int}
|
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
|
||||||
codeBuffer = IOBuffer()
|
codeBuffer = IOBuffer()
|
||||||
operands = Vector{Union{Float32, String}}() # Maybe make it of type ANY. Then I could put the register name "on the stack" instead and build up the code like that. Could also make it easier implementing variables/params
|
operands = Vector{Union{Float32, String}}()
|
||||||
|
|
||||||
registerCounter = 0
|
|
||||||
println(expression)
|
println(expression)
|
||||||
for i in eachindex(expression)
|
for i in eachindex(expression)
|
||||||
token = expression[i]
|
token = expression[i]
|
||||||
|
@ -216,9 +121,9 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
|
||||||
if token.Type == FLOAT32
|
if token.Type == FLOAT32
|
||||||
push!(operands, reinterpret(Float32, token.Value))
|
push!(operands, reinterpret(Float32, token.Value))
|
||||||
elseif token.Type == OPERATOR
|
elseif token.Type == OPERATOR
|
||||||
# function call to see if operator is unary -> adapt below calculation; probably able to reuse register
|
# get_ptx_operator will be reworked completly to return the code for the operation
|
||||||
operator = get_ptx_operator(reinterpret(Operator, token.Value))
|
operator = get_ptx_operator(reinterpret(Operator, token.Value))
|
||||||
register = "%f$registerCounter"
|
register = get_next_free_register("f")
|
||||||
print(codeBuffer, " $operator $register, ")
|
print(codeBuffer, " $operator $register, ")
|
||||||
|
|
||||||
ops = last(operands, 2)
|
ops = last(operands, 2)
|
||||||
|
@ -227,13 +132,12 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
|
||||||
println(codeBuffer, ";")
|
println(codeBuffer, ";")
|
||||||
|
|
||||||
push!(operands, register)
|
push!(operands, register)
|
||||||
registerCounter += 1
|
|
||||||
elseif token.Type == INDEX
|
elseif token.Type == INDEX
|
||||||
# TODO
|
# TODO
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
return (String(take!(codeBuffer)), registerCounter)
|
return String(take!(codeBuffer))
|
||||||
end
|
end
|
||||||
|
|
||||||
function type_to_ptx_type(type::DataType)::String
|
function type_to_ptx_type(type::DataType)::String
|
||||||
|
@ -246,7 +150,7 @@ function type_to_ptx_type(type::DataType)::String
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: Probably change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns).
|
# TODO: Change this, to return the entire calculation not just the operator. Because for POWER and EXP we need multiple instructions to calculate them (seperation of concerns).
|
||||||
function get_ptx_operator(operator::Operator)::String
|
function get_ptx_operator(operator::Operator)::String
|
||||||
if operator == ADD
|
if operator == ADD
|
||||||
return "add.f32"
|
return "add.f32"
|
||||||
|
@ -273,12 +177,12 @@ end
|
||||||
|
|
||||||
let registers = Dict() # stores the count of the register already used.
|
let registers = Dict() # stores the count of the register already used.
|
||||||
global get_next_free_register
|
global get_next_free_register
|
||||||
global get_used_registers
|
global get_register_definitions
|
||||||
|
|
||||||
# By convention these names correspond to the following types:
|
# By convention these names correspond to the following types:
|
||||||
# - p -> pred
|
# - p -> pred
|
||||||
# - f32 -> float32
|
# - f -> float32
|
||||||
# - b32 -> 32 bit
|
# - r -> 32 bit
|
||||||
# - var -> float32
|
# - var -> float32
|
||||||
# - param -> float32 !! although, they might get inserted as fixed number and not be sent to gpu?
|
# - param -> float32 !! although, they might get inserted as fixed number and not be sent to gpu?
|
||||||
function get_next_free_register(name::String)::String
|
function get_next_free_register(name::String)::String
|
||||||
|
@ -288,11 +192,31 @@ let registers = Dict() # stores the count of the register already used.
|
||||||
registers[name] = 1
|
registers[name] = 1
|
||||||
end
|
end
|
||||||
|
|
||||||
return string("%", name, registers[name])
|
return string("%", name, registers[name] - 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
function get_used_registers()
|
function get_register_definitions()::String
|
||||||
return pairs(registers)
|
registersBuffer = IOBuffer()
|
||||||
|
|
||||||
|
for definition in registers
|
||||||
|
regType = ""
|
||||||
|
if definition.first == "p"
|
||||||
|
regType = ".pred"
|
||||||
|
elseif definition.first == "f"
|
||||||
|
regType = ".f32"
|
||||||
|
elseif definition.first == "var"
|
||||||
|
regType = ".f32"
|
||||||
|
elseif definition.first == "param"
|
||||||
|
regType = ".f32"
|
||||||
|
elseif definition.first == "r"
|
||||||
|
regType = ".b32"
|
||||||
|
else
|
||||||
|
throw(ArgumentError("Unknown register name used. Name '$(definition.first)' cannot be mapped to a PTX type."))
|
||||||
|
end
|
||||||
|
println(registersBuffer, ".reg $regType %$(definition.first)<$(definition.second)>;")
|
||||||
|
end
|
||||||
|
|
||||||
|
return String(take!(registersBuffer))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user