Add Efficientnet (#572)

* EfficientNet.

* Complete the efficientnet implementation.

* Improve group handling.

* Get the efficientnet to work.
This commit is contained in:
Laurent Mazare
2023-08-23 18:02:58 +01:00
committed by GitHub
parent eedd85ffa7
commit 431051cc32
4 changed files with 448 additions and 11 deletions

View File

@ -93,8 +93,8 @@ impl Tensor {
let params = ParamsConv1D {
b_size,
l_in,
c_out,
c_in,
c_out: c_out / groups,
c_in: c_in / groups,
k_size,
padding,
stride,
@ -103,9 +103,11 @@ impl Tensor {
self.conv1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.map(|block| block.conv1d_single_group(kernel, &params))
.zip(&kernel)
.map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
@ -146,8 +148,8 @@ impl Tensor {
i_w,
k_h,
k_w,
c_out,
c_in,
c_out: c_out / groups,
c_in: c_in / groups,
padding,
stride,
};
@ -155,9 +157,11 @@ impl Tensor {
self.conv2d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.map(|block| block.conv2d_single_group(kernel, &params))
.zip(&kernel)
.map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}