Add some conv2d tests. (#347)

* Add some conv2d tests.

* Add a simpler conv2d test.

* More conv2d testing + bugfix.

* Add a todo.
This commit is contained in:
Laurent Mazare
2023-08-08 20:02:42 +02:00
committed by GitHub
parent 13ce68ff9b
commit 1e6dbeac01
3 changed files with 121 additions and 4 deletions

View File

@ -1037,7 +1037,7 @@ struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>(
&self,
inp: &[T],
inp_l: &Layout,
@ -1063,11 +1063,13 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + dst_w;
let mut d = T::zero();
for offset_h in 0..p.k_h {
let src_h_plus = p.stride * dst_h + offset_h;
// TODO: Handle the case where padding is larger than p.k_h / 2.
let src_h_plus = p.stride * dst_h + offset_h + p.k_h / 2 - p.padding;
if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h {
let src_h = src_h_plus - p.k_h / 2;
for offset_w in 0..p.k_w {
let src_w_plus = p.stride * dst_w + offset_w;
let src_w_plus =
p.stride * dst_w + offset_w + p.k_w / 2 - p.padding;
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
let src_w = src_w_plus - p.k_w / 2;

View File

@ -53,7 +53,9 @@ impl DType {
}
}
pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
pub trait WithDType:
Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static
{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;

View File

@ -0,0 +1,113 @@
mod test_utils;
use anyhow::Result;
use candle_core::{Device, Tensor};
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 4, 5, 5))
w = torch.randn((2, 4, 3, 3))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
#[test]
fn conv2d() -> Result<()> {
let dev = &Device::Cpu;
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
dev,
)?;
let w = Tensor::new(
&[
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
0.5583, 0.4623, 0.6026,
],
dev,
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
Ok(())
}
/* This test is based on the following script.
import torch
torch.manual_seed(4242)
t = torch.randn((1, 2, 3, 3))
w = torch.randn((1, 2, 1, 1))
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
*/
#[test]
fn conv2d_small() -> Result<()> {
let dev = &Device::Cpu;
let t = Tensor::new(
&[
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
-0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278,
],
dev,
)?;
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
);
Ok(())
}
#[test]
fn conv2d_smaller() -> Result<()> {
let dev = &Device::Cpu;
let t = Tensor::new(
&[
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
],
dev,
)?;
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
let res = t.conv2d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[-0.6197]
);
Ok(())
}