Cuda support for the mnist training. (#277)

* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
This commit is contained in:
Laurent Mazare
2023-07-29 19:48:04 +01:00
committed by GitHub
parent 16c33383eb
commit c950a5c6b1
6 changed files with 453 additions and 28 deletions

View File

@ -244,7 +244,7 @@ impl ReduceIndex {
val = s
}
}
dst[unstr_index] = g(val, acc)
dst_to_set[unstr_index] = g(val, acc)
}
}
}

View File

@ -438,6 +438,28 @@ trait Map2InPlace {
}
}
trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> {
}
struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1 for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
impl<'a> Map1Any for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
wrap: W,
) -> Result<S> {
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> {
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
let name = match self.1 {
ReduceOp::Sum => "fast_sum",
ReduceOp::Min => "fast_min",
ReduceOp::Max => "fast_max",
ReduceOp::ArgMin => "fast_argmin",
ReduceOp::ArgMax => "fast_argmax",
let (name, check_empty, return_index) = match self.1 {
ReduceOp::Sum => ("fast_sum", false, false),
ReduceOp::Min => ("fast_min", true, false),
ReduceOp::Max => ("fast_max", true, false),
ReduceOp::ArgMin => ("fast_argmin", true, true),
ReduceOp::ArgMax => ("fast_argmax", true, true),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
if return_index {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(out))
} else {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(wrap(out))
}
}
}