mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00

F16 support for stable diffusion (#488)

* F16 support for stable diffusion.
* Keep the attention bits in F32.
* Keep more of the attention bits in F32.
* More mixed precision support.
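The core pattern of this change: weights and activations can live in F16, but the attention arithmetic (the scaled dot product and the softmax) is upcast to F32 and the result is cast back afterwards. A minimal sketch of that pattern, mirroring the hunks below (function and argument names are illustrative, not part of the commit):

```rust
use candle::{DType, Result, Tensor, D};
use candle_nn as nn;

/// Run scaled-dot-product attention in F32 regardless of the input dtype,
/// then return the result in the caller's dtype.
fn attention_f32(query: &Tensor, key: &Tensor, value: &Tensor, scale: f64) -> Result<Tensor> {
    let in_dtype = query.dtype();
    // Upcast F16 activations before the numerically sensitive steps.
    let query = query.to_dtype(DType::F32)?;
    let key = key.to_dtype(DType::F32)?;
    let value = value.to_dtype(DType::F32)?;
    let xs = query.matmul(&(key.t()? * scale)?)?;
    // Softmax and the value matmul stay in F32; only the output is downcast.
    nn::ops::softmax(&xs, D::Minus1)?
        .matmul(&value)?
        .to_dtype(in_dtype)
}
```

F16's narrow range makes the exponentials inside softmax prone to overflow and rounding error, which is why these bits are kept in F32.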
```diff
@@ -1,5 +1,5 @@
 //! Attention Based Building Blocks
-use candle::{IndexOp, Result, Tensor, D};
+use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug)]
```
```diff
@@ -147,6 +147,10 @@ impl CrossAttention {
     ) -> Result<Tensor> {
         let batch_size_attention = query.dim(0)?;
         let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+        let in_dtype = query.dtype();
+        let query = query.to_dtype(DType::F32)?;
+        let key = key.to_dtype(DType::F32)?;
+        let value = value.to_dtype(DType::F32)?;
 
         for i in 0..batch_size_attention / slice_size {
             let start_idx = i * slice_size;
```
```diff
@@ -158,7 +162,7 @@ impl CrossAttention {
             let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
             hidden_states.push(xs)
         }
-        let hidden_states = Tensor::stack(&hidden_states, 0)?;
+        let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
         self.reshape_batch_dim_to_heads(&hidden_states)
     }
 
```
```diff
@@ -183,8 +187,14 @@ impl CrossAttention {
                 .squeeze(0)?
                 .to_dtype(init_dtype)?
         } else {
+            let in_dtype = query.dtype();
+            let query = query.to_dtype(DType::F32)?;
+            let key = key.to_dtype(DType::F32)?;
+            let value = value.to_dtype(DType::F32)?;
             let xs = query.matmul(&(key.t()? * self.scale)?)?;
-            nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+            nn::ops::softmax(&xs, D::Minus1)?
+                .matmul(&value)?
+                .to_dtype(in_dtype)?
         };
         self.reshape_batch_dim_to_heads(&xs)
     }
 
```
```diff
@@ -457,10 +467,15 @@ impl AttentionBlock {
         let num_heads = channels / num_head_channels;
         let group_norm =
             nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
-        let query = nn::linear(channels, channels, vs.pp("query"))?;
-        let key = nn::linear(channels, channels, vs.pp("key"))?;
-        let value = nn::linear(channels, channels, vs.pp("value"))?;
-        let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
+            ("to_q", "to_k", "to_v", "to_out.0")
+        } else {
+            ("query", "key", "value", "proj_attn")
+        };
+        let query = nn::linear(channels, channels, vs.pp(q_path))?;
+        let key = nn::linear(channels, channels, vs.pp(k_path))?;
+        let value = nn::linear(channels, channels, vs.pp(v_path))?;
+        let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
         let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
```
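The dtype-dependent paths above are needed because the `.fp16.safetensors` checkpoints on the Hub were exported with a newer diffusers naming scheme for the `AttentionBlock` projections (`to_q`/`to_k`/`to_v`/`to_out.0`), while the F32 files keep the older `query`/`key`/`value`/`proj_attn` names. A small sketch for checking which scheme a given checkpoint uses, relying only on the `MmapedFile` API already visible in this diff (the path argument is a placeholder):

```rust
/// Print every tensor name stored in a safetensors checkpoint.
fn list_tensor_names(path: &str) -> candle::Result<()> {
    // Memory-map the file and read its header without loading tensor data.
    let weights = unsafe { candle::safetensors::MmapedFile::new(path)? };
    let weights = weights.deserialize()?;
    for name in weights.names() {
        println!("{name}");
    }
    Ok(())
}
```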
```diff
@@ -483,6 +498,7 @@ impl AttentionBlock {
 
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
+        let in_dtype = xs.dtype();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self
```
```diff
@@ -495,9 +511,13 @@ impl AttentionBlock {
         let key_proj = self.key.forward(&xs)?;
         let value_proj = self.value.forward(&xs)?;
 
-        let query_states = self.transpose_for_scores(query_proj)?;
-        let key_states = self.transpose_for_scores(key_proj)?;
-        let value_states = self.transpose_for_scores(value_proj)?;
+        let query_states = self
+            .transpose_for_scores(query_proj)?
+            .to_dtype(DType::F32)?;
+        let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
+        let value_states = self
+            .transpose_for_scores(value_proj)?
+            .to_dtype(DType::F32)?;
 
         let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
         let attention_scores =
```
```diff
@@ -506,6 +526,7 @@ impl AttentionBlock {
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
 
         let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+        let xs = xs.to_dtype(in_dtype)?;
         let xs = xs.transpose(1, 2)?.contiguous()?;
         let xs = xs.flatten_from(D::Minus2)?;
         let xs = self
```
```diff
@@ -5,7 +5,7 @@
 //! pairs of images with related texts.
 //!
 //! https://github.com/openai/CLIP
-use candle::{Device, Result, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
```
```diff
@@ -146,18 +146,22 @@ impl ClipAttention {
     }
 
     fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+        let in_dtype = xs.dtype();
         let (bsz, seq_len, embed_dim) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scale)?;
         let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
         let query_states = self
             .shape(&query_states, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let key_states = self
             .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let value_states = self
             .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
-            .reshape(proj_shape)?;
+            .reshape(proj_shape)?
+            .to_dtype(DType::F32)?;
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
 
         let src_len = key_states.dim(1)?;
```
```diff
@@ -168,7 +172,7 @@ impl ClipAttention {
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
 
-        let attn_output = attn_weights.matmul(&value_states)?;
+        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
         let attn_output = attn_output
             .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
             .transpose(1, 2)?
```
```diff
@@ -44,10 +44,10 @@ impl Timesteps {
 impl Timesteps {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let half_dim = (self.num_channels / 2) as u32;
-        let exponent =
-            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+        let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
+            * -f64::ln(10000.))?;
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
-        let emb = exponent.exp()?;
+        let emb = exponent.exp()?.to_dtype(xs.dtype())?;
         // emb = timesteps[:, None].float() * emb[None, :]
         let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
```
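Two dtype choices in the `Timesteps` hunk above are deliberate: the log-spaced frequencies are built in F32 even when `xs` is F16, and only the final `exp` result is cast back to `xs.dtype()`. A minimal CPU-only illustration of what rounding the frequencies through F16 first would cost (the half-dim of 160 is a hypothetical value):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let half_dim = 160u32;
    // Log-spaced frequency exponents, computed in F32 as the diff does.
    let freqs = (Tensor::arange(0, half_dim, &dev)?.to_dtype(DType::F32)?
        * (-f64::ln(10000.) / half_dim as f64))?;
    let exact = freqs.exp()?.to_vec1::<f32>()?;
    // The same computation with the frequencies rounded through F16 first.
    let rounded = freqs
        .to_dtype(DType::F16)?
        .exp()?
        .to_dtype(DType::F32)?
        .to_vec1::<f32>()?;
    let max_err = exact
        .iter()
        .zip(rounded.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    println!("max abs error from F16 rounding: {max_err}");
    Ok(())
}
```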
```diff
@@ -93,6 +93,9 @@ struct Args {
 
     #[arg(long)]
     use_flash_attn: bool,
+
+    #[arg(long)]
+    use_f16: bool,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
```
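With the new clap argument, half-precision inference becomes opt-in from the command line. A hypothetical invocation (only `--use-f16` is established by this diff; the example name and feature flags are assumptions about the repository layout):

```
cargo run --example stable-diffusion --release --features cuda -- --use-f16
```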
```diff
@@ -117,21 +120,39 @@ impl StableDiffusionVersion {
         }
     }
 
-    fn unet_file(&self) -> &'static str {
+    fn unet_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
+            Self::V1_5 | Self::V2_1 => {
+                if use_f16 {
+                    "unet/diffusion_pytorch_model.fp16.safetensors"
+                } else {
+                    "unet/diffusion_pytorch_model.safetensors"
+                }
+            }
         }
     }
 
-    fn vae_file(&self) -> &'static str {
+    fn vae_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
+            Self::V1_5 | Self::V2_1 => {
+                if use_f16 {
+                    "vae/diffusion_pytorch_model.fp16.safetensors"
+                } else {
+                    "vae/diffusion_pytorch_model.safetensors"
+                }
+            }
         }
     }
 
-    fn clip_file(&self) -> &'static str {
+    fn clip_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
+            Self::V1_5 | Self::V2_1 => {
+                if use_f16 {
+                    "text_encoder/model.fp16.safetensors"
+                } else {
+                    "text_encoder/model.safetensors"
+                }
+            }
         }
     }
 }
```
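The `.fp16.safetensors` names above match the half-precision variants that diffusers publishes next to the F32 weights in the same Hub repos. A sketch of fetching one directly through hf-hub, following the same `Api` calls used in `ModelFile::get` below (the repo id is a placeholder):

```rust
use hf_hub::api::sync::Api;

/// Download (or reuse the cached copy of) the fp16 UNet weights.
fn fetch_fp16_unet(repo_id: &str) -> anyhow::Result<std::path::PathBuf> {
    let path = Api::new()?
        .model(repo_id.to_string())
        .get("unet/diffusion_pytorch_model.fp16.safetensors")?;
    Ok(path)
}
```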
```diff
@@ -144,6 +165,7 @@ impl ModelFile {
         &self,
         filename: Option<String>,
         version: StableDiffusionVersion,
+        use_f16: bool,
     ) -> Result<std::path::PathBuf> {
         use hf_hub::api::sync::Api;
         match filename {
```
```diff
@@ -151,9 +173,9 @@ impl ModelFile {
             None => {
                 let (repo, path) = match self {
                     Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
-                    Self::Clip => (version.repo(), version.clip_file()),
-                    Self::Unet => (version.repo(), version.unet_file()),
-                    Self::Vae => (version.repo(), version.vae_file()),
+                    Self::Clip => (version.repo(), version.clip_file(use_f16)),
+                    Self::Unet => (version.repo(), version.unet_file(use_f16)),
+                    Self::Vae => (version.repo(), version.vae_file(use_f16)),
                 };
                 let filename = Api::new()?.model(repo.to_string()).get(path)?;
                 Ok(filename)
```
```diff
@@ -209,6 +231,8 @@ fn run(args: Args) -> Result<()> {
         vae_weights,
         unet_weights,
         tracing,
+        use_f16,
+        use_flash_attn,
         ..
     } = args;
 
```
```diff
@@ -220,6 +244,7 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
+    let dtype = if use_f16 { DType::F16 } else { DType::F32 };
     let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
            stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
```
```diff
@@ -232,7 +257,7 @@ fn run(args: Args) -> Result<()> {
     let scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
 
-    let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
+    let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &sd_config.clip.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
```
```diff
@@ -260,18 +285,20 @@ fn run(args: Args) -> Result<()> {
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
 
     println!("Building the Clip transformer.");
-    let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
-    let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
-    let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+    let text_embeddings = {
+        let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
+        let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
+        let text_embeddings = text_model.forward(&tokens)?;
+        let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+        Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+    };
 
     println!("Building the autoencoder.");
-    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
-    let vae = sd_config.build_vae(&vae_weights, &device)?;
+    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
+    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
     println!("Building the unet.");
-    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
+    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
+    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
 
     let bsize = 1;
     for idx in 0..num_samples {
```
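Note that the text encoder deliberately stays in F32 even with `--use-f16`: its weights are fetched with `use_f16 = false` and built at `DType::F32`, and only the resulting embeddings are cast to the pipeline dtype before being handed to the UNet. A tiny illustration of that hand-off (function name is illustrative):

```rust
use candle::{DType, Result, Tensor};

/// Cast F32 CLIP embeddings to whatever dtype the UNet/VAE were loaded with.
fn to_unet_dtype(text_embeddings: Tensor, use_f16: bool) -> Result<Tensor> {
    let dtype = if use_f16 { DType::F16 } else { DType::F32 };
    text_embeddings.to_dtype(dtype)
}
```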
```diff
@@ -280,7 +307,8 @@ fn run(args: Args) -> Result<()> {
             1f32,
             (bsize, 4, sd_config.height / 8, sd_config.width / 8),
             &device,
-        )?;
+        )?
+        .to_dtype(dtype)?;
 
         // scale the initial noise by the standard deviation required by the scheduler
         latents = (latents * scheduler.init_noise_sigma())?;
```
```diff
@@ -159,10 +159,11 @@ impl StableDiffusionConfig {
         &self,
         vae_weights: P,
         device: &Device,
+        dtype: DType,
     ) -> Result<vae::AutoEncoderKL> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
         let weights = weights.deserialize()?;
-        let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
         let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
         Ok(autoencoder)
```
```diff
@@ -174,10 +175,11 @@ impl StableDiffusionConfig {
         device: &Device,
         in_channels: usize,
         use_flash_attn: bool,
+        dtype: DType,
     ) -> Result<unet_2d::UNet2DConditionModel> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
         let weights = weights.deserialize()?;
-        let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
         let unet = unet_2d::UNet2DConditionModel::new(
             vs_unet,
             in_channels,
```
```diff
@@ -196,10 +198,11 @@ impl StableDiffusionConfig {
         &self,
         clip_weights: P,
         device: &Device,
+        dtype: DType,
     ) -> Result<clip::ClipTextTransformer> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
         let weights = weights.deserialize()?;
-        let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
         let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
         Ok(text_model)
     }
```
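All three builders now thread the caller's dtype into `VarBuilder::from_safetensors`, so the same checkpoint-loading code serves both precisions. A condensed sketch of that shared pattern (names are illustrative; the error handling matches the builders above):

```rust
use candle::{DType, Device, Result};
use candle_nn as nn;

/// Memory-map a safetensors file and expose its tensors at `dtype`.
fn load_weights(path: &str, dtype: DType, device: &Device) -> Result<()> {
    let weights = unsafe { candle::safetensors::MmapedFile::new(path)? };
    let weights = weights.deserialize()?;
    let _vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
    // A model would now be constructed from `_vs`, as in build_vae/build_unet.
    Ok(())
}
```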
```diff
@@ -5,7 +5,7 @@
 use crate::embeddings::{TimestepEmbedding, Timesteps};
 use crate::unet_2d_blocks::*;
 use crate::utils::{conv2d, Conv2d};
-use candle::{DType, Result, Tensor};
+use candle::{Result, Tensor};
 use candle_nn as nn;
 
 #[derive(Debug, Clone, Copy)]
```
```diff
@@ -316,7 +316,7 @@ impl UNet2DConditionModel {
             xs.clone()
         };
         // 1. time
-        let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+        let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
         let emb = self.time_proj.forward(&emb)?;
         let emb = self.time_embedding.forward(&emb)?;
         // 2. pre-process
```