diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index df1aed29..62b0bd15 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1199,7 +1199,6 @@ fn gemm_config(
             mnk: (m, n, k),
         })?,
     };
-
     Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
@@ -1464,12 +1463,11 @@ impl BackendStorage for CudaStorage {
         let n = params.c_out;
         let k = params.k_size * params.c_in;
         let m = l_out;
-        let col_l = Layout::contiguous((b, m, k));
+        let col_l = Layout::contiguous((b * m, k));
         let res = if kernel_l.is_contiguous() {
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
             let mut kernel_c = unsafe {
@@ -1477,10 +1475,9 @@ impl BackendStorage for CudaStorage {
                     .alloc_uninit(kernel_l.shape(), kernel.dtype())?
             };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
         let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@@ -1578,12 +1575,11 @@ impl BackendStorage for CudaStorage {
         let n = params.c_out;
         let k = params.k_h * params.k_w * params.c_in;
         let m = h_out * w_out;
-        let col_l = Layout::contiguous((b, m, k));
+        let col_l = Layout::contiguous((b * m, k));
         let res = if kernel_l.is_contiguous() {
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
             let mut kernel_c = unsafe {
@@ -1591,10 +1587,9 @@ impl BackendStorage for CudaStorage {
                     .alloc_uninit(kernel_l.shape(), kernel.dtype())?
             };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index d370bdf8..1b815610 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
+    let res = {
+        let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
+        t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
+    };
+    assert_eq!(res.dims(), [3, 2, 5]);
+    // Same as pytorch default padding: use zeros.
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
+        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
+        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
+    );
 
     let w = w.transpose(0, 1)?;
     // The CPU kernels applied in the contiguous and non contiguous cases are different.
@@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
+    let res = {
+        let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
+        t.conv2d(&w, 0, 1, 1, 1)?
+    };
+    assert_eq!(res.dims(), [3, 2, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
+        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
+        [
+            -4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
+            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
+        ]
+    );
 
     let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
 
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 96409042..82c82793 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -41,12 +41,36 @@ impl Linear {
 
 impl super::Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let w = match *x.dims() {
-            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
-            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
+        // When possible, we avoid using a broadcasted matmul as it is much slower
+        // than the standard matmul for the cuda and cpu backends.
+        let x = match *x.dims() {
+            [b1, b2, m, k] => {
+                if x.is_contiguous() {
+                    let w = self.weight.t()?;
+                    x.reshape((b1 * b2 * m, k))?
+                        .matmul(&w)?
+                        .reshape((b1, b2, m, ()))?
+                } else {
+                    let w = self.weight.broadcast_left((b1, b2))?.t()?;
+                    x.matmul(&w)?
+                }
+            }
+            [bsize, m, k] => {
+                if x.is_contiguous() {
+                    let w = self.weight.t()?;
+                    x.reshape((bsize * m, k))?
+                        .matmul(&w)?
+                        .reshape((bsize, m, ()))?
+                } else {
+                    let w = self.weight.broadcast_left(bsize)?.t()?;
+                    x.matmul(&w)?
+                }
+            }
+            _ => {
+                let w = self.weight.t()?;
+                x.matmul(&w)?
+            }
         };
-        let x = x.matmul(&w)?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),
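For context, the core idea in the linear.rs change is to replace a broadcasted batched matmul with a single flattened 2-D matmul whenever the input is contiguous, and only fall back to the broadcasted path otherwise. The following is a minimal standalone sketch of that pattern, not part of the patch, assuming the public candle_core Tensor API; the crate name, the tensor shapes, and the helper name linear_no_broadcast are illustrative.

// Sketch (illustrative, not from the patch): fold the batch dimension into the
// row dimension so the product becomes one plain 2-D matmul, then restore the
// batch shape. Shapes: x is (b, m, k), w is (n, k).
use candle_core::{Device, Result, Tensor};

fn linear_no_broadcast(x: &Tensor, w: &Tensor) -> Result<Tensor> {
    let (b, m, k) = x.dims3()?;
    if x.is_contiguous() {
        // (b * m, k) x (k, n) -> (b * m, n), then reshape back to (b, m, n).
        x.reshape((b * m, k))?.matmul(&w.t()?)?.reshape((b, m, ()))
    } else {
        // Previous behaviour: broadcast the weight to (b, k, n) and run a
        // batched matmul.
        x.matmul(&w.broadcast_left(b)?.t()?)
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (2, 7, 16), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (32, 16), &dev)?;
    let y = linear_no_broadcast(&x, &w)?;
    assert_eq!(y.dims(), [2, 7, 32]);
    Ok(())
}

The conv1d/conv2d hunks apply the same reasoning on the im2col buffer: since the (b, m, k) column tensor is contiguous by construction, it can be viewed as (b * m, k) and multiplied once against the (k, n) kernel view, which is why the batch size passed to matmul becomes 1 and the added tests check that each sample in a batch still produces the same output as the single-sample case.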