VarBuilder cleanup (#627)

* VarBuilder cleanup.

* Implement the basic varbuilders.

* Add the sharded code.

* Proper support for tensor sharding.
This commit is contained in:
Laurent Mazare
2023-08-27 18:03:26 +01:00
committed by GitHub
parent be471d50ab
commit 4c338b0cd9
12 changed files with 409 additions and 291 deletions

View File

@ -179,11 +179,11 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
}
let running_mean = vb.get_or_init(num_features, "running_mean", crate::Init::Const(0.))?;
let running_var = vb.get_or_init(num_features, "running_var", crate::Init::Const(1.))?;
let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
let weight_and_bias = if config.affine {
let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
Some((weight, bias))
} else {
None