Handle the case where the kernel is not contiguous in the cuda backend. (#809)
@@ -1746,10 +1746,20 @@ impl BackendStorage for CudaStorage {
         let k = params.k_h * params.k_w * params.c_in;
         let m = h_out * w_out;
         let col_l = Layout::contiguous((b, m, k));
-        let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-            .transpose(1, 2)?
-            .broadcast_as((b, k, n))?;
-        let res = col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?;
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            // The copy lands at offset 0 in kernel_c, so use a plain contiguous layout.
+            let kernel_l = Layout::contiguous((1, n, k)).transpose(1, 2)?.broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
             .transpose(1, 3)?;
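For context, below is a minimal standalone sketch of the situation this commit addresses: a transposed kernel view shares its storage but has swapped strides, so it is not contiguous, and the CUDA im2col path now copies it into a contiguous buffer before the matmul. This is an illustrative sketch, not part of the commit; it assumes candle's public Tensor API (Device::cuda_if_available, Tensor::randn, Tensor::is_contiguous, and conv2d taking padding, stride, dilation, groups).

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Use the first CUDA device if present, otherwise fall back to CPU.
    let dev = Device::cuda_if_available(0)?;
    // Contiguous kernel of shape (c_out, c_in, k_h, k_w).
    let kernel = Tensor::randn(0f32, 1.0, (8, 4, 3, 3), &dev)?;
    // A transposed view shares the same storage but with swapped strides,
    // so it is no longer contiguous in memory.
    let kernel_nc = kernel.transpose(2, 3)?;
    assert!(!kernel_nc.is_contiguous());
    // (batch, c_in, h, w) input. With this commit, the CUDA backend first
    // copies the non-contiguous kernel into a contiguous buffer instead of
    // building a broadcast layout over storage it assumed was contiguous.
    let img = Tensor::randn(0f32, 1.0, (1, 4, 16, 16), &dev)?;
    let out = img.conv2d(&kernel_nc, 0, 1, 1, 1)?;
    println!("{:?}", out.shape()); // (1, 8, 14, 14)
    Ok(())
}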