transpiler: generates valid PTX and evaluates expressions correctly
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
Some checks are pending
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Waiting to run
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Waiting to run
This commit is contained in:
@ -6,39 +6,83 @@ using CUDA
|
||||
|
||||
All entries that cannot be filled have ```invalidElement``` as their value
|
||||
"
|
||||
function convert_to_matrix(vec::Vector{Vector{T}}, invalidElement::T)::Matrix{T} where T
|
||||
vecCols = get_max_inner_length(vec)
|
||||
vecRows = length(vec)
|
||||
vecMat = fill(invalidElement, vecCols, vecRows)
|
||||
|
||||
for i in eachindex(vec)
|
||||
vecMat[:,i] = copyto!(vecMat[:,i], vec[i])
|
||||
end
|
||||
function convert_to_matrix(vecs::Vector{Vector{T}}, invalidElement::T)::Matrix{T} where T
|
||||
maxLength = get_max_inner_length(vecs)
|
||||
|
||||
# Pad the shorter vectors with the invalidElement
|
||||
paddedVecs = [vcat(vec, fill(invalidElement, maxLength - length(vec))) for vec in vecs]
|
||||
vecMat = hcat(paddedVecs...)
|
||||
|
||||
return vecMat
|
||||
end
|
||||
|
||||
"Retrieves the number of entries for the largest inner vector"
|
||||
function get_max_inner_length(vec::Vector{Vector{T}})::Int where T
|
||||
maxLength = 0
|
||||
@inbounds for i in eachindex(vec)
|
||||
if length(vec[i]) > maxLength
|
||||
maxLength = length(vec[i])
|
||||
end
|
||||
end
|
||||
|
||||
return maxLength
|
||||
function get_max_inner_length(vecs::Vector{Vector{T}})::Int where T
|
||||
return maximum(length.(vecs))
|
||||
end
|
||||
|
||||
"Returns a CuArray filed with the data provided. The inner vectors do not have to have the same length. All missing elements will be the value ```invalidElement```"
|
||||
function create_cuda_array(data::Vector{Vector{T}}, invalidElement::T)::CuArray{T} where T
|
||||
dataCols = Utils.get_max_inner_length(data)
|
||||
dataRows = length(data)
|
||||
dataMat = Utils.convert_to_matrix(data, invalidElement)
|
||||
cudaArr = CuArray{T}(undef, dataCols, dataRows) # length(parameters) == number of expressions
|
||||
copyto!(cudaArr, dataMat)
|
||||
dataMat = convert_to_matrix(data, invalidElement)
|
||||
cudaArr = CuArray(dataMat)
|
||||
|
||||
return cudaArr
|
||||
end
|
||||
|
||||
struct RegisterManager
|
||||
registers::Dict
|
||||
symtable::Dict
|
||||
end
|
||||
|
||||
function get_next_free_register(manager::RegisterManager, name::String)::String
|
||||
if haskey(manager.registers, name)
|
||||
manager.registers[name] += 1
|
||||
else
|
||||
manager.registers[name] = 1
|
||||
end
|
||||
|
||||
return string("%", name, manager.registers[name] - 1)
|
||||
end
|
||||
|
||||
function get_register_definitions(manager::RegisterManager)::String
|
||||
registersBuffer = IOBuffer()
|
||||
|
||||
for definition in manager.registers
|
||||
regType = ""
|
||||
if definition.first == "p"
|
||||
regType = ".pred"
|
||||
elseif definition.first == "f"
|
||||
regType = ".f32"
|
||||
elseif definition.first == "var"
|
||||
regType = ".f32"
|
||||
elseif definition.first == "param"
|
||||
regType = ".f32"
|
||||
elseif definition.first == "r"
|
||||
regType = ".b32"
|
||||
elseif definition.first == "rd"
|
||||
regType = ".b64"
|
||||
elseif definition.first == "parameter"
|
||||
regType = ".b64"
|
||||
elseif definition.first == "i"
|
||||
regType = ".b64"
|
||||
else
|
||||
throw(ArgumentError("Unknown register name used. Name '$(definition.first)' cannot be mapped to a PTX type."))
|
||||
end
|
||||
println(registersBuffer, ".reg $regType %$(definition.first)<$(definition.second)>;")
|
||||
end
|
||||
|
||||
return String(take!(registersBuffer))
|
||||
end
|
||||
|
||||
"Returns the register for this variable/parameter and true if it is used for the first time and false otherwise."
|
||||
function get_register_for_name(manager::RegisterManager, varName::String)
|
||||
if haskey(manager.symtable, varName)
|
||||
return (manager.symtable[varName], false)
|
||||
else
|
||||
reg = get_next_free_register(manager, "var")
|
||||
manager.symtable[varName] = reg
|
||||
return (reg, true)
|
||||
end
|
||||
end
|
||||
|
||||
end
|
Reference in New Issue
Block a user