Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
312 lines
13 KiB
Julia
312 lines
13 KiB
Julia
module Transpiler
|
|
using CUDA
|
|
using ..ExpressionProcessing
|
|
using ..Utils
|
|
|
|
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
|
|
|
|
# Number of bytes occupied by one Float32 value; the address arithmetic of the generated PTX is based on it.
const BYTES = sizeof(Float32)
# A PTX operand: either an immediate Float32 constant or the name of a register holding the value.
const Operand = Union{String, Float32}
|
|
|
|
"""
    evaluate(expressions, cudaVars, variableColumns, variableRows, parameters)::Matrix{Float32}

Transpile every expression to PTX, compile it and launch it on the GPU (one kernel per expression).

- param ```expressions```: The expressions to evaluate, in postfix form
- param ```cudaVars```: The variables, already on the GPU, stored as a column major matrix with one variable set per column
- param ```variableColumns```: The number of variable sets (columns of the variable matrix). One GPU thread is started per variable set
- param ```variableRows```: The size of one variable set (rows of the variable matrix)
- param ```parameters```: One parameter set per expression; ```parameters[i]``` is used by ```expressions[i]```

Returns a `variableColumns × length(expressions)` matrix whose column `i` holds the results of
`expressions[i]` for every variable set. If there are no variable sets or no expressions, an empty
matrix of that shape is returned without touching the GPU.
"""
function evaluate(expressions::Vector{ExpressionProcessing.PostfixType}, cudaVars::CuArray{Float32}, variableColumns::Integer, variableRows::Integer, parameters::Vector{Vector{Float32}})::Matrix{Float32}
    # Nothing to launch. Also avoids `cld(0, 0)` (DivideError) below when there are no variable sets.
    if variableColumns == 0 || isempty(expressions)
        return Matrix{Float32}(undef, variableColumns, length(expressions))
    end

    cudaParams = Utils.create_cuda_array(parameters, NaN32) # maybe make constant (see PerformanceTests.jl for more info)

    # each expression has nr. of variable sets (nr. of columns of the variables) results and there are n expressions
    cudaResults = CuArray{Float32}(undef, variableColumns, length(expressions))

    threads = min(variableColumns, 256)
    blocks = cld(variableColumns, threads)

    kernelName = "evaluate_gpu"
    # Loop-invariant: the longest parameter set determines the row count of the parameter matrix.
    maxParamSetSize = Utils.get_max_inner_length(parameters)

    Threads.@threads for i in eachindex(expressions)
        # i-1 because julia is 1-based but PTX needs 0-based indexing
        kernel = transpile(expressions[i], variableRows, maxParamSetSize, variableColumns, i - 1, kernelName)
        compiledKernel = CompileKernel(kernel, kernelName)
        cudacall(compiledKernel, (CuPtr{Float32}, CuPtr{Float32}, CuPtr{Float32}), cudaVars, cudaParams, cudaResults; threads=threads, blocks=blocks)
    end

    # Copy the results back to the host.
    return Array(cudaResults)
end
|
|
|
|
"""
    evaluate(kernels, cudaVars, nrOfVariableSets, parameters, kernelName)::Matrix{Float32}

A simplified version of the evaluate function. It takes a list of already transpiled PTX kernels
(e.g. produced by `transpile`). Each kernel is linked and launched on the GPU. This should yield
better performance where the same expressions are evaluated multiple times, i.e. for parameter
optimisation, because the transpilation step is skipped.

- param ```kernels```: The PTX source of every kernel. Kernel `i` is expected to have been transpiled with expression index `i-1`, so that it writes its results to column `i`
- param ```cudaVars```: The variables, already on the GPU
- param ```nrOfVariableSets```: The number of variable sets. One GPU thread is started per variable set
- param ```parameters```: One parameter set per kernel
- param ```kernelName```: The name of the entry point in every kernel

Returns a `nrOfVariableSets × length(kernels)` matrix. If there are no variable sets or no kernels,
an empty matrix of that shape is returned without touching the GPU.
"""
function evaluate(kernels::Vector{String}, cudaVars::CuArray{Float32}, nrOfVariableSets::Integer, parameters::Vector{Vector{Float32}}, kernelName::String)::Matrix{Float32}
    # Nothing to launch. Also avoids `cld(0, 0)` (DivideError) below when there are no variable sets.
    if nrOfVariableSets == 0 || isempty(kernels)
        return Matrix{Float32}(undef, nrOfVariableSets, length(kernels))
    end

    cudaParams = Utils.create_cuda_array(parameters, NaN32) # maybe make constant (see PerformanceTests.jl for more info)

    # each kernel has nr. of variable sets results and there is one kernel per expression
    cudaResults = CuArray{Float32}(undef, nrOfVariableSets, length(kernels))

    threads = min(nrOfVariableSets, 256)
    blocks = cld(nrOfVariableSets, threads)

    Threads.@threads for i in eachindex(kernels)
        compiledKernel = CompileKernel(kernels[i], kernelName)
        cudacall(compiledKernel, (CuPtr{Float32}, CuPtr{Float32}, CuPtr{Float32}), cudaVars, cudaParams, cudaResults; threads=threads, blocks=blocks)
    end

    # Copy the results back to the host.
    return Array(cudaResults)
end
|
|
|
|
"""
    CompileKernel(ptxKernel, kernelName)::CuFunction

Link the PTX source `ptxKernel` with a `CuLink`, load the resulting image as a `CuModule` and
return the function named `kernelName` from it.
"""
function CompileKernel(ptxKernel::String, kernelName::String)::CuFunction
    link = CuLink()
    add_data!(link, kernelName, ptxKernel)
    loadedModule = CuModule(complete(link))
    return CuFunction(loadedModule, kernelName)
end
|
|
|
|
# Possible performance improvement: let the helper functions return (or write into) an IOBuffer instead of a String, so the generated code is not copied through intermediate strings, e.g.:
|
|
# seekstart(buf1); write(buf2, buf1)
|
|
"""
    transpile(expression, varSetSize, paramSetSize, nrOfVariableSets, expressionIndex, kernelName)::String

Generate the complete PTX source of a kernel called `kernelName` that evaluates `expression`.
The kernel takes pointers to the variables, the parameters and the results. Each thread
evaluates one variable set and exits early if its index is outside the range of variable sets.

- param ```varSetSize```: The size of a variable set. Equal to number of rows of variable matrix (in a column major matrix)
- param ```paramSetSize```: The size of the longest parameter set. As it has to be stored in a column major matrix, the nr of rows is dependent on the longest parameter set
- param ```nrOfVariableSets```: The total number of variable sets
- param ```expressionIndex```: The 0-based index of the expression
- param ```kernelName```: The name of the generated entry point
"""
function transpile(expression::ExpressionProcessing.PostfixType, varSetSize::Integer, paramSetSize::Integer,
                   nrOfVariableSets::Integer, expressionIndex::Integer, kernelName::String)::String
    exitLabel = "L__BB0_2"
    regManager = Utils.RegisterManager(Dict(), Dict())

    # TODO: Suboptimal solution. get_kernel_signature should also return the name of the registers used for the parameters, so further below, we do not have to hard-code them
    signature, paramLoading = get_kernel_signature(kernelName, [Float32, Float32, Float32], regManager) # Vars, Params, Results
    guardClause, threadId64Reg = get_guard_clause(exitLabel, nrOfVariableSets, regManager)

    # The calculation has to be generated before the register declarations are emitted,
    # because it allocates further registers in regManager.
    calcCode = generate_calculation_code(expression, "%parameter0", varSetSize, "%parameter1", paramSetSize, "%parameter2",
                                         threadId64Reg, expressionIndex, nrOfVariableSets, regManager)

    return sprint() do io
        println(io, get_cuda_header())
        println(io, signature)
        println(io, "{")
        println(io, Utils.get_register_definitions(regManager))
        println(io, paramLoading)
        println(io, guardClause)
        println(io, calcCode)
        # exit jump location
        println(io, exitLabel, ": ret;")
        println(io, "}")
    end
end
|
|
|
|
"""
    get_cuda_header(; version = "8.5", target = "sm_61", addressSize = 64)::String

Return the PTX module header (`.version`, `.target` and `.address_size` directives).
The defaults reproduce the previously hard-coded header.

- kwparam ```version```: The PTX ISA version
- kwparam ```target```: The target architecture (e.g. "sm_61")
- kwparam ```addressSize```: The address size in bits. PTX only supports 32 or 64

Throws an `ArgumentError` if `addressSize` is neither 32 nor 64.
"""
function get_cuda_header(; version::AbstractString = "8.5", target::AbstractString = "sm_61", addressSize::Integer = 64)::String
    addressSize in (32, 64) || throw(ArgumentError("addressSize must be 32 or 64, got $addressSize"))

    return """

        .version $version
        .target $target
        .address_size $addressSize
        """
end
|
|
|
|
"""
    get_kernel_signature(kernelName, parameters, regManager)::Tuple{String, String}

Build the `.visible .entry` header of the kernel and the PTX that loads its parameters.

Every entry of `parameters` becomes one `.param .u64 param_i` pointer argument (i is 1-based). The element
type itself is not used in the declaration. For each argument, an `rd` register receives the raw
parameter, and a `parameter` register receives it converted to a global address via `cvta.to.global.u64`.

- param ```parameters```: One element type per kernel argument. `transpile` passes [variables, parameters, results]

Returns `(signature, parameterLoadingCode)`.
"""
function get_kernel_signature(kernelName::String, parameters::Vector{DataType}, regManager::Utils.RegisterManager)::Tuple{String, String}
    loadingCode = IOBuffer()
    declarations = String[]

    for idx in eachindex(parameters)
        push!(declarations, " .param .u64 param_$idx")

        rawPointerReg = Utils.get_next_free_register(regManager, "rd")
        println(loadingCode, "ld.param.u64 $rawPointerReg, [param_$idx];")
        println(loadingCode, "cvta.to.global.u64 $(Utils.get_next_free_register(regManager, "parameter")), $rawPointerReg;")
    end

    signature = ".visible .entry " * kernelName * "(\n" * join(declarations, ",\n") * ")"
    return (signature, String(take!(loadingCode)))
end
|
|
|
|
"""
    get_guard_clause(exitJumpLocation, nrOfVarSets, regManager)::Tuple{String, String}

Constructs the PTX code used for handling the case where too many threads are started.

The global thread index is computed as `%ntid.x * %ctaid.x + %tid.x` (0-based). Every thread whose
index is not a valid variable set index, i.e. `index >= nrOfVarSets`, branches to `exitJumpLocation`.

- param ```exitJumpLocation```: The label that out-of-range threads jump to
- param ```nrOfVarSets```: The total amount of variable sets for the kernel (embedded as an immediate value)
- param ```regManager```: The register manager used to allocate the registers of the guard clause

Returns `(guardCode, threadId64Reg)`, where `threadId64Reg` is the 64-bit register holding the global thread index.
"""
function get_guard_clause(exitJumpLocation::String, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::Tuple{String, String}
    guardBuffer = IOBuffer()

    blockDimReg = Utils.get_next_free_register(regManager, "r")       # threads per block (%ntid.x)
    blockIndexReg = Utils.get_next_free_register(regManager, "r")     # index of the block (%ctaid.x)
    threadInBlockReg = Utils.get_next_free_register(regManager, "r")  # index of the thread inside its block (%tid.x)

    println(guardBuffer, "mov.u32 $blockDimReg, %ntid.x;")
    println(guardBuffer, "mov.u32 $blockIndexReg, %ctaid.x;")
    println(guardBuffer, "mov.u32 $threadInBlockReg, %tid.x;")

    globalThreadId = Utils.get_next_free_register(regManager, "r") # basically the index of the thread in the variable set
    breakCondition = Utils.get_next_free_register(regManager, "p")
    println(guardBuffer, "mad.lo.s32 $globalThreadId, $blockDimReg, $blockIndexReg, $threadInBlockReg;")
    # Valid 0-based indices are 0 .. nrOfVarSets-1, so a thread with index == nrOfVarSets must exit as well.
    # Otherwise it would read past the variables and write into the next expression's results.
    println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index >= nrOfVariableSets

    # branch to end if breakCondition is true
    println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")

    # Convert threadIdReg to a 64 bit register. Not 64 bit from the start, as this would take up more registers. Performance tests can be performed to determine if it is faster doing this, or making everything 64-bit from the start
    threadId64Reg = Utils.get_next_free_register(regManager, "rd")
    print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")

    return (String(take!(guardBuffer)), threadId64Reg)
end
|
|
|
|
"""
    generate_calculation_code(expression, variablesLocation, variablesSetSize, parametersLocation, parametersSetSize,
                              resultsLocation, threadId64Reg, expressionIndex, nrOfVarSets, regManager)::String

Translate the postfix `expression` into PTX that evaluates it for the variable set of the current
thread and stores the result.

Operands are kept on a stack:
- constants are pushed as immediate values;
- variables and parameters are loaded into a register only on their first access, and the same register is pushed on later accesses;
- every operator pops its operands and pushes the register holding its result.

The value left on the stack is stored at
`resultsLocation + (expressionIndex * nrOfVarSets + threadId) * BYTES`.

- param ```variablesLocation```: Register holding the global address of the variables
- param ```variablesSetSize```: Size of one variable set
- param ```parametersLocation```: Register holding the global address of the parameters
- param ```parametersSetSize```: Size of the largest parameter set
- param ```resultsLocation```: Register holding the global address of the results
- param ```threadId64Reg```: 64-bit register holding the 0-based index of the variable set of this thread
- param ```expressionIndex```: 0-based index of the expression. Selects the parameter set and the result column
- param ```nrOfVarSets```: Total number of variable sets

Throws an `ArgumentError` for an unknown token type.
"""
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesLocation::String, variablesSetSize::Integer,
                                   parametersLocation::String, parametersSetSize::Integer, resultsLocation::String,
                                   threadId64Reg::String, expressionIndex::Integer, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::String
    codeBuffer = IOBuffer()
    operands = Vector{Operand}()

    exprId64Reg = Utils.get_next_free_register(regManager, "rd")
    println(codeBuffer, "mov.u64 $exprId64Reg, $expressionIndex;")

    for token in expression
        if token.Type == FLOAT32
            value = reinterpret(Float32, token.Value)
            if isfinite(value)
                push!(operands, value)
            else
                # Values like "Inf" would otherwise be written as "Inf", which the PTX compiler does not understand.
                # Use PTX's hex float form 0fXXXXXXXX of the raw bits. Going through UInt32 keeps negative values
                # (e.g. -Inf) valid even if token.Value is a signed integer.
                push!(operands, "0f" * string(reinterpret(UInt32, value), base = 16, pad = 8))
            end
        elseif token.Type == OPERATOR
            operator = reinterpret(Operator, token.Value)

            right = nothing
            if is_binary_operator(operator)
                right = pop!(operands)
                left = pop!(operands)
            else
                left = pop!(operands)
            end
            operation, resultRegister = get_operation(operator, regManager, left, right)

            println(codeBuffer, operation)
            push!(operands, resultRegister)
        elseif token.Type == VARIABLE
            var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
            if first_access
                println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
            end
            push!(operands, var)
        elseif token.Type == PARAMETER
            param, first_access = Utils.get_register_for_name(regManager, "p$(token.Value)")
            if first_access
                println(codeBuffer, load_into_register(param, parametersLocation, token.Value, exprId64Reg, parametersSetSize, regManager))
            end
            push!(operands, param)
        else
            throw(ArgumentError("Token unknown. Token was '$(token)'"))
        end
    end

    # Store the final value at results[expressionIndex * nrOfVarSets + threadId] (column major, one column per expression).
    tempReg = Utils.get_next_free_register(regManager, "rd")
    println(codeBuffer, """

        add.u64 $tempReg, $(expressionIndex * nrOfVarSets), $threadId64Reg;
        mad.lo.u64 $tempReg, $tempReg, $BYTES, $resultsLocation;
        st.global.f32 [$tempReg], $(pop!(operands));
        """)

    return String(take!(codeBuffer))
end
|
|
|
|
"""
    load_into_register(register, loadLocation, valueIndex, setIndexReg64, setSize, regManager)::String

Loads a value from a location into the given register. It is assumed that the location refers to a column-major matrix.
The element address is `loadLocation + (setIndex * setSize + (valueIndex - 1)) * BYTES`.

- param ```register```: The register where the loaded value will be stored
- param ```loadLocation```: The location from where to load the value
- param ```valueIndex```: 1-based index of the value in the variable set/parameter set
- param ```setIndexReg64```: 0-based index of the set. Needed to calculate the actual index from the ```valueIndex```. Is equal to the global threadId
- param ```setSize```: The size of one set. Needed to calculate the actual index from the ```valueIndex```. Total number of elements in the set (length(set))

Returns the PTX code (3 instructions) that performs the load.
"""
function load_into_register(register::String, loadLocation::String, valueIndex::Integer, setIndexReg64::String, setSize::Integer, regManager::Utils.RegisterManager)::String
    addressReg = Utils.get_next_free_register(regManager, "rd")
    setStrideBytes = setSize * BYTES             # distance between two consecutive sets
    valueOffsetBytes = (valueIndex - 1) * BYTES  # position of the value inside its set

    # "mad" computes the byte offset and "add" applies it to the base address, like C pointer arithmetic.
    # TODO: This is not the most efficient way. The index of the set should be calculated only once if possible and not like here multiple times
    instructions = (
        "mad.lo.u64 $addressReg, $setIndexReg64, $setStrideBytes, $valueOffsetBytes;",
        "add.u64 $addressReg, $loadLocation, $addressReg;",
        "ld.global.f32 $register, [$addressReg];",
    )
    return "\n" * join(instructions, "\n")
end
|
|
|
|
"""
    type_to_ptx_type(type)::String

Map a Julia type to its PTX type suffix: `Int64` → ".s64", `Int32` → ".s32", `Float32` → ".f32".
Every other type falls back to the untyped 64-bit ".b64".
"""
function type_to_ptx_type(type::DataType)::String
    type == Int64 && return ".s64"
    type == Int32 && return ".s32"
    type == Float32 && return ".f32"
    return ".b64"
end
|
|
|
|
"""
    get_operation(operator, regManager, left, right = nothing)::Tuple{String, String}

Emit the PTX instruction(s) that apply `operator` to `left` (and `right` for binary operators).
A new `f` register is always allocated for the result. POWER, LOG and EXP are expanded into
`lg2`/`ex2` based sequences, as nvcc generates them.

Returns `(code, resultRegister)`.
Throws an `ArgumentError` if a binary operator is given only one operand, or if the operator has no PTX translation.
"""
function get_operation(operator::Operator, regManager::Utils.RegisterManager, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
    dest = Utils.get_next_free_register(regManager, "f")

    if is_binary_operator(operator) && isnothing(right)
        throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operand has been given."))
    end

    # Templates for the single-instruction forms "instr dest, a, b;" and "instr dest, a;".
    binary(instr::String) = "$instr $dest, $left, $right;"
    unary(instr::String) = "$instr $dest, $left;"

    code = if operator == ADD
        binary("add.f32")
    elseif operator == SUBTRACT
        binary("sub.f32")
    elseif operator == MULTIPLY
        binary("mul.f32")
    elseif operator == DIVIDE
        binary("div.approx.f32")
    elseif operator == POWER
        # x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)"
        "\n// x^y:\nlg2.approx.f32 $dest, $left;\nmul.f32 $dest, $right, $dest;\nex2.approx.f32 $dest, $dest;"
    elseif operator == ABS
        unary("abs.f32")
    elseif operator == LOG
        # log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)"
        "\n// log(x):\nlg2.approx.f32 $dest, $left;\nmul.f32 $dest, $dest, 0.693147182;"
    elseif operator == EXP
        # e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)"
        "\n// e^x:\nmul.f32 $dest, $left, 1.44269502;\nex2.approx.f32 $dest, $dest;"
    elseif operator == SQRT
        unary("sqrt.approx.f32")
    elseif operator == INV
        unary("rcp.approx.f32")
    else
        throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
    end

    return (code, dest)
end
|
|
|
|
end
|
|
|