mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add some conv1d test + bugfix using padding. (#349)
This commit is contained in:
@ -8,25 +8,11 @@ use anyhow::Result;
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
||||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
||||||
let c = a.matmul(&b)?;
|
let start = std::time::Instant::now();
|
||||||
println!("{a} {b} {c}");
|
let res = inp.conv2d(&w, 0, 1);
|
||||||
|
println!("{:?}", start.elapsed());
|
||||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
|
println!("{res:?}");
|
||||||
let t1 = Tensor::new(data, &Device::Cpu)?;
|
|
||||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
|
|
||||||
let t2 = Tensor::new(data2, &Device::Cpu)?;
|
|
||||||
assert_eq!(
|
|
||||||
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
|
|
||||||
.t()?
|
|
||||||
.to_vec2::<f32>()?,
|
|
||||||
[
|
|
||||||
[3.0, 1.0, 4.0, 1.0, 5.0],
|
|
||||||
[2.0, 7.0, 1.0, 8.0, 2.0],
|
|
||||||
[5.0, 5.0, 5.0, 5.0, 5.0],
|
|
||||||
[2.0, 7.0, 1.0, 8.0, 2.0]
|
|
||||||
]
|
|
||||||
);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1011,7 +1011,7 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||||||
let dst_idx = dst_idx + dst_l;
|
let dst_idx = dst_idx + dst_l;
|
||||||
let mut d = T::zero();
|
let mut d = T::zero();
|
||||||
for offset in 0..p.k_size {
|
for offset in 0..p.k_size {
|
||||||
let src_l_plus = p.stride * dst_l + offset;
|
let src_l_plus = p.stride * dst_l + offset + k_over_2 - p.padding;
|
||||||
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||||
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
|
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
|
||||||
let src_l = src_l_plus - k_over_2;
|
let src_l = src_l_plus - k_over_2;
|
||||||
|
@ -6,6 +6,46 @@ use candle_core::{Device, Tensor};
|
|||||||
import torch
|
import torch
|
||||||
torch.manual_seed(4242)
|
torch.manual_seed(4242)
|
||||||
|
|
||||||
|
t = torch.randn((1, 4, 5))
|
||||||
|
w = torch.randn((2, 4, 3))
|
||||||
|
print(t.flatten())
|
||||||
|
print(w.flatten())
|
||||||
|
res = torch.nn.functional.conv1d(t, w)
|
||||||
|
print(res.flatten())
|
||||||
|
*/
|
||||||
|
#[test]
|
||||||
|
fn conv1d() -> Result<()> {
|
||||||
|
let dev = &Device::Cpu;
|
||||||
|
let t = Tensor::new(
|
||||||
|
&[
|
||||||
|
0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
||||||
|
1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
|
||||||
|
],
|
||||||
|
dev,
|
||||||
|
)?
|
||||||
|
.reshape((1, 4, 5))?;
|
||||||
|
let w = Tensor::new(
|
||||||
|
&[
|
||||||
|
-0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
|
||||||
|
-1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
|
||||||
|
-0.6451, -0.0840, -1.4247, 0.5512,
|
||||||
|
],
|
||||||
|
dev,
|
||||||
|
)?
|
||||||
|
.reshape((2, 4, 3))?;
|
||||||
|
let res = t.conv1d(&w, 0, 1)?;
|
||||||
|
assert_eq!(res.dims(), [1, 2, 3]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
|
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/* This test is based on the following script.
|
||||||
|
import torch
|
||||||
|
torch.manual_seed(4242)
|
||||||
|
|
||||||
t = torch.randn((1, 4, 5, 5))
|
t = torch.randn((1, 4, 5, 5))
|
||||||
w = torch.randn((2, 4, 3, 3))
|
w = torch.randn((2, 4, 3, 3))
|
||||||
print(t.flatten())
|
print(t.flatten())
|
||||||
|
Reference in New Issue
Block a user