benchmarking: further tests done. The transpiler seems to take a very long time; this needs further investigation
Some checks failed
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.10) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, 1.6) (push) Has been cancelled
CI / Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} (x64, ubuntu-latest, pre) (push) Has been cancelled

This commit is contained in:
Daniel
2025-05-11 16:54:19 +02:00
parent 5b31fbb270
commit 3d80ae95e4
6 changed files with 39 additions and 38 deletions

View File

@ -157,7 +157,7 @@ function get_kernel_signature(kernelName::String, parameters::Vector{DataType},
println(signatureBuffer, "(")
for i in eachindex(parameters)
print(signatureBuffer, " .param .u64", " ", "param_", i)
print(signatureBuffer, " .param .u64 param_", i)
parametersLocation = Utils.get_next_free_register(regManager, "rd")
println(paramLoadingBuffer, "ld.param.u64 $parametersLocation, [param_$i];")
@ -183,21 +183,21 @@ function get_guard_clause(exitJumpLocation::String, nrOfVarSets::Integer, regMan
threadsPerCTA = Utils.get_next_free_register(regManager, "r")
currentThreadId = Utils.get_next_free_register(regManager, "r")
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
println(guardBuffer, "mov.u32 $threadIds, %ntid.x;")
println(guardBuffer, "mov.u32 $threadsPerCTA, %ctaid.x;")
println(guardBuffer, "mov.u32 $currentThreadId, %tid.x;")
globalThreadId = Utils.get_next_free_register(regManager, "r") # basically the index of the thread in the variable set
breakCondition = Utils.get_next_free_register(regManager, "p")
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
println(guardBuffer, "setp.gt.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
println(guardBuffer, "mad.lo.s32 $globalThreadId, $threadIds, $threadsPerCTA, $currentThreadId;")
println(guardBuffer, "setp.gt.s32 $breakCondition, $globalThreadId, $nrOfVarSets;") # guard clause = index > nrOfVariableSets
# branch to end if breakCondition is true
println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
println(guardBuffer, "@$breakCondition bra $exitJumpLocation;")
# Convert globalThreadId to a 64-bit register. It is not 64-bit from the start, as this would take up more registers. Performance tests can be performed to determine whether it is faster to do this or to make everything 64-bit from the start
threadId64Reg = Utils.get_next_free_register(regManager, "rd")
print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")
print(guardBuffer, "cvt.u64.u32 $threadId64Reg, $globalThreadId;")
return (String(take!(guardBuffer)), threadId64Reg)
end
@ -306,38 +306,38 @@ function get_operation(operator::Operator, regManager::Utils.RegisterManager, le
end
if operator == ADD
resultCode = "add.f32 $resultRegister, $left, $right;"
resultCode = "add.f32 $resultRegister, $left, $right;"
elseif operator == SUBTRACT
resultCode = "sub.f32 $resultRegister, $left, $right;"
resultCode = "sub.f32 $resultRegister, $left, $right;"
elseif operator == MULTIPLY
resultCode = "mul.f32 $resultRegister, $left, $right;"
resultCode = "mul.f32 $resultRegister, $left, $right;"
elseif operator == DIVIDE
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
resultCode = "div.approx.f32 $resultRegister, $left, $right;"
elseif operator == POWER
# x^y == 2^(y*log2(x)) as generated by nvcc for "pow(x, y)"
resultCode = "
// x^y:
lg2.approx.f32 $resultRegister, $left;
mul.f32 $resultRegister, $right, $resultRegister;
ex2.approx.f32 $resultRegister, $resultRegister;"
lg2.approx.f32 $resultRegister, $left;
mul.f32 $resultRegister, $right, $resultRegister;
ex2.approx.f32 $resultRegister, $resultRegister;"
elseif operator == ABS
resultCode = "abs.f32 $resultRegister, $left;"
resultCode = "abs.f32 $resultRegister, $left;"
elseif operator == LOG
# log(x) == log2(x) * ln(2) as generated by nvcc for "log(x)"
resultCode = "
// log(x):
lg2.approx.f32 $resultRegister, $left;
mul.f32 $resultRegister, $resultRegister, 0.693147182;"
lg2.approx.f32 $resultRegister, $left;
mul.f32 $resultRegister, $resultRegister, 0.693147182;"
elseif operator == EXP
# e^x == 2^(x/ln(2)) as generated by nvcc for "exp(x)"
resultCode = "
// e^x:
mul.f32 $resultRegister, $left, 1.44269502;
ex2.approx.f32 $resultRegister, $resultRegister;"
mul.f32 $resultRegister, $left, 1.44269502;
ex2.approx.f32 $resultRegister, $resultRegister;"
elseif operator == SQRT
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
resultCode = "sqrt.approx.f32 $resultRegister, $left;"
elseif operator == INV
resultCode = "rcp.approx.f32 $resultRegister, $left;"
resultCode = "rcp.approx.f32 $resultRegister, $left;"
else
throw(ArgumentError("Operator conversion to ptx not implemented for '$operator'"))
end