module Transpiler

using CUDA
using ..ExpressionProcessing
using ..Utils

# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
const BYTES = sizeof(Float32)
const Operand = Union{Float32, String} # Operand is either fixed value or register

"
Transpiles every expression to PTX, compiles it and launches one kernel per expression.

- param ```expressions```: The expressions (in postfix form) to evaluate
- param ```cudaVars```: The variables on the GPU, stored column-major with one variable set per column
- param ```variableColumns```: The number of variable sets (columns of ```cudaVars```)
- param ```variableRows```: The size of one variable set (rows of ```cudaVars```)
- param ```parameters```: The parameter sets; the parameter set used by expression ```i``` is selected with the 0-based index ```i-1```
- returns: A ```variableColumns``` x ```length(expressions)``` matrix (copied back to the host through the return type conversion). Column ```i``` holds the results of expression ```i``` for every variable set.
"
function evaluate(expressions::Vector{ExpressionProcessing.PostfixType}, cudaVars::CuArray{Float32}, variableColumns::Integer, variableRows::Integer, parameters::Vector{Vector{Float32}})::Matrix{Float32}
    # Nothing to evaluate; also avoids cld(0, 0) below, which would throw a DivideError
    if variableColumns == 0
        return Matrix{Float32}(undef, 0, length(expressions))
    end

    cudaParams = Utils.create_cuda_array(parameters, NaN32) # maybe make constant (see PerformanceTests.jl for more info)

    # each expression has nr. of variable sets (nr. of columns of the variables) results and there are n expressions
    cudaResults = CuArray{Float32}(undef, variableColumns, length(expressions))

    threads = min(variableColumns, 256)
    blocks = cld(variableColumns, threads)

    kernelName = "evaluate_gpu"
    # Loop invariant: the size of the longest parameter set is the same for every kernel
    paramSetSize = Utils.get_max_inner_length(parameters)
    @inbounds Threads.@threads for i in eachindex(expressions)
        kernel = transpile(expressions[i], variableRows, paramSetSize, variableColumns, i-1, kernelName) # i-1 because julia is 1-based but PTX needs 0-based indexing
        compiledKernel = CompileKernel(kernel, kernelName)
        cudacall(compiledKernel, (CuPtr{Float32},CuPtr{Float32},CuPtr{Float32}), cudaVars, cudaParams, cudaResults; threads=threads, blocks=blocks)
        # CUDA.jl launches on a task-local stream; wait here so every kernel has finished
        # before the results are copied back to the host after the loop
        CUDA.synchronize()
    end

    return cudaResults
end

"
A simplified version of the evaluate function. It takes a list of already transpiled PTX kernels to be compiled and executed. This should yield better performance when the same expressions are evaluated multiple times, i.e. for parameter optimisation.

- param ```kernels```: The PTX code of the kernels. NOTE(review): kernel ```i``` is assumed to have been transpiled with expression index ```i-1``` and with ```kernelName``` as entry name, so that it writes into column ```i``` of the result
- param ```cudaVars```: The variables on the GPU, stored column-major with one variable set per column
- param ```nrOfVariableSets```: The number of variable sets (columns of ```cudaVars```)
- param ```parameters```: The parameter sets of the expressions
- param ```kernelName```: The entry name used in the PTX code of the kernels
- returns: A ```nrOfVariableSets``` x ```length(kernels)``` matrix (copied back to the host through the return type conversion)
"
function evaluate(kernels::Vector{String}, cudaVars::CuArray{Float32}, nrOfVariableSets::Integer, parameters::Vector{Vector{Float32}}, kernelName::String)::Matrix{Float32}
    # Nothing to evaluate; also avoids cld(0, 0) below, which would throw a DivideError
    if nrOfVariableSets == 0
        return Matrix{Float32}(undef, 0, length(kernels))
    end

    cudaParams = Utils.create_cuda_array(parameters, NaN32) # maybe make constant (see PerformanceTests.jl for more info)

    # each expression has nr. of variable sets (nr. of columns of the variables) results and there are n expressions
    cudaResults = CuArray{Float32}(undef, nrOfVariableSets, length(kernels))

    threads = min(nrOfVariableSets, 256)
    blocks = cld(nrOfVariableSets, threads)

    @inbounds Threads.@threads for i in eachindex(kernels)
        compiledKernel = CompileKernel(kernels[i], kernelName)
        cudacall(compiledKernel, (CuPtr{Float32},CuPtr{Float32},CuPtr{Float32}), cudaVars, cudaParams, cudaResults; threads=threads, blocks=blocks)
        # CUDA.jl launches on a task-local stream; wait here so every kernel has finished
        # before the results are copied back to the host after the loop
        CUDA.synchronize()
    end

    return cudaResults
end

"
Links the given PTX code and returns the kernel function named ```kernelName``` from the resulting module.
"
function CompileKernel(ptxKernel::String, kernelName::String)::CuFunction
    linker = CuLink()
    add_data!(linker, kernelName, ptxKernel)

    image = complete(linker)

    mod = CuModule(image)
    return CuFunction(mod, kernelName)
end

# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
# seekstart(buf1); write(buf2, buf1)
"
Generates the complete PTX code of a kernel that evaluates one expression for every variable set.
The kernel takes three pointers: variables, parameters and results.

- param ```expression```: The expression in postfix form
- param ```varSetSize```: The size of a variable set. Equal to number of rows of variable matrix (in a column major matrix)
- param ```paramSetSize```: The size of the longest parameter set. As it has to be stored in a column major matrix, the nr of rows is dependent on the longest parameter set
- param ```nrOfVariableSets```: The number of variable sets; threads with a higher index exit immediately
- param ```expressionIndex```: The 0-based index of the expression
- param ```kernelName```: The entry name of the generated kernel
"
function transpile(expression::ExpressionProcessing.PostfixType, varSetSize::Integer, paramSetSize::Integer, nrOfVariableSets::Integer, expressionIndex::Integer, kernelName::String)::String
    exitJumpLocationMarker = "L__BB0_2"
    ptxBuffer = IOBuffer()
    regManager = Utils.RegisterManager(Dict(), Dict())

    # TODO: Suboptimal solution. get_kernel_signature should also return the name of the registers used for the parameters, so further below, we do not have to hard-code them
    signature, paramLoading = get_kernel_signature(kernelName, [Float32, Float32, Float32], regManager) # Vars, Params, Results
    guardClause, threadId64Reg = get_guard_clause(exitJumpLocationMarker, nrOfVariableSets, regManager)

    println(ptxBuffer, get_cuda_header())
    println(ptxBuffer, signature)
    println(ptxBuffer, "{")

    # Generated before the register definitions are printed, because it allocates further registers
    calc_code = generate_calculation_code(expression, "%parameter0", varSetSize, "%parameter1", paramSetSize, "%parameter2", threadId64Reg, expressionIndex, nrOfVariableSets, regManager)
    println(ptxBuffer, Utils.get_register_definitions(regManager))
    println(ptxBuffer, paramLoading)
    println(ptxBuffer, guardClause)
    println(ptxBuffer, calc_code)

    # exit jump location
    print(ptxBuffer, exitJumpLocationMarker); println(ptxBuffer, ": ret;")
    println(ptxBuffer, "}")

    generatedCode = String(take!(ptxBuffer))
    return generatedCode
end

# TODO: Make version, target and address_size configurable
"
Returns the PTX module header (PTX version, target architecture and address size).
"
function get_cuda_header()::String
    return "
.version 8.5
.target sm_61
.address_size 64
"
end

"
Builds the kernel entry declaration and the code loading its parameters.

Every entry of ```parameters``` becomes one ```.param .u64 param_i``` (a pointer); only the number of entries is used, not their types.
For each parameter, the generated loading code reads it into an ```rd``` register and converts it to a global address stored in a ```parameter``` register.
NOTE(review): transpile assumes these registers are named %parameter0, %parameter1, ... in declaration order, which depends on Utils.get_next_free_register.

- param ```parameters```: As used by transpile: [1] = variables; [2] = parameters; [3] = results
- returns: (signature, parameter loading code)
"
function get_kernel_signature(kernelName::String, parameters::Vector{DataType}, regManager::Utils.RegisterManager)::Tuple{String, String}
    signatureBuffer = IOBuffer()
    paramLoadingBuffer = IOBuffer()
    print(signatureBuffer, ".visible .entry ")
    print(signatureBuffer, kernelName)
    println(signatureBuffer, "(")

    for i in eachindex(parameters)
        print(signatureBuffer, " .param .u64 param_", i)

        parametersLocation = Utils.get_next_free_register(regManager, "rd")
        println(paramLoadingBuffer, "ld.param.u64 $parametersLocation, [param_$i];")
        println(paramLoadingBuffer, "cvta.to.global.u64 $(Utils.get_next_free_register(regManager, "parameter")), $parametersLocation;")
        if i != lastindex(parameters)
            println(signatureBuffer, ",")
        end
    end

    print(signatureBuffer, ")")
    return (String(take!(signatureBuffer)), String(take!(paramLoadingBuffer)))
end

"
Constructs the PTX code used for handling the case where too many threads are started.
The global thread index is computed as ```%ntid.x * %ctaid.x + %tid.x```; every thread whose index is not a valid 0-based variable set index (index >= nrOfVarSets) jumps to ```exitJumpLocation```.

- param ```exitJumpLocation```: The label to branch to for superfluous threads
- param ```nrOfVarSets```: The total amount of variable sets for the kernel (emitted as an immediate value)
- returns: (guard code, name of the 64-bit register holding the global thread index)
"
function get_guard_clause(exitJumpLocation::String, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::Tuple{String, String}
    guardBuffer = IOBuffer()

    blockDimReg = Utils.get_next_free_register(regManager, "r")
    blockIdxReg = Utils.get_next_free_register(regManager, "r")
    localThreadIdReg = Utils.get_next_free_register(regManager, "r")
    println(guardBuffer, "mov.u32 $blockDimReg, %ntid.x;")
    println(guardBuffer, "mov.u32 $blockIdxReg, %ctaid.x;")
    println(guardBuffer, "mov.u32 $localThreadIdReg, %tid.x;")

    globalThreadId = Utils.get_next_free_register(regManager, "r") # basically the index of the thread in the variable set
    breakCondition = Utils.get_next_free_register(regManager, "p")
    println(guardBuffer, "mad.lo.s32 $globalThreadId, $blockDimReg, $blockIdxReg, $localThreadIdReg;")
    # globalThreadId is 0-based, so the valid range is [0, nrOfVarSets): exit on index >= nrOfVarSets
    println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;")

    # branch to end if breakCondition is true
    println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")

    # Convert threadIdReg to a 64 bit register. Not 64 bit from the start, as this would take up more registers. Performance tests can be performed to determine if it is faster doing this, or making everything 64-bit from the start
    threadId64Reg = Utils.get_next_free_register(regManager, "rd")
    print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")

    return (String(take!(guardBuffer)), threadId64Reg)
end

"
Generates the PTX code evaluating the postfix expression with an operand stack and storing the final value in the results matrix at ```expressionIndex * nrOfVarSets + threadId```.
Each variable/parameter is loaded only on its first access; later accesses reuse its register.

- param ```variablesSetSize```: Size of one variable set (rows of the variable matrix)
- param ```parametersSetSize```: Size of the largest parameter set
- param ```threadId64Reg```: 64-bit register holding the 0-based index of the variable set of this thread
- param ```expressionIndex```: The 0-based index of the expression; also selects the parameter set
"
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesLocation::String, variablesSetSize::Integer, parametersLocation::String, parametersSetSize::Integer, resultsLocation::String, threadId64Reg::String, expressionIndex::Integer, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::String
    codeBuffer = IOBuffer()
    operands = Vector{Operand}()

    exprId64Reg = Utils.get_next_free_register(regManager, "rd")
    println(codeBuffer, "mov.u64 $exprId64Reg, $expressionIndex;")

    for token in expression
        if token.Type == FLOAT32
            value = reinterpret(Float32, token.Value)
            if isfinite(value)
                push!(operands, value)
            else
                # Otherwise, values like "Inf" would be written as "Inf" and therefore not understandable to the PTX compiler.
                # PTX requires "0f" followed by exactly 8 hex digits of the IEEE bits; going through UInt32 with padding
                # avoids a "-" sign (e.g. -Inf stored as a signed integer) or missing leading zeros
                push!(operands, "0f" * string(reinterpret(UInt32, value), base = 16, pad = 8))
            end
        elseif token.Type == OPERATOR
            operator = reinterpret(Operator, token.Value)

            right = nothing
            if is_binary_operator(operator)
                right = pop!(operands)
                left = pop!(operands)
            else
                left = pop!(operands)
            end
            operation, resultRegister = get_operation(operator, regManager, left, right)

            println(codeBuffer, operation)
            push!(operands, resultRegister)
        elseif token.Type == VARIABLE
            var, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
            if first_access
                println(codeBuffer, load_into_register(var, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
            end
            push!(operands, var)
        elseif token.Type == PARAMETER
            param, first_access = Utils.get_register_for_name(regManager, "p$(token.Value)")
            if first_access
                println(codeBuffer, load_into_register(param, parametersLocation, token.Value, exprId64Reg, parametersSetSize, regManager))
            end
            push!(operands, param)
        else
            throw("Token unkown. Token was '$(token)'")
        end
    end

    tempReg = Utils.get_next_free_register(regManager, "rd")
    # results is column-major with one column per expression: address = resultsLocation + (expressionIndex * nrOfVarSets + threadId) * BYTES
    println(codeBuffer, "
    add.u64 $tempReg, $((expressionIndex)*nrOfVarSets), $threadId64Reg;
    mad.lo.u64 $tempReg, $tempReg, $BYTES, $resultsLocation;
    st.global.f32 [$tempReg], $(pop!(operands));
    ")

    return String(take!(codeBuffer))
end

"
Loads a value from a location into the given register.
It is assumed that the location refers to a column-major matrix

- param ```register```: The register where the loaded value will be stored
- param ```loadLocation```: The location from where to load the value
- param ```valueIndex```: 1-based index of the value in the variable set/parameter set
- param ```setIndexReg64```: 0-based index of the set. Needed to calculate the actual index from the ```valueIndex```. Is equal to the global threadId
- param ```setSize```: The size of one set. Needed to calculate the actual index from the ```valueIndex```. Total number of elements in the set (length(set))
"
function load_into_register(register::String, loadLocation::String, valueIndex::Integer, setIndexReg64::String, setSize::Integer, regManager::Utils.RegisterManager)::String
    tempReg = Utils.get_next_free_register(regManager, "rd")

    # "mad" calculates the offset and "add" applies the offset. Classical pointer arithmetic for accessing values of an array like in C
    return "
    mad.lo.u64 $tempReg, $setIndexReg64, $(setSize*BYTES), $((valueIndex - 1) * BYTES);
    add.u64 $tempReg, $loadLocation, $tempReg;
    ld.global.f32 $register, [$tempReg];"
    #TODO: This is not the most efficient way. The index of the set should be calculated only once if possible and not like here multiple times
end

"
Maps a Julia type to its PTX type suffix; any type other than Int64, Int32 and Float32 maps to the untyped 64-bit suffix.
"
function type_to_ptx_type(type::DataType)::String
    if type == Int64
        return ".s64"
    elseif type == Int32
        return ".s32"
    elseif type == Float32
        return ".f32"
    else
        return ".b64"
    end
end

"
Generates the PTX code for one operator, writing into a freshly allocated ```f``` register.

- param ```left```: The (only) operand of unary operators or the left operand of binary operators
- param ```right```: The right operand; must be given for binary operators
- returns: (PTX code, name of the result register)
- throws: ArgumentError if a binary operator is missing its right operand or the operator is not supported
"
function get_operation(operator::Operator, regManager::Utils.RegisterManager, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
    resultRegister = Utils.get_next_free_register(regManager, "f")
    resultCode = ""

    if is_binary_operator(operator) && isnothing(right)
        throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operand has been given."))
    end

    if operator == ADD
        resultCode = "add.f32 $resultRegister, $left, $right;"
    elseif operator == SUBTRACT
        resultCode = "sub.f32 $resultRegister, $left, $right;"
    elseif operator == MULTIPLY
        resultCode = "mul.f32 $resultRegister, $left, $right;"
    elseif operator == DIVIDE
        resultCode = "div.approx.f32 $resultRegister, $left, $right;"
    elseif operator == POWER
        # x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)"
        resultCode = "
        // x^y:
        lg2.approx.f32 $resultRegister, $left;
        mul.f32 $resultRegister, $right, $resultRegister;
        ex2.approx.f32 $resultRegister, $resultRegister;"
    elseif operator == ABS
        resultCode = "abs.f32 $resultRegister, $left;"
    elseif operator == LOG
        # log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)"
        resultCode = "
        // log(x):
        lg2.approx.f32 $resultRegister, $left;
        mul.f32 $resultRegister, $resultRegister, 0.693147182;"
    elseif operator == EXP
        # e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)"
        resultCode = "
        // e^x:
        mul.f32 $resultRegister, $left, 1.44269502;
        ex2.approx.f32 $resultRegister, $resultRegister;"
    elseif operator == SQRT
        resultCode = "sqrt.approx.f32 $resultRegister, $left;"
    elseif operator == INV
        resultCode = "rcp.approx.f32 $resultRegister, $left;"
    else
        throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
    end

    return (resultCode, resultRegister)
end

end