mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear. * Also avoid batched matmul in conv1d. * Also tweak the conv2d. * Batched tests. * Also cover conv2d.
This commit is contained in:
@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
let res = {
|
||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
|
||||
};
|
||||
assert_eq!(res.dims(), [3, 2, 5]);
|
||||
// Same as pytorch default padding: use zeros.
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
|
||||
let w = w.transpose(0, 1)?;
|
||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||
@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
let res = {
|
||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||
t.conv2d(&w, 0, 1, 1, 1)?
|
||||
};
|
||||
assert_eq!(res.dims(), [3, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||
[
|
||||
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
|
||||
|
Reference in New Issue
Block a user