mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a KV cache to T5. (#873)
* Add a KV cache to T5. * Suggest using release mode. * Use the kv cache in decoding. * Add a comment.
This commit is contained in:
@ -77,7 +77,7 @@ fn main() -> Result<()> {
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
let model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
|
@ -3,7 +3,7 @@
|
||||
## Encoder-decoder example:
|
||||
|
||||
```bash
|
||||
$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
|
||||
...
|
||||
Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||
Eine schöne Kerze.
|
||||
@ -13,7 +13,7 @@ Running on CPU, to run on GPU, build this example with `--features cuda`
|
||||
## Sentence embedding example:
|
||||
|
||||
```bash
|
||||
$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
|
||||
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
|
||||
...
|
||||
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
|
||||
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
|
||||
|
@ -48,10 +48,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
@ -131,6 +127,7 @@ impl T5ModelBuilder {
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
|
||||
let device = &builder.device;
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
@ -142,32 +139,32 @@ fn main() -> Result<()> {
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
|
||||
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
if !args.decode {
|
||||
let model = builder.build_encoder()?;
|
||||
for idx in 0..args.n {
|
||||
let mut model = builder.build_encoder()?;
|
||||
let start = std::time::Instant::now();
|
||||
let ys = model.forward(&input_token_ids)?;
|
||||
if idx == 0 {
|
||||
println!("{ys}");
|
||||
}
|
||||
println!("Took {:?}", start.elapsed());
|
||||
}
|
||||
} else {
|
||||
let model = builder.build_conditional_generation()?;
|
||||
let mut model = builder.build_conditional_generation()?;
|
||||
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for _index in 0.. {
|
||||
for index in 0.. {
|
||||
if output_token_ids.len() > 512 {
|
||||
break;
|
||||
}
|
||||
let decoder_token_ids =
|
||||
Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
|
||||
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
|
||||
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
|
||||
} else {
|
||||
let last_token = *output_token_ids.last().unwrap();
|
||||
Tensor::new(&[last_token], device)?.unsqueeze(0)?
|
||||
};
|
||||
let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
|
||||
let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
|
||||
if (next_token_id as usize) == builder.config.eos_token_id {
|
||||
if next_token_id as usize == builder.config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
output_token_ids.push(next_token_id);
|
||||
@ -186,7 +183,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let model = builder.build_encoder()?;
|
||||
let mut model = builder.build_encoder()?;
|
||||
let sentences = [
|
||||
"The cat sits outside",
|
||||
"A man is playing guitar",
|
||||
|
499
candle-examples/examples/wuerstchen/main.rs
Normal file
499
candle-examples/examples/wuerstchen/main.rs
Normal file
@ -0,0 +1,499 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use candle_transformers::models::stable_diffusion;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Module, Tensor, D};
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const GUIDANCE_SCALE: f64 = 7.5;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to be used for image generation.
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
|
||||
)]
|
||||
prompt: String,
|
||||
|
||||
#[arg(long, default_value = "")]
|
||||
uncond_prompt: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The height in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
height: Option<usize>,
|
||||
|
||||
/// The width in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
width: Option<usize>,
|
||||
|
||||
/// The UNet weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
unet_weights: Option<String>,
|
||||
|
||||
/// The CLIP weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
clip_weights: Option<String>,
|
||||
|
||||
/// The VAE weight file, in .safetensors format.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
vae_weights: Option<String>,
|
||||
|
||||
#[arg(long, value_name = "FILE")]
|
||||
/// The file specifying the tokenizer to used for tokenization.
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||
#[arg(long)]
|
||||
sliced_attention_size: Option<usize>,
|
||||
|
||||
/// The number of steps to run the diffusion for.
|
||||
#[arg(long, default_value_t = 30)]
|
||||
n_steps: usize,
|
||||
|
||||
/// The number of samples to generate.
|
||||
#[arg(long, default_value_t = 1)]
|
||||
num_samples: i64,
|
||||
|
||||
/// The name of the final image to generate.
|
||||
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
|
||||
final_image: String,
|
||||
|
||||
#[arg(long, value_enum, default_value = "v2-1")]
|
||||
sd_version: StableDiffusionVersion,
|
||||
|
||||
/// Generate intermediary images at each step.
|
||||
#[arg(long, action)]
|
||||
intermediary_images: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_f16: bool,
|
||||
|
||||
#[arg(long, value_name = "FILE")]
|
||||
img2img: Option<String>,
|
||||
|
||||
/// The strength, indicates how much to transform the initial image. The
|
||||
/// value must be between 0 and 1, a value of 1 discards the initial image
|
||||
/// information.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
img2img_strength: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
|
||||
enum StableDiffusionVersion {
|
||||
V1_5,
|
||||
V2_1,
|
||||
Xl,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ModelFile {
|
||||
Tokenizer,
|
||||
Tokenizer2,
|
||||
Clip,
|
||||
Clip2,
|
||||
Unet,
|
||||
Vae,
|
||||
}
|
||||
|
||||
impl StableDiffusionVersion {
|
||||
fn repo(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
|
||||
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
|
||||
}
|
||||
}
|
||||
|
||||
fn unet_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors"
|
||||
} else {
|
||||
"unet/diffusion_pytorch_model.safetensors"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn vae_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors"
|
||||
} else {
|
||||
"vae/diffusion_pytorch_model.safetensors"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clip_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"text_encoder/model.fp16.safetensors"
|
||||
} else {
|
||||
"text_encoder/model.safetensors"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clip2_file(&self, use_f16: bool) -> &'static str {
|
||||
match self {
|
||||
Self::V1_5 | Self::V2_1 | Self::Xl => {
|
||||
if use_f16 {
|
||||
"text_encoder_2/model.fp16.safetensors"
|
||||
} else {
|
||||
"text_encoder_2/model.safetensors"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelFile {
|
||||
fn get(
|
||||
&self,
|
||||
filename: Option<String>,
|
||||
version: StableDiffusionVersion,
|
||||
use_f16: bool,
|
||||
) -> Result<std::path::PathBuf> {
|
||||
use hf_hub::api::sync::Api;
|
||||
match filename {
|
||||
Some(filename) => Ok(std::path::PathBuf::from(filename)),
|
||||
None => {
|
||||
let (repo, path) = match self {
|
||||
Self::Tokenizer => {
|
||||
let tokenizer_repo = match version {
|
||||
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
|
||||
"openai/clip-vit-base-patch32"
|
||||
}
|
||||
StableDiffusionVersion::Xl => {
|
||||
// This seems similar to the patch32 version except some very small
|
||||
// difference in the split regex.
|
||||
"openai/clip-vit-large-patch14"
|
||||
}
|
||||
};
|
||||
(tokenizer_repo, "tokenizer.json")
|
||||
}
|
||||
Self::Tokenizer2 => {
|
||||
("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
|
||||
}
|
||||
Self::Clip => (version.repo(), version.clip_file(use_f16)),
|
||||
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
|
||||
Self::Unet => (version.repo(), version.unet_file(use_f16)),
|
||||
Self::Vae => (version.repo(), version.vae_file(use_f16)),
|
||||
};
|
||||
let filename = Api::new()?.model(repo.to_string()).get(path)?;
|
||||
Ok(filename)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_filename(
|
||||
basename: &str,
|
||||
sample_idx: i64,
|
||||
num_samples: i64,
|
||||
timestep_idx: Option<usize>,
|
||||
) -> String {
|
||||
let filename = if num_samples > 1 {
|
||||
match basename.rsplit_once('.') {
|
||||
None => format!("{basename}.{sample_idx}.png"),
|
||||
Some((filename_no_extension, extension)) => {
|
||||
format!("{filename_no_extension}.{sample_idx}.{extension}")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
basename.to_string()
|
||||
};
|
||||
match timestep_idx {
|
||||
None => filename,
|
||||
Some(timestep_idx) => match filename.rsplit_once('.') {
|
||||
None => format!("{filename}-{timestep_idx}.png"),
|
||||
Some((filename_no_extension, extension)) => {
|
||||
format!("{filename_no_extension}-{timestep_idx}.{extension}")
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn text_embeddings(
|
||||
prompt: &str,
|
||||
uncond_prompt: &str,
|
||||
tokenizer: Option<String>,
|
||||
clip_weights: Option<String>,
|
||||
sd_version: StableDiffusionVersion,
|
||||
sd_config: &stable_diffusion::StableDiffusionConfig,
|
||||
use_f16: bool,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
first: bool,
|
||||
) -> Result<Tensor> {
|
||||
let tokenizer_file = if first {
|
||||
ModelFile::Tokenizer
|
||||
} else {
|
||||
ModelFile::Tokenizer2
|
||||
};
|
||||
let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let pad_id = match &sd_config.clip.pad_with {
|
||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
|
||||
};
|
||||
println!("Running with prompt \"{prompt}\".");
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
tokens.push(pad_id)
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the Clip transformer.");
|
||||
let clip_weights_file = if first {
|
||||
ModelFile::Clip
|
||||
} else {
|
||||
ModelFile::Clip2
|
||||
};
|
||||
let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
|
||||
let clip_config = if first {
|
||||
&sd_config.clip
|
||||
} else {
|
||||
sd_config.clip2.as_ref().unwrap()
|
||||
};
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward(&tokens)?;
|
||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
|
||||
Ok(text_embeddings)
|
||||
}
|
||||
|
||||
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
|
||||
let img = image::io::Reader::open(path)?.decode()?;
|
||||
let (height, width) = (img.height() as usize, img.width() as usize);
|
||||
let height = height - height % 32;
|
||||
let width = width - width % 32;
|
||||
let img = img.resize_to_fill(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
let img = img.to_rgb8();
|
||||
let img = img.into_raw();
|
||||
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
||||
.permute((2, 0, 1))?
|
||||
.to_dtype(DType::F32)?
|
||||
.affine(2. / 255., -1.)?
|
||||
.unsqueeze(0)?;
|
||||
Ok(img)
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let Args {
|
||||
prompt,
|
||||
uncond_prompt,
|
||||
cpu,
|
||||
height,
|
||||
width,
|
||||
n_steps,
|
||||
tokenizer,
|
||||
final_image,
|
||||
sliced_attention_size,
|
||||
num_samples,
|
||||
sd_version,
|
||||
clip_weights,
|
||||
vae_weights,
|
||||
unet_weights,
|
||||
tracing,
|
||||
use_f16,
|
||||
use_flash_attn,
|
||||
img2img,
|
||||
img2img_strength,
|
||||
..
|
||||
} = args;
|
||||
|
||||
if !(0. ..=1.).contains(&img2img_strength) {
|
||||
anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
|
||||
}
|
||||
|
||||
let _guard = if tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
|
||||
let sd_config = match sd_version {
|
||||
StableDiffusionVersion::V1_5 => {
|
||||
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
|
||||
}
|
||||
StableDiffusionVersion::V2_1 => {
|
||||
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
|
||||
}
|
||||
StableDiffusionVersion::Xl => {
|
||||
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
|
||||
}
|
||||
};
|
||||
|
||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||
let device = candle_examples::device(cpu)?;
|
||||
|
||||
let which = match sd_version {
|
||||
StableDiffusionVersion::Xl => vec![true, false],
|
||||
_ => vec![true],
|
||||
};
|
||||
let text_embeddings = which
|
||||
.iter()
|
||||
.map(|first| {
|
||||
text_embeddings(
|
||||
&prompt,
|
||||
&uncond_prompt,
|
||||
tokenizer.clone(),
|
||||
clip_weights.clone(),
|
||||
sd_version,
|
||||
&sd_config,
|
||||
use_f16,
|
||||
&device,
|
||||
dtype,
|
||||
*first,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
|
||||
println!("{text_embeddings:?}");
|
||||
|
||||
println!("Building the autoencoder.");
|
||||
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
|
||||
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
|
||||
let init_latent_dist = match &img2img {
|
||||
None => None,
|
||||
Some(image) => {
|
||||
let image = image_preprocess(image)?.to_device(&device)?;
|
||||
Some(vae.encode(&image)?)
|
||||
}
|
||||
};
|
||||
println!("Building the unet.");
|
||||
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
|
||||
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
|
||||
|
||||
let t_start = if img2img.is_some() {
|
||||
n_steps - (n_steps as f64 * img2img_strength) as usize
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let bsize = 1;
|
||||
for idx in 0..num_samples {
|
||||
let timesteps = scheduler.timesteps();
|
||||
let latents = match &init_latent_dist {
|
||||
Some(init_latent_dist) => {
|
||||
let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
|
||||
if t_start < timesteps.len() {
|
||||
let noise = latents.randn_like(0f64, 1f64)?;
|
||||
scheduler.add_noise(&latents, noise, timesteps[t_start])?
|
||||
} else {
|
||||
latents
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
|
||||
&device,
|
||||
)?;
|
||||
// scale the initial noise by the standard deviation required by the scheduler
|
||||
(latents * scheduler.init_noise_sigma())?
|
||||
}
|
||||
};
|
||||
let mut latents = latents.to_dtype(dtype)?;
|
||||
|
||||
println!("starting sampling");
|
||||
for (timestep_index, ×tep) in timesteps.iter().enumerate() {
|
||||
if timestep_index < t_start {
|
||||
continue;
|
||||
}
|
||||
let start_time = std::time::Instant::now();
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
|
||||
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
|
||||
let noise_pred =
|
||||
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred =
|
||||
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
|
||||
latents = scheduler.step(&noise_pred, timestep, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
|
||||
|
||||
if args.intermediary_images {
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename =
|
||||
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
|
||||
candle_examples::save_image(&image, image_filename)?
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"Generating the final image for sample {}/{}.",
|
||||
idx + 1,
|
||||
num_samples
|
||||
);
|
||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
||||
// TODO: Add the clamping between 0 and 1.
|
||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||
candle_examples::save_image(&image, image_filename)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
run(args)
|
||||
}
|
@ -54,7 +54,7 @@ pub struct Config {
|
||||
is_decoder: bool,
|
||||
is_encoder_decoder: bool,
|
||||
#[serde(default = "default_use_cache")]
|
||||
use_cache: bool,
|
||||
pub use_cache: bool,
|
||||
pub pad_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
}
|
||||
@ -245,10 +245,17 @@ struct T5Attention {
|
||||
relative_attention_num_buckets: usize,
|
||||
relative_attention_max_distance: usize,
|
||||
inner_dim: usize,
|
||||
use_cache: bool,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl T5Attention {
|
||||
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(
|
||||
has_relative_attention_bias: bool,
|
||||
decoder: bool,
|
||||
vb: VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
|
||||
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
|
||||
@ -275,11 +282,13 @@ impl T5Attention {
|
||||
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
|
||||
relative_attention_max_distance: cfg.relative_attention_max_distance,
|
||||
inner_dim,
|
||||
use_cache: cfg.use_cache && decoder,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
key_value_states: Option<&Tensor>,
|
||||
@ -287,7 +296,6 @@ impl T5Attention {
|
||||
) -> Result<(Tensor, Option<Tensor>)> {
|
||||
// Performs Self-attention (if key_value_states is None) or attention
|
||||
// over source sentence (provided by key_value_states).
|
||||
// TODO: kv caching.
|
||||
let kv_input = match key_value_states {
|
||||
None => xs,
|
||||
Some(key_value_states) => key_value_states,
|
||||
@ -301,14 +309,22 @@ impl T5Attention {
|
||||
.reshape((b_sz, q_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
let mut k = k
|
||||
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let v = v
|
||||
let mut v = v
|
||||
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
if self.use_cache {
|
||||
if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
|
||||
k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
|
||||
v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
};
|
||||
// TODO: Use flash_attn.
|
||||
let scores = q.matmul(&k.t()?)?;
|
||||
let scores = match mask {
|
||||
@ -394,8 +410,8 @@ struct T5LayerSelfAttention {
|
||||
}
|
||||
|
||||
impl T5LayerSelfAttention {
|
||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
||||
fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
@ -405,7 +421,7 @@ impl T5LayerSelfAttention {
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
@ -426,8 +442,8 @@ struct T5LayerCrossAttention {
|
||||
}
|
||||
|
||||
impl T5LayerCrossAttention {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
|
||||
fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
|
||||
let layer_norm =
|
||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
@ -437,7 +453,7 @@ impl T5LayerCrossAttention {
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
hidden_states: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
key_value_states: &Tensor,
|
||||
@ -462,11 +478,17 @@ struct T5Block {
|
||||
}
|
||||
|
||||
impl T5Block {
|
||||
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
fn load(
|
||||
has_relative_attention_bias: bool,
|
||||
decoder: bool,
|
||||
vb: VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let vb = vb.pp("layer");
|
||||
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
|
||||
let self_attn =
|
||||
T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
|
||||
let cross_attn = if cfg.is_decoder {
|
||||
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
|
||||
Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -480,19 +502,28 @@ impl T5Block {
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
position_bias: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Option<Tensor>)> {
|
||||
// TODO: Cache masks
|
||||
let mask = match self.cross_attn.is_some() {
|
||||
true => Some(get_mask(xs.dim(1)?, xs.device())?),
|
||||
true => {
|
||||
let mask_len = xs.dim(1)?;
|
||||
// If the input seq length is 1, no need for a mask, this is also helpful to avoid shape
|
||||
// issues when using the KV cache in the decoder.
|
||||
if mask_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(mask_len, xs.device())?)
|
||||
}
|
||||
}
|
||||
false => None,
|
||||
};
|
||||
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
|
||||
// TODO: clamp for f16?
|
||||
if let Some(cross_attn) = &self.cross_attn {
|
||||
if let Some(cross_attn) = &mut self.cross_attn {
|
||||
(xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
|
||||
// TODO: clamp for f16?
|
||||
}
|
||||
@ -510,9 +541,9 @@ struct T5Stack {
|
||||
}
|
||||
|
||||
impl T5Stack {
|
||||
fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
|
||||
fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
|
||||
let block = (0..cfg.num_layers)
|
||||
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
|
||||
.map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let final_layer_norm = T5LayerNorm::load(
|
||||
cfg.d_model,
|
||||
@ -527,14 +558,14 @@ impl T5Stack {
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||
let mut hidden_states = input_embeds;
|
||||
let mut position_bias = None;
|
||||
for block in self.block.iter() {
|
||||
for block in self.block.iter_mut() {
|
||||
(hidden_states, position_bias) = block.forward(
|
||||
&hidden_states,
|
||||
position_bias.as_ref(),
|
||||
@ -555,14 +586,14 @@ impl T5EncoderModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||
let shared = Arc::new(shared);
|
||||
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
|
||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
device: vb.device().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
self.encoder.forward(input_ids, None)
|
||||
}
|
||||
|
||||
@ -589,13 +620,13 @@ impl T5ForConditionalGeneration {
|
||||
encoder_cfg.is_decoder = false;
|
||||
encoder_cfg.use_cache = false;
|
||||
encoder_cfg.is_encoder_decoder = false;
|
||||
let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
|
||||
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
|
||||
|
||||
let mut decoder_cfg = cfg.clone();
|
||||
decoder_cfg.is_decoder = true;
|
||||
decoder_cfg.is_encoder_decoder = false;
|
||||
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
|
||||
let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
|
||||
let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
|
||||
|
||||
Ok(Self {
|
||||
encoder,
|
||||
@ -605,7 +636,7 @@ impl T5ForConditionalGeneration {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
|
||||
let encoder_output = self.encoder.forward(input_ids, None)?;
|
||||
let decoder_output = self
|
||||
.decoder
|
||||
|
Reference in New Issue
Block a user