Add Efficientnet (#572)

* EfficientNet.

* Complete the efficientnet implementation.

* Improve group handling.

* Get the efficientnet to work.
This commit is contained in:
Laurent Mazare
2023-08-23 18:02:58 +01:00
committed by GitHub
parent eedd85ffa7
commit 431051cc32
4 changed files with 448 additions and 11 deletions

View File

@ -124,7 +124,11 @@ pub fn conv1d(
vs: crate::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init((out_channels, in_channels, kernel_size), "weight", init_ws)?;
let ws = vs.get_or_init(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
)?;
let bound = 1. / (in_channels as f64).sqrt();
let init_bs = crate::Init::Uniform {
lo: -bound,
@ -143,7 +147,12 @@ pub fn conv2d(
) -> Result<Conv2d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init(
(out_channels, in_channels, kernel_size, kernel_size),
(
out_channels,
in_channels / cfg.groups,
kernel_size,
kernel_size,
),
"weight",
init_ws,
)?;
@ -165,7 +174,12 @@ pub fn conv2d_no_bias(
) -> Result<Conv2d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init(
(out_channels, in_channels, kernel_size, kernel_size),
(
out_channels,
in_channels / cfg.groups,
kernel_size,
kernel_size,
),
"weight",
init_ws,
)?;

View File

@ -129,7 +129,7 @@ impl<'a> VarBuilder<'a> {
})
}
pub fn push_prefix(&self, s: &str) -> Self {
pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
@ -139,7 +139,7 @@ impl<'a> VarBuilder<'a> {
}
/// Short alias for `push_prefix`.
pub fn pp(&self, s: &str) -> Self {
pub fn pp<S: ToString>(&self, s: S) -> Self {
self.push_prefix(s)
}