Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear.
* Also avoid batched matmul in conv1d.
* Also tweak the conv2d.
* Batched tests.
* Also cover conv2d.
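The change rests on a simple equivalence: for a contiguous (b, m, k) input, a matmul against the weight broadcast to (b, k, n) produces the same values as reshaping the input to (b * m, k) and running a single non-batched matmul against the (k, n) weight, and the plain GEMM path is much faster on the cuda and cpu backends. A minimal sketch of that equivalence using the public candle_core API (illustrative shapes; my illustration, not code from the commit):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, m, k, n) = (4, 3, 8, 5);
    let x = Tensor::randn(0f32, 1f32, (b, m, k), &dev)?;
    // Linear stores its weight as (out_features, in_features) = (n, k).
    let w = Tensor::randn(0f32, 1f32, (n, k), &dev)?;

    // Batched path: broadcast the weight to (b, k, n), then a batched matmul.
    let batched = x.matmul(&w.broadcast_left(b)?.t()?)?;

    // Flattened path: one plain (b * m, k) x (k, n) matmul, reshaped back.
    let flat = x.reshape((b * m, k))?.matmul(&w.t()?)?.reshape((b, m, n))?;

    // The two paths agree up to floating-point noise.
    let diff = (batched - flat)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-4);
    Ok(())
}

The hunks below apply this in three places: the im2col matmuls of the cuda conv1d and conv2d, and nn::Linear's forward; the first hunk only drops a stray blank line in gemm_config.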
@@ -1199,7 +1199,6 @@ fn gemm_config<T>(
             mnk: (m, n, k),
         })?,
     };
-
     Ok(StridedBatchedConfig {
         batch_size: b as i32,
         gemm,
@@ -1464,12 +1463,11 @@ impl BackendStorage for CudaStorage {
         let n = params.c_out;
         let k = params.k_size * params.c_in;
         let m = l_out;
-        let col_l = Layout::contiguous((b, m, k));
+        let col_l = Layout::contiguous((b * m, k));
         let res = if kernel_l.is_contiguous() {
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
             let mut kernel_c = unsafe {
@@ -1477,10 +1475,9 @@ impl BackendStorage for CudaStorage {
                     .alloc_uninit(kernel_l.shape(), kernel.dtype())?
             };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
         let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@@ -1578,12 +1575,11 @@ impl BackendStorage for CudaStorage {
         let n = params.c_out;
         let k = params.k_h * params.k_w * params.c_in;
         let m = h_out * w_out;
-        let col_l = Layout::contiguous((b, m, k));
+        let col_l = Layout::contiguous((b * m, k));
         let res = if kernel_l.is_contiguous() {
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
             let mut kernel_c = unsafe {
@@ -1591,10 +1587,9 @@ impl BackendStorage for CudaStorage {
                     .alloc_uninit(kernel_l.shape(), kernel.dtype())?
             };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
-            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
-                .transpose(1, 2)?
-                .broadcast_as((b, k, n))?;
-            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+            let kernel_l =
+                Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
+            col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
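The conv1d and conv2d hunks above share one observation: after im2col, every batch element is multiplied by the same kernel, so the batched (b, m, k) x (b, k, n) product collapses into a single (b * m, k) x (k, n) one. That is why col_l becomes Layout::contiguous((b * m, k)), the kernel layout loses its broadcast_as((b, k, n)), and the matmul batch size drops to 1. The same computation expressed against the public API (a sketch with stride 1 and no padding; my own helper, not the backend code):

use candle_core::{Device, Result, Tensor};

// Conv1d via im2col and a single non-batched matmul: the (batch, position)
// pair is folded into the row dimension, mirroring the backend change.
fn conv1d_via_matmul(t: &Tensor, w: &Tensor) -> Result<Tensor> {
    let (b, c_in, l) = t.dims3()?;
    let (c_out, _c_in, k_size) = w.dims3()?;
    let l_out = l - k_size + 1;
    // Build the (b * l_out, c_in * k_size) column matrix.
    let mut cols = Vec::with_capacity(l_out);
    for i in 0..l_out {
        cols.push(t.narrow(2, i, k_size)?.reshape((b, 1, c_in * k_size))?);
    }
    let col = Tensor::cat(&cols, 1)?.reshape((b * l_out, c_in * k_size))?;
    // One plain matmul against the (c_in * k_size, c_out) kernel view.
    let res = col.matmul(&w.reshape((c_out, c_in * k_size))?.t()?)?;
    // Back to (b, c_out, l_out).
    res.reshape((b, l_out, c_out))?.transpose(1, 2)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::randn(0f32, 1f32, (2, 4, 7), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (3, 4, 5), &dev)?;
    let reference = t.conv1d(&w, 0, 1, 1, 1)?;
    let ours = conv1d_via_matmul(&t, &w)?;
    let diff = (reference - ours)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-4);
    Ok(())
}

The tests that follow exercise exactly this batched path.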
@@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
+    let res = {
+        let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
+        t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
+    };
+    assert_eq!(res.dims(), [3, 2, 5]);
+    // Same as pytorch default padding: use zeros.
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
+        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
+        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
+    );
 
     let w = w.transpose(0, 1)?;
     // The CPU kernels applied in the contiguous and non contiguous cases are different.
@@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
+    let res = {
+        let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
+        t.conv2d(&w, 0, 1, 1, 1)?
+    };
+    assert_eq!(res.dims(), [3, 2, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
+        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
+        [
+            -4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
+            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
+        ]
+    );
 
     let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
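The new assertions above pin down the batch handling: the input is sandwiched between two all-zero batch elements, so any cross-batch leakage in the flattened matmul would corrupt either the zero outputs or the middle element. The same property can be spot-checked generically (a sketch, not part of the test suite; conv2d with arbitrary sizes):

use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::randn(0f32, 1f32, (1, 2, 6, 6), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (4, 2, 3, 3), &dev)?;
    // Sandwich the input between two all-zero batch elements, as the tests do.
    let batched = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?
        .conv2d(&w, 0, 1, 1, 1)?;
    // The middle element must match the single-sample conv...
    let single = t.conv2d(&w, 0, 1, 1, 1)?;
    let diff = (batched.i(1)?.unsqueeze(0)? - single)?
        .abs()?
        .sum_all()?
        .to_scalar::<f32>()?;
    assert!(diff < 1e-4);
    // ...and the zero elements must stay exactly zero (there is no bias).
    let leak = batched.i(0)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert_eq!(leak, 0.0);
    Ok(())
}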
@@ -41,12 +41,36 @@ impl Linear {
 
 impl super::Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let w = match *x.dims() {
-            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
-            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
+        // When possible, we avoid using a broadcasted matmul as it is much slower
+        // than the standard matmul for the cuda and cpu backends.
+        let x = match *x.dims() {
+            [b1, b2, m, k] => {
+                if x.is_contiguous() {
+                    let w = self.weight.t()?;
+                    x.reshape((b1 * b2 * m, k))?
+                        .matmul(&w)?
+                        .reshape((b1, b2, m, ()))?
+                } else {
+                    let w = self.weight.broadcast_left((b1, b2))?.t()?;
+                    x.matmul(&w)?
+                }
+            }
+            [bsize, m, k] => {
+                if x.is_contiguous() {
+                    let w = self.weight.t()?;
+                    x.reshape((bsize * m, k))?
+                        .matmul(&w)?
+                        .reshape((bsize, m, ()))?
+                } else {
+                    let w = self.weight.broadcast_left(bsize)?.t()?;
+                    x.matmul(&w)?
+                }
+            }
+            _ => {
+                let w = self.weight.t()?;
+                x.matmul(&w)?
+            }
         };
-        let x = x.matmul(&w)?;
         match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),
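For completeness, a quick shape check of the rewritten Linear::forward (illustrative sizes; Linear::new and the Module trait are the existing candle_nn API). Contiguous 3D and 4D inputs now take the flatten-matmul-reshape fast path, non-contiguous ones fall back to the broadcasted matmul, and the output shapes are unchanged either way:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (out_features, in_features) weight with a zero bias.
    let w = Tensor::randn(0f32, 1f32, (5, 8), &dev)?;
    let bias = Tensor::zeros(5, DType::F32, &dev)?;
    let linear = Linear::new(w, Some(bias));
    // 2D input: a plain matmul, as before.
    let x2 = Tensor::randn(0f32, 1f32, (3, 8), &dev)?;
    assert_eq!(linear.forward(&x2)?.dims(), [3, 5]);
    // Contiguous 3D input: flattened to (2 * 3, 8) internally.
    let x3 = Tensor::randn(0f32, 1f32, (2, 3, 8), &dev)?;
    assert_eq!(linear.forward(&x3)?.dims(), [2, 3, 5]);
    // Contiguous 4D input: flattened to (2 * 2 * 3, 8) internally.
    let x4 = Tensor::randn(0f32, 1f32, (2, 2, 3, 8), &dev)?;
    assert_eq!(linear.forward(&x4)?.dims(), [2, 2, 3, 5]);
    Ok(())
}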