mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Test for the transposed conv1d. (#1254)
This commit is contained in:
@ -196,8 +196,8 @@ impl Tensor {
|
|||||||
stride: usize,
|
stride: usize,
|
||||||
dilation: usize,
|
dilation: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
|
||||||
let (b_size, c_in, l_in) = self.dims3()?;
|
let (b_size, c_in, l_in) = self.dims3()?;
|
||||||
|
let (c_in_k, c_out, k_size) = kernel.dims3()?;
|
||||||
if c_in != c_in_k {
|
if c_in != c_in_k {
|
||||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,11 @@ res = torch.nn.functional.conv1d(t, w)
|
|||||||
print(res.flatten())
|
print(res.flatten())
|
||||||
res = torch.nn.functional.conv1d(t, w, padding=1)
|
res = torch.nn.functional.conv1d(t, w, padding=1)
|
||||||
print(res.flatten())
|
print(res.flatten())
|
||||||
|
|
||||||
|
w_t = w.transpose(0, 1)
|
||||||
|
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||||
|
print(res.shape)
|
||||||
|
print(res)
|
||||||
*/
|
*/
|
||||||
fn conv1d(dev: &Device) -> Result<()> {
|
fn conv1d(dev: &Device) -> Result<()> {
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
@ -45,6 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||||
);
|
);
|
||||||
|
if dev.is_cpu() {
|
||||||
|
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
|
assert_eq!(res.dims(), [1, 2, 7]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
|
[
|
||||||
|
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||||
|
4.7076, -5.9745, -0.8276, 1.621
|
||||||
|
],
|
||||||
|
);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user