mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Expose the fields from batch-norm. (#1176)
This commit is contained in:
@ -100,9 +100,19 @@ impl BatchNorm {
|
||||
num_features,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl BatchNorm {
|
||||
pub fn running_mean(&self) -> &Tensor {
|
||||
&self.running_mean
|
||||
}
|
||||
|
||||
pub fn running_var(&self) -> &Tensor {
|
||||
&self.running_var
|
||||
}
|
||||
|
||||
pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
|
||||
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
|
||||
}
|
||||
|
||||
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
|
Reference in New Issue
Block a user