Implement group-norm. (#334)

* Implement group-norm.

* Add some testing for group-norm.
Laurent Mazare
2023-08-07 07:53:05 +02:00
committed by GitHub
parent 2c9f605976
commit 5bb2fce998
5 changed files with 150 additions and 14 deletions

@@ -1,10 +1,9 @@
 //! Group Normalization.
 //!
 //! This layer applies Group Normalization over a mini-batch of inputs.
-use candle::{Result, Tensor};
+use candle::{DType, Result, Tensor};
 
 // This group norm version handles both weight and bias so removes the mean.
-#[allow(dead_code)]
 #[derive(Debug)]
 pub struct GroupNorm {
     weight: Tensor,
@@ -21,18 +20,50 @@ impl GroupNorm {
         num_channels: usize,
         num_groups: usize,
         eps: f64,
-    ) -> Self {
-        Self {
+    ) -> Result<Self> {
+        if num_channels % num_groups != 0 {
+            candle::bail!(
+                "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
+            )
+        }
+        Ok(Self {
             weight,
             bias,
             eps,
             num_channels,
             num_groups,
-        }
+        })
     }
 
-    pub fn forward(&self, _: &Tensor) -> Result<Tensor> {
-        todo!()
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x_shape = x.dims();
+        if x_shape.len() <= 2 {
+            candle::bail!("input rank for GroupNorm should be at least 3");
+        }
+        let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
+        let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
+        if n_channels != self.num_channels {
+            candle::bail!(
+                "unexpected num-channels in GroupNorm ({n_channels} <> {}",
+                self.num_channels
+            )
+        }
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?
+            .reshape(x_shape)
     }
 }
 
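For reference, the forward pass above boils down to the following standalone sketch of group normalization on a single batch element, using plain f32 slices instead of candle tensors and fixing weight to 1 and bias to 0; group_norm_ref and its (channels, spatial) layout are illustrative only and not part of this commit.

// Reference sketch: split the channels into groups and normalize each group by
// its own mean and (biased) variance, mirroring the tensor version per batch element.
fn group_norm_ref(x: &[f32], channels: usize, groups: usize, eps: f32) -> Vec<f32> {
    // `x` is one batch element laid out as (channels, spatial), row-major.
    let spatial = x.len() / channels;
    let group_size = channels / groups * spatial; // elements covered by one group
    let mut out = vec![0f32; x.len()];
    for g in 0..groups {
        let chunk = &x[g * group_size..(g + 1) * group_size];
        let mean = chunk.iter().sum::<f32>() / group_size as f32;
        let var = chunk.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / group_size as f32;
        let inv_std = 1.0 / (var + eps).sqrt();
        for (o, v) in out[g * group_size..(g + 1) * group_size].iter_mut().zip(chunk) {
            *o = (v - mean) * inv_std; // weight = 1, bias = 0
        }
    }
    out
}

fn main() {
    // 4 channels in 2 groups, 3 spatial positions per channel: two groups of 6 values.
    let x: Vec<f32> = (0..12).map(|v| v as f32).collect();
    println!("{:?}", group_norm_ref(&x, 4, 2, 1e-5));
}
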
@@ -44,5 +75,5 @@ pub fn group_norm(
 ) -> Result<GroupNorm> {
     let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
     let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
-    Ok(GroupNorm::new(weight, bias, num_channels, num_groups, eps))
+    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
 }
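
With these changes in place, the layer can be constructed and run directly, for example as below; the candle_nn::GroupNorm import path follows the current crate layout and is an assumption here, while the new and forward signatures are the ones from the diff above.

use candle::{DType, Device, Result, Tensor};
use candle_nn::GroupNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Per-channel scale and shift, matching the Const(1.) / Const(0.) defaults used by group_norm.
    let weight = Tensor::ones(4, DType::F32, &dev)?;
    let bias = Tensor::zeros(4, DType::F32, &dev)?;
    // 4 channels split into 2 groups; new() rejects num_groups that do not divide num_channels.
    let gn = GroupNorm::new(weight, bias, 4, 2, 1e-5)?;
    // Input of shape (batch, channels, spatial); forward requires rank >= 3.
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 2), &dev)?;
    let ys = gn.forward(&xs)?;
    println!("{ys}");
    Ok(())
}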