Dilated convolutions (#657)

* Add the dilation parameter.

* Restore the basic optimizer example.

* Dilation support in cudnn.

* Use the dilation parameter in the cpu backend.

* More dilation support.

* No support for dilation in transposed convolutions.

* Add dilation to a test.

* Remove a print.

* Helper function.
This commit is contained in:
Laurent Mazare
2023-08-29 16:12:11 +01:00
committed by GitHub
parent ee8bb1bde1
commit a044907ffc
20 changed files with 231 additions and 45 deletions

View File

@ -11,7 +11,7 @@ fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
let res = inp.conv2d(&w, 0, 1, 1)?;
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())

View File

@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv1d(&d.1, 0, 1, 1)
d.0.conv1d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 5;
@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.conv2d(&d.1, 0, 1, 1)
d.0.conv2d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 1;

View File

@ -11,11 +11,11 @@ fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let out_t = in_t.conv2d(&k_t, 0, 1, 1)?;
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?;
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
@ -23,7 +23,7 @@ fn main() -> Result<()> {
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(())
}

View File

@ -197,21 +197,28 @@ impl Tensor {
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
let out_size =
(grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg =
grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
let grad_arg = grad.conv_transpose2d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;

View File

@ -11,12 +11,12 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
let dilation = 1;
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
(self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@ -36,17 +36,16 @@ pub struct ParamsConv2D {
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
(self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
(self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
(self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
(self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
(self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
(self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
(self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
(self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
@ -96,6 +94,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@ -107,6 +106,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
@ -130,6 +130,7 @@ impl Tensor {
k_size,
padding,
stride,
dilation,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@ -154,6 +155,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@ -165,6 +167,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
@ -184,6 +187,7 @@ impl Tensor {
c_in: c_in / groups,
padding,
stride,
dilation,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
@ -206,6 +210,7 @@ impl Tensor {
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@ -223,6 +228,7 @@ impl Tensor {
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
@ -236,6 +242,7 @@ impl Tensor {
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))

View File

@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let src_l = p.stride * dst_l + offset;
let src_l = (p.stride * dst_l + offset) * p.dilation;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
let src_h = p.stride * dst_h + offset_h;
let src_h = (p.stride * dst_h + offset_h) * p.dilation;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
let src_w = p.stride * dst_w + offset_w;
let src_w = (p.stride * dst_w + offset_w) * p.dilation;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
@ -1186,6 +1186,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
if p.dilation != 1 {
crate::bail!(
"dilation {} is not supported for conv-transpose2d",
p.dilation
)
}
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];

View File

@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> {
crate::bail!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
let params = (
el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> {
crate::bail!("unexpected input shape for conv2d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
let params = (
el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
if p.dilation != 1 {
crate::bail!(
"dilation {} is not supported for conv-transpose2d",
p.dilation
)
}
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,

View File

@ -48,7 +48,7 @@ pub(crate) fn launch_conv2d<
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [1, 1],
/* dilation */ [params.dilation as i32, params.dilation as i32],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [

View File

@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
for &[[[[S; N4]; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3, N4)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
for i1 in 0..N1 {
for i2 in 0..N2 {
for i3 in 0..N3 {
vec.extend(self[i1][i2][i3])
}
}
}
S::to_cpu_storage_owned(vec)
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))

View File

@ -81,6 +81,7 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
@ -89,6 +90,7 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
@ -98,6 +100,7 @@ pub enum Op {
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
AvgPool2D {

View File

@ -32,13 +32,13 @@ fn conv1d(dev: &Device) -> Result<()> {
dev,
)?
.reshape((2, 4, 3))?;
let res = t.conv1d(&w, 0, 1, 1)?;
let res = t.conv1d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
@ -51,13 +51,13 @@ fn conv1d(dev: &Device) -> Result<()> {
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
let res = t.conv1d(&w, 0, 1, 1)?;
let res = t.conv1d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -81,6 +81,10 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res)
res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -113,7 +117,7 @@ fn conv2d(dev: &Device) -> Result<()> {
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
let res = t.conv2d(&w, 0, 1, 1)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -122,7 +126,7 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -147,6 +151,13 @@ fn conv2d(dev: &Device) -> Result<()> {
]
]
);
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.45, -2.3504],
);
Ok(())
}
@ -182,13 +193,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
);
let res = t.conv2d(&w, 2, 1, 1)?;
let res = t.conv2d(&w, 2, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 7, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -200,13 +211,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
);
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [2, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -230,7 +241,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
let res = t.conv2d(&w, 0, 1, 1)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -261,7 +272,7 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
let t = t.reshape((1, 2, 4, 2))?;
let w = w.reshape((1, 2, 1, 1))?;
let res = t.conv2d(&w, 0, 1, 1)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -270,6 +281,36 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
Ok(())
}
/*
import torch
torch.manual_seed(4242)
t = torch.randn((1, 4, 5, 5), requires_grad=True)
w = torch.randn((2, 4, 3, 3), requires_grad=True)
print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad.flatten())
print(w.grad.shape)
print(w.grad.flatten())
t.grad.zero_()
w.grad.zero_()
res = torch.nn.functional.conv2d(t, w, stride=2)
print(res.flatten())
loss = (res ** 2).sum()
print(loss)
loss.backward()
print(t.grad.shape)
print(t.grad[0])
print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
use candle_core::Var;
let t = Var::from_slice(
@ -302,7 +343,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
(2, 4, 3, 3),
dev,
)?;
let res = t.conv2d(&w, 0, 1, 1)?;
let res = t.conv2d(&w, 0, 1, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
let grads = loss.backward()?;
@ -335,6 +376,74 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
-34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
]
);
// Same as before but with stride.
let res = t.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 0.94, 3.49, -7.71],
[-1.8, -7.82, 8.9, 8.46, 7.43],
[-25.84, 22.09, -19.27, -0.22, 1.69],
[4.02, 18.53, -18.37, 2.3, -24.51],
[7.72, -9.68, -12.34, 5.6, -20.22]
],
[
[21.73, 3.39, -18.27, 3.86, -3.65],
[8.25, 3.73, 30.73, -8.61, -11.93],
[-72.15, -15.36, -17.53, -12.32, -1.61],
[-22.32, -7.79, -91.82, 6.44, -37.69],
[52.88, 14.44, 42.75, 9.88, 2.01]
],
[
[-8.98, 9.91, 6.75, -4.68, 15.38],
[4.93, -0.33, 9.94, -1.46, 14.78],
[13.62, -30.63, 3.96, -3.58, -4.48],
[-14.13, 1.19, -34.43, 3.08, -33.83],
[17.28, 12.94, 31.83, -3.35, 6.81]
],
[
[23.54, 6.98, -24.52, 0.52, 4.87],
[9.65, 6.18, 1.71, -25.23, -4.93],
[-54.99, -23.66, 3.19, -3.73, 18.58],
[-21.35, -10.39, -39.88, 28.73, -30.76],
[-9.13, 11.12, -14.0, -8.23, -11.25]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[28.34, -45.75, 7.32],
[0.72, -35.28, 19.23],
[-28.29, 20.89, -5.18]
],
[
[-16.04, -16.38, 32.12],
[57.5, 25.81, 11.96],
[-18.66, 8.48, -9.92]
],
[
[2.93, 1.57, -23.76],
[12.74, -26.2, -17.88],
[-14.98, -9.35, 12.2]
],
[
[-0.18, -6.82, 20.79],
[-2.54, 27.11, -10.11],
[-0.41, -3.18, -0.07]
]
]
);
Ok(())
}

View File

@ -278,6 +278,7 @@ impl EncodecConv1d {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
@ -289,6 +290,7 @@ impl EncodecConv1d {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,

View File

@ -66,6 +66,7 @@ impl ResnetBlock2D {
stride: 1,
padding: 1,
groups: 1,
dilation: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@ -80,6 +81,7 @@ impl ResnetBlock2D {
stride: 1,
padding: 0,
groups: 1,
dilation: 1,
};
Some(conv2d(
in_channels,

View File

@ -281,11 +281,13 @@ impl AudioEncoder {
padding: 1,
stride: 1,
groups: 1,
dilation: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

View File

@ -132,6 +132,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
stride,
padding,
groups: 1,
dilation: 1,
};
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?

View File

@ -93,6 +93,7 @@ impl ConvBlock {
padding,
stride,
groups: 1,
dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;

View File

@ -8,6 +8,7 @@ __device__ void conv1d(
const size_t l_out,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@ -36,7 +37,7 @@ __device__ void conv1d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t offset = 0; offset < k_size; ++offset) {
size_t src_l = stride * dst_l + offset;
size_t src_l = (stride * dst_l + offset) * dilation;
if (src_l < padding || src_l >= padding + l_in) {
continue;
}
@ -58,6 +59,7 @@ __device__ void conv2d(
const size_t h_out,
const size_t stride,
const size_t padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@ -90,13 +92,13 @@ __device__ void conv2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = stride * dst_w + w_offset;
size_t src_w = (stride * dst_w + w_offset) * dilation;
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = stride * dst_h + h_offset;
size_t src_h = (stride * dst_h + h_offset) * dilation;
if (src_h < padding || src_h >= h_in + padding) {
continue;
}
@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
const size_t stride,
const size_t padding,
const size_t out_padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
const size_t num_dims, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
const size_t h_out, \
const size_t stride, \
const size_t padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
const size_t stride, \
const size_t padding, \
const size_t out_padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \

View File

@ -5,6 +5,7 @@ use candle::{Result, Tensor};
pub struct Conv1dConfig {
pub padding: usize,
pub stride: usize,
pub dilation: usize,
pub groups: usize,
}
@ -13,6 +14,7 @@ impl Default for Conv1dConfig {
Self {
padding: 0,
stride: 1,
dilation: 1,
groups: 1,
}
}
@ -45,6 +47,7 @@ impl crate::Module for Conv1d {
&self.weight,
self.config.padding,
self.config.stride,
self.config.dilation,
self.config.groups,
)?;
match &self.bias {
@ -62,6 +65,7 @@ impl crate::Module for Conv1d {
pub struct Conv2dConfig {
pub padding: usize,
pub stride: usize,
pub dilation: usize,
pub groups: usize,
}
@ -70,6 +74,7 @@ impl Default for Conv2dConfig {
Self {
padding: 0,
stride: 1,
dilation: 1,
groups: 1,
}
}
@ -103,6 +108,7 @@ impl crate::Module for Conv2d {
&self.weight,
self.config.padding,
self.config.stride,
self.config.dilation,
self.config.groups,
)?;
match &self.bias {

View File

@ -269,11 +269,13 @@ impl AudioEncoder {
padding: 1,
stride: 1,
groups: 1,
dilation: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

View File

@ -97,6 +97,7 @@ impl ConvBlock {
padding,
stride,
groups: 1,
dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;