Preliminary support for SDXL. (#647)

* Preliminary support for SDXL.

* More SDXL support.

* More SDXL.

* Use the proper clip config.

* Querying for existing tensors.

* More robust test.
This commit is contained in:
Laurent Mazare
2023-08-29 09:00:04 +01:00
committed by GitHub
parent 49326fb925
commit 33c23c19b6
5 changed files with 298 additions and 58 deletions

View File

@ -473,7 +473,7 @@ impl AttentionBlock {
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")

View File

@ -69,6 +69,36 @@ impl Config {
activation: Activation::Gelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
pub fn sdxl() -> Self {
Self {
vocab_size: 49408,
embed_dim: 768,
intermediate_size: 3072,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 12,
num_attention_heads: 12,
projection_dim: 768,
activation: Activation::QuickGelu,
}
}
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
pub fn sdxl2() -> Self {
Self {
vocab_size: 49408,
embed_dim: 1280,
intermediate_size: 5120,
max_position_embeddings: 77,
pad_with: Some("!".to_string()),
num_hidden_layers: 32,
num_attention_heads: 20,
projection_dim: 1280,
activation: Activation::Gelu,
}
}
}
// CLIP Text Model

View File

@ -17,7 +17,7 @@ mod utils;
mod vae;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle::{DType, Device, IndexOp, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@ -102,12 +102,16 @@ struct Args {
enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
}
#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
Tokenizer2,
Clip,
Clip2,
Unet,
Vae,
}
@ -115,6 +119,7 @@ enum ModelFile {
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
@ -122,7 +127,7 @@ impl StableDiffusionVersion {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@ -134,7 +139,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@ -146,7 +151,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@ -155,12 +160,21 @@ impl StableDiffusionVersion {
}
}
}
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
"text_encoder_2/model.safetensors"
}
}
}
}
}
impl ModelFile {
const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
const TOKENIZER_PATH: &str = "tokenizer.json";
fn get(
&self,
filename: Option<String>,
@ -172,8 +186,24 @@ impl ModelFile {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
Self::Tokenizer => {
let tokenizer_repo = match version {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
StableDiffusionVersion::Xl => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
}
};
(tokenizer_repo, "tokenizer.json")
}
Self::Tokenizer2 => {
("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
}
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
@ -211,6 +241,71 @@ fn output_filename(
}
}
#[allow(clippy::too_many_arguments)]
fn text_embeddings(
prompt: &str,
uncond_prompt: &str,
tokenizer: Option<String>,
clip_weights: Option<String>,
sd_version: StableDiffusionVersion,
sd_config: &stable_diffusion::StableDiffusionConfig,
use_f16: bool,
device: &Device,
dtype: DType,
first: bool,
) -> Result<Tensor> {
let tokenizer_file = if first {
ModelFile::Tokenizer
} else {
ModelFile::Tokenizer2
};
let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &sd_config.clip.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
} else {
ModelFile::Clip2
};
let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
let clip_config = if first {
&sd_config.clip
} else {
sd_config.clip2.as_ref().unwrap()
};
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
Ok(text_embeddings)
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &sd_config.clip.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let text_embeddings = {
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
let which = match sd_version {
StableDiffusionVersion::Xl => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
.iter()
.map(|first| {
text_embeddings(
&prompt,
&uncond_prompt,
tokenizer.clone(),
clip_weights.clone(),
sd_version,
&sd_config,
use_f16,
&device,
dtype,
*first,
)
})
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;

View File

@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
@ -51,7 +52,7 @@ impl StableDiffusionConfig {
norm_num_groups: 32,
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
@ -68,6 +69,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
@ -118,7 +120,7 @@ impl StableDiffusionConfig {
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
@ -135,6 +137,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
clip2: None,
autoencoder,
scheduler,
unet,
@ -155,6 +158,83 @@ impl StableDiffusionConfig {
)
}
fn sdxl_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
pub fn sdxl(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
Self::sdxl_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
PredictionType::Epsilon,
)
}
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@ -193,17 +273,17 @@ impl StableDiffusionConfig {
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
&self,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
Ok(text_model)
}
}
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
clip: &clip::Config,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
Ok(text_model)
}