mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Expose the fields from batch-norm. (#1176)
This commit is contained in:
@ -100,9 +100,19 @@ impl BatchNorm {
|
|||||||
num_features,
|
num_features,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl BatchNorm {
|
pub fn running_mean(&self) -> &Tensor {
|
||||||
|
&self.running_mean
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn running_var(&self) -> &Tensor {
|
||||||
|
&self.running_var
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
|
||||||
|
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
|
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x_dtype = x.dtype();
|
let x_dtype = x.dtype();
|
||||||
let internal_dtype = match x_dtype {
|
let internal_dtype = match x_dtype {
|
||||||
|
Reference in New Issue
Block a user