updated all to 32-bit to save registers and boost performance
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
This commit is contained in:
@ -106,11 +106,11 @@ function transpile(expression::ExpressionProcessing.PostfixType)::String
|
||||
ptxBuffer = IOBuffer()
|
||||
|
||||
println(ptxBuffer, get_cuda_header())
|
||||
println(ptxBuffer, get_kernel_signature("ExpressionProcessing", [Int64, Float64]))
|
||||
println(ptxBuffer, get_kernel_signature("ExpressionProcessing", [Int32, Float32]))
|
||||
println(ptxBuffer, "{")
|
||||
|
||||
# TODO: Actually calculate the number of needed registers and extend to more register kinds
|
||||
println(ptxBuffer, get_register_definitions(1, 5, 1)) # apparently I can define registers anywhere. This might make things easier
|
||||
println(ptxBuffer, get_register_definitions(1, 5, 0)) # apparently I can define registers anywhere. This might make things easier
|
||||
# TODO: Parameter loading
|
||||
println(ptxBuffer, get_guard_clause())
|
||||
|
||||
@ -120,7 +120,9 @@ function transpile(expression::ExpressionProcessing.PostfixType)::String
|
||||
# Variables have: %var0 to %varn - 1
|
||||
# Parameters have: %param0 to %paramn - 1
|
||||
# Code goes here
|
||||
println(ptxBuffer, generate_calculation_code(expression))
|
||||
(calc_code, fRegisterCount) = generate_calculation_code(expression)
|
||||
println(ptxBuffer, get_register_definitions(0, 0, fRegisterCount))
|
||||
println(ptxBuffer, calc_code)
|
||||
|
||||
# exit jump location
|
||||
print(ptxBuffer, exitJumpLocationMarker); println(ptxBuffer, ": ret;")
|
||||
@ -184,7 +186,8 @@ function get_guard_clause()::String
|
||||
return String(take!(guardBuffer))
|
||||
end
|
||||
|
||||
function get_register_definitions(nrPred::Int, nr32Bit::Int, nrFloat64::Int)::String
|
||||
# TODO: refactor this for better usage. Maybe make this generate only one register definition and pass in the details
|
||||
function get_register_definitions(nrPred::Int, nr32Bit::Int, nrFloat32::Int)::String
|
||||
registersBuffer = IOBuffer()
|
||||
|
||||
if nrPred > 0
|
||||
@ -193,8 +196,8 @@ function get_register_definitions(nrPred::Int, nr32Bit::Int, nrFloat64::Int)::St
|
||||
if nr32Bit > 0
|
||||
println(registersBuffer, ".reg .b32 %r<$nr32Bit>;")
|
||||
end
|
||||
if nrFloat64 > 0
|
||||
println(registersBuffer, ".reg .f64 %f<$nrFloat64>;")
|
||||
if nrFloat32 > 0
|
||||
println(registersBuffer, ".reg .f32 %f<$nrFloat32>;")
|
||||
end
|
||||
|
||||
return String(take!(registersBuffer))
|
||||
@ -205,38 +208,41 @@ end
|
||||
# Probably do this: Get Expr -> traverse tree -> if child node is Expr: basically replace that node with the register containing the result of that Expr
|
||||
|
||||
# Current assumption: Expression only made out of constant values
|
||||
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::String
|
||||
function generate_calculation_code(expression::ExpressionProcessing.PostfixType)::Tuple{String, Int}
|
||||
codeBuffer = IOBuffer()
|
||||
operands = Vector{Float64}()
|
||||
operands = Vector{Union{Float32, String}}() # Maybe make it of type ANY. Then I could put the register name "on the stack" instead and build up the code like that. Could also make it easier implementing variables/params
|
||||
|
||||
registerCounter = 0
|
||||
println(expression)
|
||||
for i in eachindex(expression)
|
||||
token = expression[i]
|
||||
|
||||
if token.Type == FLOAT64
|
||||
push!(operands, reinterpret(Float64, token.Value))
|
||||
if token.Type == FLOAT32
|
||||
push!(operands, reinterpret(Float32, token.Value))
|
||||
elseif token.Type == OPERATOR
|
||||
operator = get_ptx_operator(reinterpret(Operator, token.Value))
|
||||
print(codeBuffer, " $operator %f$registerCounter ")
|
||||
register = "%f$registerCounter"
|
||||
print(codeBuffer, " $operator $register, ")
|
||||
|
||||
# Ugly temporary proof of concept which is ignoring unary operators
|
||||
if length(operands) == 0
|
||||
print(codeBuffer, "%f")
|
||||
print(codeBuffer, registerCounter - 2) # add result before previous result
|
||||
end
|
||||
print(codeBuffer, " ")
|
||||
if length(operands) <= 1
|
||||
print(codeBuffer, "%f")
|
||||
print(codeBuffer, registerCounter - 1) # add previous result
|
||||
end
|
||||
print(codeBuffer, " ")
|
||||
# if length(operands) == 0
|
||||
# print(codeBuffer, "%f")
|
||||
# print(codeBuffer, registerCounter - 2) # add result before previous result
|
||||
# end
|
||||
# print(codeBuffer, " ")
|
||||
# if length(operands) <= 1
|
||||
# print(codeBuffer, "%f")
|
||||
# print(codeBuffer, registerCounter - 1) # add previous result
|
||||
# end
|
||||
# print(codeBuffer, " ")
|
||||
|
||||
ops = last(operands, 2)
|
||||
pop!(operands);pop!(operands)
|
||||
print(codeBuffer, join(ops, ", ")) # if operands has too few values it means the previous calculation is needed. So we need to use registerCounter - 1 or registerCounter - 2 previous registers
|
||||
println(codeBuffer, ";")
|
||||
|
||||
# empty!(operands)
|
||||
push!(operands, register)
|
||||
registerCounter += 1
|
||||
end
|
||||
|
||||
@ -247,14 +253,14 @@ function generate_calculation_code(expression::ExpressionProcessing.PostfixType)
|
||||
# on all other operations either 1 or 2 (one if unary and two if binary operator)
|
||||
end
|
||||
|
||||
return String(take!(codeBuffer))
|
||||
return (String(take!(codeBuffer)), registerCounter)
|
||||
end
|
||||
|
||||
function type_to_ptx_type(type::DataType)::String
|
||||
if type == Int64
|
||||
return ".s64"
|
||||
elseif type == Float64
|
||||
return ".f64"
|
||||
elseif type == Float32
|
||||
return ".f32"
|
||||
else
|
||||
return ".b64"
|
||||
end
|
||||
@ -264,23 +270,23 @@ end
|
||||
# Left out for now since I don't have register management yet
|
||||
function get_ptx_operator(operator::Operator)::String
|
||||
if operator == ADD
|
||||
return "add.f64"
|
||||
return "add.f32"
|
||||
elseif operator == SUBTRACT
|
||||
return "sub.f64"
|
||||
return "sub.f32"
|
||||
elseif operator == MULTIPLY
|
||||
return "mul.f64"
|
||||
return "mul.f32"
|
||||
elseif operator == DIVIDE
|
||||
return "div.approx.f64"
|
||||
return "div.approx.f32"
|
||||
elseif operator == POWER
|
||||
return ""
|
||||
elseif operator == ABS
|
||||
return "abs.f64"
|
||||
return "abs.f32"
|
||||
elseif operator == LOG
|
||||
return "lg2.approx.f64"
|
||||
return "lg2.approx.f32"
|
||||
elseif operator == EXP
|
||||
return ""
|
||||
elseif operator == SQRT
|
||||
return "sqrt.approx.f64"
|
||||
return "sqrt.approx.f32"
|
||||
else
|
||||
throw(ArgumentError("Operator conversion to ptx not implemented for $operator"))
|
||||
end
|
||||
|
Reference in New Issue
Block a user