Factorize the kernel naming scheme.

This commit is contained in:
laurent
2023-06-29 09:29:59 +01:00
parent d3c7b0d168
commit fff13dbb4e

View File

@ -277,6 +277,11 @@ impl Map1 for Clone {
}
}
/// Builds the name of a dtype-specialised kernel.
///
/// The result is `root`, an underscore, then the string form of `T`'s dtype
/// (`T::DTYPE.as_str()`), e.g. a root of `"affine"` yields `"affine_<dtype>"`.
fn kernel_name<T>(root: &str) -> String
where
    T: WithDType,
{
    format!("{}_{}", root, T::DTYPE.as_str())
}
/// Pair of `f64` parameters for the affine operation; `Map1` is implemented for it below.
// NOTE(review): the fields are presumably (multiplier, offset), i.e. `x * .0 + .1` — confirm
// against the `affine` kernel before relying on the order.
struct Affine(f64, f64);
impl Map1 for Affine {
fn f<T: DeviceRepr + WithDType>(
@ -291,8 +296,7 @@ impl Map1 for Affine {
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let kernel_name = format!("affine_{}", T::DTYPE.as_str());
let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }?;
let params = (
@ -337,8 +341,7 @@ impl<'a> Map1 for Sum<'a> {
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
let src = &src.slice(layout.start_offset()..);
let kernel_name = format!("sum_{}", T::DTYPE.as_str());
let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
let out = dev.alloc_zeros::<T>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
// SAFETY: ffi.