mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Preliminary support for SDXL. (#647)
* Preliminary support for SDXL. * More SDXL support. * More SDXL. * Use the proper clip config. * Querying for existing tensors. * More robust test.
This commit is contained in:
@ -473,7 +473,7 @@ impl AttentionBlock {
|
|||||||
let num_heads = channels / num_head_channels;
|
let num_heads = channels / num_head_channels;
|
||||||
let group_norm =
|
let group_norm =
|
||||||
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
|
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
|
||||||
let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
|
let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
|
||||||
("to_q", "to_k", "to_v", "to_out.0")
|
("to_q", "to_k", "to_v", "to_out.0")
|
||||||
} else {
|
} else {
|
||||||
("query", "key", "value", "proj_attn")
|
("query", "key", "value", "proj_attn")
|
||||||
|
@ -69,6 +69,36 @@ impl Config {
|
|||||||
activation: Activation::Gelu,
|
activation: Activation::Gelu,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
|
||||||
|
pub fn sdxl() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 49408,
|
||||||
|
embed_dim: 768,
|
||||||
|
intermediate_size: 3072,
|
||||||
|
max_position_embeddings: 77,
|
||||||
|
pad_with: Some("!".to_string()),
|
||||||
|
num_hidden_layers: 12,
|
||||||
|
num_attention_heads: 12,
|
||||||
|
projection_dim: 768,
|
||||||
|
activation: Activation::QuickGelu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
|
||||||
|
pub fn sdxl2() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 49408,
|
||||||
|
embed_dim: 1280,
|
||||||
|
intermediate_size: 5120,
|
||||||
|
max_position_embeddings: 77,
|
||||||
|
pad_with: Some("!".to_string()),
|
||||||
|
num_hidden_layers: 32,
|
||||||
|
num_attention_heads: 20,
|
||||||
|
projection_dim: 1280,
|
||||||
|
activation: Activation::Gelu,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CLIP Text Model
|
// CLIP Text Model
|
||||||
|
@ -17,7 +17,7 @@ mod utils;
|
|||||||
mod vae;
|
mod vae;
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device, IndexOp, Tensor};
|
use candle::{DType, Device, IndexOp, Tensor, D};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
@ -102,12 +102,16 @@ struct Args {
|
|||||||
enum StableDiffusionVersion {
|
enum StableDiffusionVersion {
|
||||||
V1_5,
|
V1_5,
|
||||||
V2_1,
|
V2_1,
|
||||||
|
Xl,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
enum ModelFile {
|
enum ModelFile {
|
||||||
Tokenizer,
|
Tokenizer,
|
||||||
|
Tokenizer2,
|
||||||
Clip,
|
Clip,
|
||||||
|
Clip2,
|
||||||
Unet,
|
Unet,
|
||||||
Vae,
|
Vae,
|
||||||
}
|
}
|
||||||
@ -115,6 +119,7 @@ enum ModelFile {
|
|||||||
impl StableDiffusionVersion {
|
impl StableDiffusionVersion {
|
||||||
fn repo(&self) -> &'static str {
|
fn repo(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
|
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
|
||||||
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||||
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||||
}
|
}
|
||||||
@ -122,7 +127,7 @@ impl StableDiffusionVersion {
|
|||||||
|
|
||||||
fn unet_file(&self, use_f16: bool) -> &'static str {
|
fn unet_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"unet/diffusion_pytorch_model.fp16.safetensors"
|
"unet/diffusion_pytorch_model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -134,7 +139,7 @@ impl StableDiffusionVersion {
|
|||||||
|
|
||||||
fn vae_file(&self, use_f16: bool) -> &'static str {
|
fn vae_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"vae/diffusion_pytorch_model.fp16.safetensors"
|
"vae/diffusion_pytorch_model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -146,7 +151,7 @@ impl StableDiffusionVersion {
|
|||||||
|
|
||||||
fn clip_file(&self, use_f16: bool) -> &'static str {
|
fn clip_file(&self, use_f16: bool) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::V1_5 | Self::V2_1 => {
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
if use_f16 {
|
if use_f16 {
|
||||||
"text_encoder/model.fp16.safetensors"
|
"text_encoder/model.fp16.safetensors"
|
||||||
} else {
|
} else {
|
||||||
@ -155,12 +160,21 @@ impl StableDiffusionVersion {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clip2_file(&self, use_f16: bool) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||||
|
if use_f16 {
|
||||||
|
"text_encoder_2/model.fp16.safetensors"
|
||||||
|
} else {
|
||||||
|
"text_encoder_2/model.safetensors"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelFile {
|
impl ModelFile {
|
||||||
const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
|
|
||||||
const TOKENIZER_PATH: &str = "tokenizer.json";
|
|
||||||
|
|
||||||
fn get(
|
fn get(
|
||||||
&self,
|
&self,
|
||||||
filename: Option<String>,
|
filename: Option<String>,
|
||||||
@ -172,8 +186,24 @@ impl ModelFile {
|
|||||||
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
||||||
None => {
|
None => {
|
||||||
let (repo, path) = match self {
|
let (repo, path) = match self {
|
||||||
Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
|
Self::Tokenizer => {
|
||||||
|
let tokenizer_repo = match version {
|
||||||
|
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
|
||||||
|
"openai/clip-vit-base-patch32"
|
||||||
|
}
|
||||||
|
StableDiffusionVersion::Xl => {
|
||||||
|
// This seems similar to the patch32 version except some very small
|
||||||
|
// difference in the split regex.
|
||||||
|
"openai/clip-vit-large-patch14"
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(tokenizer_repo, "tokenizer.json")
|
||||||
|
}
|
||||||
|
Self::Tokenizer2 => {
|
||||||
|
("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
|
||||||
|
}
|
||||||
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
||||||
|
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
||||||
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
||||||
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
||||||
};
|
};
|
||||||
@ -211,6 +241,71 @@ fn output_filename(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn text_embeddings(
|
||||||
|
prompt: &str,
|
||||||
|
uncond_prompt: &str,
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
clip_weights: Option<String>,
|
||||||
|
sd_version: StableDiffusionVersion,
|
||||||
|
sd_config: &stable_diffusion::StableDiffusionConfig,
|
||||||
|
use_f16: bool,
|
||||||
|
device: &Device,
|
||||||
|
dtype: DType,
|
||||||
|
first: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let tokenizer_file = if first {
|
||||||
|
ModelFile::Tokenizer
|
||||||
|
} else {
|
||||||
|
ModelFile::Tokenizer2
|
||||||
|
};
|
||||||
|
let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
let pad_id = match &sd_config.clip.pad_with {
|
||||||
|
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||||
|
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||||
|
};
|
||||||
|
println!("Running with prompt \"{prompt}\".");
|
||||||
|
let mut tokens = tokenizer
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
while tokens.len() < sd_config.clip.max_position_embeddings {
|
||||||
|
tokens.push(pad_id)
|
||||||
|
}
|
||||||
|
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
let mut uncond_tokens = tokenizer
|
||||||
|
.encode(uncond_prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||||
|
uncond_tokens.push(pad_id)
|
||||||
|
}
|
||||||
|
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
println!("Building the Clip transformer.");
|
||||||
|
let clip_weights_file = if first {
|
||||||
|
ModelFile::Clip
|
||||||
|
} else {
|
||||||
|
ModelFile::Clip2
|
||||||
|
};
|
||||||
|
let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
|
||||||
|
let clip_config = if first {
|
||||||
|
&sd_config.clip
|
||||||
|
} else {
|
||||||
|
sd_config.clip2.as_ref().unwrap()
|
||||||
|
};
|
||||||
|
let text_model =
|
||||||
|
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
|
||||||
|
let text_embeddings = text_model.forward(&tokens)?;
|
||||||
|
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||||
|
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
|
||||||
|
Ok(text_embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
fn run(args: Args) -> Result<()> {
|
fn run(args: Args) -> Result<()> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
|
|||||||
StableDiffusionVersion::V2_1 => {
|
StableDiffusionVersion::V2_1 => {
|
||||||
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
|
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
|
||||||
}
|
}
|
||||||
|
StableDiffusionVersion::Xl => {
|
||||||
|
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
|
|
||||||
let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
|
let which = match sd_version {
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
StableDiffusionVersion::Xl => vec![true, false],
|
||||||
let pad_id = match &sd_config.clip.pad_with {
|
_ => vec![true],
|
||||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
|
||||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
|
||||||
};
|
|
||||||
println!("Running with prompt \"{prompt}\".");
|
|
||||||
let mut tokens = tokenizer
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
while tokens.len() < sd_config.clip.max_position_embeddings {
|
|
||||||
tokens.push(pad_id)
|
|
||||||
}
|
|
||||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
|
||||||
|
|
||||||
let mut uncond_tokens = tokenizer
|
|
||||||
.encode(uncond_prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
|
||||||
uncond_tokens.push(pad_id)
|
|
||||||
}
|
|
||||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
|
||||||
|
|
||||||
println!("Building the Clip transformer.");
|
|
||||||
let text_embeddings = {
|
|
||||||
let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
|
|
||||||
let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
|
|
||||||
let text_embeddings = text_model.forward(&tokens)?;
|
|
||||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
|
||||||
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
|
|
||||||
};
|
};
|
||||||
|
let text_embeddings = which
|
||||||
|
.iter()
|
||||||
|
.map(|first| {
|
||||||
|
text_embeddings(
|
||||||
|
&prompt,
|
||||||
|
&uncond_prompt,
|
||||||
|
tokenizer.clone(),
|
||||||
|
clip_weights.clone(),
|
||||||
|
sd_version,
|
||||||
|
&sd_config,
|
||||||
|
use_f16,
|
||||||
|
&device,
|
||||||
|
dtype,
|
||||||
|
*first,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
|
||||||
|
println!("{text_embeddings:?}");
|
||||||
|
|
||||||
println!("Building the autoencoder.");
|
println!("Building the autoencoder.");
|
||||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
||||||
|
@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
|
|||||||
pub width: usize,
|
pub width: usize,
|
||||||
pub height: usize,
|
pub height: usize,
|
||||||
pub clip: clip::Config,
|
pub clip: clip::Config,
|
||||||
|
pub clip2: Option<clip::Config>,
|
||||||
autoencoder: vae::AutoEncoderKLConfig,
|
autoencoder: vae::AutoEncoderKLConfig,
|
||||||
unet: unet_2d::UNet2DConditionModelConfig,
|
unet: unet_2d::UNet2DConditionModelConfig,
|
||||||
scheduler: ddim::DDIMSchedulerConfig,
|
scheduler: ddim::DDIMSchedulerConfig,
|
||||||
@ -51,7 +52,7 @@ impl StableDiffusionConfig {
|
|||||||
norm_num_groups: 32,
|
norm_num_groups: 32,
|
||||||
};
|
};
|
||||||
let height = if let Some(height) = height {
|
let height = if let Some(height) = height {
|
||||||
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
|
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||||
height
|
height
|
||||||
} else {
|
} else {
|
||||||
512
|
512
|
||||||
@ -68,6 +69,7 @@ impl StableDiffusionConfig {
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
clip: clip::Config::v1_5(),
|
clip: clip::Config::v1_5(),
|
||||||
|
clip2: None,
|
||||||
autoencoder,
|
autoencoder,
|
||||||
scheduler: Default::default(),
|
scheduler: Default::default(),
|
||||||
unet,
|
unet,
|
||||||
@ -118,7 +120,7 @@ impl StableDiffusionConfig {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let height = if let Some(height) = height {
|
let height = if let Some(height) = height {
|
||||||
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
|
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||||
height
|
height
|
||||||
} else {
|
} else {
|
||||||
768
|
768
|
||||||
@ -135,6 +137,7 @@ impl StableDiffusionConfig {
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
clip: clip::Config::v2_1(),
|
clip: clip::Config::v2_1(),
|
||||||
|
clip2: None,
|
||||||
autoencoder,
|
autoencoder,
|
||||||
scheduler,
|
scheduler,
|
||||||
unet,
|
unet,
|
||||||
@ -155,6 +158,83 @@ impl StableDiffusionConfig {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sdxl_(
|
||||||
|
sliced_attention_size: Option<usize>,
|
||||||
|
height: Option<usize>,
|
||||||
|
width: Option<usize>,
|
||||||
|
prediction_type: PredictionType,
|
||||||
|
) -> Self {
|
||||||
|
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||||
|
out_channels,
|
||||||
|
use_cross_attn,
|
||||||
|
attention_head_dim,
|
||||||
|
};
|
||||||
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
|
||||||
|
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||||
|
blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
|
||||||
|
center_input_sample: false,
|
||||||
|
cross_attention_dim: 2048,
|
||||||
|
downsample_padding: 1,
|
||||||
|
flip_sin_to_cos: true,
|
||||||
|
freq_shift: 0.,
|
||||||
|
layers_per_block: 2,
|
||||||
|
mid_block_scale_factor: 1.,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
norm_num_groups: 32,
|
||||||
|
sliced_attention_size,
|
||||||
|
use_linear_projection: true,
|
||||||
|
};
|
||||||
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
|
||||||
|
let autoencoder = vae::AutoEncoderKLConfig {
|
||||||
|
block_out_channels: vec![128, 256, 512, 512],
|
||||||
|
layers_per_block: 2,
|
||||||
|
latent_channels: 4,
|
||||||
|
norm_num_groups: 32,
|
||||||
|
};
|
||||||
|
let scheduler = ddim::DDIMSchedulerConfig {
|
||||||
|
prediction_type,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let height = if let Some(height) = height {
|
||||||
|
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||||
|
height
|
||||||
|
} else {
|
||||||
|
1024
|
||||||
|
};
|
||||||
|
|
||||||
|
let width = if let Some(width) = width {
|
||||||
|
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
||||||
|
width
|
||||||
|
} else {
|
||||||
|
1024
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
clip: clip::Config::sdxl(),
|
||||||
|
clip2: Some(clip::Config::sdxl2()),
|
||||||
|
autoencoder,
|
||||||
|
scheduler,
|
||||||
|
unet,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sdxl(
|
||||||
|
sliced_attention_size: Option<usize>,
|
||||||
|
height: Option<usize>,
|
||||||
|
width: Option<usize>,
|
||||||
|
) -> Self {
|
||||||
|
Self::sdxl_(
|
||||||
|
sliced_attention_size,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
|
||||||
|
PredictionType::Epsilon,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build_vae<P: AsRef<std::path::Path>>(
|
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||||
&self,
|
&self,
|
||||||
vae_weights: P,
|
vae_weights: P,
|
||||||
@ -193,17 +273,17 @@ impl StableDiffusionConfig {
|
|||||||
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
|
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
|
||||||
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
|
||||||
&self,
|
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
||||||
clip_weights: P,
|
clip: &clip::Config,
|
||||||
device: &Device,
|
clip_weights: P,
|
||||||
dtype: DType,
|
device: &Device,
|
||||||
) -> Result<clip::ClipTextTransformer> {
|
dtype: DType,
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
) -> Result<clip::ClipTextTransformer> {
|
||||||
let weights = weights.deserialize()?;
|
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
||||||
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
|
let weights = weights.deserialize()?;
|
||||||
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
|
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
|
||||||
Ok(text_model)
|
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
|
||||||
}
|
Ok(text_model)
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,8 @@ pub trait Backend {
|
|||||||
dtype: DType,
|
dtype: DType,
|
||||||
dev: &Device,
|
dev: &Device,
|
||||||
) -> Result<Tensor>;
|
) -> Result<Tensor>;
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait SimpleBackend {
|
pub trait SimpleBackend {
|
||||||
@ -64,6 +66,8 @@ pub trait SimpleBackend {
|
|||||||
dtype: DType,
|
dtype: DType,
|
||||||
dev: &Device,
|
dev: &Device,
|
||||||
) -> Result<Tensor>;
|
) -> Result<Tensor>;
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
||||||
@ -78,6 +82,10 @@ impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
|||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
self.as_ref().get(s, name, h, dtype, dev)
|
self.as_ref().get(s, name, h, dtype, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.as_ref().contains_tensor(name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||||
@ -94,6 +102,8 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a new `VarBuilder` adding `s` to the current prefix. This can be think of as `cd`
|
||||||
|
/// into a directory.
|
||||||
pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
|
pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
|
||||||
let mut path = self.path.clone();
|
let mut path = self.path.clone();
|
||||||
path.push(s.to_string());
|
path.push(s.to_string());
|
||||||
@ -109,10 +119,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
self.push_prefix(s)
|
self.push_prefix(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The device used by default.
|
||||||
pub fn device(&self) -> &Device {
|
pub fn device(&self) -> &Device {
|
||||||
&self.data.device
|
&self.data.device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The dtype used by default.
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
self.data.dtype
|
self.data.dtype
|
||||||
}
|
}
|
||||||
@ -125,6 +137,14 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// This returns true only if a tensor with the passed in name is available. E.g. when passed
|
||||||
|
/// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b`
|
||||||
|
/// exists.
|
||||||
|
pub fn contains_tensor(&self, tensor_name: &str) -> bool {
|
||||||
|
let path = self.path(tensor_name);
|
||||||
|
self.data.backend.contains_tensor(&path)
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve the tensor associated with the given name at the current path.
|
/// Retrieve the tensor associated with the given name at the current path.
|
||||||
pub fn get_with_hints<S: Into<Shape>>(
|
pub fn get_with_hints<S: Into<Shape>>(
|
||||||
&self,
|
&self,
|
||||||
@ -149,6 +169,10 @@ impl SimpleBackend for Zeros {
|
|||||||
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
|
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
|
||||||
Tensor::zeros(s, dtype, dev)
|
Tensor::zeros(s, dtype, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, _name: &str) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SimpleBackend for HashMap<String, Tensor> {
|
impl SimpleBackend for HashMap<String, Tensor> {
|
||||||
@ -179,6 +203,10 @@ impl SimpleBackend for HashMap<String, Tensor> {
|
|||||||
}
|
}
|
||||||
tensor.to_device(dev)?.to_dtype(dtype)
|
tensor.to_device(dev)?.to_dtype(dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.contains_key(name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SimpleBackend for VarMap {
|
impl SimpleBackend for VarMap {
|
||||||
@ -192,6 +220,10 @@ impl SimpleBackend for VarMap {
|
|||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
VarMap::get(self, s, name, h, dtype, dev)
|
VarMap::get(self, s, name, h, dtype, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.data().lock().unwrap().contains_key(name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SafeTensorWithRouting<'a> {
|
struct SafeTensorWithRouting<'a> {
|
||||||
@ -228,6 +260,10 @@ impl<'a> SimpleBackend for SafeTensorWithRouting<'a> {
|
|||||||
}
|
}
|
||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.routing.contains_key(name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SimpleBackend for candle::npy::NpzTensors {
|
impl SimpleBackend for candle::npy::NpzTensors {
|
||||||
@ -257,6 +293,10 @@ impl SimpleBackend for candle::npy::NpzTensors {
|
|||||||
}
|
}
|
||||||
Ok(tensor)
|
Ok(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.get(name).map_or(false, |v| v.is_some())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> VarBuilder<'a> {
|
impl<'a> VarBuilder<'a> {
|
||||||
@ -425,4 +465,8 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
|
|||||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||||
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
|
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.0.routing.contains_key(name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user