F16 support for stable diffusion (#488)

* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
Laurent Mazare
2023-08-17 13:48:56 +01:00
committed by GitHub
parent c3176f0dfb
commit 5d99026fd2
6 changed files with 99 additions and 43 deletions
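The change is a mixed-precision scheme: the weights and most of the forward pass can run in F16, while numerically sensitive steps (the attention softmax, the timestep embeddings) are upcast to F32 and cast back afterwards. Below is a minimal sketch of that pattern for the "keep the attention bits in F32" items in the commit message; it assumes candle_nn::ops::softmax and an illustrative scale parameter, and is not the commit's actual attention code.

use candle::{DType, Result, Tensor, D};

// Minimal sketch of the mixed-precision attention pattern: the surrounding
// model may run in F16, but the q @ k^T / softmax step is carried out in F32
// and only cast back down at the end. `scale` stands in for the usual
// 1/sqrt(head_dim) factor and is an illustrative parameter, not a field of
// the actual module.
fn attention_in_f32(query: &Tensor, key: &Tensor, value: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = query.dtype();
    // Upcast the numerically sensitive part to F32.
    let query = query.to_dtype(DType::F32)?;
    let key = key.to_dtype(DType::F32)?;
    let value = value.to_dtype(DType::F32)?;
    let attn = (query.matmul(&key.t()?)? * scale)?;
    let attn = candle_nn::ops::softmax(&attn, D::Minus1)?;
    // Cast back to the caller's dtype (e.g. F16) for the rest of the network.
    attn.matmul(&value)?.to_dtype(in_dtype)
}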

@@ -44,10 +44,10 @@ impl Timesteps
 impl Timesteps {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let half_dim = (self.num_channels / 2) as u32;
-        let exponent =
-            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+        let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
+            * -f64::ln(10000.))?;
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
-        let emb = exponent.exp()?;
+        let emb = exponent.exp()?.to_dtype(xs.dtype())?;
         // emb = timesteps[:, None].float() * emb[None, :]
         let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
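
For reference, here is a hypothetical standalone version of the new code path in this hunk, with a small usage example: the sinusoidal frequencies are computed in F32 regardless of the timestep dtype, and only the final embedding is cast back to match the input (e.g. F16). The timestep_embedding helper and the main driver are illustrative, not part of the commit.

use candle::{DType, Device, Result, Tensor, D};

// Standalone sketch of the pattern in the hunk above: compute the frequency
// table in F32, then cast to the input dtype before mixing in the timesteps.
fn timestep_embedding(xs: &Tensor, num_channels: usize, downscale_freq_shift: f64) -> Result<Tensor> {
    let half_dim = (num_channels / 2) as u32;
    let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(DType::F32)?
        * -f64::ln(10000.))?;
    let exponent = (exponent / (half_dim as f64 - downscale_freq_shift))?;
    // exp() is done in F32 to preserve the small frequency values, then cast down.
    let emb = exponent.exp()?.to_dtype(xs.dtype())?;
    // emb = timesteps[:, None] * emb[None, :]
    let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
    let (cos, sin) = (emb.cos()?, emb.sin()?);
    Tensor::cat(&[&cos, &sin], D::Minus1)
}

fn main() -> Result<()> {
    // Timesteps in F16, as they would be when the whole pipeline runs in half precision.
    let ts = Tensor::new(&[1.0f32, 250., 999.], &Device::Cpu)?.to_dtype(DType::F16)?;
    let emb = timestep_embedding(&ts, 320, 0.)?;
    println!("{:?}", emb.shape());
    Ok(())
}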