Finish reduce kernels.

This commit is contained in:
Nicolas Patry
2023-12-17 19:07:00 +01:00
parent 6bc92e63cb
commit 972903021c
6 changed files with 258 additions and 39 deletions

View File

@ -543,6 +543,7 @@ fn argmax(device: &Device) -> Result<()> {
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?);
assert_eq!(
tensor
.argmax_keepdim(0)?