mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add Efficientnet (#572)
* EfficientNet. * Complete the efficientnet implementation. * Improve group handling. * Get the efficientnet to work.
This commit is contained in:
@ -93,8 +93,8 @@ impl Tensor {
|
||||
let params = ParamsConv1D {
|
||||
b_size,
|
||||
l_in,
|
||||
c_out,
|
||||
c_in,
|
||||
c_out: c_out / groups,
|
||||
c_in: c_in / groups,
|
||||
k_size,
|
||||
padding,
|
||||
stride,
|
||||
@ -103,9 +103,11 @@ impl Tensor {
|
||||
self.conv1d_single_group(kernel, ¶ms)
|
||||
} else {
|
||||
let blocks = self.chunk(groups, 1)?;
|
||||
let kernel = kernel.chunk(groups, 0)?;
|
||||
let blocks = blocks
|
||||
.iter()
|
||||
.map(|block| block.conv1d_single_group(kernel, ¶ms))
|
||||
.zip(&kernel)
|
||||
.map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Tensor::cat(&blocks, 1)
|
||||
}
|
||||
@ -146,8 +148,8 @@ impl Tensor {
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_in,
|
||||
c_out: c_out / groups,
|
||||
c_in: c_in / groups,
|
||||
padding,
|
||||
stride,
|
||||
};
|
||||
@ -155,9 +157,11 @@ impl Tensor {
|
||||
self.conv2d_single_group(kernel, ¶ms)
|
||||
} else {
|
||||
let blocks = self.chunk(groups, 1)?;
|
||||
let kernel = kernel.chunk(groups, 0)?;
|
||||
let blocks = blocks
|
||||
.iter()
|
||||
.map(|block| block.conv2d_single_group(kernel, ¶ms))
|
||||
.zip(&kernel)
|
||||
.map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Tensor::cat(&blocks, 1)
|
||||
}
|
||||
|
Reference in New Issue
Block a user