Mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 04:10:46 +00:00
Only use classifier free guidance for the prior. (#896)
* Only use classifier free guidance for the prior.
* Add another specific layer-norm structure.
* Tweaks.
* Fix the latent shape.
* Print the prior shape.
* More shape fixes.
* Remove some debugging continue.
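For context, classifier-free guidance (CFG) runs the prior once with the conditional embedding and once with an unconditional one, then extrapolates between the two predictions with a guidance scale; per the title, this commit restricts that step to the prior stage. A minimal sketch of the combination using candle tensor ops, where the apply_cfg helper and its signature are illustrative only and not code from this commit:

use candle::{Result, Tensor};

/// Classifier-free guidance mix: uncond + scale * (cond - uncond).
fn apply_cfg(uncond: &Tensor, cond: &Tensor, scale: f64) -> Result<Tensor> {
    let delta = cond.broadcast_sub(uncond)?;
    uncond.broadcast_add(&(delta * scale)?)
}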
@@ -34,6 +34,34 @@ impl Module for WLayerNorm {
     }
 }
 
+#[derive(Debug)]
+pub struct LayerNormNoWeights {
+    eps: f64,
+}
+
+impl LayerNormNoWeights {
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
+    }
+}
+
+impl Module for LayerNormNoWeights {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)
+    }
+}
+
 #[derive(Debug)]
 pub struct TimestepBlock {
     mapper: candle_nn::Linear,
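The LayerNormNoWeights added above normalizes over the last dimension, y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps), with no learnable scale or bias; only eps is stored, and the intermediate math is promoted to f32 for f16/bf16 inputs. A hypothetical usage sketch follows; the input shape, the demo function, and the candle::Module import path are assumptions, not part of the commit:

use candle::{Device, Module, Result, Tensor};

fn demo() -> Result<()> {
    // No learnable parameters, so the size argument is ignored by `new`.
    let ln = LayerNormNoWeights::new(16)?;
    // Illustrative input: a batch of 2 rows, each 16 wide, on the CPU.
    let xs = Tensor::randn(0f32, 1f32, (2, 16), &Device::Cpu)?;
    let ys = ln.forward(&xs)?;
    // Shape is preserved; each row now has roughly zero mean and unit variance.
    assert_eq!(ys.dims(), &[2, 16]);
    Ok(())
}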