From 1a276b5da79a4bb2305dde7368b800d165599819 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 17 Sep 2023 09:00:45 +0200 Subject: [PATCH] Add a KV cache to T5. (#873) * Add a KV cache to T5. * Suggest using release mode. * Use the kv cache in decoding. * Add a comment. --- candle-examples/examples/musicgen/main.rs | 2 +- candle-examples/examples/t5/README.md | 4 +- candle-examples/examples/t5/main.rs | 37 +- candle-examples/examples/wuerstchen/main.rs | 499 ++++++++++++++++++++ candle-transformers/src/models/t5.rs | 85 ++-- 5 files changed, 577 insertions(+), 50 deletions(-) create mode 100644 candle-examples/examples/wuerstchen/main.rs diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index df8c3135..0fae67b5 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -77,7 +77,7 @@ fn main() -> Result<()> { let model = model.deserialize()?; let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let config = GenConfig::small(); - let model = MusicgenForConditionalGeneration::load(vb, config)?; + let mut model = MusicgenForConditionalGeneration::load(vb, config)?; let tokens = tokenizer .encode(args.prompt.as_str(), true) diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md index c6ea2125..6a406467 100644 --- a/candle-examples/examples/t5/README.md +++ b/candle-examples/examples/t5/README.md @@ -3,7 +3,7 @@ ## Encoder-decoder example: ```bash -$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode +$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode ... Running on CPU, to run on GPU, build this example with `--features cuda` Eine schöne Kerze. @@ -13,7 +13,7 @@ Running on CPU, to run on GPU, build this example with `--features cuda` ## Sentence embedding example: ```bash -$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle." +$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle." ... [[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265], [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164], diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 00291609..c432e004 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -48,10 +48,6 @@ struct Args { #[arg(long)] prompt: Option, - /// The number of times to run the prompt. - #[arg(long, default_value = "1")] - n: usize, - /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, @@ -131,6 +127,7 @@ impl T5ModelBuilder { fn main() -> Result<()> { let args = Args::parse(); let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?; + let device = &builder.device; let tokenizer = tokenizer .with_padding(None) .with_truncation(None) @@ -142,32 +139,32 @@ fn main() -> Result<()> { .map_err(E::msg)? 
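A note on the `let mut model` change in the musicgen example above: once the T5 stacks own a KV cache, their forward pass has to take `&mut self` so it can append to the cache, and that mutability requirement propagates up to every caller. A toy sketch of the pattern, using hypothetical types rather than the candle API:

```rust
// Toy illustration: a cache that grows on every call forces `&mut self` on
// forward, so the caller must bind the model with `mut`.
struct CachedLayer {
    kv_cache: Option<Vec<f32>>, // stand-in for the cached (key, value) tensors
}

impl CachedLayer {
    fn forward(&mut self, new_kv: &[f32]) -> usize {
        match self.kv_cache.as_mut() {
            Some(cache) => cache.extend_from_slice(new_kv),
            None => self.kv_cache = Some(new_kv.to_vec()),
        }
        // Effective sequence length seen by attention: everything cached so far.
        self.kv_cache.as_ref().map_or(0, |c| c.len())
    }
}

fn main() {
    let mut layer = CachedLayer { kv_cache: None }; // `mut` is now required
    assert_eq!(layer.forward(&[0.1, 0.2]), 2);
    assert_eq!(layer.forward(&[0.3]), 3); // the cache keeps growing across calls
}
```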
.get_ids() .to_vec(); - let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?; + let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; if !args.decode { - let model = builder.build_encoder()?; - for idx in 0..args.n { - let start = std::time::Instant::now(); - let ys = model.forward(&input_token_ids)?; - if idx == 0 { - println!("{ys}"); - } - println!("Took {:?}", start.elapsed()); - } + let mut model = builder.build_encoder()?; + let start = std::time::Instant::now(); + let ys = model.forward(&input_token_ids)?; + println!("{ys}"); + println!("Took {:?}", start.elapsed()); } else { - let model = builder.build_conditional_generation()?; + let mut model = builder.build_conditional_generation()?; let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); let mut logits_processor = LogitsProcessor::new(299792458, None, None); let start = std::time::Instant::now(); - for _index in 0.. { + for index in 0.. { if output_token_ids.len() > 512 { break; } - let decoder_token_ids = - Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?; + let decoder_token_ids = if index == 0 || !builder.config.use_cache { + Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)? + } else { + let last_token = *output_token_ids.last().unwrap(); + Tensor::new(&[last_token], device)?.unsqueeze(0)? + }; let logits = model.forward(&input_token_ids, &decoder_token_ids)?; let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?; - if (next_token_id as usize) == builder.config.eos_token_id { + if next_token_id as usize == builder.config.eos_token_id { break; } output_token_ids.push(next_token_id); @@ -186,7 +183,7 @@ fn main() -> Result<()> { } } None => { - let model = builder.build_encoder()?; + let mut model = builder.build_encoder()?; let sentences = [ "The cat sits outside", "A man is playing guitar", diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs new file mode 100644 index 00000000..c8b771a0 --- /dev/null +++ b/candle-examples/examples/wuerstchen/main.rs @@ -0,0 +1,499 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use candle_transformers::models::stable_diffusion; + +use anyhow::{Error as E, Result}; +use candle::{DType, Device, IndexOp, Module, Tensor, D}; +use clap::Parser; +use tokenizers::Tokenizer; + +const GUIDANCE_SCALE: f64 = 7.5; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + /// The prompt to be used for image generation. + #[arg( + long, + default_value = "A very realistic photo of a rusty robot walking on a sandy beach" + )] + prompt: String, + + #[arg(long, default_value = "")] + uncond_prompt: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// The height in pixels of the generated image. + #[arg(long)] + height: Option, + + /// The width in pixels of the generated image. + #[arg(long)] + width: Option, + + /// The UNet weight file, in .safetensors format. + #[arg(long, value_name = "FILE")] + unet_weights: Option, + + /// The CLIP weight file, in .safetensors format. + #[arg(long, value_name = "FILE")] + clip_weights: Option, + + /// The VAE weight file, in .safetensors format. 
+ #[arg(long, value_name = "FILE")] + vae_weights: Option, + + #[arg(long, value_name = "FILE")] + /// The file specifying the tokenizer to used for tokenization. + tokenizer: Option, + + /// The size of the sliced attention or 0 for automatic slicing (disabled by default) + #[arg(long)] + sliced_attention_size: Option, + + /// The number of steps to run the diffusion for. + #[arg(long, default_value_t = 30)] + n_steps: usize, + + /// The number of samples to generate. + #[arg(long, default_value_t = 1)] + num_samples: i64, + + /// The name of the final image to generate. + #[arg(long, value_name = "FILE", default_value = "sd_final.png")] + final_image: String, + + #[arg(long, value_enum, default_value = "v2-1")] + sd_version: StableDiffusionVersion, + + /// Generate intermediary images at each step. + #[arg(long, action)] + intermediary_images: bool, + + #[arg(long)] + use_flash_attn: bool, + + #[arg(long)] + use_f16: bool, + + #[arg(long, value_name = "FILE")] + img2img: Option, + + /// The strength, indicates how much to transform the initial image. The + /// value must be between 0 and 1, a value of 1 discards the initial image + /// information. + #[arg(long, default_value_t = 0.8)] + img2img_strength: f64, +} + +#[derive(Debug, Clone, Copy, clap::ValueEnum)] +enum StableDiffusionVersion { + V1_5, + V2_1, + Xl, +} + +#[allow(unused)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ModelFile { + Tokenizer, + Tokenizer2, + Clip, + Clip2, + Unet, + Vae, +} + +impl StableDiffusionVersion { + fn repo(&self) -> &'static str { + match self { + Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", + Self::V2_1 => "stabilityai/stable-diffusion-2-1", + Self::V1_5 => "runwayml/stable-diffusion-v1-5", + } + } + + fn unet_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl => { + if use_f16 { + "unet/diffusion_pytorch_model.fp16.safetensors" + } else { + "unet/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn vae_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl => { + if use_f16 { + "vae/diffusion_pytorch_model.fp16.safetensors" + } else { + "vae/diffusion_pytorch_model.safetensors" + } + } + } + } + + fn clip_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl => { + if use_f16 { + "text_encoder/model.fp16.safetensors" + } else { + "text_encoder/model.safetensors" + } + } + } + } + + fn clip2_file(&self, use_f16: bool) -> &'static str { + match self { + Self::V1_5 | Self::V2_1 | Self::Xl => { + if use_f16 { + "text_encoder_2/model.fp16.safetensors" + } else { + "text_encoder_2/model.safetensors" + } + } + } + } +} + +impl ModelFile { + fn get( + &self, + filename: Option, + version: StableDiffusionVersion, + use_f16: bool, + ) -> Result { + use hf_hub::api::sync::Api; + match filename { + Some(filename) => Ok(std::path::PathBuf::from(filename)), + None => { + let (repo, path) = match self { + Self::Tokenizer => { + let tokenizer_repo = match version { + StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { + "openai/clip-vit-base-patch32" + } + StableDiffusionVersion::Xl => { + // This seems similar to the patch32 version except some very small + // difference in the split regex. 
+ "openai/clip-vit-large-patch14" + } + }; + (tokenizer_repo, "tokenizer.json") + } + Self::Tokenizer2 => { + ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") + } + Self::Clip => (version.repo(), version.clip_file(use_f16)), + Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), + Self::Unet => (version.repo(), version.unet_file(use_f16)), + Self::Vae => (version.repo(), version.vae_file(use_f16)), + }; + let filename = Api::new()?.model(repo.to_string()).get(path)?; + Ok(filename) + } + } + } +} + +fn output_filename( + basename: &str, + sample_idx: i64, + num_samples: i64, + timestep_idx: Option, +) -> String { + let filename = if num_samples > 1 { + match basename.rsplit_once('.') { + None => format!("{basename}.{sample_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}.{sample_idx}.{extension}") + } + } + } else { + basename.to_string() + }; + match timestep_idx { + None => filename, + Some(timestep_idx) => match filename.rsplit_once('.') { + None => format!("{filename}-{timestep_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}-{timestep_idx}.{extension}") + } + }, + } +} + +#[allow(clippy::too_many_arguments)] +fn text_embeddings( + prompt: &str, + uncond_prompt: &str, + tokenizer: Option, + clip_weights: Option, + sd_version: StableDiffusionVersion, + sd_config: &stable_diffusion::StableDiffusionConfig, + use_f16: bool, + device: &Device, + dtype: DType, + first: bool, +) -> Result { + let tokenizer_file = if first { + ModelFile::Tokenizer + } else { + ModelFile::Tokenizer2 + }; + let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let pad_id = match &sd_config.clip.pad_with { + Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), + None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), + }; + println!("Running with prompt \"{prompt}\"."); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? 
+ .get_ids() + .to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + + println!("Building the Clip transformer."); + let clip_weights_file = if first { + ModelFile::Clip + } else { + ModelFile::Clip2 + }; + let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?; + let clip_config = if first { + &sd_config.clip + } else { + sd_config.clip2.as_ref().unwrap() + }; + let text_model = + stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?; + let text_embeddings = text_model.forward(&tokens)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?; + Ok(text_embeddings) +} + +fn image_preprocess>(path: T) -> anyhow::Result { + let img = image::io::Reader::open(path)?.decode()?; + let (height, width) = (img.height() as usize, img.width() as usize); + let height = height - height % 32; + let width = width - width % 32; + let img = img.resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::CatmullRom, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)? + .unsqueeze(0)?; + Ok(img) +} + +fn run(args: Args) -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let Args { + prompt, + uncond_prompt, + cpu, + height, + width, + n_steps, + tokenizer, + final_image, + sliced_attention_size, + num_samples, + sd_version, + clip_weights, + vae_weights, + unet_weights, + tracing, + use_f16, + use_flash_attn, + img2img, + img2img_strength, + .. + } = args; + + if !(0. 
..=1.).contains(&img2img_strength) { + anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}") + } + + let _guard = if tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let dtype = if use_f16 { DType::F16 } else { DType::F32 }; + let sd_config = match sd_version { + StableDiffusionVersion::V1_5 => { + stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width) + } + StableDiffusionVersion::V2_1 => { + stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width) + } + StableDiffusionVersion::Xl => { + stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width) + } + }; + + let scheduler = sd_config.build_scheduler(n_steps)?; + let device = candle_examples::device(cpu)?; + + let which = match sd_version { + StableDiffusionVersion::Xl => vec![true, false], + _ => vec![true], + }; + let text_embeddings = which + .iter() + .map(|first| { + text_embeddings( + &prompt, + &uncond_prompt, + tokenizer.clone(), + clip_weights.clone(), + sd_version, + &sd_config, + use_f16, + &device, + dtype, + *first, + ) + }) + .collect::>>()?; + let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; + println!("{text_embeddings:?}"); + + println!("Building the autoencoder."); + let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?; + let vae = sd_config.build_vae(&vae_weights, &device, dtype)?; + let init_latent_dist = match &img2img { + None => None, + Some(image) => { + let image = image_preprocess(image)?.to_device(&device)?; + Some(vae.encode(&image)?) + } + }; + println!("Building the unet."); + let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?; + let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?; + + let t_start = if img2img.is_some() { + n_steps - (n_steps as f64 * img2img_strength) as usize + } else { + 0 + }; + let bsize = 1; + for idx in 0..num_samples { + let timesteps = scheduler.timesteps(); + let latents = match &init_latent_dist { + Some(init_latent_dist) => { + let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?; + if t_start < timesteps.len() { + let noise = latents.randn_like(0f64, 1f64)?; + scheduler.add_noise(&latents, noise, timesteps[t_start])? + } else { + latents + } + } + None => { + let latents = Tensor::randn( + 0f32, + 1f32, + (bsize, 4, sd_config.height / 8, sd_config.width / 8), + &device, + )?; + // scale the initial noise by the standard deviation required by the scheduler + (latents * scheduler.init_noise_sigma())? + } + }; + let mut latents = latents.to_dtype(dtype)?; + + println!("starting sampling"); + for (timestep_index, ×tep) in timesteps.iter().enumerate() { + if timestep_index < t_start { + continue; + } + let start_time = std::time::Instant::now(); + let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + + let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; + let noise_pred = + unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + let noise_pred = + (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? 
* GUIDANCE_SCALE)?)?; + latents = scheduler.step(&noise_pred, timestep, &latents)?; + let dt = start_time.elapsed().as_secs_f32(); + println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); + + if args.intermediary_images { + let image = vae.decode(&(&latents / 0.18215)?)?; + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; + let image_filename = + output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1)); + candle_examples::save_image(&image, image_filename)? + } + } + + println!( + "Generating the final image for sample {}/{}.", + idx + 1, + num_samples + ); + let image = vae.decode(&(&latents / 0.18215)?)?; + // TODO: Add the clamping between 0 and 1. + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; + let image_filename = output_filename(&final_image, idx + 1, num_samples, None); + candle_examples::save_image(&image, image_filename)? + } + Ok(()) +} + +fn main() -> Result<()> { + let args = Args::parse(); + run(args) +} diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index c35dea0b..8b621f64 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -54,7 +54,7 @@ pub struct Config { is_decoder: bool, is_encoder_decoder: bool, #[serde(default = "default_use_cache")] - use_cache: bool, + pub use_cache: bool, pub pad_token_id: usize, pub eos_token_id: usize, } @@ -245,10 +245,17 @@ struct T5Attention { relative_attention_num_buckets: usize, relative_attention_max_distance: usize, inner_dim: usize, + use_cache: bool, + kv_cache: Option<(Tensor, Tensor)>, } impl T5Attention { - fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result { let inner_dim = cfg.num_heads * cfg.d_kv; let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?; let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?; @@ -275,11 +282,13 @@ impl T5Attention { relative_attention_num_buckets: cfg.relative_attention_num_buckets, relative_attention_max_distance: cfg.relative_attention_max_distance, inner_dim, + use_cache: cfg.use_cache && decoder, + kv_cache: None, }) } fn forward( - &self, + &mut self, xs: &Tensor, position_bias: Option<&Tensor>, key_value_states: Option<&Tensor>, @@ -287,7 +296,6 @@ impl T5Attention { ) -> Result<(Tensor, Option)> { // Performs Self-attention (if key_value_states is None) or attention // over source sentence (provided by key_value_states). - // TODO: kv caching. let kv_input = match key_value_states { None => xs, Some(key_value_states) => key_value_states, @@ -301,14 +309,22 @@ impl T5Attention { .reshape((b_sz, q_len, self.n_heads, self.d_kv))? .transpose(1, 2)? .contiguous()?; - let k = k + let mut k = k .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? .transpose(1, 2)? .contiguous()?; - let v = v + let mut v = v .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? .transpose(1, 2)? .contiguous()?; + + if self.use_cache { + if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache { + k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?; + }; + self.kv_cache = Some((k.clone(), v.clone())); + }; // TODO: Use flash_attn. 
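The cache handling just introduced in `T5Attention::forward` reduces to one idea: at each decoding step, concatenate the freshly projected key/value tensors onto whatever is already cached along the sequence axis (dim 2), store the result back, and attend over the full cached sequence. A stripped-down sketch of that append step, assuming the `candle` crate as used elsewhere in this patch (the helper name and shapes are illustrative, not part of the library):

```rust
use candle::{Device, Result, Tensor};

/// Append newly computed K/V to an optional cache along the sequence axis (dim 2).
/// Mirrors the pattern used in `T5Attention::forward` above.
fn append_kv(
    cache: &mut Option<(Tensor, Tensor)>,
    k: &Tensor, // (batch, heads, seq, d_kv)
    v: &Tensor,
) -> Result<(Tensor, Tensor)> {
    let (k, v) = match cache.as_ref() {
        Some((ck, cv)) => (
            Tensor::cat(&[ck, k], 2)?.contiguous()?,
            Tensor::cat(&[cv, v], 2)?.contiguous()?,
        ),
        None => (k.clone(), v.clone()),
    };
    *cache = Some((k.clone(), v.clone()));
    Ok((k, v))
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut cache = None;
    // Two decoding steps, one new token each: the cached sequence grows 1 -> 2.
    for step in 0..2 {
        let k = Tensor::zeros((1, 8, 1, 64), candle::DType::F32, &dev)?;
        let v = k.clone();
        let (k_all, _v_all) = append_kv(&mut cache, &k, &v)?;
        assert_eq!(k_all.dims()[2], step + 1);
    }
    Ok(())
}
```

Because the cached keys and values already cover every earlier position, later steps only need the projections of the newest token, which is what makes the one-token decoder input in the example loop valid.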
let scores = q.matmul(&k.t()?)?; let scores = match mask { @@ -394,8 +410,8 @@ struct T5LayerSelfAttention { } impl T5LayerSelfAttention { - fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result { - let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?; + fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result { + let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?; let layer_norm = T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; Ok(Self { @@ -405,7 +421,7 @@ impl T5LayerSelfAttention { } fn forward( - &self, + &mut self, xs: &Tensor, position_bias: Option<&Tensor>, mask: Option<&Tensor>, @@ -426,8 +442,8 @@ struct T5LayerCrossAttention { } impl T5LayerCrossAttention { - fn load(vb: VarBuilder, cfg: &Config) -> Result { - let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?; + fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result { + let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?; let layer_norm = T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; Ok(Self { @@ -437,7 +453,7 @@ impl T5LayerCrossAttention { } fn forward( - &self, + &mut self, hidden_states: &Tensor, position_bias: Option<&Tensor>, key_value_states: &Tensor, @@ -462,11 +478,17 @@ struct T5Block { } impl T5Block { - fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result { let vb = vb.pp("layer"); - let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?; + let self_attn = + T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?; let cross_attn = if cfg.is_decoder { - Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?) + Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?) } else { None }; @@ -480,19 +502,28 @@ impl T5Block { } fn forward( - &self, + &mut self, xs: &Tensor, position_bias: Option<&Tensor>, encoder_hidden_states: Option<&Tensor>, ) -> Result<(Tensor, Option)> { // TODO: Cache masks let mask = match self.cross_attn.is_some() { - true => Some(get_mask(xs.dim(1)?, xs.device())?), + true => { + let mask_len = xs.dim(1)?; + // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape + // issues when using the KV cache in the decoder. + if mask_len <= 1 { + None + } else { + Some(get_mask(mask_len, xs.device())?) + } + } false => None, }; let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; // TODO: clamp for f16? - if let Some(cross_attn) = &self.cross_attn { + if let Some(cross_attn) = &mut self.cross_attn { (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; // TODO: clamp for f16? 
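About the mask change in `T5Block::forward` above: once the KV cache is active, each decoder step feeds a single token whose query is allowed to attend to every cached position, so there is nothing to mask, and skipping the mask also avoids a shape mismatch between the 1-token query and the longer cached keys. A causal mask is only needed when several decoder positions are processed in one call. Below is a sketch of such a lower-triangular mask for illustration; the actual `get_mask` helper in `t5.rs` may differ in detail:

```rust
use candle::{Device, Result, Tensor};

/// Illustration of a causal mask: entry (i, j) is 1 when position j lies in the
/// future of position i and should be blocked. With a single query token there
/// is nothing to block, which is why the block above skips the mask entirely.
fn causal_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

fn main() -> Result<()> {
    let mask = causal_mask(3, &Device::Cpu)?;
    // [[0, 1, 1],
    //  [0, 0, 1],
    //  [0, 0, 0]]
    println!("{mask}");
    Ok(())
}
```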
 }
@@ -510,9 +541,9 @@ struct T5Stack {
 }
 
 impl T5Stack {
-    fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+    fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
         let block = (0..cfg.num_layers)
-            .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
+            .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
             .collect::<Result<Vec<_>>>()?;
         let final_layer_norm = T5LayerNorm::load(
             cfg.d_model,
@@ -527,14 +558,14 @@ impl T5Stack {
     }
 
     fn forward(
-        &self,
+        &mut self,
         input_ids: &Tensor,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
-        for block in self.block.iter() {
+        for block in self.block.iter_mut() {
             (hidden_states, position_bias) = block.forward(
                 &hidden_states,
                 position_bias.as_ref(),
@@ -555,14 +586,14 @@ impl T5EncoderModel {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
-        let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
+        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
             encoder,
             device: vb.device().clone(),
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
         self.encoder.forward(input_ids, None)
     }
 
@@ -589,13 +620,13 @@ impl T5ForConditionalGeneration {
         encoder_cfg.is_decoder = false;
         encoder_cfg.use_cache = false;
         encoder_cfg.is_encoder_decoder = false;
-        let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
 
         let mut decoder_cfg = cfg.clone();
         decoder_cfg.is_decoder = true;
         decoder_cfg.is_encoder_decoder = false;
         decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
-        let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+        let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
 
         Ok(Self {
             encoder,
@@ -605,7 +636,7 @@ impl T5ForConditionalGeneration {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
         let encoder_output = self.encoder.forward(input_ids, None)?;
         let decoder_output = self
             .decoder
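Putting the pieces together, a caller drives the cached decoder the way the updated `candle-examples/examples/t5/main.rs` does: bind the model mutably, feed the whole decoder prefix on the first step, and afterwards feed only the most recently generated token while the per-layer caches supply the earlier positions. A condensed sketch of that loop, assuming the model, config, encoded prompt and `LogitsProcessor` are built as in the example:

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::t5::{Config, T5ForConditionalGeneration};

/// Decoding loop that exploits the KV cache: the full prefix is only fed on the
/// first step, afterwards just the newest token. Condensed from the t5 example.
fn decode(
    model: &mut T5ForConditionalGeneration,
    config: &Config,
    input_token_ids: &Tensor,
    logits_processor: &mut LogitsProcessor,
    device: &Device,
) -> Result<Vec<u32>> {
    let mut output_token_ids = vec![config.pad_token_id as u32];
    for index in 0.. {
        if output_token_ids.len() > 512 {
            break;
        }
        // With the cache enabled, only the newest token needs to be fed after
        // the first step; earlier keys/values are read from the cache.
        let decoder_token_ids = if index == 0 || !config.use_cache {
            Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
        } else {
            let last_token = *output_token_ids.last().unwrap();
            Tensor::new(&[last_token], device)?.unsqueeze(0)?
        };
        let logits = model.forward(input_token_ids, &decoder_token_ids)?;
        let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
        if next_token_id as usize == config.eos_token_id {
            break;
        }
        output_token_ids.push(next_token_id);
    }
    Ok(output_token_ids)
}
```

Since the cache lives inside the model and this patch does not add a reset mechanism, a `T5ForConditionalGeneration` instance is effectively tied to a single generation here, which matches how the example constructs it immediately before decoding.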