Mirror of https://github.com/huggingface/candle.git (synced 2025-06-22 04:22:50 +00:00)
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
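The message mentions argmin/argmax tests; as a rough sketch of what exercising the new kernels could look like from the public API (a hypothetical test, not taken from this commit, assuming candle's Tensor::argmax and Device::new_cuda behave as they do today):

use candle::{Device, Result, Tensor};

// Hypothetical test: argmax over the last dimension should return u32 indices,
// computed on the GPU by the new fast_argmax kernel.
fn argmax_cuda() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let t = Tensor::new(&[[1f32, 5., 3.], [4., 2., 6.]], &dev)?;
    let idx = t.argmax(1)?; // reduces dim 1; the result dtype is U32
    assert_eq!(idx.to_vec1::<u32>()?, [1, 2]);
    Ok(())
}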
@@ -244,7 +244,7 @@ impl ReduceIndex {
                         val = s
                     }
                 }
-                dst[unstr_index] = g(val, acc)
+                dst_to_set[unstr_index] = g(val, acc)
             }
         }
     }
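The min/max fix above writes through `dst_to_set` rather than `dst`; judging from the names, `dst_to_set` is a sub-slice already positioned at the right output offset, so indexing the full `dst` with a slice-relative `unstr_index` would land on the wrong element. A standalone toy (not candle code) of that class of off-by-offset bug:

// Reduce each row of `src` to its max. Writing through the per-row sub-slice
// keeps indices relative to that row's output slot; indexing the full `dst`
// with a row-relative index would only ever touch the first row's slot.
fn max_rows(src: &[f32], dst: &mut [f32], row_len: usize) {
    for (row, dst_to_set) in src.chunks(row_len).zip(dst.chunks_mut(1)) {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        dst_to_set[0] = max; // correct: sub-slice relative
        // dst[0] = max;     // shape of the old bug: always writes slot 0
    }
}

fn main() {
    let src = [1f32, 5., 3., 4., 2., 6.];
    let mut dst = [0f32; 2];
    max_rows(&src, &mut dst, 3);
    assert_eq!(dst, [5., 6.]);
}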
@@ -438,6 +438,28 @@ trait Map2InPlace
     }
 }

+trait Map1Any {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+        wrap: W,
+    ) -> Result<S>;
+
+    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
+        let out = match s {
+            S::U8(s) => self.f(s, d, l, S::U8)?,
+            S::U32(s) => self.f(s, d, l, S::U32)?,
+            S::BF16(s) => self.f(s, d, l, S::BF16)?,
+            S::F16(s) => self.f(s, d, l, S::F16)?,
+            S::F32(s) => self.f(s, d, l, S::F32)?,
+            S::F64(s) => self.f(s, d, l, S::F64)?,
+        };
+        Ok(out)
+    }
+}
+
 trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
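The `wrap: W` parameter of the new `Map1Any` trait exploits the fact that a tuple-variant constructor such as `S::F32` is itself a `Fn(CudaSlice<f32>) -> S`: `map` passes the constructor matching the input dtype, and `f` may either apply it or ignore it and return a different variant (as the argmin/argmax path does with `S::U32`). A self-contained sketch of the pattern with a toy enum (not the real candle types):

// Toy stand-ins for candle's storage enum.
enum Storage {
    U32(Vec<u32>),
    F32(Vec<f32>),
}

// `wrap` carries the "same dtype as the input" constructor; index-returning
// reductions ignore it and always produce U32.
fn reduce<W: Fn(Vec<f32>) -> Storage>(src: &[f32], return_index: bool, wrap: W) -> Storage {
    if return_index {
        let argmax = src
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
        Storage::U32(vec![argmax])
    } else {
        let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        wrap(vec![max])
    }
}

fn main() {
    // The variant constructor is passed as an ordinary function value.
    assert!(matches!(reduce(&[1., 5., 3.], false, Storage::F32), Storage::F32(v) if v == [5.0]));
    assert!(matches!(reduce(&[1., 5., 3.], true, Storage::F32), Storage::U32(i) if i == [1]));
}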
@@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> {
 }

 struct FastReduce<'a>(&'a [usize], ReduceOp);
-impl<'a> Map1 for FastReduce<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+impl<'a> Map1Any for FastReduce<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
         &self,
         src: &CudaSlice<T>,
         dev: &CudaDevice,
         layout: &Layout,
-    ) -> Result<CudaSlice<T>> {
+        wrap: W,
+    ) -> Result<S> {
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
         let src_el: usize = src_dims.iter().product();
@@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> {
             .htod_copy([dims.as_slice(), stride.as_slice()].concat())
             .w()?;
         let src = &src.slice(layout.start_offset()..);
-        let name = match self.1 {
-            ReduceOp::Sum => "fast_sum",
-            ReduceOp::Min => "fast_min",
-            ReduceOp::Max => "fast_max",
-            ReduceOp::ArgMin => "fast_argmin",
-            ReduceOp::ArgMax => "fast_argmax",
+        let (name, check_empty, return_index) = match self.1 {
+            ReduceOp::Sum => ("fast_sum", false, false),
+            ReduceOp::Min => ("fast_min", true, false),
+            ReduceOp::Max => ("fast_max", true, false),
+            ReduceOp::ArgMin => ("fast_argmin", true, true),
+            ReduceOp::ArgMax => ("fast_argmax", true, true),
         };
+        if check_empty && layout.shape().elem_count() == 0 {
+            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+        }
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
-        // SAFETY: filled in by the follow up kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
+        if return_index {
+            // SAFETY: filled in by the follow up kernel.
+            let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(S::U32(out))
+        } else {
+            // SAFETY: filled in by the follow up kernel.
+            let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(wrap(out))
+        }
     }
 }
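The new `check_empty` guard applies to min/max/argmin/argmax but not to sum: summing zero elements has a natural identity, while a min/max or an index over nothing does not. A toy illustration (not candle code):

fn main() {
    let empty: Vec<f32> = vec![];
    // Sum over nothing is well defined: the identity 0.
    let sum: f32 = empty.iter().sum();
    // Max over nothing only yields the fold's sentinel, not a real element.
    let max = empty.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Argmax over nothing has no index at all, hence the EmptyTensor error.
    let argmax = empty
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i);
    println!("sum={sum} max={max} argmax={argmax:?}");
}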