mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Factorize the kernel naming scheme.
This commit is contained in:
@ -277,6 +277,11 @@ impl Map1 for Clone {
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||
let dtype = T::DTYPE.as_str();
|
||||
format!("{root}_{dtype}")
|
||||
}
|
||||
|
||||
struct Affine(f64, f64);
|
||||
impl Map1 for Affine {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
@ -291,8 +296,7 @@ impl Map1 for Affine {
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let kernel_name = format!("affine_{}", T::DTYPE.as_str());
|
||||
let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }?;
|
||||
let params = (
|
||||
@ -337,8 +341,7 @@ impl<'a> Map1 for Sum<'a> {
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let kernel_name = format!("sum_{}", T::DTYPE.as_str());
|
||||
let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<T>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
|
Reference in New Issue
Block a user