mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear. * Also avoid batched matmul in conv1d. * Also tweak the conv2d. * Batched tests. * Also cover conv2d.
This commit is contained in:
@ -1199,7 +1199,6 @@ fn gemm_config<T>(
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
};
|
||||
|
||||
Ok(StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
@ -1464,12 +1463,11 @@ impl BackendStorage for CudaStorage {
|
||||
let n = params.c_out;
|
||||
let k = params.k_size * params.c_in;
|
||||
let m = l_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let col_l = Layout::contiguous((b * m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = unsafe {
|
||||
@ -1477,10 +1475,9 @@ impl BackendStorage for CudaStorage {
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||
@ -1578,12 +1575,11 @@ impl BackendStorage for CudaStorage {
|
||||
let n = params.c_out;
|
||||
let k = params.k_h * params.k_w * params.c_in;
|
||||
let m = h_out * w_out;
|
||||
let col_l = Layout::contiguous((b, m, k));
|
||||
let col_l = Layout::contiguous((b * m, k));
|
||||
let res = if kernel_l.is_contiguous() {
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
} else {
|
||||
// Make the kernel contiguous if not already the case.
|
||||
let mut kernel_c = unsafe {
|
||||
@ -1591,10 +1587,9 @@ impl BackendStorage for CudaStorage {
|
||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||
};
|
||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||
.transpose(1, 2)?
|
||||
.broadcast_as((b, k, n))?;
|
||||
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||
let kernel_l =
|
||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
||||
};
|
||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||
.transpose(1, 2)?
|
||||
|
Reference in New Issue
Block a user