mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Use Map1 for sum.
This commit is contained in:
@ -1,7 +1,9 @@
|
|||||||
use crate::{CpuStorage, DType, Layout, Shape, WithDType};
|
use crate::{CpuStorage, DType, Layout, Shape, WithDType};
|
||||||
use candle_kernels as kernels;
|
use candle_kernels as kernels;
|
||||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||||
use cudarc::driver::{CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{
|
||||||
|
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
|
};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -244,7 +246,7 @@ enum CudaStorageSlice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
trait Map1 {
|
trait Map1 {
|
||||||
fn f<T: DeviceRepr + WithDType>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
&self,
|
&self,
|
||||||
src: &CudaSlice<T>,
|
src: &CudaSlice<T>,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
@ -276,7 +278,6 @@ impl Map1 for Clone {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Affine(f64, f64);
|
struct Affine(f64, f64);
|
||||||
|
|
||||||
impl Map1 for Affine {
|
impl Map1 for Affine {
|
||||||
fn f<T: DeviceRepr + WithDType>(
|
fn f<T: DeviceRepr + WithDType>(
|
||||||
&self,
|
&self,
|
||||||
@ -309,6 +310,43 @@ impl Map1 for Affine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Sum<'a>(&'a [usize]);
|
||||||
|
impl<'a> Map1 for Sum<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = layout.shape();
|
||||||
|
let src_dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let mut dst_el = el;
|
||||||
|
for &sum_dim in self.0.iter() {
|
||||||
|
dst_el /= src_dims[sum_dim];
|
||||||
|
}
|
||||||
|
let mut sum_dims = self.0.to_vec();
|
||||||
|
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||||
|
// indexes.
|
||||||
|
sum_dims.sort();
|
||||||
|
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
||||||
|
let sum_dims_s: Vec<usize> = sum_dims
|
||||||
|
.iter()
|
||||||
|
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
||||||
|
.collect();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let kernel_name = format!("sum_{}", T::DTYPE.as_str());
|
||||||
|
let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<T>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn slice_src_and_dst<'a, T>(
|
fn slice_src_and_dst<'a, T>(
|
||||||
src: &'a CudaSlice<T>,
|
src: &'a CudaSlice<T>,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
@ -486,73 +524,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
let shape = layout.shape();
|
let device = self.device().clone();
|
||||||
let src_dims = shape.dims();
|
let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
|
||||||
let el = shape.elem_count();
|
|
||||||
let mut dst_el = el;
|
|
||||||
for &sum_dim in sum_dims.iter() {
|
|
||||||
dst_el /= src_dims[sum_dim];
|
|
||||||
}
|
|
||||||
let mut sum_dims = sum_dims.to_vec();
|
|
||||||
// Sort the sum_dims as they have to be processed from left to right when converting the
|
|
||||||
// indexes.
|
|
||||||
sum_dims.sort();
|
|
||||||
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
|
||||||
let sum_dims_s: Vec<usize> = sum_dims
|
|
||||||
.iter()
|
|
||||||
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
|
||||||
.collect();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
|
||||||
let slice = match &self.slice {
|
|
||||||
CudaStorageSlice::U32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<u32>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::BF16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<bf16>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<f16>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F64(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<f64>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user