Add a KV cache to T5. (#873)

* Add a KV cache to T5.

* Suggest using release mode.

* Use the kv cache in decoding.

* Add a comment.
This commit is contained in:
Laurent Mazare
2023-09-17 09:00:45 +02:00
committed by GitHub
parent 8658df3485
commit 1a276b5da7
5 changed files with 577 additions and 50 deletions

View File

@ -77,7 +77,7 @@ fn main() -> Result<()> {
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
let model = MusicgenForConditionalGeneration::load(vb, config)?;
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer
.encode(args.prompt.as_str(), true)

View File

@ -3,7 +3,7 @@
## Encoder-decoder example:
```bash
$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
@ -13,7 +13,7 @@ Running on CPU, to run on GPU, build this example with `--features cuda`
## Sentence embedding example:
```bash
$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],

View File

@ -48,10 +48,6 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
@ -131,6 +127,7 @@ impl T5ModelBuilder {
fn main() -> Result<()> {
let args = Args::parse();
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
@ -142,32 +139,32 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
if !args.decode {
let model = builder.build_encoder()?;
for idx in 0..args.n {
let mut model = builder.build_encoder()?;
let start = std::time::Instant::now();
let ys = model.forward(&input_token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
} else {
let model = builder.build_conditional_generation()?;
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
let start = std::time::Instant::now();
for _index in 0.. {
for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids =
Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
if (next_token_id as usize) == builder.config.eos_token_id {
if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
@ -186,7 +183,7 @@ fn main() -> Result<()> {
}
}
None => {
let model = builder.build_encoder()?;
let mut model = builder.build_encoder()?;
let sentences = [
"The cat sits outside",
"A man is playing guitar",

View File

@ -0,0 +1,499 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
const GUIDANCE_SCALE: f64 = 7.5;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,
/// The UNet weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
unet_weights: Option<String>,
/// The CLIP weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,
/// The VAE weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
vae_weights: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to used for tokenization.
tokenizer: Option<String>,
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
#[arg(long)]
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
#[arg(long, default_value_t = 30)]
n_steps: usize,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
num_samples: i64,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
final_image: String,
#[arg(long, value_enum, default_value = "v2-1")]
sd_version: StableDiffusionVersion,
/// Generate intermediary images at each step.
#[arg(long, action)]
intermediary_images: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
use_f16: bool,
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
/// The strength, indicates how much to transform the initial image. The
/// value must be between 0 and 1, a value of 1 discards the initial image
/// information.
#[arg(long, default_value_t = 0.8)]
img2img_strength: f64,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
}
#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
Tokenizer2,
Clip,
Clip2,
Unet,
Vae,
}
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
}
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
"unet/diffusion_pytorch_model.safetensors"
}
}
}
}
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
"vae/diffusion_pytorch_model.safetensors"
}
}
}
}
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
"text_encoder/model.safetensors"
}
}
}
}
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
"text_encoder_2/model.safetensors"
}
}
}
}
}
impl ModelFile {
fn get(
&self,
filename: Option<String>,
version: StableDiffusionVersion,
use_f16: bool,
) -> Result<std::path::PathBuf> {
use hf_hub::api::sync::Api;
match filename {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
Self::Tokenizer => {
let tokenizer_repo = match version {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
StableDiffusionVersion::Xl => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
}
};
(tokenizer_repo, "tokenizer.json")
}
Self::Tokenizer2 => {
("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
}
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
}
}
}
}
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
match basename.rsplit_once('.') {
None => format!("{basename}.{sample_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}.{sample_idx}.{extension}")
}
}
} else {
basename.to_string()
};
match timestep_idx {
None => filename,
Some(timestep_idx) => match filename.rsplit_once('.') {
None => format!("{filename}-{timestep_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}-{timestep_idx}.{extension}")
}
},
}
}
#[allow(clippy::too_many_arguments)]
fn text_embeddings(
prompt: &str,
uncond_prompt: &str,
tokenizer: Option<String>,
clip_weights: Option<String>,
sd_version: StableDiffusionVersion,
sd_config: &stable_diffusion::StableDiffusionConfig,
use_f16: bool,
device: &Device,
dtype: DType,
first: bool,
) -> Result<Tensor> {
let tokenizer_file = if first {
ModelFile::Tokenizer
} else {
ModelFile::Tokenizer2
};
let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &sd_config.clip.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
} else {
ModelFile::Clip2
};
let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
let clip_config = if first {
&sd_config.clip
} else {
sd_config.clip2.as_ref().unwrap()
};
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
Ok(text_embeddings)
}
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
let img = image::io::Reader::open(path)?.decode()?;
let (height, width) = (img.height() as usize, img.width() as usize);
let height = height - height % 32;
let width = width - width % 32;
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::CatmullRom,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?
.unsqueeze(0)?;
Ok(img)
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
height,
width,
n_steps,
tokenizer,
final_image,
sliced_attention_size,
num_samples,
sd_version,
clip_weights,
vae_weights,
unet_weights,
tracing,
use_f16,
use_flash_attn,
img2img,
img2img_strength,
..
} = args;
if !(0. ..=1.).contains(&img2img_strength) {
anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
}
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
}
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
let which = match sd_version {
StableDiffusionVersion::Xl => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
.iter()
.map(|first| {
text_embeddings(
&prompt,
&uncond_prompt,
tokenizer.clone(),
clip_weights.clone(),
sd_version,
&sd_config,
use_f16,
&device,
dtype,
*first,
)
})
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
let init_latent_dist = match &img2img {
None => None,
Some(image) => {
let image = image_preprocess(image)?.to_device(&device)?;
Some(vae.encode(&image)?)
}
};
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
let t_start = if img2img.is_some() {
n_steps - (n_steps as f64 * img2img_strength) as usize
} else {
0
};
let bsize = 1;
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
} else {
latents
}
}
None => {
let latents = Tensor::randn(
0f32,
1f32,
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
&device,
)?;
// scale the initial noise by the standard deviation required by the scheduler
(latents * scheduler.init_noise_sigma())?
}
};
let mut latents = latents.to_dtype(dtype)?;
println!("starting sampling");
for (timestep_index, &timestep) in timesteps.iter().enumerate() {
if timestep_index < t_start {
continue;
}
let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
let noise_pred =
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
let image = vae.decode(&(&latents / 0.18215)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
candle_examples::save_image(&image, image_filename)?
}
}
println!(
"Generating the final image for sample {}/{}.",
idx + 1,
num_samples
);
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
}
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}

View File

@ -54,7 +54,7 @@ pub struct Config {
is_decoder: bool,
is_encoder_decoder: bool,
#[serde(default = "default_use_cache")]
use_cache: bool,
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
}
@ -245,10 +245,17 @@ struct T5Attention {
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
inner_dim: usize,
use_cache: bool,
kv_cache: Option<(Tensor, Tensor)>,
}
impl T5Attention {
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(
has_relative_attention_bias: bool,
decoder: bool,
vb: VarBuilder,
cfg: &Config,
) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
@ -275,11 +282,13 @@ impl T5Attention {
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
relative_attention_max_distance: cfg.relative_attention_max_distance,
inner_dim,
use_cache: cfg.use_cache && decoder,
kv_cache: None,
})
}
fn forward(
&self,
&mut self,
xs: &Tensor,
position_bias: Option<&Tensor>,
key_value_states: Option<&Tensor>,
@ -287,7 +296,6 @@ impl T5Attention {
) -> Result<(Tensor, Option<Tensor>)> {
// Performs Self-attention (if key_value_states is None) or attention
// over source sentence (provided by key_value_states).
// TODO: kv caching.
let kv_input = match key_value_states {
None => xs,
Some(key_value_states) => key_value_states,
@ -301,14 +309,22 @@ impl T5Attention {
.reshape((b_sz, q_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let k = k
let mut k = k
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let v = v
let mut v = v
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
if self.use_cache {
if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
};
self.kv_cache = Some((k.clone(), v.clone()));
};
// TODO: Use flash_attn.
let scores = q.matmul(&k.t()?)?;
let scores = match mask {
@ -394,8 +410,8 @@ struct T5LayerSelfAttention {
}
impl T5LayerSelfAttention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
@ -405,7 +421,7 @@ impl T5LayerSelfAttention {
}
fn forward(
&self,
&mut self,
xs: &Tensor,
position_bias: Option<&Tensor>,
mask: Option<&Tensor>,
@ -426,8 +442,8 @@ struct T5LayerCrossAttention {
}
impl T5LayerCrossAttention {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
@ -437,7 +453,7 @@ impl T5LayerCrossAttention {
}
fn forward(
&self,
&mut self,
hidden_states: &Tensor,
position_bias: Option<&Tensor>,
key_value_states: &Tensor,
@ -462,11 +478,17 @@ struct T5Block {
}
impl T5Block {
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(
has_relative_attention_bias: bool,
decoder: bool,
vb: VarBuilder,
cfg: &Config,
) -> Result<Self> {
let vb = vb.pp("layer");
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
let self_attn =
T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
let cross_attn = if cfg.is_decoder {
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
} else {
None
};
@ -480,19 +502,28 @@ impl T5Block {
}
fn forward(
&self,
&mut self,
xs: &Tensor,
position_bias: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
// TODO: Cache masks
let mask = match self.cross_attn.is_some() {
true => Some(get_mask(xs.dim(1)?, xs.device())?),
true => {
let mask_len = xs.dim(1)?;
// If the input seq length is 1, no need for a mask, this is also helpful to avoid shape
// issues when using the KV cache in the decoder.
if mask_len <= 1 {
None
} else {
Some(get_mask(mask_len, xs.device())?)
}
}
false => None,
};
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
if let Some(cross_attn) = &mut self.cross_attn {
(xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
// TODO: clamp for f16?
}
@ -510,9 +541,9 @@ struct T5Stack {
}
impl T5Stack {
fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
let block = (0..cfg.num_layers)
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
.map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let final_layer_norm = T5LayerNorm::load(
cfg.d_model,
@ -527,14 +558,14 @@ impl T5Stack {
}
fn forward(
&self,
&mut self,
input_ids: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter() {
for block in self.block.iter_mut() {
(hidden_states, position_bias) = block.forward(
&hidden_states,
position_bias.as_ref(),
@ -555,14 +586,14 @@ impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
encoder,
device: vb.device().clone(),
})
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
self.encoder.forward(input_ids, None)
}
@ -589,13 +620,13 @@ impl T5ForConditionalGeneration {
encoder_cfg.is_decoder = false;
encoder_cfg.use_cache = false;
encoder_cfg.is_encoder_decoder = false;
let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
let mut decoder_cfg = cfg.clone();
decoder_cfg.is_decoder = true;
decoder_cfg.is_encoder_decoder = false;
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
Ok(Self {
encoder,
@ -605,7 +636,7 @@ impl T5ForConditionalGeneration {
})
}
pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
let encoder_output = self.encoder.forward(input_ids, None)?;
let decoder_output = self
.decoder