Conv1d test with padding. (#356)

This commit is contained in:
Laurent Mazare
2023-08-09 06:45:38 +02:00
committed by GitHub
parent cf965ecaa8
commit dbc6f281c9
2 changed files with 31 additions and 1 deletions

View File

@ -1060,7 +1060,6 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + dst_w;
let mut d = T::zero();
for offset_h in 0..p.k_h {
// TODO: Handle the case where padding is larger than p.k_h / 2.
let src_h = (p.stride * dst_h + offset_h)
.saturating_sub(p.padding)
.min(p.i_h - 1);

View File

@ -39,6 +39,37 @@ fn conv1d() -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
/* Note that the default for padding is different from PyTorch at the moment: instead of
padding with zeros, the edge value from the input tensor is used, i.e. this is similiar to:
t = torch.nn.functional.pad(t, (1, 1), mode='replicate')
res = torch.nn.functional.conv1d(t, w, padding=0)
*/
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.5209, 2.6357, -1.3336, 4.1393, 0.4951, 3.6855, -1.1784, 3.5675, 0.5069, 4.9562]
);
Ok(())
}
#[test]
fn conv1d_small() -> Result<()> {
let dev = &Device::Cpu;
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, 0.4056, -0.8689, -0.0773],
);
Ok(())
}