Im2col cuda optimization. (#2885)

This commit is contained in:
Laurent Mazare
2025-04-13 10:07:53 +02:00
committed by GitHub
parent 15ed0b11ce
commit d9198deb37
2 changed files with 21 additions and 21 deletions

View File

@ -157,15 +157,15 @@ impl Map1 for Im2Col1D {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let threads = dims[0] * l_out * dims[1];
let cfg = LaunchConfig::for_num_elems(threads as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el)? };
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, threads);
barg!(builder, l_out);
barg!(builder, self.l_k);
barg!(builder, self.stride);