Cuda support for the mnist training. (#277)

* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
This commit is contained in:
Laurent Mazare
2023-07-29 19:48:04 +01:00
committed by GitHub
parent 16c33383eb
commit c950a5c6b1
6 changed files with 453 additions and 28 deletions

View File

@ -244,7 +244,7 @@ impl ReduceIndex {
val = s
}
}
dst[unstr_index] = g(val, acc)
dst_to_set[unstr_index] = g(val, acc)
}
}
}