mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Factorize the kernel naming scheme.
This commit is contained in:
@ -277,6 +277,11 @@ impl Map1 for Clone {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||||
|
let dtype = T::DTYPE.as_str();
|
||||||
|
format!("{root}_{dtype}")
|
||||||
|
}
|
||||||
|
|
||||||
struct Affine(f64, f64);
|
struct Affine(f64, f64);
|
||||||
impl Map1 for Affine {
|
impl Map1 for Affine {
|
||||||
fn f<T: DeviceRepr + WithDType>(
|
fn f<T: DeviceRepr + WithDType>(
|
||||||
@ -291,8 +296,7 @@ impl Map1 for Affine {
|
|||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let kernel_name = format!("affine_{}", T::DTYPE.as_str());
|
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
|
||||||
let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el) }?;
|
let out = unsafe { dev.alloc::<T>(el) }?;
|
||||||
let params = (
|
let params = (
|
||||||
@ -337,8 +341,7 @@ impl<'a> Map1 for Sum<'a> {
|
|||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let kernel_name = format!("sum_{}", T::DTYPE.as_str());
|
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
|
||||||
let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<T>(dst_el)?;
|
let out = dev.alloc_zeros::<T>(dst_el)?;
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
|
Reference in New Issue
Block a user