mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
362 lines
15 KiB
Rust
362 lines
15 KiB
Rust
use anyhow::Result;
|
|
use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
|
|
|
|
/* This test is based on the following script.
|
|
import torch
|
|
torch.manual_seed(4242)
|
|
|
|
t = torch.randn((1, 4, 5))
|
|
w = torch.randn((2, 4, 3))
|
|
print(t.flatten())
|
|
print(w.flatten())
|
|
res = torch.nn.functional.conv1d(t, w)
|
|
print(res.flatten())
|
|
res = torch.nn.functional.conv1d(t, w, padding=1)
|
|
print(res.flatten())
|
|
*/
|
|
fn conv1d(dev: &Device) -> Result<()> {
|
|
let t = Tensor::new(
|
|
&[
|
|
0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
|
1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
|
|
],
|
|
dev,
|
|
)?
|
|
.reshape((1, 4, 5))?;
|
|
let w = Tensor::new(
|
|
&[
|
|
-0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
|
|
-1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
|
|
-0.6451, -0.0840, -1.4247, 0.5512,
|
|
],
|
|
dev,
|
|
)?
|
|
.reshape((2, 4, 3))?;
|
|
let res = t.conv1d(&w, 0, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 2, 3]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
|
|
);
|
|
let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 2, 5]);
|
|
// Same as pytorch default padding: use zeros.
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
fn conv1d_small(dev: &Device) -> Result<()> {
|
|
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
|
|
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
|
|
let res = t.conv1d(&w, 0, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 2]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[0.4056, -0.8689]
|
|
);
|
|
let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 4]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[0.0, 0.4056, -0.8689, -0.0773],
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/* This test is based on the following script.
|
|
import torch
|
|
torch.manual_seed(4242)
|
|
|
|
t = torch.randn((1, 4, 5, 5))
|
|
w = torch.randn((2, 4, 3, 3))
|
|
print(t.flatten())
|
|
print(w.flatten())
|
|
res = torch.nn.functional.conv2d(t, w)
|
|
print(res.flatten())
|
|
|
|
w_t = w.transpose(0, 1)
|
|
res = torch.nn.functional.conv_transpose2d(t, w_t)
|
|
print(res.shape)
|
|
print(res)
|
|
*/
|
|
fn conv2d(dev: &Device) -> Result<()> {
|
|
let t = Tensor::new(
|
|
&[
|
|
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
|
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
|
|
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
|
|
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
|
|
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
|
|
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
|
|
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
|
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
|
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
|
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
|
],
|
|
dev,
|
|
)?;
|
|
let w = Tensor::new(
|
|
&[
|
|
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
|
|
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
|
|
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
|
|
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
|
|
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
|
|
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
|
|
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
|
|
0.5583, 0.4623, 0.6026,
|
|
],
|
|
dev,
|
|
)?;
|
|
let t = t.reshape((1, 4, 5, 5))?;
|
|
let w = w.reshape((2, 4, 3, 3))?;
|
|
let res = t.conv2d(&w, 0, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 2, 3, 3]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[
|
|
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
|
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
|
]
|
|
);
|
|
if dev.is_cpu() {
|
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
|
|
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
|
assert_eq!(
|
|
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
|
[
|
|
[
|
|
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
|
|
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
|
|
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
|
|
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
|
|
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
|
|
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
|
|
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
|
|
],
|
|
[
|
|
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
|
|
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
|
|
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
|
|
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
|
|
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
|
|
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
|
|
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
|
|
]
|
|
]
|
|
);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/* This test is based on the following script.
|
|
import torch
|
|
torch.manual_seed(4242)
|
|
|
|
t = torch.randn((1, 2, 3, 3))
|
|
w = torch.randn((1, 2, 1, 1))
|
|
print(t.flatten())
|
|
print(w.flatten())
|
|
res = torch.nn.functional.conv2d(t, w)
|
|
print(res.flatten())
|
|
|
|
w_t = w.transpose(0, 1)
|
|
res = torch.nn.functional.conv_transpose2d(t, w_t)
|
|
print(res.shape)
|
|
print(res.flatten())
|
|
|
|
t_t = t.transpose(0, 1)
|
|
res = torch.nn.functional.conv_transpose2d(t_t, w)
|
|
print(res.shape)
|
|
print(res.flatten())
|
|
*/
|
|
fn conv2d_small(dev: &Device) -> Result<()> {
|
|
let t = Tensor::new(
|
|
&[
|
|
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
|
-0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278,
|
|
],
|
|
dev,
|
|
)?;
|
|
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
|
|
let t = t.reshape((1, 2, 3, 3))?;
|
|
let w = w.reshape((1, 2, 1, 1))?;
|
|
let res = t.conv2d(&w, 0, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
|
|
);
|
|
let res = t.conv2d(&w, 2, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 7, 7]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
|
|
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
|
|
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
|
]
|
|
);
|
|
// TODO: enable the test for cuda once we have the proper implementation in place.
|
|
if dev.is_cpu() {
|
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
|
|
);
|
|
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
|
|
assert_eq!(res.dims(), [2, 2, 3, 3]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[
|
|
-0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728,
|
|
0.528, -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838,
|
|
0.5802, -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396,
|
|
-0.8156, 0.4594, 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
|
|
]
|
|
);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn conv2d_smaller(dev: &Device) -> Result<()> {
|
|
let t = Tensor::new(
|
|
&[
|
|
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
|
|
],
|
|
dev,
|
|
)?;
|
|
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
|
|
let t = t.reshape((1, 1, 3, 3))?;
|
|
let w = w.reshape((1, 1, 3, 3))?;
|
|
let res = t.conv2d(&w, 0, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 1, 1]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[-0.6197]
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/* This test is based on the following script.
|
|
import torch
|
|
torch.manual_seed(4242)
|
|
|
|
t = torch.randn((1, 2, 4, 2))
|
|
w = torch.randn((1, 2, 1, 1))
|
|
print(t.flatten())
|
|
print(w.flatten())
|
|
res = torch.nn.functional.conv2d(t, w)
|
|
print(res.flatten())
|
|
*/
|
|
fn conv2d_non_square(dev: &Device) -> Result<()> {
|
|
let t = Tensor::new(
|
|
&[
|
|
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
|
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699,
|
|
],
|
|
dev,
|
|
)?;
|
|
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
|
|
let t = t.reshape((1, 2, 4, 2))?;
|
|
let w = w.reshape((1, 2, 1, 1))?;
|
|
let res = t.conv2d(&w, 0, 1, 1)?;
|
|
assert_eq!(res.dims(), [1, 1, 4, 2]);
|
|
assert_eq!(
|
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
|
[0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467]
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/// Checks the gradients that flow back through `conv2d` on the CPU device.
///
/// A `(1, 4, 5, 5)` input and a `(2, 4, 3, 3)` kernel are created as `Var`s, the loss is the
/// sum of the squared convolution outputs, and the gradients of that loss with respect to
/// both variables are compared (rounded to four decimals) against hard-coded reference values.
#[test]
fn conv2d_grad() -> Result<()> {
    use candle_core::Var;

    let dev = &Device::Cpu;
    let input_values = [
        0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
        1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
        1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
        0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
        1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
        0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
        0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
        0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
        -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
        -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
    ];
    let kernel_values = [
        -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
        -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
        -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
        0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
        0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
        -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
        1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
        0.5583, 0.4623, 0.6026,
    ];
    let t = Var::from_slice(&input_values, (1, 4, 5, 5), dev)?;
    let w = Var::from_slice(&kernel_values, (2, 4, 3, 3), dev)?;

    // Scalar loss: sum of the squared convolution outputs.
    let conv_out = t.conv2d(&w, 0, 1, 1)?;
    let loss = conv_out.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);

    // Back-propagate and fetch the gradient stored for each variable.
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);

    let got_grad_t = test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?;
    assert_eq!(
        got_grad_t,
        [
            9.2868, -2.8352, -5.7117, 3.3817, -7.7094, -19.1549, 7.016, 29.1037, 9.3411, 34.7339,
            -22.8726, 24.3502, -39.88, -14.007, 21.076, 9.9419, 13.6333, -34.6796, 11.2073,
            -6.2617, 7.7209, -6.3224, -16.6373, -1.0837, -20.2215, 21.7302, -0.3744, -4.0573,
            5.8163, -3.6529, -30.7319, 14.5468, 87.699, 31.6035, 4.5304, -89.785, -75.3709,
            -57.4327, -7.5602, 92.9585, 18.791, -4.6311, -159.7521, -42.4656, -47.2644, 52.8768,
            37.3172, 48.9978, 12.8192, 2.014, -8.9826, 20.1759, 16.621, 12.0599, 15.3849, 19.9979,
            2.5725, -15.2197, 72.6244, -10.7496, 2.2541, -31.2003, 3.753, -0.2049, 9.7574, -0.6824,
            5.2107, -40.4361, -22.5891, -61.6085, 17.2837, 20.4149, 37.5454, 5.2262, 6.8126,
            23.5361, 23.6173, -9.9866, -9.1324, 4.8664, -35.0617, -26.1023, 63.4757, 25.8144,
            -39.2069, -70.6834, -46.9565, 2.3252, 41.8093, 82.4205, -28.626, -11.7812, -35.3284,
            -10.2771, -28.5694, -9.1258, 7.213, -9.0459, -9.6222, -11.2544
        ]
    );

    let got_grad_w = test_utils::to_vec1_round(&grad_w.flatten_all()?, 4)?;
    assert_eq!(
        got_grad_w,
        [
            -28.9232, -22.8833, -141.2296, 73.3462, 61.074, 47.8125, -20.0013, -73.7086, -41.8217,
            -13.5919, 21.501, 28.7179, 28.5683, -46.8486, -90.1874, 143.6107, 16.6764, 7.4259,
            18.8794, -90.8122, -20.2865, 54.7909, 82.6287, 22.943, 77.8084, -16.3928, -13.1977,
            9.3442, -40.3869, -26.6153, 5.3344, -60.9081, 9.0869, -59.368, 7.081, 58.6391, 5.5476,
            20.5152, 2.4985, -17.2466, -6.802, 22.2146, 30.1511, -7.5179, -37.4588, 5.6654,
            22.5832, 9.0316, 47.0547, 17.6123, 37.3121, -98.1295, -14.6141, -4.7958, -6.3597,
            44.6949, 23.3418, 8.3728, -13.52, 80.0522, -34.2403, -16.3648, -12.3139, 1.9195,
            -33.6244, -14.102, -49.2305, -7.3853, 11.4995, -9.9826, 9.6588, 29.6042
        ]
    );
    Ok(())
}
|
|
|
|
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
|
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
|
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
|
test_device!(
|
|
conv2d_non_square,
|
|
conv2d_non_square_cpu,
|
|
conv2d_non_square_gpu
|
|
);
|
|
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
|
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|