mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1. * Fix clippy lints for 1.78. * More clippy fixes.
This commit is contained in:
@ -250,44 +250,6 @@ impl Map1 for Powf {
|
||||
}
|
||||
}
|
||||
|
||||
struct Sum<'a>(&'a [usize]);
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let src_dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let mut dst_el = el;
|
||||
for &sum_dim in self.0.iter() {
|
||||
dst_el /= src_dims[sum_dim];
|
||||
}
|
||||
let mut sum_dims = self.0.to_vec();
|
||||
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||
// indexes.
|
||||
sum_dims.sort();
|
||||
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
||||
let sum_dims_s: Vec<usize> = sum_dims
|
||||
.iter()
|
||||
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
||||
.collect();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev
|
||||
.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
|
||||
.w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<T>(dst_el).w()?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct FastReduce<'a>(&'a [usize], ReduceOp);
|
||||
impl<'a> Map1Any for FastReduce<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
|
Reference in New Issue
Block a user