mirror of https://github.com/huggingface/candle.git
F16 support for stable diffusion (#488)
* F16 support for stable diffusion.
* Keep the attention bits in F32.
* Keep more of the attention bits in F32.
* More mixed precision support.
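The pattern behind "keep the attention bits in F32" is to upcast only the overflow-prone part of the computation and cast back afterwards, so weights and activations stay in F16 while the softmax reduction runs in full precision. Below is a minimal sketch of that pattern, not the candle source: the function name, the `scale` parameter, and the hand-written softmax are illustrative assumptions.

use candle::{DType, Result, Tensor, D};

// Illustrative sketch of the mixed-precision pattern described in the
// commit message; not taken from the candle codebase.
fn attention_scores_f32(q: &Tensor, k: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = q.dtype();
    // Upcast before the matmul/softmax: F16 saturates at ~65504 and keeps
    // only ~10 bits of mantissa, so large logits overflow and the softmax
    // normalization loses precision.
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let attn = (q.matmul(&k.t()?)? * scale)?;
    // Numerically stable softmax over the last dimension, done in F32.
    let max = attn.max_keepdim(D::Minus1)?;
    let exp = attn.broadcast_sub(&max)?.exp()?;
    let attn = exp.broadcast_div(&exp.sum_keepdim(D::Minus1)?)?;
    // Only the final probabilities are cast back to the input dtype (e.g. F16).
    attn.to_dtype(in_dtype)
}

Casting back only at the end keeps the memory savings of F16 while confining the F32 cost to the single reduction that needs it.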
@@ -44,10 +44,10 @@ impl Timesteps {
 impl Timesteps {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let half_dim = (self.num_channels / 2) as u32;
-        let exponent =
-            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+        let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
+            * -f64::ln(10000.))?;
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
-        let emb = exponent.exp()?;
+        let emb = exponent.exp()?.to_dtype(xs.dtype())?;
         // emb = timesteps[:, None].float() * emb[None, :]
         let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
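For context, here is the patched computation as a self-contained sketch. It is a free function for illustration; in candle this is a method on `Timesteps`, and `num_channels`, `downscale_freq_shift`, and the final cos/sin concatenation order are stand-ins for the struct's fields and surrounding code. The key change mirrors the diff: the frequency table is built in F32 even when `xs` is F16, because an F16 arange and exp lose most of the resolution of the low frequencies, and the result is cast back to the input dtype only after the exponential.

use candle::{DType, Device, Result, Tensor, D};

// Standalone sketch of the patched sinusoidal timestep embedding.
fn timestep_embedding(xs: &Tensor, num_channels: usize, downscale_freq_shift: f64) -> Result<Tensor> {
    let half_dim = (num_channels / 2) as u32;
    // Build the frequency exponents in F32 regardless of the input dtype.
    let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(DType::F32)?
        * -f64::ln(10000.))?;
    let exponent = (exponent / (half_dim as f64 - downscale_freq_shift))?;
    // Cast back to the input dtype only after the exponential.
    let emb = exponent.exp()?.to_dtype(xs.dtype())?;
    // emb = timesteps[:, None] * emb[None, :]
    let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
    // Concatenate cos and sin halves on the channel dimension (order assumed).
    Tensor::cat(&[emb.cos()?, emb.sin()?], D::Minus1)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Three timesteps in F16, as a half-precision pipeline would pass them.
    let t = Tensor::new(&[0f32, 10., 500.], &device)?.to_dtype(DType::F16)?;
    let emb = timestep_embedding(&t, 320, 0.)?;
    println!("{:?}", emb.shape()); // [3, 320]
    Ok(())
}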