mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 04:10:46 +00:00
Add conv-transpose. (#635)
* Add conv-transpose.
* Return zeros for now.
* Naive CPU implementation.
* Add a conv-transpose test + fix the cpu implementation.
* Add a second test.
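For reference, the new operation is exposed as a method on Tensor and exercised in the tests below. A minimal usage sketch follows; the candle_core crate path and the Tensor::randn constructor are assumptions not taken from this diff, and the (kernel, padding, output_padding, stride) argument order is inferred from the conv_transpose2d(&w, 0, 0, 1) calls in the test code.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input laid out as (batch, c_in, h, w); shapes are chosen so the output
    // matches the [1, 1, 3, 3] assertion in the test below.
    let t = Tensor::randn(0f32, 1., (1, 2, 3, 3), &dev)?;
    // Kernel laid out as (c_in, c_out, kh, kw); the test obtains this layout by
    // transposing a conv2d weight with w.transpose(0, 1).
    let w = Tensor::randn(0f32, 1., (2, 1, 1, 1), &dev)?;
    // padding = 0, output_padding = 0, stride = 1.
    let res = t.conv_transpose2d(&w, 0, 0, 1)?;
    // With a 1x1 kernel and stride 1 the spatial size is unchanged.
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    Ok(())
}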
@@ -130,6 +130,16 @@ print(t.flatten())
 print(w.flatten())
 res = torch.nn.functional.conv2d(t, w)
 print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose2d(t, w_t)
+print(res.shape)
+print(res.flatten())
+
+t_t = t.transpose(0, 1)
+res = torch.nn.functional.conv_transpose2d(t_t, w)
+print(res.shape)
+print(res.flatten())
 */
 fn conv2d_small(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -160,6 +170,26 @@ fn conv2d_small(dev: &Device) -> Result<()> {
             0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
+    // TODO: enable the test for cuda once we have the proper implementation in place.
+    if dev.is_cpu() {
+        let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+        assert_eq!(res.dims(), [1, 1, 3, 3]);
+        assert_eq!(
+            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+            [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
+        );
+        let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+        assert_eq!(res.dims(), [2, 2, 3, 3]);
+        assert_eq!(
+            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+            [
+                -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728,
+                0.528, -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838,
+                0.5802, -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396,
+                -0.8156, 0.4594, 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
+            ]
+        );
+    }
     Ok(())
 }
 
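The diff above only shows the tests; the naive CPU implementation mentioned in the commit message is not part of this excerpt. As a rough illustration of the scatter formulation of conv-transpose (a hypothetical standalone sketch over flat NCHW buffers, not candle's actual kernel), each input element is accumulated into a kh x kw window of the output:

/// Rough, unoptimized conv-transpose2d over flat, row-major NCHW buffers.
/// Purely illustrative; the kernel uses the (c_in, c_out, kh, kw) layout
/// exercised by the test above.
fn conv_transpose2d_naive(
    input: &[f32],
    kernel: &[f32],
    (b, c_in, h, w): (usize, usize, usize, usize),
    (c_out, kh, kw): (usize, usize, usize),
    padding: usize,
    output_padding: usize,
    stride: usize,
) -> Vec<f32> {
    // Same output-size formula as torch.nn.functional.conv_transpose2d with dilation = 1.
    let out_h = (h - 1) * stride + kh + output_padding - 2 * padding;
    let out_w = (w - 1) * stride + kw + output_padding - 2 * padding;
    let mut out = vec![0f32; b * c_out * out_h * out_w];
    for bi in 0..b {
        for ci in 0..c_in {
            for ih in 0..h {
                for iw in 0..w {
                    let x = input[((bi * c_in + ci) * h + ih) * w + iw];
                    for co in 0..c_out {
                        for ki in 0..kh {
                            for kj in 0..kw {
                                // Scatter: input pixel (ih, iw) contributes to output pixel
                                // (ih * stride + ki - padding, iw * stride + kj - padding).
                                let oh = ih * stride + ki;
                                let ow = iw * stride + kj;
                                if oh < padding || ow < padding {
                                    continue;
                                }
                                let (oh, ow) = (oh - padding, ow - padding);
                                if oh >= out_h || ow >= out_w {
                                    continue;
                                }
                                let k = kernel[((ci * c_out + co) * kh + ki) * kw + kj];
                                out[((bi * c_out + co) * out_h + oh) * out_w + ow] += x * k;
                            }
                        }
                    }
                }
            }
        }
    }
    out
}

fn main() {
    // Shapes mirroring the first assertion above: (1, 2, 3, 3) input, (2, 1, 1, 1) kernel.
    let input = vec![1f32; 18];
    let kernel = vec![0.5f32; 2];
    let out = conv_transpose2d_naive(&input, &kernel, (1, 2, 3, 3), (1, 1, 1), 0, 0, 1);
    assert_eq!(out.len(), 9); // (1, 1, 3, 3)
    // Each output pixel sums the two input channels: 1.0 * 0.5 + 1.0 * 0.5.
    assert!(out.iter().all(|&v| (v - 1.0).abs() < 1e-6));
}

With a 1x1 kernel, stride 1 and no padding, conv-transpose degenerates to a per-pixel channel mix, which is why the first assertion expects the output to keep the input's 3x3 spatial size.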