mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Bugfix for conv-transpose1d (#1734)
* Add a currently broken test. * Bugfix + fix test.
This commit is contained in:
@ -1263,6 +1263,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||
let l_out = p.l_out();
|
||||
|
@ -18,6 +18,9 @@ w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
*/
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -59,6 +62,17 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 4, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
|
||||
[
|
||||
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
|
||||
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
|
||||
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
|
||||
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user