Avoid using batched-matmul in nn::Linear. (#2883)

* Avoid using batched-matmul in nn::Linear.

* Also avoid batched matmul in conv1d.

* Also tweak the conv2d.

* Batched tests.

* Also cover conv2d.
This commit is contained in:
Laurent Mazare
2025-04-12 19:53:58 +02:00
committed by GitHub
parent d7b7ce16e4
commit 34505fdf3a
3 changed files with 73 additions and 24 deletions

View File

@ -1199,7 +1199,6 @@ fn gemm_config<T>(
mnk: (m, n, k),
})?,
};
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
@ -1464,12 +1463,11 @@ impl BackendStorage for CudaStorage {
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let col_l = Layout::contiguous((b * m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
@ -1477,10 +1475,9 @@ impl BackendStorage for CudaStorage {
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@ -1578,12 +1575,11 @@ impl BackendStorage for CudaStorage {
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let col_l = Layout::contiguous((b * m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
@ -1591,10 +1587,9 @@ impl BackendStorage for CudaStorage {
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?

View File

@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv2d(&w, 0, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;

View File

@ -41,12 +41,36 @@ impl Linear {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match *x.dims() {
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
// When possible, we avoid using a broadcasted matmul as it is much slower
// than the standard matmul for the cuda and cpu backends.
let x = match *x.dims() {
[b1, b2, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((b1 * b2 * m, k))?
.matmul(&w)?
.reshape((b1, b2, m, ()))?
} else {
let w = self.weight.broadcast_left((b1, b2))?.t()?;
x.matmul(&w)?
}
}
[bsize, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((bsize * m, k))?
.matmul(&w)?
.reshape((bsize, m, ()))?
} else {
let w = self.weight.broadcast_left(bsize)?.t()?;
x.matmul(&w)?
}
}
_ => {
let w = self.weight.t()?;
x.matmul(&w)?
}
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),