mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Support dilation in conv-transpose2d. (#671)
This commit is contained in:
@ -1186,12 +1186,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
|||||||
const OP: &'static str = "conv_transpose2d";
|
const OP: &'static str = "conv_transpose2d";
|
||||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||||
let p = self.0;
|
let p = self.0;
|
||||||
if p.dilation != 1 {
|
|
||||||
crate::bail!(
|
|
||||||
"dilation {} is not supported for conv-transpose2d",
|
|
||||||
p.dilation
|
|
||||||
)
|
|
||||||
}
|
|
||||||
let inp = &inp[inp_l.start_offset()..];
|
let inp = &inp[inp_l.start_offset()..];
|
||||||
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
||||||
let k = &k[k_l.start_offset()..];
|
let k = &k[k_l.start_offset()..];
|
||||||
@ -1235,8 +1229,8 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
|||||||
for b_idx in 0..p.b_size {
|
for b_idx in 0..p.b_size {
|
||||||
for inp_y in 0..p.i_h {
|
for inp_y in 0..p.i_h {
|
||||||
for inp_x in 0..p.i_w {
|
for inp_x in 0..p.i_w {
|
||||||
let out_x = inp_x * p.stride + k_x;
|
let out_x = inp_x * p.stride + k_x * p.dilation;
|
||||||
let out_y = inp_y * p.stride + k_y;
|
let out_y = inp_y * p.stride + k_y * p.dilation;
|
||||||
if out_x < p.padding || out_y < p.padding {
|
if out_x < p.padding || out_y < p.padding {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -1046,12 +1046,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
|||||||
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
||||||
// Input shape: (b_size, c_in, h_in, w_in)
|
// Input shape: (b_size, c_in, h_in, w_in)
|
||||||
let p = &self.0;
|
let p = &self.0;
|
||||||
if p.dilation != 1 {
|
|
||||||
crate::bail!(
|
|
||||||
"dilation {} is not supported for conv-transpose2d",
|
|
||||||
p.dilation
|
|
||||||
)
|
|
||||||
}
|
|
||||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
|
@ -85,6 +85,10 @@ print(res)
|
|||||||
res = torch.nn.functional.conv2d(t, w, dilation=2)
|
res = torch.nn.functional.conv2d(t, w, dilation=2)
|
||||||
print(res.shape)
|
print(res.shape)
|
||||||
print(res[0])
|
print(res[0])
|
||||||
|
|
||||||
|
res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
|
||||||
|
print(res.shape)
|
||||||
|
print(res)
|
||||||
*/
|
*/
|
||||||
fn conv2d(dev: &Device) -> Result<()> {
|
fn conv2d(dev: &Device) -> Result<()> {
|
||||||
let t = Tensor::new(
|
let t = Tensor::new(
|
||||||
@ -158,6 +162,37 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.45, -2.3504],
|
[2.45, -2.3504],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Transpose and dilations.
|
||||||
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||||
|
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
|
||||||
|
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
|
||||||
|
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
|
||||||
|
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
|
||||||
|
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
|
||||||
|
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
|
||||||
|
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
|
||||||
|
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
|
||||||
|
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
|
||||||
|
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
|
||||||
|
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
|
||||||
|
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
|
||||||
|
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
|
||||||
|
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
|
||||||
|
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
|
||||||
|
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
|
||||||
|
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,15 +155,15 @@ __device__ void conv_transpose2d(
|
|||||||
const size_t src_idx0 = b_idx * src_s[0];
|
const size_t src_idx0 = b_idx * src_s[0];
|
||||||
A d = 0;
|
A d = 0;
|
||||||
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
|
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
|
||||||
// let out_x = inp_x * p.stride + k_x - p.padding;
|
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
|
||||||
int inp_x_stride = (int)(out_x + padding) - k_x;
|
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
|
||||||
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
if (inp_x_stride < 0 || inp_x_stride % stride) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
int inp_x = inp_x_stride / stride;
|
int inp_x = inp_x_stride / stride;
|
||||||
if (inp_x >= w_in) continue;
|
if (inp_x >= w_in) continue;
|
||||||
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
|
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
|
||||||
int inp_y_stride = (int)(out_y + padding) - k_y;
|
int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
|
||||||
if (inp_y_stride < 0 || inp_y_stride % stride) {
|
if (inp_y_stride < 0 || inp_y_stride % stride) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user