Compare commits

...

4 Commits

SHA1 Message Date
b2796ce6ef Reduce the number of threads. 2023-09-09 21:10:30 +01:00
c22f23e568 Bugfix. 2023-09-09 19:11:59 +01:00
b283b3e181 Merge branch 'main' into conv-transpose-groups 2023-09-09 18:50:13 +01:00
b4fe316fa1 Group support in conv-transpose2d. 2023-09-09 18:11:41 +01:00
8 changed files with 54 additions and 26 deletions

View File

@@ -213,6 +213,7 @@ impl Tensor {
                     out_padding,
                     *stride,
                     *dilation,
+                    1,
                 )?;
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.add(&grad_arg)?;
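
For context, the `1` added above is the new trailing `groups` argument of `conv_transpose2d`; the backward pass keeps the ungrouped behaviour for now. A sketch of the surrounding call site, with `out_grad` and `kernel` assumed from context rather than shown in the hunk:

    // Assumed reconstruction: the conv2d input gradient is computed via
    // conv_transpose2d, with groups pinned to 1 (no grouped backprop yet).
    let grad_arg = out_grad.conv_transpose2d(
        kernel,
        *padding,
        out_padding,
        *stride,
        *dilation,
        1, // groups
    )?;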

View File

@@ -60,12 +60,13 @@ pub struct ParamsConvTranspose2D {
     pub(crate) i_w: usize,
     pub(crate) k_h: usize,
     pub(crate) k_w: usize,
-    pub(crate) c_out: usize,
+    pub(crate) c_out_per_group: usize,
     pub(crate) c_in: usize,
     pub(crate) padding: usize,
     pub(crate) output_padding: usize,
     pub(crate) stride: usize,
     pub(crate) dilation: usize,
+    pub(crate) groups: usize,
 }

 impl ParamsConvTranspose2D {
@@ -80,7 +81,8 @@ impl ParamsConvTranspose2D {
     }

     pub(crate) fn out_dims(&self) -> Vec<usize> {
-        vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
+        let c_out = self.c_out_per_group * self.groups;
+        vec![self.b_size, c_out, self.out_h(), self.out_w()]
     }
 }
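
With groups, `out_dims` derives the channel count as `c_out_per_group * groups`; the spatial sizes still follow the usual transposed-convolution formula (assuming that is what `out_h`/`out_w` compute). A worked check:

    // i = 8, k = 3, stride = 2, padding = 1, output_padding = 1, dilation = 1:
    let out = (8 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1;
    assert_eq!(out, 16);
    // With b_size = 1, c_out_per_group = 4, groups = 2, out_dims() would
    // therefore return [1, 8, 16, 16].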
@@ -211,24 +213,29 @@ impl Tensor {
         output_padding: usize,
         stride: usize,
         dilation: usize,
+        groups: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
-        let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
+        let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
         if c_in != c_in_k {
             crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
         }
+        if c_in % groups != 0 {
+            crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
+        }
         let params = ParamsConvTranspose2D {
             b_size,
             i_h,
             i_w,
             k_h,
             k_w,
-            c_out,
+            c_out_per_group,
             c_in,
             padding,
             output_padding,
             stride,
             dilation,
+            groups,
         };
         let storage = self.storage().conv_transpose2d(
             self.layout(),
@@ -243,6 +250,7 @@ impl Tensor {
             output_padding: params.output_padding,
             stride: params.stride,
             dilation: params.dilation,
+            groups: params.groups,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
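
A minimal usage sketch of the extended signature (shape-only, random values; assumes the public `Tensor` API as of this change). The kernel layout is `(c_in, c_out / groups, k_h, k_w)`, so the output channel count is the kernel's dim 1 times `groups`:

    use candle_core::{Device, Result, Tensor};

    fn grouped_conv_transpose2d() -> Result<Tensor> {
        let dev = Device::Cpu;
        let t = Tensor::randn(0f32, 1f32, (1, 4, 8, 8), &dev)?; // c_in = 4
        let w = Tensor::randn(0f32, 1f32, (4, 6, 3, 3), &dev)?; // c_out = 6 * 2 = 12
        // padding, output_padding, stride, dilation, groups
        let res = t.conv_transpose2d(&w, 0, 0, 1, 1, 2)?;
        assert_eq!(res.dims(), [1, 12, 10, 10]);
        Ok(res)
    }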

View File

@@ -1190,8 +1190,9 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         let (out_h, out_w) = (p.out_h(), p.out_w());
         // Output shape: [b_size, c_out, out_h, out_w].
-        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
-        let dst_s0 = p.c_out * out_h * out_w;
+        let c_out = p.groups * p.c_out_per_group;
+        let dst = vec![T::zero(); p.b_size * c_out * out_h * out_w];
+        let dst_s0 = c_out * out_h * out_w;
         let dst_s1 = out_h * out_w;
         let dst_s2 = out_w;
         let dst_s3 = 1;
@@ -1214,12 +1215,16 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
             }
         }
+        let c_in_per_group = p.c_in / p.groups;
         for k_y in 0..p.k_h {
             for k_x in 0..p.k_w {
-                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
-                    let k_cont = (0..p.c_in)
+                (0..c_out).into_par_iter().for_each(|dst_c_idx| {
+                    let (group_idx, dst_c_idx_in_group) =
+                        (dst_c_idx / p.c_out_per_group, dst_c_idx % p.c_out_per_group);
+                    let k_cont = (0..c_in_per_group)
                         .map(|c_in_idx| {
-                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
+                            let c_in_idx = group_idx * c_in_per_group + c_in_idx;
+                            k[c_in_idx * k_s0 + dst_c_idx_in_group * k_s1 + k_y * k_s2 + k_x * k_s3]
                         })
                         .collect::<Vec<_>>();
                     for b_idx in 0..p.b_size {
@@ -1245,7 +1250,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
                                     inp_cont.as_ptr(),
                                     k_cont.as_ptr(),
                                     &mut d,
-                                    p.c_in,
+                                    c_in_per_group,
                                 )
                             }
                             let dst_p = dst.as_ptr();
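
The CPU path's group bookkeeping is a plain div/mod split: every output channel belongs to one group and dots only that group's `c_in_per_group` input channels. In isolation:

    // Assumes c_out = groups * c_out_per_group, as the params guarantee.
    fn split_channel(dst_c_idx: usize, c_out_per_group: usize) -> (usize, usize) {
        (dst_c_idx / c_out_per_group, dst_c_idx % c_out_per_group)
    }
    // With c_out_per_group = 3, output channel 4 is channel 1 of group 1, and
    // it reads input channels group_idx * c_in_per_group..(group_idx + 1) * c_in_per_group.
    assert_eq!(split_channel(4, 3), (1, 1));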

View File

@@ -1050,11 +1050,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         k_l: &Layout,
         dev: &CudaDevice,
     ) -> Result<CudaSlice<T>> {
-        // Kernel shape: (c_in_k, c_out, h_k, w_k)
+        // Kernel shape: (c_in_k, c_out / groups, h_k, w_k)
         // Input shape: (b_size, c_in, h_in, w_in)
         let p = &self.0;
         let (out_w, out_h) = (p.out_w(), p.out_h());
-        let dst_el = p.c_out * out_w * out_h * p.b_size;
+        let dst_el = p.c_out_per_group * p.groups * out_w * out_h * p.b_size;
         let inp = &inp.slice(inp_l.start_offset()..);
         let k = &k.slice(k_l.start_offset()..);
         let shape = inp_l.shape();
@@ -1063,7 +1063,13 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        const NUM_THREADS: u32 = 512;
+        let num_blocks = (dst_el as u32 + NUM_THREADS - 1) / NUM_THREADS;
+        let mut cfg = LaunchConfig {
+            grid_dim: (num_blocks, 1, 1),
+            block_dim: (NUM_THREADS, 1, 1),
+            shared_mem_bytes: 0,
+        };
         let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
         let ds = if dims.len() == 4 {
             [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
@@ -1072,13 +1078,13 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         };
         let ds = dev.htod_copy(ds).w()?;
         let params = (
-            el,
             out_w,
             out_h,
             p.stride,
             p.padding,
             p.output_padding,
             p.dilation,
+            p.groups,
             &ds,
             inp,
             k,
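
The launch-config change (the "Reduce the number of threads." commit) replaces `LaunchConfig::for_num_elems` with an explicit 512-thread block and a ceil-divided grid, one thread per output element. The grid arithmetic on its own:

    const NUM_THREADS: u32 = 512;
    let dst_el: u32 = 1000; // example element count
    // Round up so num_blocks * NUM_THREADS >= dst_el.
    let num_blocks = (dst_el + NUM_THREADS - 1) / NUM_THREADS;
    assert_eq!(num_blocks, 2); // 2 * 512 = 1024 >= 1000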

View File

@@ -102,6 +102,7 @@ pub enum Op {
         output_padding: usize,
         stride: usize,
         dilation: usize,
+        groups: usize,
     },
     AvgPool2D {

View File

@@ -130,7 +130,7 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 7, 7]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -164,7 +164,7 @@ fn conv2d(dev: &Device) -> Result<()> {
     );

     // Transpose and dilations.
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2, 1)?;
     assert_eq!(res.dims(), [1, 2, 9, 9]);
     assert_eq!(
         test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -246,13 +246,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
             0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
-    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [1, 1, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
     );
-    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
+    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1, 1)?;
     assert_eq!(res.dims(), [2, 2, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
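
All of the updated assertions pass `groups = 1`, so the existing expected values are unchanged. A shape-only sketch of a grouped case (hypothetical, not part of the suite shown):

    // c_in = 4, groups = 2, kernel (4, 2, 3, 3) => c_out = 2 * 2 = 4.
    let t = Tensor::randn(0f32, 1f32, (1, 4, 5, 5), dev)?;
    let w = Tensor::randn(0f32, 1f32, (4, 2, 3, 3), dev)?;
    let res = t.conv_transpose2d(&w, 0, 0, 1, 1, 2)?;
    assert_eq!(res.dims(), [1, 4, 7, 7]);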

View File

@@ -116,13 +116,13 @@ __device__ void conv2d(
 // Naive implementation of conv_transpose2d.
 template <typename T, typename A>
 __device__ void conv_transpose2d(
-    const size_t src_numel,
     const size_t w_out,
     const size_t h_out,
     const size_t stride,
     const size_t padding,
     const size_t out_padding,
     const size_t dilation,
+    const size_t groups,
     const size_t *info,
     const T *src,
     const T *kernel,
@@ -130,17 +130,18 @@ __device__ void conv_transpose2d(
 ) {
   const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
   // src: (b_size, c_in, h_in, w_in)
-  // k: (c_in, c_out, h_k, w_k)
+  // k: (c_in, c_out / groups, h_k, w_k)
   const size_t *src_dims = info;
   const size_t *src_s = info + 4;
   const size_t *k_dims = info + 8;
   const size_t *k_s = info + 12;
   const size_t h_k = k_dims[2];
   const size_t w_k = k_dims[3];
-  const size_t c_out = k_dims[1];
+  const size_t c_out_per_group = k_dims[1];
   const size_t c_in = src_dims[1];
   const size_t h_in = src_dims[2];
   const size_t w_in = src_dims[3];
+  const size_t c_out = c_out_per_group * groups;
   if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
     return;
   }
@@ -148,6 +149,10 @@ __device__ void conv_transpose2d(
   // TODO
   const size_t b_idx = dst_i / (w_out * h_out * c_out);
   const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
+  const size_t c_idx_in_group = dst_c_idx % c_out_per_group;
+  const size_t c_in_per_group = c_in / groups;
+  const size_t group_idx = dst_c_idx / c_out_per_group;
+  // const size_t c_in_per_group = c_in;
   // NCHW layout.
   const size_t out_y = (dst_i / w_out) % h_out;
   const size_t out_x = dst_i % w_out;
@@ -169,9 +174,9 @@ __device__ void conv_transpose2d(
       }
       int inp_y = inp_y_stride / stride;
       if (inp_y >= h_in) continue;
-      for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+      for (size_t src_c_idx = group_idx * c_in_per_group; src_c_idx < (group_idx + 1) * c_in_per_group; ++src_c_idx) {
         const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
-        const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
+        const size_t k_idx = src_c_idx * k_s[0] + c_idx_in_group * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
         d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
       }
     }
@@ -365,19 +370,19 @@ extern "C" __global__ void FN_NAME( \
 #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
-    const size_t src_numel, \
     const size_t w_out, \
     const size_t h_out, \
     const size_t stride, \
     const size_t padding, \
     const size_t out_padding, \
     const size_t dilation, \
+    const size_t groups, \
     const size_t *info, \
     const TYPENAME *src, \
     const TYPENAME *kernel, \
     TYPENAME *dst \
 ) { \
-  conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
+  conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, groups, info, src, kernel, dst); \
 } \

 #define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
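
Each CUDA thread produces one output element and recovers its `(b, c, y, x)` coordinates from the flat NCHW index. The same decomposition, mirrored in Rust as a sketch:

    fn decompose(dst_i: usize, c_out: usize, h_out: usize, w_out: usize) -> (usize, usize, usize, usize) {
        let b_idx = dst_i / (w_out * h_out * c_out);
        let dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
        let out_y = (dst_i / w_out) % h_out;
        let out_x = dst_i % w_out;
        (b_idx, dst_c_idx, out_y, out_x)
    }
    // First element of batch 1 when c_out = 4 and the output is 7x7.
    assert_eq!(decompose(4 * 7 * 7, 4, 7, 7), (1, 0, 0, 0));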

View File

@@ -127,7 +127,7 @@ pub struct ConvTranspose2dConfig {
     pub output_padding: usize,
     pub stride: usize,
     pub dilation: usize,
-    // TODO: support groups.
+    pub groups: usize,
 }
impl Default for ConvTranspose2dConfig {
@@ -137,6 +137,7 @@ impl Default for ConvTranspose2dConfig {
             output_padding: 0,
             stride: 1,
             dilation: 1,
+            groups: 1,
         }
     }
 }
@@ -170,6 +171,7 @@ impl crate::Module for ConvTranspose2d {
             self.config.output_padding,
             self.config.stride,
             self.config.dilation,
+            self.config.groups,
         )?;
         match &self.bias {
             None => Ok(x),
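
With the TODO resolved, `groups` is now part of the public config, and `Default` keeps it at 1, so existing callers are unaffected. A usage sketch (assuming the config is constructed as usual):

    let cfg = candle_nn::ConvTranspose2dConfig {
        groups: 2,
        ..Default::default() // padding 0, output_padding 0, stride 1, dilation 1
    };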