master-thesis/package/src/Transpiler.jl
Wiplinger Daniel - s2310454043 101ccef67b
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
benchmarking: tuned blocksizes; slightly improved performance; mostly improved standard deviation
2025-04-12 13:20:50 +02:00

328 lines
13 KiB
Julia

module Transpiler
using CUDA
using ..ExpressionProcessing
using ..Utils
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
# Size in bytes of one Float32 value; used for the pointer arithmetic in the generated PTX.
const BYTES = sizeof(Float32)
# An operand is either a fixed value (literal) or the name of a register.
const Operand = Union{Float32, String}
# Compiled kernels, reused when the same expression is evaluated again (e.g. multiple runs with
# the same expression but different parameter values).
# NOTE: the generated PTX hard-codes the set sizes, the number of variable sets and the expression
# index, so a key must distinguish all of them, not just the expression.
# Declared `const` so that accesses are type-stable; the Dict itself remains mutable.
const cache = Dict{Expr, CuFunction}()
"""
    evaluate(expressions, variables, parameters)::Matrix{Float32}

Transpile every expression to a PTX kernel, run each kernel once over all variable sets and
return the results. Entry `[j, i]` of the result is expression `i` evaluated on variable set `j`.

- param ```expressions```: The expressions to evaluate
- param ```variables```: Column-major matrix with one variable set per column
- param ```parameters```: The parameter sets; expression `i` reads parameter column `i`
  (assumes `Utils.create_cuda_array` stores `parameters[i]` as column `i` — TODO confirm)

Compiled kernels are stored in `cache`. Because the generated PTX hard-codes the variable set
size, the size of the largest parameter set, the number of variable sets and the expression
index, all of them are part of the cache key.
"""
function evaluate(expressions::Vector{Expr}, variables::Matrix{Float32}, parameters::Vector{Vector{Float32}})::Matrix{Float32}
    varRows = size(variables, 1)
    variableCols = size(variables, 2)
    nrOfExpressions = length(expressions)

    # Nothing to compute: skip compilation and avoid the DivideError of cld(0, 0) below.
    if variableCols == 0 || nrOfExpressions == 0
        return Matrix{Float32}(undef, variableCols, nrOfExpressions)
    end

    # Computed once instead of on every cache miss; it is also part of the cache key.
    maxParamSetSize = Utils.get_max_inner_length(parameters)

    kernels = Vector{CuFunction}(undef, nrOfExpressions)
    # TODO: compiling in parallel (Threads.@threads) took 0.03 seconds in the simple "functionality"
    # tests while sequential took 0.00009 seconds; re-test during the performance tests. A parallel
    # version must protect `cache` with ONE lock shared by all iterations.
    for i in eachindex(expressions)
        expressionIndex = i - 1 # Julia is 1-based but PTX needs 0-based indexing
        # Everything that transpile bakes into the PTX belongs to the key; otherwise a cached kernel
        # (e.g. for the same expression at another index, or for a different number of variable
        # sets) would read/write the wrong memory. An Expr is used to keep the cache's key type.
        cacheKey = Expr(:tuple, expressions[i], varRows, maxParamSetSize, variableCols, expressionIndex)
        kernels[i] = get!(cache, cacheKey) do
            formattedExpr = ExpressionProcessing.expr_to_postfix(expressions[i])
            kernel = transpile(formattedExpr, varRows, maxParamSetSize, variableCols, expressionIndex)
            linker = CuLink()
            add_data!(linker, "ExpressionProcessing", kernel)
            image = complete(linker)
            CuFunction(CuModule(image), "ExpressionProcessing")
        end
    end

    cudaVars = CuArray(variables) # maybe put in shared memory (see PerformanceTests.jl for more info)
    cudaParams = Utils.create_cuda_array(parameters, NaN32) # maybe make constant (see PerformanceTests.jl for more info)
    # each expression has nr. of variable sets (nr. of columns of the variables) results and there are n expressions
    cudaResults = CuArray{Float32}(undef, variableCols, nrOfExpressions)

    # The launch configuration is identical for every kernel, so it is computed once. Surplus
    # threads of the last block are expected to leave through the kernel's guard clause.
    threads = min(variableCols, 96)
    blocks = cld(variableCols, threads)
    # TODO: also try launching the kernels concurrently; multiple grids might improve performance
    for kernel in kernels
        cudacall(kernel, (CuPtr{Float32}, CuPtr{Float32}, CuPtr{Float32}), cudaVars, cudaParams, cudaResults; threads=threads, blocks=blocks)
    end
    # The declared return type converts (copies) the device results into a host Matrix.
    return cudaResults
end
# Possible optimisation: let the helper functions return their IOBuffer instead of a String, so the
# kernel can be assembled without intermediate strings, e.g.
# seekstart(buf1); write(buf2, buf1)
"""
    transpile(expression, varSetSize, paramSetSize, nrOfVariableSets, expressionIndex)::String

Generate the complete PTX source of the kernel `ExpressionProcessing` for one expression.

- param ```expression```: The expression in postfix form
- param ```varSetSize```: The size of a variable set. Equal to the number of rows of the variable matrix (in a column-major matrix)
- param ```paramSetSize```: The size of the longest parameter set. As the parameter sets are stored in a column-major matrix, the number of rows depends on the longest parameter set
- param ```nrOfVariableSets```: The number of variable sets; used by the guard clause and for the position of the results
- param ```expressionIndex```: The 0-based index of the expression
"""
function transpile(expression::ExpressionProcessing.PostfixType, varSetSize::Integer, paramSetSize::Integer,
    nrOfVariableSets::Integer, expressionIndex::Integer)::String
    exitLabel = "\$L__BB0_2"
    registers = Utils.RegisterManager(Dict(), Dict())

    # Registers are handed out in the order the snippets are generated. The register declarations
    # can therefore only be produced after every snippet (including the calculation) exists.
    # TODO: Suboptimal solution
    signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Float32, Float32, Float32], registers) # Vars, Params, Results
    guardClause, threadId64Reg = get_guard_clause(exitLabel, nrOfVariableSets, registers)
    calculation = generate_calculation_code(expression, "%parameter0", varSetSize, "%parameter1", paramSetSize, "%parameter2",
        threadId64Reg, expressionIndex, nrOfVariableSets, registers)

    # Every section is emitted on its own line(s), followed by a final newline.
    sections = (
        get_cuda_header(),
        signature,
        "{",
        Utils.get_register_definitions(registers),
        paramLoading,
        guardClause,
        calculation,
        exitLabel * ": ret;", # exit jump location
        "}",
    )
    return join(sections, "\n") * "\n"
end
"""
    get_cuda_header(; version = "8.5", target = "sm_61", addressSize = 64)::String

Return the PTX module header consisting of the `.version`, `.target` and `.address_size`
directives. The defaults reproduce the previously hard-coded header.

- param ```version```: PTX ISA version the generated code is written against
- param ```target```: GPU architecture the code targets (e.g. "sm_61")
- param ```addressSize```: Width in bits of the addresses used by the module; PTX only allows 32 or 64.
  The kernels generated here use `.u64` address arithmetic, so 64 is the value to use.

Throws an `ArgumentError` if `addressSize` is neither 32 nor 64.
"""
function get_cuda_header(; version::AbstractString = "8.5", target::AbstractString = "sm_61", addressSize::Integer = 64)::String
    if !(addressSize in (32, 64))
        throw(ArgumentError("PTX address size must be 32 or 64, got $addressSize"))
    end
    return string("\n", ".version ", version, "\n", ".target ", target, "\n", ".address_size ", addressSize, "\n")
end
"""
    get_kernel_signature(kernelName, parameters, regManager)::Tuple{String, String}

Build the `.visible .entry` declaration of the kernel and the PTX that loads its parameters.

Every entry of `parameters` becomes a `.param .u64` parameter named `param_<i>` (1-based); only
the number of entries matters, the types themselves are not inspected. For each parameter the
value is loaded into an `rd` register and converted to a global address in a `parameter`
register (`cvta.to.global.u64`).

- param ```kernelName```: Name of the entry function
- param ```parameters```: One entry per kernel argument, in argument order (here: variables, parameters, results)
- param ```regManager```: Register manager; registers are allocated in parameter order

Returns the signature (without the body) and the parameter-loading code.
"""
function get_kernel_signature(kernelName::String, parameters::Vector{DataType}, regManager::Utils.RegisterManager)::Tuple{String, String}
    declarations = String[]
    loadCode = IOBuffer()
    for paramNr in eachindex(parameters)
        push!(declarations, " .param .u64 param_$paramNr")
        rawPointer = Utils.get_next_free_register(regManager, "rd")
        globalPointer = Utils.get_next_free_register(regManager, "parameter")
        println(loadCode, "ld.param.u64 $rawPointer, [param_$paramNr];")
        println(loadCode, "cvta.to.global.u64 $globalPointer, $rawPointer;")
    end
    signature = ".visible .entry $(kernelName)(\n" * join(declarations, ",\n") * ")"
    return (signature, String(take!(loadCode)))
end
"""
    get_guard_clause(exitJumpLocation, nrOfVarSets, regManager)::Tuple{String, String}

Construct the PTX code for the case where more threads are started than there are variable sets.
Every thread computes its global (0-based) index from `%ntid.x * %ctaid.x + %tid.x`. Threads
whose index is not a valid variable set index (index >= nrOfVarSets) branch to `exitJumpLocation`.

- param ```exitJumpLocation```: Label of the kernel's exit (`ret`) instruction
- param ```nrOfVarSets```: Total number of variable sets; emitted as an immediate operand
- param ```regManager```: Register manager used to allocate the registers of this snippet

Returns the PTX code and the name of the 64-bit register that holds the global thread index,
i.e. the 0-based index of the variable set this thread evaluates.
"""
function get_guard_clause(exitJumpLocation::String, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::Tuple{String, String}
    guardBuffer = IOBuffer()
    threadsPerBlock = Utils.get_next_free_register(regManager, "r")
    blockIndex = Utils.get_next_free_register(regManager, "r")
    threadIndexInBlock = Utils.get_next_free_register(regManager, "r")
    println(guardBuffer, "mov.u32 $threadsPerBlock, %ntid.x;")
    println(guardBuffer, "mov.u32 $blockIndex, %ctaid.x;")
    println(guardBuffer, "mov.u32 $threadIndexInBlock, %tid.x;")

    globalThreadId = Utils.get_next_free_register(regManager, "r") # basically the index of the thread in the variable set
    breakCondition = Utils.get_next_free_register(regManager, "p")
    println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadsPerBlock, $blockIndex, $threadIndexInBlock;")
    # The index is 0-based: valid indices are 0 .. nrOfVarSets-1, so index >= nrOfVarSets must exit.
    # (Using "gt" here would let the thread with index == nrOfVarSets read and write out of bounds.)
    println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;")
    # branch to end if breakCondition is true
    println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")

    # Convert the index to a 64-bit register for the address arithmetic. It is not 64-bit from the
    # start, as this would take up more registers. Performance tests can determine whether making
    # everything 64-bit from the start is faster.
    threadId64Reg = Utils.get_next_free_register(regManager, "rd")
    print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")
    return (String(take!(guardBuffer)), threadId64Reg)
end
# Formats a Float32 as an exact PTX single-precision literal: "0f" followed by the 8 hex digits of
# its IEEE-754 bit pattern. Unlike Julia's decimal printing this is always exact and also covers
# Inf and NaN, which have no valid decimal PTX form.
_ptx_f32_literal(value::Float32)::String = "0f" * uppercase(string(reinterpret(UInt32, value); base = 16, pad = 8))

# Pops the top of the operand stack, reporting a malformed postfix expression with a descriptive
# ArgumentError instead of the generic empty-array error.
function _pop_operand!(operands::Vector{Operand}, operator)::Operand
    if isempty(operands)
        throw(ArgumentError("Malformed postfix expression: missing operand for operator '$operator'."))
    end
    return pop!(operands)
end

"""
    generate_calculation_code(expression, variablesLocation, variablesSetSize, parametersLocation,
        parametersSetSize, resultsLocation, threadId64Reg, expressionIndex, nrOfVarSets, regManager)::String

Generate the PTX instructions that evaluate the postfix `expression` for the variable set selected
by `threadId64Reg` and store the result in the results matrix.

The tokens are processed with an operand stack:
- constants are pushed as exact PTX `f32` literals;
- variables (positive indices) and parameters (non-positive indices) are loaded into a register on
  their first use; later uses push the same register;
- operators pop their operand(s) and push the register that holds their result.

- param ```variablesLocation```, ```parametersLocation```, ```resultsLocation```: Registers holding the global addresses of the respective matrices
- param ```variablesSetSize```: Number of values per variable set (rows of the variable matrix)
- param ```parametersSetSize```: Size of the largest parameter set
- param ```threadId64Reg```: 64-bit register holding the 0-based index of the variable set
- param ```expressionIndex```: 0-based index of the expression; selects the parameter set and the results column
- param ```nrOfVarSets```: Number of variable sets, i.e. the length of one results column

Throws an `ArgumentError` if the postfix expression is malformed (an operator without enough
operands, more than one remaining operand, or an unsupported token type).
"""
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesLocation::String, variablesSetSize::Integer,
    parametersLocation::String, parametersSetSize::Integer, resultsLocation::String,
    threadId64Reg::String, expressionIndex::Integer, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::String
    codeBuffer = IOBuffer()
    operands = Vector{Operand}()

    exprId64Reg = Utils.get_next_free_register(regManager, "rd")
    println(codeBuffer, "mov.u64 $exprId64Reg, $expressionIndex;")

    for token in expression
        if token.Type == FLOAT32
            push!(operands, _ptx_f32_literal(reinterpret(Float32, token.Value)))
        elseif token.Type == OPERATOR
            operator = reinterpret(Operator, token.Value)
            # For binary operators the right operand is on top of the stack.
            right = is_binary_operator(operator) ? _pop_operand!(operands, operator) : nothing
            left = _pop_operand!(operands, operator)
            operation, resultRegister = get_operation(operator, regManager, left, right)
            println(codeBuffer, operation)
            push!(operands, resultRegister)
        elseif token.Type == INDEX
            if token.Value > 0 # variables
                var, firstAccess = Utils.get_register_for_name(regManager, "x$(token.Value)")
                if firstAccess # only the first use emits a load
                    println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
                end
                push!(operands, var)
            else # parameters
                absVal = abs(token.Value)
                param, firstAccess = Utils.get_register_for_name(regManager, "p$absVal")
                if firstAccess # only the first use emits a load
                    println(codeBuffer, load_into_register(param, parametersLocation, absVal, exprId64Reg, parametersSetSize, regManager))
                end
                push!(operands, param)
            end
        else
            throw(ArgumentError("Unsupported token type '$(token.Type)' in postfix expression."))
        end
    end

    # A well-formed postfix expression leaves exactly one operand: the result.
    if length(operands) != 1
        throw(ArgumentError("Malformed postfix expression: expected exactly one result, but $(length(operands)) operands remain."))
    end
    resultValue = operands[1]

    # Store at index expressionIndex * nrOfVarSets + threadId of the column-major results matrix.
    tempReg = Utils.get_next_free_register(regManager, "rd")
    println(codeBuffer,
        "\n" *
        "add.u64 $tempReg, $(expressionIndex * nrOfVarSets), $threadId64Reg;\n" *
        "mad.lo.u64 $tempReg, $tempReg, $BYTES, $resultsLocation;\n" *
        "st.global.f32 [$tempReg], $resultValue;\n")
    return String(take!(codeBuffer))
end
"""
    load_into_register(register, loadLocation, valueIndex, setIndexReg64, setSize, regManager)::String

Generate PTX that loads one `f32` value of a column-major matrix in global memory into `register`.
The byte offset is `setIndex * setSize * BYTES + (valueIndex - 1) * BYTES`, i.e. row `valueIndex`
of column `setIndex`; it is added to the base address like classical C pointer arithmetic.

- param ```register```: The register where the loaded value will be stored
- param ```loadLocation```: Register holding the base address to load from
- param ```valueIndex```: 1-based index of the value in the variable set/parameter set
- param ```setIndexReg64```: 64-bit register holding the 0-based index of the set (the column), e.g. the global thread id for variables or the expression index for parameters
- param ```setSize```: The number of elements per set (length(set)), i.e. the rows of the matrix
- param ```regManager```: Allocates the temporary 64-bit address register

The returned code starts with a newline and has no trailing newline.
"""
function load_into_register(register::String, loadLocation::String, valueIndex::Integer, setIndexReg64::String, setSize::Integer, regManager::Utils.RegisterManager)::String
    address = Utils.get_next_free_register(regManager, "rd")
    columnStride = setSize * BYTES
    rowOffset = (valueIndex - 1) * BYTES
    instructions = (
        "mad.lo.u64 $address, $setIndexReg64, $columnStride, $rowOffset;", # offset within the matrix
        "add.u64 $address, $loadLocation, $address;",                      # absolute address
        "ld.global.f32 $register, [$address];",
    )
    return "\n" * join(instructions, "\n")
end
"""
    type_to_ptx_type(type::DataType)::String

Map a Julia type to a PTX type suffix: `Int64` → `.s64`, `Int32` → `.s32`, `Float32` → `.f32`.
Every other type falls back to the untyped 64-bit `.b64`.
"""
function type_to_ptx_type(type::DataType)::String
    for (juliaType, ptxType) in ((Int64, ".s64"), (Int32, ".s32"), (Float32, ".f32"))
        type == juliaType && return ptxType
    end
    return ".b64"
end
"""
    get_operation(operator, regManager, left, right = nothing)::Tuple{String, String}

Generate the PTX code that applies `operator` to its operand(s). The result is written to a newly
allocated `f` register. Its name is returned together with the code so it can serve as an operand
of subsequent operations.

- param ```operator```: The operator to translate
- param ```regManager```: Register manager used to allocate the result register
- param ```left```: First (or only) operand; a register name or a literal
- param ```right```: Second operand; required for binary operators and ignored otherwise

`DIVIDE`, `POWER`, `LOG`, `EXP` and `SQRT` are built from approximate (`.approx`) instructions.
Throws an `ArgumentError` if a binary operator is given only one operand or if the operator has
no PTX translation.
"""
function get_operation(operator::Operator, regManager::Utils.RegisterManager, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
    # Validate before allocating, so a failed call does not consume a register.
    if is_binary_operator(operator) && isnothing(right)
        throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operand has been given."))
    end
    resultRegister = Utils.get_next_free_register(regManager, "f")

    resultCode = if operator == ADD
        "add.f32 $resultRegister, $left, $right;"
    elseif operator == SUBTRACT
        "sub.f32 $resultRegister, $left, $right;"
    elseif operator == MULTIPLY
        "mul.f32 $resultRegister, $left, $right;"
    elseif operator == DIVIDE
        "div.approx.f32 $resultRegister, $left, $right;"
    elseif operator == POWER
        # x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)"
        "\n" *
        "// x^y:\n" *
        "lg2.approx.f32 $resultRegister, $left;\n" *
        "mul.f32 $resultRegister, $right, $resultRegister;\n" *
        "ex2.approx.f32 $resultRegister, $resultRegister;"
    elseif operator == ABS
        "abs.f32 $resultRegister, $left;"
    elseif operator == LOG
        # log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)"
        "\n" *
        "// log(x):\n" *
        "lg2.approx.f32 $resultRegister, $left;\n" *
        "mul.f32 $resultRegister, $resultRegister, 0.693147182;"
    elseif operator == EXP
        # e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)"
        "\n" *
        "// e^x:\n" *
        "mul.f32 $resultRegister, $left, 1.44269502;\n" *
        "ex2.approx.f32 $resultRegister, $resultRegister;"
    elseif operator == SQRT
        "sqrt.approx.f32 $resultRegister, $left;"
    else
        throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
    end
    return (resultCode, resultRegister)
end
end