Only use classifier free guidance for the prior. (#896)

* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging continue statements.
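
For context, classifier-free guidance blends an unconditional and a text-conditioned prediction as uncond + scale * (cond - uncond); after this change the blend is applied only while sampling the prior, not in the later stages. Below is a minimal sketch of that step using candle tensor ops; the apply_cfg name and prior_guidance_scale parameter are illustrative assumptions, not identifiers from this repository.

// Illustrative classifier-free guidance step, not taken from the diff.
// `uncond` and `cond` are the prior's predictions without and with the
// text conditioning; both have the same shape.
fn apply_cfg(
    uncond: &candle_core::Tensor,
    cond: &candle_core::Tensor,
    prior_guidance_scale: f64,
) -> candle_core::Result<candle_core::Tensor> {
    // result = uncond + scale * (cond - uncond)
    let diff = (cond - uncond)?;
    uncond + (diff * prior_guidance_scale)?
}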
Author: Laurent Mazare
Date: 2023-09-19 14:13:05 +01:00
Committed by: GitHub
Parent: 9cf26c5cff
Commit: 06e46d7c3b
4 changed files with 125 additions and 74 deletions


@@ -34,6 +34,34 @@ impl Module for WLayerNorm {
    }
}
/// Layer normalization without any learnable scale or bias, normalizing over
/// the last dimension only.
#[derive(Debug)]
pub struct LayerNormNoWeights {
    eps: f64,
}

impl LayerNormNoWeights {
    pub fn new(_size: usize) -> Result<Self> {
        Ok(Self { eps: 1e-6 })
    }
}

impl Module for LayerNormNoWeights {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let x_dtype = xs.dtype();
        // Run the normalization in f32 when the input is half precision.
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = xs.dim(D::Minus1)?;
        let xs = xs.to_dtype(internal_dtype)?;
        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let xs = xs.broadcast_sub(&mean_x)?;
        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
            .to_dtype(x_dtype)
    }
}
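
Not part of the diff: a small usage sketch exercising the new layer norm; the demo function, the tensor shape, and the use of candle_core::Device::Cpu / Tensor::randn are illustrative assumptions rather than code from this change.

// Illustrative only: normalize a random tensor over its last dimension.
use candle_core::{Device, Tensor};
use candle_nn::Module;

fn demo() -> candle_core::Result<()> {
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 64), &Device::Cpu)?;
    let ln = LayerNormNoWeights::new(64)?;
    // Same shape out, with roughly zero mean and unit variance along the
    // last dimension.
    let ys = ln.forward(&xs)?;
    println!("{:?}", ys.dims());
    Ok(())
}
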
#[derive(Debug)]
pub struct TimestepBlock {
    mapper: candle_nn::Linear,