Conv1d test with padding. (#356)

This commit is contained in:
Laurent Mazare
2023-08-09 06:45:38 +02:00
committed by GitHub
parent cf965ecaa8
commit dbc6f281c9
2 changed files with 31 additions and 1 deletions

View File

@ -39,6 +39,37 @@ fn conv1d() -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
/* Note that the default for padding is different from PyTorch at the moment: instead of
padding with zeros, the edge value from the input tensor is used, i.e. this is similiar to:
t = torch.nn.functional.pad(t, (1, 1), mode='replicate')
res = torch.nn.functional.conv1d(t, w, padding=0)
*/
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.5209, 2.6357, -1.3336, 4.1393, 0.4951, 3.6855, -1.1784, 3.5675, 0.5069, 4.9562]
);
Ok(())
}
#[test]
fn conv1d_small() -> Result<()> {
let dev = &Device::Cpu;
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
let res = t.conv1d(&w, 0, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, 0.4056, -0.8689, -0.0773],
);
Ok(())
}