mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Bugfix for conv-transpose1d (#1734)
* Add a currently broken test. * Bugfix + fix test.
This commit is contained in:
@ -1263,6 +1263,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
|
|||||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||||
let p = self.0;
|
let p = self.0;
|
||||||
let inp = &inp[inp_l.start_offset()..];
|
let inp = &inp[inp_l.start_offset()..];
|
||||||
|
let k = &k[k_l.start_offset()..];
|
||||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||||
let l_out = p.l_out();
|
let l_out = p.l_out();
|
||||||
|
@ -18,6 +18,9 @@ w_t = w.transpose(0, 1)
|
|||||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||||
print(res.shape)
|
print(res.shape)
|
||||||
print(res)
|
print(res)
|
||||||
|
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
|
||||||
|
print(res.shape)
|
||||||
|
print(res)
|
||||||
*/
|
*/
|
||||||
fn conv1d(dev: &Device) -> Result<()> {
|
fn conv1d(dev: &Device) -> Result<()> {
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
@ -59,6 +62,17 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
4.7076, -5.9745, -0.8276, 1.621
|
4.7076, -5.9745, -0.8276, 1.621
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
|
||||||
|
assert_eq!(res.dims(), [1, 4, 7]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
|
||||||
|
[
|
||||||
|
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
|
||||||
|
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
|
||||||
|
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
|
||||||
|
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
|
||||||
|
]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user