module Transpiler
using CUDA
using ..ExpressionProcessing
using ..Utils

# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
const BYTES = sizeof(Float32)
const Operand = Union{Float32, String} # Operand is either fixed value or register

# The generated PTX hard-codes the number of variable rows, the parameter set size, the number
# of variable sets and the 0-based expression index. A compiled kernel is therefore only
# reusable if ALL of these match, so they are part of the cache key together with the expression.
const KernelCacheKey = Tuple{Expr, Int, Int, Int, Int}

# Needed if multiple runs with the same expressions but different parameter values are performed
const cache = Dict{KernelCacheKey, CuFunction}()

"""
    evaluate(expressions, variables, parameters)

Transpile every expression to a PTX kernel (or reuse a cached one), run all kernels and
return the results.

# Arguments
- `expressions`: The expressions to evaluate.
- `variables`: Matrix whose columns are the variable sets (one column per set).
- `parameters`: One parameter vector per expression.

# Returns
A matrix with one row per variable set and one column per expression.
"""
function evaluate(expressions::Vector{Expr}, variables::Matrix{Float32}, parameters::Vector{Vector{Float32}})::Matrix{Float32}
    varRows = size(variables, 1)
    variableCols = size(variables, 2)

    # Without variable sets there is nothing to compute; a launch with 0 threads is not possible
    # (and cld(0, 0) below would throw a DivideError).
    if variableCols == 0
        return Matrix{Float32}(undef, 0, length(expressions))
    end

    paramSetSize = Int(Utils.get_max_inner_length(parameters))
    kernels = Vector{CuFunction}(undef, length(expressions))

    # TODO: test a Threads.@threads version of this loop again when doing performance tests.
    # With the simple "functionality" tests the parallel version took 0.03 seconds while the
    # sequential one took 0.00009 seconds (only one thread was available at the time).
    # A parallel version must guard every access to `cache` with one shared lock.
    for i in eachindex(expressions)
        expressionIndex = i - 1 # julia is 1-based but PTX needs 0-based indexing
        key = (expressions[i], varRows, paramSetSize, variableCols, expressionIndex)
        kernels[i] = get!(cache, key) do
            _compile_kernel(expressions[i], varRows, paramSetSize, variableCols, expressionIndex)
        end
    end

    cudaVars = CuArray(variables) # maybe put in shared memory (see PerformanceTests.jl for more info)
    cudaParams = Utils.create_cuda_array(parameters, NaN32) # maybe make constant (see PerformanceTests.jl for more info)

    # each expression has nr. of variable sets (nr. of columns of the variables) results and there are n expressions
    cudaResults = CuArray{Float32}(undef, variableCols, length(expressions))

    # The launch configuration only depends on the number of variable sets, so compute it once
    threads = min(variableCols, 256)
    blocks = cld(variableCols, threads)

    # execute each kernel (also try doing this with Threads.@threads. Since we can have multiple grids, this might improve performance)
    for kernel in kernels
        cudacall(kernel, (CuPtr{Float32},CuPtr{Float32},CuPtr{Float32}), cudaVars, cudaParams, cudaResults; threads=threads, blocks=blocks)
    end

    # The declared return type converts the device array to a host Matrix{Float32}
    return cudaResults
end

# Transpiles a single expression to PTX, links it and returns the resulting kernel function.
function _compile_kernel(expression::Expr, varRows::Integer, paramSetSize::Integer, variableCols::Integer, expressionIndex::Integer)::CuFunction
    formattedExpr = ExpressionProcessing.expr_to_postfix(expression)
    kernel = transpile(formattedExpr, varRows, paramSetSize, variableCols, expressionIndex)

    linker = CuLink()
    add_data!(linker, "ExpressionProcessing", kernel)

    image = complete(linker)
    mod = CuModule(image)
    return CuFunction(mod, "ExpressionProcessing")
end

# To increase performance, it would probably be best for all helper functions to return their IO Buffer and not a string
# seekstart(buf1); write(buf2, buf1)
"""
    transpile(expression, varSetSize, paramSetSize, nrOfVariableSets, expressionIndex)

Generate the complete PTX source of the kernel that evaluates `expression`.

- param `varSetSize`: The size of a variable set. Equal to number of rows of variable matrix (in a column major matrix)
- param `paramSetSize`: The size of the longest parameter set. As it has to be stored in a column major matrix, the nr of rows is dependent on the longest parameter set
- param `nrOfVariableSets`: The number of variable sets (number of columns of the variable matrix)
- param `expressionIndex`: The 0-based index of the expression
"""
function transpile(expression::ExpressionProcessing.PostfixType, varSetSize::Integer, paramSetSize::Integer,
    nrOfVariableSets::Integer, expressionIndex::Integer)::String
    exitJumpLocationMarker = "\$L__BB0_2"
    ptxBuffer = IOBuffer()
    regManager = Utils.RegisterManager(Dict(), Dict())

    # TODO: Suboptimal solution
    # The order of the following calls matters: the kernel signature must request its registers
    # first. NOTE(review): the hard-coded names "%parameter0".."%parameter2" below assume the
    # register manager numbers the "parameter" registers from 0 in request order.
    signature, paramLoading = get_kernel_signature("ExpressionProcessing", [Float32, Float32, Float32], regManager) # Vars, Params, Results
    guardClause, threadId64Reg = get_guard_clause(exitJumpLocationMarker, nrOfVariableSets, regManager)

    println(ptxBuffer, get_cuda_header())
    println(ptxBuffer, signature)
    println(ptxBuffer, "{")

    calc_code = generate_calculation_code(expression, "%parameter0", varSetSize, "%parameter1", paramSetSize, "%parameter2",
        threadId64Reg, expressionIndex, nrOfVariableSets, regManager)

    # Register definitions can only be emitted after all code generating functions requested their registers
    println(ptxBuffer, Utils.get_register_definitions(regManager))
    println(ptxBuffer, paramLoading)
    println(ptxBuffer, guardClause)
    println(ptxBuffer, calc_code)

    # exit jump location
    println(ptxBuffer, exitJumpLocationMarker, ": ret;")
    println(ptxBuffer, "}")

    return String(take!(ptxBuffer))
end

# TODO: Make version, target and address_size configurable; also see what address_size means exactly
"Returns the PTX module header (version, target architecture and address size)."
function get_cuda_header()::String
    return """
    .version 8.5
    .target sm_61
    .address_size 64
    """
end

"""
    get_kernel_signature(kernelName, parameters, regManager)

Build the kernel entry signature and the PTX code that loads the kernel parameters.

- param `kernelName`: Name of the kernel entry point
- param `parameters`: One entry per kernel parameter. Only the number of entries is used; every
  parameter is declared as a 64-bit value and converted to a global address.
  `transpile` passes three entries: [1] = variables; [2] = parameters; [3] = results
- param `regManager`: Register manager from which the needed registers are requested

Returns the tuple `(signature, parameterLoadingCode)`.
"""
function get_kernel_signature(kernelName::String, parameters::Vector{DataType}, regManager::Utils.RegisterManager)::Tuple{String, String}
    signatureBuffer = IOBuffer()
    paramLoadingBuffer = IOBuffer()

    print(signatureBuffer, ".visible .entry ")
    print(signatureBuffer, kernelName)
    println(signatureBuffer, "(")

    for i in eachindex(parameters)
        print(signatureBuffer, "  .param .u64", " ", "param_", i)

        # Load the raw parameter and convert it to a global address
        parametersLocation = Utils.get_next_free_register(regManager, "rd")
        println(paramLoadingBuffer, "ld.param.u64 $parametersLocation, [param_$i];")
        println(paramLoadingBuffer, "cvta.to.global.u64 $(Utils.get_next_free_register(regManager, "parameter")), $parametersLocation;")

        if i != lastindex(parameters)
            println(signatureBuffer, ",")
        end
    end

    print(signatureBuffer, ")")
    return (String(take!(signatureBuffer)), String(take!(paramLoadingBuffer)))
end

"""
    get_guard_clause(exitJumpLocation, nrOfVarSets, regManager)

Constructs the PTX code used for handling the case where too many threads are started.
Every thread whose global index is not a valid (0-based) variable set index jumps to
`exitJumpLocation`.

- param `exitJumpLocation`: The label to jump to if the thread has nothing to do
- param `nrOfVarSets`: The total amount of variable sets for the kernel
- param `regManager`: Register manager from which the needed registers are requested

Returns the tuple `(guardCode, threadId64Reg)` where `threadId64Reg` is the 64-bit register
holding the global thread index.
"""
function get_guard_clause(exitJumpLocation::String, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::Tuple{String, String}
    guardBuffer = IOBuffer()

    threadsPerCTA = Utils.get_next_free_register(regManager, "r")
    ctaId = Utils.get_next_free_register(regManager, "r")
    currentThreadId = Utils.get_next_free_register(regManager, "r")

    println(guardBuffer, "mov.u32 $threadsPerCTA, %ntid.x;")
    println(guardBuffer, "mov.u32 $ctaId, %ctaid.x;")
    println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")

    globalThreadId = Utils.get_next_free_register(regManager, "r") # basically the index of the thread in the variable set
    breakCondition = Utils.get_next_free_register(regManager, "p")
    println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadsPerCTA, $ctaId, $currentThreadId;")
    # The global thread id is 0-based, so valid ids are 0..nrOfVarSets-1.
    # guard clause = index >= nrOfVariableSets ("gt" would let one surplus thread write out of range)
    println(guardBuffer, "setp.ge.s32 $breakCondition, $globalThreadId, $nrOfVarSets;")

    # branch to end if breakCondition is true
    println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")

    # Convert threadIdReg to a 64 bit register. Not 64 bit from the start, as this would take up more registers.
    # Performance tests can be performed to determine if it is faster doing this, or making everything 64-bit from the start
    threadId64Reg = Utils.get_next_free_register(regManager, "rd")
    print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")

    return (String(take!(guardBuffer)), threadId64Reg)
end

"""
    generate_calculation_code(expression, variablesLocation, variablesSetSize, parametersLocation,
        parametersSetSize, resultsLocation, threadId64Reg, expressionIndex, nrOfVarSets, regManager)

Generate the PTX code that evaluates the postfix `expression` and stores the result.

- param `variablesLocation`: Register holding the address of the variable matrix
- param `variablesSetSize`: Size of one variable set
- param `parametersLocation`: Register holding the address of the parameter matrix
- param `parametersSetSize`: Size of the largest parameter set
- param `resultsLocation`: Register holding the address of the result matrix
- param `threadId64Reg`: 64-bit register holding the global thread index (index of the variable set)
- param `expressionIndex`: The 0-based index of the expression
- param `nrOfVarSets`: The number of variable sets
"""
function generate_calculation_code(expression::ExpressionProcessing.PostfixType, variablesLocation::String, variablesSetSize::Integer,
    parametersLocation::String, parametersSetSize::Integer, resultsLocation::String, threadId64Reg::String,
    expressionIndex::Integer, nrOfVarSets::Integer, regManager::Utils.RegisterManager)::String

    codeBuffer = IOBuffer()
    operands = Vector{Operand}()

    # The expression index selects the parameter set belonging to this expression
    exprId64Reg = Utils.get_next_free_register(regManager, "rd")
    println(codeBuffer, "mov.u64 $exprId64Reg, $expressionIndex;")

    for token in expression
        if token.Type == FLOAT32
            push!(operands, reinterpret(Float32, token.Value))
        elseif token.Type == OPERATOR
            operator = reinterpret(Operator, token.Value)

            # Postfix notation: the right operand is on top of the stack
            right = nothing
            if is_binary_operator(operator)
                right = pop!(operands)
                left = pop!(operands)
            else
                left = pop!(operands)
            end
            operation, resultRegister = get_operation(operator, regManager, left, right)

            println(codeBuffer, operation)
            push!(operands, resultRegister)
        elseif token.Type == INDEX
            if token.Value > 0 # variables
                varReg, first_access = Utils.get_register_for_name(regManager, "x$(token.Value)")
                # Only emit the load on first access; afterwards the register is reused
                if first_access
                    println(codeBuffer, load_into_register(varReg, variablesLocation, token.Value, threadId64Reg, variablesSetSize, regManager))
                end
                push!(operands, varReg)
            else # parameters
                absVal = abs(token.Value)
                paramReg, first_access = Utils.get_register_for_name(regManager, "p$absVal")
                if first_access
                    println(codeBuffer, load_into_register(paramReg, parametersLocation, absVal, exprId64Reg, parametersSetSize, regManager))
                end
                push!(operands, paramReg)
            end
        end
    end

    # Store the remaining operand (the result of the expression) at
    # results[expressionIndex * nrOfVarSets + threadId] (0-based element index)
    tempReg = Utils.get_next_free_register(regManager, "rd")
    resultOperand = pop!(operands)
    resultOffset = expressionIndex * nrOfVarSets

    println(codeBuffer, """

    add.u64 $tempReg, $resultOffset, $threadId64Reg;
    mad.lo.u64 $tempReg, $tempReg, $BYTES, $resultsLocation;
    st.global.f32 [$tempReg], $resultOperand;
    """)

    return String(take!(codeBuffer))
end

"""
    load_into_register(register, loadLocation, valueIndex, setIndexReg64, setSize, regManager)

Loads a value from a location into the given register. It is assumed that the location refers to a column-major matrix

- param `register`: The register where the loaded value will be stored
- param `loadLocation`: The location from where to load the value
- param `valueIndex`: 1-based index of the value in the variable set/parameter set
- param `setIndexReg64`: 0-based index of the set. Needed to calculate the actual index from the `valueIndex`. Is equal to the global threadId for variables and to the expression index for parameters
- param `setSize`: The size of one set. Needed to calculate the actual index from the `valueIndex`. Total number of elements in the set (length(set))
"""
function load_into_register(register::String, loadLocation::String, valueIndex::Integer, setIndexReg64::String, setSize::Integer, regManager::Utils.RegisterManager)::String
    tempReg = Utils.get_next_free_register(regManager, "rd")

    setStrideBytes = setSize * BYTES
    valueOffsetBytes = (valueIndex - 1) * BYTES

    # "mad" calculates the offset and "add" applies the offset. Classical pointer arithmetic for accessing values of an array like in C
    return """

    mad.lo.u64 $tempReg, $setIndexReg64, $setStrideBytes, $valueOffsetBytes;
    add.u64 $tempReg, $loadLocation, $tempReg;
    ld.global.f32 $register, [$tempReg];"""
end

"Maps a Julia type to the PTX type specifier used for it. Unknown types map to `.b64`."
function type_to_ptx_type(type::DataType)::String
    if type == Int64
        return ".s64"
    elseif type == Int32
        return ".s32"
    elseif type == Float32
        return ".f32"
    else
        return ".b64"
    end
end

"""
    get_operation(operator, regManager, left, right = nothing)

Generate the PTX code for applying `operator` to the given operand(s).

- param `left`: The only operand of a unary operator, or the left operand of a binary operator
- param `right`: The right operand of a binary operator. Ignored for unary operators

Returns the tuple `(code, resultRegister)` where `resultRegister` holds the result of the operation.

Throws an `ArgumentError` if a binary operator is given without a right operand, or if the
operator has no PTX conversion.
"""
function get_operation(operator::Operator, regManager::Utils.RegisterManager, left::Operand, right::Union{Operand, Nothing} = nothing)::Tuple{String, String}
    # Validate before requesting a register, so a failed call does not use one up
    if is_binary_operator(operator) && isnothing(right)
        throw(ArgumentError("Given operator '$operator' is a binary operator. However only one operand has been given."))
    end

    resultRegister = Utils.get_next_free_register(regManager, "f")
    resultCode = ""

    if operator == ADD
        resultCode = "add.f32 $resultRegister, $left, $right;"
    elseif operator == SUBTRACT
        resultCode = "sub.f32 $resultRegister, $left, $right;"
    elseif operator == MULTIPLY
        resultCode = "mul.f32 $resultRegister, $left, $right;"
    elseif operator == DIVIDE
        resultCode = "div.approx.f32 $resultRegister, $left, $right;"
    elseif operator == POWER
        # x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)"
        resultCode = """

        // x^y:
        lg2.approx.f32 $resultRegister, $left;
        mul.f32 $resultRegister, $right, $resultRegister;
        ex2.approx.f32 $resultRegister, $resultRegister;"""
    elseif operator == ABS
        resultCode = "abs.f32 $resultRegister, $left;"
    elseif operator == LOG
        # log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)"
        resultCode = """

        // log(x):
        lg2.approx.f32 $resultRegister, $left;
        mul.f32 $resultRegister, $resultRegister, 0.693147182;"""
    elseif operator == EXP
        # e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)"
        resultCode = """

        // e^x:
        mul.f32 $resultRegister, $left, 1.44269502;
        ex2.approx.f32 $resultRegister, $resultRegister;"""
    elseif operator == SQRT
        resultCode = "sqrt.approx.f32 $resultRegister, $left;"
    else
        throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
    end

    return (resultCode, resultRegister)
end

end