Any examples of algorithm.reduction.argmax usage?

I am writing a fast `where` function, and from the description `argwhere` looks like a good candidate, but I can't figure out how to use it. For example, what is `OutputChainPtr`?

```mojo
@always_inline
fn where(self, val: Int8, vec: DTypePointer[DType.int8]):
    @parameter
    fn w[nelts: Int](i: Int):
        var res = (self.values.simd_load[nelts](i) == val)
        # algorithm.reduction.argmax
    algorithm.vectorize[nelts, w](ar_size * ar_size)
```
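To make clear what the kernel above is trying to compute, here is a plain Python sketch of `argwhere`-style semantics. It is not Mojo and not any Mojo library API. The helper name `argwhere_chunked` and its `width` parameter are hypothetical, and `width` only stands in for `nelts`. The loop walks the data one chunk at a time, the way `vectorize` hands a kernel one SIMD-width slice per call. Each chunk gets a lane-wise comparison, and the matching lanes are then compacted into absolute indices.

```python
from typing import List, Sequence


def argwhere_chunked(values: Sequence[int], val: int, width: int = 8) -> List[int]:
    """Return every index i where values[i] == val.

    Hypothetical helper: `width` plays the role of `nelts`, so each loop
    iteration handles one "SIMD register" worth of elements.
    """
    if width <= 0:
        raise ValueError("width must be positive")
    out: List[int] = []
    for start in range(0, len(values), width):
        chunk = values[start:start + width]  # the last chunk may be shorter (the tail)
        # Step 1: lane-wise comparison, like `simd_load[nelts](i) == val`.
        mask = [x == val for x in chunk]
        # Step 2: compact the set lanes into absolute indices.
        for lane, hit in enumerate(mask):
            if hit:
                out.append(start + lane)
    return out


if __name__ == "__main__":
    data = [3, 7, 3, 1, 3, 9, 7, 3, 0, 3]
    print(argwhere_chunked(data, 3, width=4))  # [0, 2, 4, 7, 9]
```

Step 1 maps directly onto a SIMD comparison. Step 2 can produce a different number of outputs per chunk, and that is plausibly why an `argwhere`-style result is harder to get from a fixed-shape reduction such as `argmax`.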
zverianskii (OP) · 15mo ago
@Jack Clayton hi, any help here? I have spent a week trying to find a way to use SIMD for tasks like `argwhere`.
Jack Clayton · 14mo ago
I spent some time playing with this and it didn't work as I expected, so I'll have to get back to you on this one, @zverianskii.
zverianskii (OP) · 14mo ago
Hi @Jack Clayton, any news so far?
Jack Clayton · 14mo ago
Hi @zverianskii, it's very low-level at the moment:
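```mojo
from memory.buffer import NDBuffer
from runtime.llcl import Runtime, OwningOutputChainPtr
from algorithm.reduction import argmax
from utils.list import DimList  # assumed module path for DimList in this Mojo release

fn main():
    alias size = 42

    # 1-D input of `size` int32 values, plus a one-element buffer for the result index
    let vector = NDBuffer[1, DimList(size), DType.int32].stack_allocation()
    let output = NDBuffer[1, DimList(1), DType.index].stack_allocation()

    for i in range(size):
        vector[i] = i

    with Runtime() as runtime:
        # the output chain signals when argmax has finished writing `output`
        let out_chain = OwningOutputChainPtr(runtime)
        argmax(
            # rebind the statically shaped buffers to the unknown-shape type argmax takes
            rebind[NDBuffer[1, DimList.create_unknown[1](), DType.int32]](vector),
            0,  # axis to reduce along
            rebind[NDBuffer[1, DimList.create_unknown[1](), DType.index]](output),
            out_chain.borrow(),
        )
        out_chain.wait()  # block until the result is ready
        print("argmax:", output[0])
```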
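For reference, here is what that call computes, as a plain Python sketch that mirrors the same shapes. It is not the Mojo kernel. The input is a 1-D buffer, the reduction runs along axis 0, and the output is a one-element buffer. The names `argmax_1d` and `argmax_rows` are hypothetical. The strict `>` keeps the first index on ties, the way NumPy's `argmax` does. I have not confirmed how Mojo's `argmax` breaks ties. `argmax_rows` shows what "along the specified axis" means for a 2-D input: reducing along axis 1 gives one index per row.

```python
from typing import List, Sequence


def argmax_1d(vector: Sequence[int], output: List[int]) -> None:
    """Write the index of the largest element of `vector` into output[0]."""
    if not vector:
        raise ValueError("argmax of an empty vector is undefined")
    best = 0
    for i in range(1, len(vector)):
        if vector[i] > vector[best]:  # strict: the first maximum wins on ties
            best = i
    output[0] = best


def argmax_rows(matrix: Sequence[Sequence[int]]) -> List[int]:
    """Reduce a 2-D input along axis 1: one index per row."""
    out: List[int] = []
    for row in matrix:
        slot = [0]  # one-element "output buffer" per row
        argmax_1d(row, slot)
        out.append(slot[0])
    return out


if __name__ == "__main__":
    size = 42
    vector = list(range(size))  # same fill as `vector[i] = i`
    output = [0]
    argmax_1d(vector, output)
    print("argmax:", output[0])  # argmax: 41
    print(argmax_rows([[1, 9, 2], [7, 7, 0]]))  # [1, 0]
```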
zverianskii (OP) · 14mo ago
Thanks, @Jack Clayton. So does this function only return a single element? The documentation says it "Finds the indices of the maximum element along the specified axis." I was expecting it to return multiple indices when the maximum value is repeated, which is why I'm interested in this function.
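In NumPy-style APIs, the plural "indices" in an argmax description refers to one index per slice along the reduced axis. Ties do not produce extra indices. That fits the one-element output buffer in the example above, and I am assuming Mojo follows the same convention. If you want every position of a repeated maximum, one option is a max reduction followed by an `argwhere`-style equality scan. The helper name `argmax_all` is hypothetical, and this is plain Python, not a Mojo API:

```python
from typing import List, Sequence


def argmax_all(values: Sequence[int]) -> List[int]:
    """Return every index at which the maximum value occurs.

    Two passes: a max reduction, then an equality scan
    (the argwhere(values == max(values)) pattern).
    """
    if not values:
        return []
    m = max(values)
    return [i for i, x in enumerate(values) if x == m]


if __name__ == "__main__":
    print(argmax_all([4, 9, 1, 9, 9, 2]))  # [1, 3, 4]
```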