Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00

Dilated convolutions (#657)

* Add the dilation parameter.
* Restore the basic optimizer example.
* Dilation support in cudnn.
* Use the dilation parameter in the cpu backend.
* More dilation support.
* No support for dilation in transposed convolutions.
* Add dilation to a test.
* Remove a print.
* Helper function.
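After this change, conv1d and conv2d take a dilation argument between stride and groups, and conv_transpose2d takes one after stride. A minimal sketch of the new call shape (not part of the commit; it mirrors the dilation test added below):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let t = Tensor::randn(0f32, 1., (1, 4, 5, 5), &dev)?;
        let w = Tensor::randn(0f32, 1., (2, 4, 3, 3), &dev)?;
        // Argument order: kernel, padding, stride, dilation, groups.
        let res = t.conv2d(&w, 0, 1, 2, 1)?;
        // With dilation = 2 a 3x3 kernel spans 5 rows/columns, so a 5x5
        // input yields a single spatial output position per channel.
        assert_eq!(res.dims(), [1, 2, 1, 1]);
        Ok(())
    }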
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
     let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
     let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
     let start = std::time::Instant::now();
-    let res = inp.conv2d(&w, 0, 1, 1)?;
+    let res = inp.conv2d(&w, 0, 1, 1, 1)?;
     println!("{:?}", start.elapsed());
     println!("{res:?}");
     Ok(())
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
     }

     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        d.0.conv1d(&d.1, 0, 1, 1)
+        d.0.conv1d(&d.1, 0, 1, 1, 1)
     }

     const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
     }

     fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
-        d.0.conv2d(&d.1, 0, 1, 1)
+        d.0.conv2d(&d.1, 0, 1, 1, 1)
     }

     const ITERS: usize = 1;
@@ -11,11 +11,11 @@ fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
     let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
-    let out_t = in_t.conv2d(&k_t, 0, 1, 1)?;
+    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
     println!("{out_t}");
     let in_t = in_t.to_device(&Device::Cpu)?;
     let k_t = k_t.to_device(&Device::Cpu)?;
-    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?;
+    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
     let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
         .sqr()?
         .sum_all()?;
@@ -23,7 +23,7 @@ fn main() -> Result<()> {

     let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
     let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
-    let res = t.conv2d(&w, 1, 1, 1)?;
+    let res = t.conv2d(&w, 1, 1, 1, 1)?;
     println!("{res:?}");
     Ok(())
 }
@@ -197,21 +197,28 @@ impl Tensor {
                 kernel,
                 padding,
                 stride,
+                dilation,
             } => {
                 // The output height for conv_transpose2d is:
                 // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
                 let grad_h = grad.dim(2)?;
                 let k_h = kernel.dim(2)?;
-                let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+                let out_size =
+                    (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                 let out_padding = arg.dim(2)? - out_size;
-                let grad_arg =
-                    grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+                let grad_arg = grad.conv_transpose2d(
+                    kernel,
+                    *padding,
+                    out_padding,
+                    *stride,
+                    *dilation,
+                )?;
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.add(&grad_arg)?;

                 let grad_kernel = arg
                     .transpose(0, 1)?
-                    .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+                    .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                     .transpose(0, 1)?;
                 let sum_grad = grads.or_insert(kernel)?;
                 *sum_grad = sum_grad.add(&grad_kernel)?;
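Two things are worth spelling out in the hunk above. First, out_padding inverts the conv2d output-size formula, which loses a remainder to integer division whenever stride > 1; the recovered remainder is exactly what conv_transpose2d needs to reproduce the input size. Second, in the grad_kernel call the forward stride is passed as the gradient convolution's dilation and vice versa (the argument order is padding, stride, dilation, groups). A standalone sketch of the size bookkeeping, using the same formulas as the conv parameter structs below:

    /// Output size of a dilated convolution along one spatial dim.
    fn conv_out(i: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
        (i + 2 * padding - dilation * (k - 1) - 1) / stride + 1
    }

    /// Output size of the matching transposed convolution.
    fn conv_transpose_out(
        o: usize,
        k: usize,
        padding: usize,
        out_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> usize {
        (o - 1) * stride + dilation * (k - 1) + out_padding + 1 - 2 * padding
    }

    fn main() {
        let (i, k, padding, stride, dilation) = (10, 3, 0, 2, 2);
        let o = conv_out(i, k, padding, stride, dilation); // (10 - 5) / 2 + 1 = 3
        // out_padding recovers the remainder lost to the integer division above.
        let out_padding = i - conv_transpose_out(o, k, padding, 0, stride, dilation);
        assert_eq!(conv_transpose_out(o, k, padding, out_padding, stride, dilation), i);
    }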
@@ -11,12 +11,12 @@ pub struct ParamsConv1D {
     pub(crate) k_size: usize,
     pub(crate) padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }

 impl ParamsConv1D {
     pub(crate) fn l_out(&self) -> usize {
-        let dilation = 1;
-        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
     }

     pub(crate) fn out_dims(&self) -> Vec<usize> {
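Worked example of l_out: with l_in = 10, k_size = 3, padding = 0, stride = 1, dilation = 2, the dilated kernel spans dilation * (k_size - 1) + 1 = 5 positions, so l_out = (10 - 5) / 1 + 1 = 6. A quick check through the new public signature (a sketch, not from the commit):

    use candle_core::{DType, Device, Tensor};

    let t = Tensor::zeros((1, 1, 10), DType::F32, &Device::Cpu)?;
    let w = Tensor::zeros((1, 1, 3), DType::F32, &Device::Cpu)?;
    assert_eq!(t.conv1d(&w, 0, 1, 2, 1)?.dims(), [1, 1, 6]);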
@@ -36,17 +36,16 @@ pub struct ParamsConv2D {
     pub(crate) c_in: usize,
     pub(crate) padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }

 impl ParamsConv2D {
     pub(crate) fn out_h(&self) -> usize {
-        let dilation = 1;
-        (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
     }

     pub(crate) fn out_w(&self) -> usize {
-        let dilation = 1;
-        (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
     }

     pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
     pub(crate) padding: usize,
     pub(crate) output_padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }

 impl ParamsConvTranspose2D {
     pub(crate) fn out_h(&self) -> usize {
-        let dilation = 1;
-        (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
             - 2 * self.padding
     }

     pub(crate) fn out_w(&self) -> usize {
-        let dilation = 1;
-        (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
             - 2 * self.padding
     }

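The same substitution drives ParamsConvTranspose2D: for instance i_h = 5, k_h = 3, stride = 1, padding = 0, output_padding = 0, dilation = 1 gives out_h = (5 - 1) * 1 + 1 * (3 - 1) + 0 + 1 = 7, matching the [1, 2, 7, 7] conv_transpose2d assertion in the tests below.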
@@ -96,6 +94,7 @@ impl Tensor {
             kernel,
             padding: params.padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -107,6 +106,7 @@ impl Tensor {
         kernel: &Self,
         padding: usize,
         stride: usize,
+        dilation: usize,
         groups: usize,
     ) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.dims3()?;
@@ -130,6 +130,7 @@ impl Tensor {
             k_size,
             padding,
             stride,
+            dilation,
         };
         if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -154,6 +155,7 @@ impl Tensor {
             kernel,
             padding: params.padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -165,6 +167,7 @@ impl Tensor {
         kernel: &Self,
         padding: usize,
         stride: usize,
+        dilation: usize,
         groups: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
@@ -184,6 +187,7 @@ impl Tensor {
             c_in: c_in / groups,
             padding,
             stride,
+            dilation,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)
@@ -206,6 +210,7 @@ impl Tensor {
         padding: usize,
         output_padding: usize,
         stride: usize,
+        dilation: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
         let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@@ -223,6 +228,7 @@ impl Tensor {
             padding,
             output_padding,
             stride,
+            dilation,
         };
         let storage = self.storage().conv_transpose2d(
             self.layout(),
@@ -236,6 +242,7 @@ impl Tensor {
             padding: params.padding,
             output_padding: params.output_padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
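With the plumbing above in place, dilation composes with the existing groups support; a sketch of a grouped, dilated conv2d (hypothetical shapes, not from the commit):

    // c_in = 4 with groups = 2: the kernel's second dim is c_in / groups = 2.
    let t = Tensor::randn(0f32, 1., (1, 4, 8, 8), &Device::Cpu)?;
    let w = Tensor::randn(0f32, 1., (4, 2, 3, 3), &Device::Cpu)?;
    let res = t.conv2d(&w, /*padding*/ 1, /*stride*/ 1, /*dilation*/ 2, /*groups*/ 2)?;
    // out = (8 + 2 * 1 - 2 * (3 - 1) - 1) / 1 + 1 = 6
    assert_eq!(res.dims(), [1, 4, 6, 6]);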
@@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
                 let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                 for dst_l in 0..l_out {
                     let dst_idx = dst_idx + dst_l;
-                    let src_l = p.stride * dst_l + offset;
+                    let src_l = (p.stride * dst_l + offset) * p.dilation;
                     if src_l < p.padding || src_l >= p.padding + p.l_in {
                         continue;
                     }
@@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
                 let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
                 for dst_h in 0..out_h {
                     let dst_idx = dst_idx + dst_h * out_w;
-                    let src_h = p.stride * dst_h + offset_h;
+                    let src_h = (p.stride * dst_h + offset_h) * p.dilation;
                     if src_h < p.padding || src_h >= p.i_h + p.padding {
                         continue;
                     }
                     let src_h = src_h - p.padding;
                     for dst_w in 0..out_w {
                         let dst_idx = dst_idx + dst_w;
-                        let src_w = p.stride * dst_w + offset_w;
+                        let src_w = (p.stride * dst_w + offset_w) * p.dilation;
                         if src_w < p.padding || src_w >= p.i_w + p.padding {
                             continue;
                         }
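For comparison, the dilation convention used by torch.nn.functional.conv2d (and by the cudnn path below, which hands dilation straight to cuDNN) scales only the kernel offset when mapping an output position to a padded input position. A one-dimensional sketch of that mapping, for reference:

    // dst: output position, offset: kernel tap index, both zero-based.
    // Padded source position under the usual dilation convention.
    fn src_index(dst: usize, offset: usize, stride: usize, dilation: usize) -> usize {
        stride * dst + offset * dilation
    }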
@@ -1186,6 +1186,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
     const OP: &'static str = "conv_transpose2d";
     fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
         let p = self.0;
+        if p.dilation != 1 {
+            crate::bail!(
+                "dilation {} is not supported for conv-transpose2d",
+                p.dilation
+            )
+        }
         let inp = &inp[inp_l.start_offset()..];
         let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
         let k = &k[k_l.start_offset()..];
@@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> {
             crate::bail!("unexpected input shape for conv1d {dims:?}")
         };
         let ds = dev.htod_copy(ds).w()?;
-        let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
+        let params = (
+            el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
@@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> {
             crate::bail!("unexpected input shape for conv2d {dims:?}")
         };
         let ds = dev.htod_copy(ds).w()?;
-        let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
+        let params = (
+            el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
@@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         // Kernel shape: (c_in_k, c_out, h_k, w_k)
         // Input shape: (b_size, c_in, h_in, w_in)
         let p = &self.0;
+        if p.dilation != 1 {
+            crate::bail!(
+                "dilation {} is not supported for conv-transpose2d",
+                p.dilation
+            )
+        }
         let (out_w, out_h) = (p.out_w(), p.out_h());
         let dst_el = p.c_out * out_w * out_h * p.b_size;
         let inp = &inp.slice(inp_l.start_offset()..);
@@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
             p.stride,
             p.padding,
             p.output_padding,
+            p.dilation,
             &ds,
             inp,
             k,
|
||||
let conv = cudnn.create_conv2d::<T>(
|
||||
/* pad */ [params.padding as i32, params.padding as i32],
|
||||
/* stride */ [params.stride as i32, params.stride as i32],
|
||||
/* dilation */ [1, 1],
|
||||
/* dilation */ [params.dilation as i32, params.dilation as i32],
|
||||
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||
)?;
|
||||
let x_shape = [
|
||||
|
@@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
     }
 }

+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
+    for &[[[[S; N4]; N3]; N2]; N1]
+{
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from((N1, N2, N3, N4)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
+        for i1 in 0..N1 {
+            for i2 in 0..N2 {
+                for i3 in 0..N3 {
+                    vec.extend(self[i1][i2][i3])
+                }
+            }
+        }
+        S::to_cpu_storage_owned(vec)
+    }
+}
+
 impl Device {
     pub fn new_cuda(ordinal: usize) -> Result<Self> {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
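This is the "Helper function" bullet from the commit message: a 4d Rust array now implements NdArray, so Tensor::new can take one directly, which is convenient for small conv test fixtures. A sketch:

    // Builds a (1, 1, 2, 2) tensor straight from a nested array.
    let data = [[[[1f32, 2.], [3., 4.]]]];
    let t = Tensor::new(&data, &Device::Cpu)?;
    assert_eq!(t.dims(), [1, 1, 2, 2]);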
@@ -81,6 +81,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         stride: usize,
+        dilation: usize,
     },

     #[allow(dead_code)]
@@ -89,6 +90,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         stride: usize,
+        dilation: usize,
     },

     #[allow(dead_code)]
@@ -98,6 +100,7 @@ pub enum Op {
         padding: usize,
         output_padding: usize,
         stride: usize,
+        dilation: usize,
     },

     AvgPool2D {
@@ -32,13 +32,13 @@ fn conv1d(dev: &Device) -> Result<()> {
         dev,
     )?
     .reshape((2, 4, 3))?;
-    let res = t.conv1d(&w, 0, 1, 1)?;
+    let res = t.conv1d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
     );
-    let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
+    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 5]);
     // Same as pytorch default padding: use zeros.
     assert_eq!(
@@ -51,13 +51,13 @@ fn conv1d(dev: &Device) -> Result<()> {
 fn conv1d_small(dev: &Device) -> Result<()> {
     let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
     let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
-    let res = t.conv1d(&w, 0, 1, 1)?;
+    let res = t.conv1d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 2]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.4056, -0.8689]
     );
-    let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
+    let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 4]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -81,6 +81,10 @@ w_t = w.transpose(0, 1)
 res = torch.nn.functional.conv_transpose2d(t, w_t)
 print(res.shape)
 print(res)
+
+res = torch.nn.functional.conv2d(t, w, dilation=2)
+print(res.shape)
+print(res[0])
 */
 fn conv2d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -113,7 +117,7 @@ fn conv2d(dev: &Device) -> Result<()> {
     )?;
     let t = t.reshape((1, 4, 5, 5))?;
     let w = w.reshape((2, 4, 3, 3))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -122,7 +126,7 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 7, 7]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -147,6 +151,13 @@ fn conv2d(dev: &Device) -> Result<()> {
             ]
         ]
     );
+    // Dilations.
+    let res = t.conv2d(&w, 0, 1, 2, 1)?;
+    assert_eq!(res.dims(), [1, 2, 1, 1]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+        [2.45, -2.3504],
+    );
     Ok(())
 }

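The [1, 2, 1, 1] shape in the new dilation check follows from the out_h formula above: with i_h = 5, k_h = 3, dilation = 2 the effective kernel extent is 2 * (3 - 1) + 1 = 5, so out_h = (5 - 5) / 1 + 1 = 1. The reference values [2.45, -2.3504] come from the conv2d(t, w, dilation=2) lines added to the PyTorch snippet.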
@@ -182,13 +193,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
     let t = t.reshape((1, 2, 3, 3))?;
     let w = w.reshape((1, 2, 1, 1))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
     );
-    let res = t.conv2d(&w, 2, 1, 1)?;
+    let res = t.conv2d(&w, 2, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 7, 7]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -200,13 +211,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
             0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
     );
-    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [2, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -230,7 +241,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
     let t = t.reshape((1, 1, 3, 3))?;
     let w = w.reshape((1, 1, 3, 3))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 1, 1]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -261,7 +272,7 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
     let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
     let t = t.reshape((1, 2, 4, 2))?;
     let w = w.reshape((1, 2, 1, 1))?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 4, 2]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -270,6 +281,36 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
     Ok(())
 }

+/*
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 4, 5, 5), requires_grad=True)
+w = torch.randn((2, 4, 3, 3), requires_grad=True)
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv2d(t, w)
+print(res.flatten())
+loss = (res ** 2).sum()
+print(loss)
+loss.backward()
+print(t.grad.shape)
+print(t.grad.flatten())
+print(w.grad.shape)
+print(w.grad.flatten())
+
+t.grad.zero_()
+w.grad.zero_()
+res = torch.nn.functional.conv2d(t, w, stride=2)
+print(res.flatten())
+loss = (res ** 2).sum()
+print(loss)
+loss.backward()
+print(t.grad.shape)
+print(t.grad[0])
+print(w.grad.shape)
+print(w.grad[0])
+*/
 fn conv2d_grad(dev: &Device) -> Result<()> {
     use candle_core::Var;
     let t = Var::from_slice(
@@ -302,7 +343,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
         (2, 4, 3, 3),
         dev,
     )?;
-    let res = t.conv2d(&w, 0, 1, 1)?;
+    let res = t.conv2d(&w, 0, 1, 1, 1)?;
     let loss = res.sqr()?.sum_all()?;
     assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
     let grads = loss.backward()?;
@@ -335,6 +376,74 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
             -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
         ]
     );

+    // Same as before but with stride.
+    let res = t.conv2d(&w, 0, 2, 1, 1)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
+    let grads = loss.backward()?;
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
+        [
+            [
+                [9.29, -7.03, 0.94, 3.49, -7.71],
+                [-1.8, -7.82, 8.9, 8.46, 7.43],
+                [-25.84, 22.09, -19.27, -0.22, 1.69],
+                [4.02, 18.53, -18.37, 2.3, -24.51],
+                [7.72, -9.68, -12.34, 5.6, -20.22]
+            ],
+            [
+                [21.73, 3.39, -18.27, 3.86, -3.65],
+                [8.25, 3.73, 30.73, -8.61, -11.93],
+                [-72.15, -15.36, -17.53, -12.32, -1.61],
+                [-22.32, -7.79, -91.82, 6.44, -37.69],
+                [52.88, 14.44, 42.75, 9.88, 2.01]
+            ],
+            [
+                [-8.98, 9.91, 6.75, -4.68, 15.38],
+                [4.93, -0.33, 9.94, -1.46, 14.78],
+                [13.62, -30.63, 3.96, -3.58, -4.48],
+                [-14.13, 1.19, -34.43, 3.08, -33.83],
+                [17.28, 12.94, 31.83, -3.35, 6.81]
+            ],
+            [
+                [23.54, 6.98, -24.52, 0.52, 4.87],
+                [9.65, 6.18, 1.71, -25.23, -4.93],
+                [-54.99, -23.66, 3.19, -3.73, 18.58],
+                [-21.35, -10.39, -39.88, 28.73, -30.76],
+                [-9.13, 11.12, -14.0, -8.23, -11.25]
+            ]
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
+        [
+            [
+                [28.34, -45.75, 7.32],
+                [0.72, -35.28, 19.23],
+                [-28.29, 20.89, -5.18]
+            ],
+            [
+                [-16.04, -16.38, 32.12],
+                [57.5, 25.81, 11.96],
+                [-18.66, 8.48, -9.92]
+            ],
+            [
+                [2.93, 1.57, -23.76],
+                [12.74, -26.2, -17.88],
+                [-14.98, -9.35, 12.2]
+            ],
+            [
+                [-0.18, -6.82, 20.79],
+                [-2.54, 27.11, -10.11],
+                [-0.41, -3.18, -0.07]
+            ]
+        ]
+    );
     Ok(())
 }

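Note that the stride variant added above keeps dilation = 1: backpropagating through a dilated conv2d would route the input gradient through conv_transpose2d with the same dilation, and both the CPU and CUDA backends bail out when dilation != 1 (see the conv-transpose hunks above), so an end-to-end dilated gradient test is not yet possible.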
@@ -278,6 +278,7 @@ impl EncodecConv1d {
                     padding: 0,
                     stride,
                     groups: 1,
+                    dilation: 1,
                 },
                 vb.pp("conv"),
             )?,
@@ -289,6 +290,7 @@ impl EncodecConv1d {
                     padding: 0,
                     stride,
                     groups: 1,
+                    dilation: 1,
                 },
                 vb.pp("conv"),
             )?,

@@ -66,6 +66,7 @@ impl ResnetBlock2D {
             stride: 1,
             padding: 1,
             groups: 1,
+            dilation: 1,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
         let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -80,6 +81,7 @@ impl ResnetBlock2D {
             stride: 1,
             padding: 0,
             groups: 1,
+            dilation: 1,
         };
         Some(conv2d(
             in_channels,

@@ -281,11 +281,13 @@ impl AudioEncoder {
            padding: 1,
            stride: 1,
            groups: 1,
+           dilation: 1,
        };
        let cfg2 = Conv1dConfig {
            padding: 1,
            stride: 2,
            groups: 1,
+           dilation: 1,
        };
        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

@@ -132,6 +132,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
        stride,
        padding,
        groups: 1,
+       dilation: 1,
    };
    let conv = if bias {
        conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?

@@ -93,6 +93,7 @@ impl ConvBlock {
            padding,
            stride,
            groups: 1,
+           dilation: 1,
        };
        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
        let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
@@ -8,6 +8,7 @@ __device__ void conv1d(
    const size_t l_out,
    const size_t stride,
    const size_t padding,
+   const size_t dilation,
    const size_t *info,
    const T *src,
    const T *kernel,
@@ -36,7 +37,7 @@ __device__ void conv1d(
   const size_t src_idx0 = b_idx * src_s[0];
   A d = 0;
   for (size_t offset = 0; offset < k_size; ++offset) {
-    size_t src_l = stride * dst_l + offset;
+    size_t src_l = (stride * dst_l + offset) * dilation;
     if (src_l < padding || src_l >= padding + l_in) {
       continue;
     }
@@ -58,6 +59,7 @@ __device__ void conv2d(
    const size_t h_out,
    const size_t stride,
    const size_t padding,
+   const size_t dilation,
    const size_t *info,
    const T *src,
    const T *kernel,
@@ -90,13 +92,13 @@ __device__ void conv2d(
   const size_t src_idx0 = b_idx * src_s[0];
   A d = 0;
   for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
-    size_t src_w = stride * dst_w + w_offset;
+    size_t src_w = (stride * dst_w + w_offset) * dilation;
     if (src_w < padding || src_w >= w_in + padding) {
       continue;
     }
     src_w -= padding;
     for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
-      size_t src_h = stride * dst_h + h_offset;
+      size_t src_h = (stride * dst_h + h_offset) * dilation;
       if (src_h < padding || src_h >= h_in + padding) {
         continue;
       }
@@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
    const size_t stride,
    const size_t padding,
    const size_t out_padding,
+   const size_t dilation,
    const size_t *info,
    const T *src,
    const T *kernel,
@@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
    const size_t num_dims, \
    const size_t stride, \
    const size_t padding, \
+   const size_t dilation, \
    const size_t *info, \
    const TYPENAME *src, \
    const TYPENAME *kernel, \
    TYPENAME *dst \
 ) { \
-  conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
+  conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
 } \

 #define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
    const size_t h_out, \
    const size_t stride, \
    const size_t padding, \
+   const size_t dilation, \
    const size_t *info, \
    const TYPENAME *src, \
    const TYPENAME *kernel, \
    TYPENAME *dst \
 ) { \
-  conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
+  conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
 } \

 #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
    const size_t stride, \
    const size_t padding, \
    const size_t out_padding, \
+   const size_t dilation, \
    const size_t *info, \
    const TYPENAME *src, \
    const TYPENAME *kernel, \
    TYPENAME *dst \
 ) { \
-  conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+  conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
 } \

 #define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
 pub struct Conv1dConfig {
     pub padding: usize,
     pub stride: usize,
+    pub dilation: usize,
     pub groups: usize,
 }

@@ -13,6 +14,7 @@ impl Default for Conv1dConfig {
         Self {
             padding: 0,
             stride: 1,
+            dilation: 1,
             groups: 1,
         }
     }
@@ -45,6 +47,7 @@ impl crate::Module for Conv1d {
             &self.weight,
             self.config.padding,
             self.config.stride,
+            self.config.dilation,
             self.config.groups,
         )?;
         match &self.bias {
@@ -62,6 +65,7 @@ impl crate::Module for Conv1d {
 pub struct Conv2dConfig {
     pub padding: usize,
     pub stride: usize,
+    pub dilation: usize,
     pub groups: usize,
 }

@@ -70,6 +74,7 @@ impl Default for Conv2dConfig {
         Self {
             padding: 0,
             stride: 1,
+            dilation: 1,
             groups: 1,
         }
     }
@@ -103,6 +108,7 @@ impl crate::Module for Conv2d {
             &self.weight,
             self.config.padding,
             self.config.stride,
+            self.config.dilation,
             self.config.groups,
         )?;
         match &self.bias {
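Since the Default impls gained dilation: 1, downstream candle-nn users only need to spell out the fields they change; a sketch using struct-update syntax (assuming the usual re-export of Conv2dConfig from the crate root):

    use candle_nn::Conv2dConfig;

    let cfg = Conv2dConfig {
        padding: 1,
        dilation: 2,
        ..Default::default()
    };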
@@ -269,11 +269,13 @@ impl AudioEncoder {
            padding: 1,
            stride: 1,
            groups: 1,
+           dilation: 1,
        };
        let cfg2 = Conv1dConfig {
            padding: 1,
            stride: 2,
            groups: 1,
+           dilation: 1,
        };
        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

@@ -97,6 +97,7 @@ impl ConvBlock {
            padding,
            stride,
            groups: 1,
+           dilation: 1,
        };
        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
        let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;