mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Im2col cuda optimization. (#2885)
This commit is contained in:
@@ -157,15 +157,15 @@ impl Map1 for Im2Col1D {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let l_out = self.l_out(dims[2]);
|
||||
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let threads = dims[0] * l_out * dims[1];
|
||||
let cfg = LaunchConfig::for_num_elems(threads as u32);
|
||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el)? };
|
||||
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, dst_el);
|
||||
barg!(builder, threads);
|
||||
barg!(builder, l_out);
|
||||
barg!(builder, self.l_k);
|
||||
barg!(builder, self.stride);
|
||||
|
Reference in New Issue
Block a user