benchmarking: further tests done. Seems like transpiler takes ages, need to investigate further
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled
This commit is contained in:
@ -157,7 +157,7 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType},
|
||||
println(signatureBuffer, "(")
|
||||
|
||||
for i in eachindex(parameters)
|
||||
print(signatureBuffer, " .param .u64", " ", "param_", i)
|
||||
print(signatureBuffer, " .param .u64 param_", i)
|
||||
|
||||
parametersLocation = Utils.get_next_free_register(regManager, "rd")
|
||||
println(paramLoadingBuffer, "ld.param.u64 $parametersLocation, [param_$i];")
|
||||
@ -183,21 +183,21 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSets::Integer, regMan
|
||||
threadsPerCTA = Utils.get_next_free_register(regManager, "r")
|
||||
currentThreadId = Utils.get_next_free_register(regManager, "r")
|
||||
|
||||
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
|
||||
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
|
||||
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
|
||||
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
|
||||
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
|
||||
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
|
||||
|
||||
globalThreadId = Utils.get_next_free_register(regManager, "r") # basically the index of the thread in the variable set
|
||||
breakCondition = Utils.get_next_free_register(regManager, "p")
|
||||
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
|
||||
println(guardBuffer, "setp.gt.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
|
||||
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
|
||||
println(guardBuffer, "setp.gt.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
|
||||
|
||||
# branch to end if breakCondition is true
|
||||
println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
|
||||
println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
|
||||
|
||||
# Convert threadIdReg to a 64 bit register. Not 64 bit from the start, as this would take up more registers. Performance tests can be performed to determin if it is faster doing this, or making everything 64-bit from the start
|
||||
threadId64Reg = Utils.get_next_free_register(regManager, "rd")
|
||||
print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")
|
||||
print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")
|
||||
|
||||
return (String(take!(guardBuffer)), threadId64Reg)
|
||||
end
|
||||
@ -306,38 +306,38 @@ function get_operation(operator::Operator, regManager::Utils.RegisterManager, le
|
||||
end
|
||||
|
||||
if operator == ADD
|
||||
resultCode = "add.f32 $resultRegister, $left, $right;"
|
||||
resultCode = "add.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == SUBTRACT
|
||||
resultCode = "sub.f32 $resultRegister, $left, $right;"
|
||||
resultCode = "sub.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == MULTIPLY
|
||||
resultCode = "mul.f32 $resultRegister, $left, $right;"
|
||||
resultCode = "mul.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == DIVIDE
|
||||
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
|
||||
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
|
||||
elseif operator == POWER
|
||||
# x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)"
|
||||
resultCode = "
|
||||
// x^y:
|
||||
lg2.approx.f32 $resultRegister, $left;
|
||||
mul.f32 $resultRegister, $right, $resultRegister;
|
||||
ex2.approx.f32 $resultRegister, $resultRegister;"
|
||||
lg2.approx.f32 $resultRegister, $left;
|
||||
mul.f32 $resultRegister, $right, $resultRegister;
|
||||
ex2.approx.f32 $resultRegister, $resultRegister;"
|
||||
elseif operator == ABS
|
||||
resultCode = "abs.f32 $resultRegister, $left;"
|
||||
resultCode = "abs.f32 $resultRegister, $left;"
|
||||
elseif operator == LOG
|
||||
# log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)"
|
||||
resultCode = "
|
||||
// log(x):
|
||||
lg2.approx.f32 $resultRegister, $left;
|
||||
mul.f32 $resultRegister, $resultRegister, 0.693147182;"
|
||||
lg2.approx.f32 $resultRegister, $left;
|
||||
mul.f32 $resultRegister, $resultRegister, 0.693147182;"
|
||||
elseif operator == EXP
|
||||
# e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)"
|
||||
resultCode = "
|
||||
// e^x:
|
||||
mul.f32 $resultRegister, $left, 1.44269502;
|
||||
ex2.approx.f32 $resultRegister, $resultRegister;"
|
||||
mul.f32 $resultRegister, $left, 1.44269502;
|
||||
ex2.approx.f32 $resultRegister, $resultRegister;"
|
||||
elseif operator == SQRT
|
||||
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
|
||||
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
|
||||
elseif operator == INV
|
||||
resultCode = "rcp.approx.f32 $resultRegister, $left;"
|
||||
resultCode = "rcp.approx.f32 $resultRegister, $left;"
|
||||
else
|
||||
throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
|
||||
end
|
||||
|
Reference in New Issue
Block a user