Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 03:54:56 +00:00
F16 support for stable diffusion (#488)
* F16 support for stable diffusion.
* Keep the attention bits in F32.
* Keep more of the attention bits in F32.
* More mixed precision support.
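As a rough illustration of the "keep the attention bits in F32" point above: the usual mixed-precision pattern is to upcast the half-precision query/key/value tensors to F32 for the matmul and softmax, then cast the result back. The sketch below is hedged; it is not the code from this commit, and the function name, shapes, and scale parameter are illustrative only.

use candle::{DType, Result, Tensor};

/// Illustrative scaled dot-product attention that runs the numerically
/// sensitive part (matmul + softmax) in F32, then returns the input dtype.
fn attention_in_f32(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = q.dtype();
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    // (b, seq_q, d) x (b, d, seq_k) -> (b, seq_q, seq_k)
    let attn = (q.matmul(&k.t()?)? * scale)?;
    let attn = candle_nn::ops::softmax(&attn, candle::D::Minus1)?;
    // (b, seq_q, seq_k) x (b, seq_k, d) -> (b, seq_q, d), cast back to e.g. F16.
    attn.matmul(&v)?.to_dtype(in_dtype)
}

Keeping only these operations in F32 preserves most of the memory savings of F16 while avoiding the overflow and precision issues that the softmax is prone to in half precision.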
@@ -5,7 +5,7 @@
 use crate::embeddings::{TimestepEmbedding, Timesteps};
 use crate::unet_2d_blocks::*;
 use crate::utils::{conv2d, Conv2d};
-use candle::{DType, Result, Tensor};
+use candle::{Result, Tensor};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
@@ -316,7 +316,7 @@ impl UNet2DConditionModel {
             xs.clone()
         };
         // 1. time
-        let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+        let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
         let emb = self.time_proj.forward(&emb)?;
         let emb = self.time_embedding.forward(&emb)?;
         // 2. pre-process
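For context, the second hunk is what lets the timestep tensor follow the model dtype instead of being pinned to F32, so the UNet forward pass can run end to end in F16. Below is a minimal standalone sketch of that behaviour; the shapes, the timestep value, and the CPU device are illustrative, not taken from the commit.

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for the UNet input `xs`, here already in half precision.
    let xs = Tensor::zeros((1, 4, 64, 64), DType::F16, &device)?;
    let bsize = xs.dim(0)?;
    let timestep = 999.0;

    // Old behaviour: always F32, which causes a dtype mismatch once the
    // rest of the model's weights are in F16.
    // let emb = (Tensor::ones(bsize, DType::F32, &device)? * timestep)?;
    // New behaviour: the timestep tensor inherits the dtype of `xs`.
    let emb = (Tensor::ones(bsize, xs.dtype(), &device)? * timestep)?;
    assert_eq!(emb.dtype(), DType::F16);
    Ok(())
}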