mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
ConvTranspose1d cuda support. (#1697)
* ConvTranspose1d cuda support. * Add the conv-transpose1d kernel. * Remove some unused variables.
This commit is contained in:
@ -50,17 +50,15 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
if dev.is_cpu() {
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
}
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user