F16 support for stable diffusion (#488)

* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
This commit is contained in:
Laurent Mazare
2023-08-17 13:48:56 +01:00
committed by GitHub
parent c3176f0dfb
commit 5d99026fd2
6 changed files with 99 additions and 43 deletions

View File

@@ -5,7 +5,7 @@
use crate::embeddings::{TimestepEmbedding, Timesteps};
use crate::unet_2d_blocks::*;
use crate::utils::{conv2d, Conv2d};
-use candle::{DType, Result, Tensor};
+use candle::{Result, Tensor};
use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
@@ -316,7 +316,7 @@ impl UNet2DConditionModel {
xs.clone()
};
// 1. time
-        let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+        let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
let emb = self.time_proj.forward(&emb)?;
let emb = self.time_embedding.forward(&emb)?;
// 2. pre-process