mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
4 Commits
0.8.3
...
conv-trans
Author | SHA1 | Date | |
---|---|---|---|
b2796ce6ef | |||
c22f23e568 | |||
b283b3e181 | |||
b4fe316fa1 |
@ -213,6 +213,7 @@ impl Tensor {
|
||||
out_padding,
|
||||
*stride,
|
||||
*dilation,
|
||||
1,
|
||||
)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
|
@ -60,12 +60,13 @@ pub struct ParamsConvTranspose2D {
|
||||
pub(crate) i_w: usize,
|
||||
pub(crate) k_h: usize,
|
||||
pub(crate) k_w: usize,
|
||||
pub(crate) c_out: usize,
|
||||
pub(crate) c_out_per_group: usize,
|
||||
pub(crate) c_in: usize,
|
||||
pub(crate) padding: usize,
|
||||
pub(crate) output_padding: usize,
|
||||
pub(crate) stride: usize,
|
||||
pub(crate) dilation: usize,
|
||||
pub(crate) groups: usize,
|
||||
}
|
||||
|
||||
impl ParamsConvTranspose2D {
|
||||
@ -80,7 +81,8 @@ impl ParamsConvTranspose2D {
|
||||
}
|
||||
|
||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
|
||||
let c_out = self.c_out_per_group * self.groups;
|
||||
vec![self.b_size, c_out, self.out_h(), self.out_w()]
|
||||
}
|
||||
}
|
||||
|
||||
@ -211,24 +213,29 @@ impl Tensor {
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
|
||||
let (c_in_k, c_out_per_group, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
if c_in % groups != 0 {
|
||||
crate::bail!("in_channel {c_in} must be divisible by groups {groups}")
|
||||
}
|
||||
let params = ParamsConvTranspose2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_out_per_group,
|
||||
c_in,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
};
|
||||
let storage = self.storage().conv_transpose2d(
|
||||
self.layout(),
|
||||
@ -243,6 +250,7 @@ impl Tensor {
|
||||
output_padding: params.output_padding,
|
||||
stride: params.stride,
|
||||
dilation: params.dilation,
|
||||
groups: params.groups,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
|
||||
|
@ -1190,8 +1190,9 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
// Output shape: [b_size, c_out, out_h, out_w].
|
||||
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
let dst_s0 = p.c_out * out_h * out_w;
|
||||
let c_out = p.groups * p.c_out_per_group;
|
||||
let dst = vec![T::zero(); p.b_size * c_out * out_h * out_w];
|
||||
let dst_s0 = c_out * out_h * out_w;
|
||||
let dst_s1 = out_h * out_w;
|
||||
let dst_s2 = out_w;
|
||||
let dst_s3 = 1;
|
||||
@ -1214,12 +1215,16 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
let c_in_per_group = p.c_in / p.groups;
|
||||
for k_y in 0..p.k_h {
|
||||
for k_x in 0..p.k_w {
|
||||
(0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
let k_cont = (0..p.c_in)
|
||||
(0..c_out).into_par_iter().for_each(|dst_c_idx| {
|
||||
let (group_idx, dst_c_idx_in_group) =
|
||||
(dst_c_idx / p.c_out_per_group, dst_c_idx % p.c_out_per_group);
|
||||
let k_cont = (0..c_in_per_group)
|
||||
.map(|c_in_idx| {
|
||||
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
|
||||
let c_in_idx = group_idx * c_in_per_group + c_in_idx;
|
||||
k[c_in_idx * k_s0 + dst_c_idx_in_group * k_s1 + k_y * k_s2 + k_x * k_s3]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
@ -1245,7 +1250,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
inp_cont.as_ptr(),
|
||||
k_cont.as_ptr(),
|
||||
&mut d,
|
||||
p.c_in,
|
||||
c_in_per_group,
|
||||
)
|
||||
}
|
||||
let dst_p = dst.as_ptr();
|
||||
|
@ -1050,11 +1050,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
||||
// Kernel shape: (c_in_k, c_out / groups, h_k, w_k)
|
||||
// Input shape: (b_size, c_in, h_in, w_in)
|
||||
let p = &self.0;
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
let dst_el = p.c_out_per_group * p.groups * out_w * out_h * p.b_size;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
@ -1063,7 +1063,13 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
const NUM_THREADS: u32 = 512;
|
||||
let num_blocks = (dst_el as u32 + NUM_THREADS - 1) / NUM_THREADS;
|
||||
let mut cfg = LaunchConfig {
|
||||
grid_dim: (num_blocks, 1, 1),
|
||||
block_dim: (NUM_THREADS, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
@ -1072,13 +1078,13 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (
|
||||
el,
|
||||
out_w,
|
||||
out_h,
|
||||
p.stride,
|
||||
p.padding,
|
||||
p.output_padding,
|
||||
p.dilation,
|
||||
p.groups,
|
||||
&ds,
|
||||
inp,
|
||||
k,
|
||||
|
@ -102,6 +102,7 @@ pub enum Op {
|
||||
output_padding: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
},
|
||||
|
||||
AvgPool2D {
|
||||
|
@ -130,7 +130,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
@ -164,7 +164,7 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
);
|
||||
|
||||
// Transpose and dilations.
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
@ -246,13 +246,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
||||
]
|
||||
);
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
|
||||
);
|
||||
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
|
||||
let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [2, 2, 3, 3]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
|
@ -116,13 +116,13 @@ __device__ void conv2d(
|
||||
// Naive implementation of conv_transpose2d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv_transpose2d(
|
||||
const size_t src_numel,
|
||||
const size_t w_out,
|
||||
const size_t h_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t out_padding,
|
||||
const size_t dilation,
|
||||
const size_t groups,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
@ -130,17 +130,18 @@ __device__ void conv_transpose2d(
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, c_in, h_in, w_in)
|
||||
// k: (c_in, c_out, h_k, w_k)
|
||||
// k: (c_in, c_out / groups, h_k, w_k)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
const size_t *k_dims = info + 8;
|
||||
const size_t *k_s = info + 12;
|
||||
const size_t h_k = k_dims[2];
|
||||
const size_t w_k = k_dims[3];
|
||||
const size_t c_out = k_dims[1];
|
||||
const size_t c_out_per_group = k_dims[1];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t h_in = src_dims[2];
|
||||
const size_t w_in = src_dims[3];
|
||||
const size_t c_out = c_out_per_group * groups;
|
||||
if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
@ -148,6 +149,10 @@ __device__ void conv_transpose2d(
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c_out);
|
||||
const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
|
||||
const size_t c_idx_in_group = dst_c_idx % c_out_per_group;
|
||||
const size_t c_in_per_group = c_in / groups;
|
||||
const size_t group_idx = dst_c_idx / c_out_per_group;
|
||||
// const size_t c_in_per_group = c_in;
|
||||
// NCHW layout.
|
||||
const size_t out_y = (dst_i / w_out) % h_out;
|
||||
const size_t out_x = dst_i % w_out;
|
||||
@ -169,9 +174,9 @@ __device__ void conv_transpose2d(
|
||||
}
|
||||
int inp_y = inp_y_stride / stride;
|
||||
if (inp_y >= h_in) continue;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
for (size_t src_c_idx = group_idx * c_in_per_group; src_c_idx < (group_idx + 1) * c_in_per_group; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
|
||||
const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
|
||||
const size_t k_idx = src_c_idx * k_s[0] + c_idx_in_group * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
@ -365,19 +370,19 @@ extern "C" __global__ void FN_NAME( \
|
||||
|
||||
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t w_out, \
|
||||
const size_t h_out, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t out_padding, \
|
||||
const size_t dilation, \
|
||||
const size_t groups, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
|
||||
conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, groups, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
|
@ -127,7 +127,7 @@ pub struct ConvTranspose2dConfig {
|
||||
pub output_padding: usize,
|
||||
pub stride: usize,
|
||||
pub dilation: usize,
|
||||
// TODO: support groups.
|
||||
pub groups: usize,
|
||||
}
|
||||
|
||||
impl Default for ConvTranspose2dConfig {
|
||||
@ -137,6 +137,7 @@ impl Default for ConvTranspose2dConfig {
|
||||
output_padding: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -170,6 +171,7 @@ impl crate::Module for ConvTranspose2d {
|
||||
self.config.output_padding,
|
||||
self.config.stride,
|
||||
self.config.dilation,
|
||||
self.config.groups,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
|
Reference in New Issue
Block a user