mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Conv1d test with padding. (#356)
This commit is contained in:
@ -1060,7 +1060,6 @@ impl<'a> Map2 for Conv2D<'a> {
|
|||||||
let dst_idx = dst_idx + dst_w;
|
let dst_idx = dst_idx + dst_w;
|
||||||
let mut d = T::zero();
|
let mut d = T::zero();
|
||||||
for offset_h in 0..p.k_h {
|
for offset_h in 0..p.k_h {
|
||||||
// TODO: Handle the case where padding is larger than p.k_h / 2.
|
|
||||||
let src_h = (p.stride * dst_h + offset_h)
|
let src_h = (p.stride * dst_h + offset_h)
|
||||||
.saturating_sub(p.padding)
|
.saturating_sub(p.padding)
|
||||||
.min(p.i_h - 1);
|
.min(p.i_h - 1);
|
||||||
|
@ -39,6 +39,37 @@ fn conv1d() -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
|
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
|
||||||
);
|
);
|
||||||
|
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
|
||||||
|
assert_eq!(res.dims(), [1, 2, 5]);
|
||||||
|
/* Note that the default for padding is different from PyTorch at the moment: instead of
|
||||||
|
padding with zeros, the edge value from the input tensor is used, i.e. this is similiar to:
|
||||||
|
t = torch.nn.functional.pad(t, (1, 1), mode='replicate')
|
||||||
|
res = torch.nn.functional.conv1d(t, w, padding=0)
|
||||||
|
*/
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
|
[2.5209, 2.6357, -1.3336, 4.1393, 0.4951, 3.6855, -1.1784, 3.5675, 0.5069, 4.9562]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conv1d_small() -> Result<()> {
|
||||||
|
let dev = &Device::Cpu;
|
||||||
|
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
|
||||||
|
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
|
||||||
|
let res = t.conv1d(&w, 0, 1)?;
|
||||||
|
assert_eq!(res.dims(), [1, 1, 2]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
|
[0.4056, -0.8689]
|
||||||
|
);
|
||||||
|
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
|
||||||
|
assert_eq!(res.dims(), [1, 1, 4]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
|
[0.4056, 0.4056, -0.8689, -0.0773],
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user