Preliminary support for SDXL. (#647)

* Preliminary support for SDXL.

* More SDXL support.

* More SDXL.

* Use the proper clip config.

* Querying for existing tensors.

* More robust test.
This commit is contained in:
Laurent Mazare
2023-08-29 09:00:04 +01:00
committed by GitHub
parent 49326fb925
commit 33c23c19b6
5 changed files with 298 additions and 58 deletions

View File

@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
@ -51,7 +52,7 @@ impl StableDiffusionConfig {
norm_num_groups: 32,
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
@ -68,6 +69,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
@ -118,7 +120,7 @@ impl StableDiffusionConfig {
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
@ -135,6 +137,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
clip2: None,
autoencoder,
scheduler,
unet,
@ -155,6 +158,83 @@ impl StableDiffusionConfig {
)
}
fn sdxl_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
pub fn sdxl(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
Self::sdxl_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
PredictionType::Epsilon,
)
}
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@ -193,17 +273,17 @@ impl StableDiffusionConfig {
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
&self,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
Ok(text_model)
}
}
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
clip: &clip::Config,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
Ok(text_model)
}