diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 94abd37a..6add6eb7 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -277,6 +277,11 @@ impl Map1 for Clone { } } +fn kernel_name(root: &str) -> String { + let dtype = T::DTYPE.as_str(); + format!("{root}_{dtype}") +} + struct Affine(f64, f64); impl Map1 for Affine { fn f( @@ -291,8 +296,7 @@ impl Map1 for Affine { let cfg = LaunchConfig::for_num_elems(el as u32); let ds = dev.htod_copy([dims, layout.stride()].concat())?; let src = &src.slice(layout.start_offset()..); - let kernel_name = format!("affine_{}", T::DTYPE.as_str()); - let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?; + let func = dev.get_or_load_func(&kernel_name::("affine"), kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; let params = ( @@ -337,8 +341,7 @@ impl<'a> Map1 for Sum<'a> { let cfg = LaunchConfig::for_num_elems(el as u32); let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; let src = &src.slice(layout.start_offset()..); - let kernel_name = format!("sum_{}", T::DTYPE.as_str()); - let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?; + let func = dev.get_or_load_func(&kernel_name::("sum"), kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out); // SAFETY: ffi.