From df712ecf64daeb77eed4a368e670050617ed5d26 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 11 Sep 2023 09:48:31 +0100
Subject: [PATCH] Handle the case where the kernel is not contiguous in the
 cuda backend. (#809)

---
 candle-core/src/cuda_backend.rs | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index dbfaa928..20cab125 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1746,10 +1746,20 @@ impl BackendStorage for CudaStorage {
         let k = params.k_h * params.k_w * params.c_in;
         let m = h_out * w_out;
         let col_l = Layout::contiguous((b, m, k));
-        let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-            .transpose(1, 2)?
-            .broadcast_as((b, k, n))?;
-        let res = col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?;
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        };
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
             .transpose(1, 3)?;
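
For context, here is a minimal sketch of the situation this patch addresses, written against candle's public API rather than the backend internals. The shapes, variable names, and the conv2d argument order (padding, stride, dilation, groups) are my assumptions, not taken from the patch itself: the transpose below yields a kernel whose storage is strided rather than dense, which is exactly the case the new else branch handles by copying the kernel into a contiguous buffer before the matmul.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // NCHW input: batch 1, 4 channels, 16x16 spatial dims.
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 16, 16), &dev)?;
    // Build the kernel as (c_in, c_out, k_h, k_w) and transpose to the
    // (c_out, c_in, k_h, k_w) layout conv2d expects. The transpose only
    // swaps strides, so the resulting kernel is NOT contiguous in memory.
    let kernel = Tensor::randn(0f32, 1f32, (4, 8, 3, 3), &dev)?.transpose(0, 1)?;
    // The cuda backend lowers the convolution to an im2col + matmul over
    // the flattened kernel; that flattening is only valid for dense
    // storage, which is why the patch first copies a strided kernel into
    // a contiguous buffer.
    let ys = xs.conv2d(&kernel, 1, 1, 1, 1)?;
    println!("{:?}", ys.shape()); // expect (1, 8, 16, 16)
    Ok(())
}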