Avoid using batched-matmul in nn::Linear. (#2883)

* Avoid using batched-matmul in nn::Linear.

* Also avoid batched matmul in conv1d.

* Also tweak the conv2d.

* Batched tests.

* Also cover conv2d.
Laurent Mazare
2025-04-12 19:53:58 +02:00
committed by GitHub
parent d7b7ce16e4
commit 34505fdf3a
3 changed files with 73 additions and 24 deletions
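
To make the change concrete, here is a rough standalone sketch of the idea behind the nn::Linear rewrite, written against candle's public Tensor API (the crate name candle_core, the main function and the sizes b, m, k, n are illustrative, not part of the patch): for a contiguous (b, m, k) input and an (n, k) weight, the broadcasted matmul used before can be replaced by a single flattened (b * m, k) x (k, n) matmul plus a reshape, i.e. one large gemm instead of b smaller ones.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, m, k, n) = (3, 4, 5, 2);
    // x plays the role of the batched input, w the role of the Linear weight.
    let x = Tensor::randn(0f32, 1f32, (b, m, k), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (n, k), &dev)?;

    // Batched path: broadcast the weight to (b, k, n) and run b small matmuls.
    let batched = x.matmul(&w.broadcast_left(b)?.t()?)?;

    // Flattened path: one (b * m, k) x (k, n) matmul, then reshape back.
    let flat = x.reshape((b * m, k))?.matmul(&w.t()?)?.reshape((b, m, n))?;

    // For a contiguous input both paths give the same (b, m, n) result.
    let diff = (&batched - &flat)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    println!("sum of abs differences: {diff}");
    Ok(())
}

Non-contiguous inputs keep the broadcasted path, which is why the new forward pass below checks x.is_contiguous() before reshaping.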


@@ -1199,7 +1199,6 @@ fn gemm_config<T>(
             mnk: (m, n, k),
         })?,
     };
-
     Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
@@ -1464,12 +1463,11 @@ impl BackendStorage for CudaStorage {
         let n = params.c_out;
         let k = params.k_size * params.c_in;
         let m = l_out;
-        let col_l = Layout::contiguous((b, m, k));
+        let col_l = Layout::contiguous((b * m, k));
         let res = if kernel_l.is_contiguous() {
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
             let mut kernel_c = unsafe {
@@ -1477,10 +1475,9 @@ impl BackendStorage for CudaStorage {
                     .alloc_uninit(kernel_l.shape(), kernel.dtype())?
             };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
         let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@@ -1578,12 +1575,11 @@ impl BackendStorage for CudaStorage {
         let n = params.c_out;
         let k = params.k_h * params.k_w * params.c_in;
         let m = h_out * w_out;
-        let col_l = Layout::contiguous((b, m, k));
+        let col_l = Layout::contiguous((b * m, k));
         let res = if kernel_l.is_contiguous() {
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
             let mut kernel_c = unsafe {
@@ -1591,10 +1587,9 @@ impl BackendStorage for CudaStorage {
                     .alloc_uninit(kernel_l.shape(), kernel.dtype())?
             };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
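
The conv1d/conv2d hunks above apply the same trick one level lower, on the im2col buffer and its Layout: instead of broadcasting the (c_out, k) kernel to a (b, k, c_out) view and issuing a batched matmul, the column data is treated as a single (b * m, k) matrix and multiplied once against the transposed kernel, with batch size 1. Below is a rough shape-level sketch of that equivalence using the public Tensor API (the crate name candle_core, the main function and the sizes are made up; the real code operates on raw storage and Layouts as shown above).

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // b: batch size, l_out: output positions, k = c_in * k_size, n = c_out.
    let (b, l_out, k, n) = (2, 5, 12, 4);
    // col stands in for the im2col output, kernel for the conv weight viewed as (c_out, k).
    let col = Tensor::randn(0f32, 1f32, (b, l_out, k), &dev)?;
    let kernel = Tensor::randn(0f32, 1f32, (n, k), &dev)?;

    // Old path: broadcast the kernel and run a batched matmul, i.e. b gemms of (l_out, k) x (k, n).
    let batched = col.matmul(&kernel.broadcast_left(b)?.t()?)?;

    // New path: a single (b * l_out, k) x (k, n) gemm, then reshape.
    let flat = col
        .reshape((b * l_out, k))?
        .matmul(&kernel.t()?)?
        .reshape((b, l_out, n))?;

    // Both give (b, l_out, n); the conv code then transposes to the (b, c_out, l_out) layout.
    let diff = (&batched - &flat)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    println!("sum of abs differences: {diff}");
    println!("output dims: {:?}", flat.transpose(1, 2)?.dims()); // [2, 4, 5]
    Ok(())
}

The final transpose in the sketch mirrors the res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)? step in the backend hunks above.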


@@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
+    let res = {
+        let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
+        t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
+    };
+    assert_eq!(res.dims(), [3, 2, 5]);
+    // Same as pytorch default padding: use zeros.
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
+        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
+        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
+    );
     let w = w.transpose(0, 1)?;
     // The CPU kernels applied in the contiguous and non contiguous cases are different.
@@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
+    let res = {
+        let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
+        t.conv2d(&w, 0, 1, 1, 1)?
+    };
+    assert_eq!(res.dims(), [3, 2, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
+        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
+        [
+            -4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
+            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
+        ]
+    );
     let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;


@@ -41,12 +41,36 @@ impl Linear {
 impl super::Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let w = match *x.dims() {
-            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
-            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
+        // When possible, we avoid using a broadcasted matmul as it is much slower
+        // than the standard matmul for the cuda and cpu backends.
+        let x = match *x.dims() {
+            [b1, b2, m, k] => {
+                if x.is_contiguous() {
+                    let w = self.weight.t()?;
+                    x.reshape((b1 * b2 * m, k))?
+                        .matmul(&w)?
+                        .reshape((b1, b2, m, ()))?
+                } else {
+                    let w = self.weight.broadcast_left((b1, b2))?.t()?;
+                    x.matmul(&w)?
+                }
+            }
+            [bsize, m, k] => {
+                if x.is_contiguous() {
+                    let w = self.weight.t()?;
+                    x.reshape((bsize * m, k))?
+                        .matmul(&w)?
+                        .reshape((bsize, m, ()))?
+                } else {
+                    let w = self.weight.broadcast_left(bsize)?.t()?;
+                    x.matmul(&w)?
+                }
+            }
+            _ => {
+                let w = self.weight.t()?;
+                x.matmul(&w)?
+            }
         };
-        let x = x.matmul(&w)?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),