Implement T5 decoding (#864)

* Load t5 decoder

* Run enc, dec, and lm head, but no cross attn

* Cross-attention over key_value_states

* New arg for decoder input ids

* Add mask, don't forward position biases through decoder

* Update t5 examples

* Clippy + rustfmt
This commit is contained in:
Juarez Bochi
2023-09-15 13:05:12 -07:00
committed by GitHub
parent c2007ac88f
commit 3e49f8fce5
3 changed files with 260 additions and 61 deletions

View File

@ -1,17 +1,25 @@
# candle-t5
Generates embeddings using a T5 model. It doesn't support generation yet.
## Encoder-decoder example:
```bash
$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
Loaded and encoded 2.014244792s
[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387],
[-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
[-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300],
...
[-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306],
[-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384],
[ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]]
Tensor[[1, 9, 1024], f32]
Took 2.1363425s
```
$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```
## Sentence embedding example:
```bash
$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
[ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
[-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
[ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
Tensor[[1, 5, 512], f32]
Took 303.766583ms
```

View File

@ -3,18 +3,22 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::models::t5;
use anyhow::{anyhow, Error as E, Result};
use candle::{DType, Tensor};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
#[derive(Parser, Debug)]
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
@ -36,7 +40,11 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// Compute embeddings for this prompt, otherwise compute sentence similarities.
/// Enable decoding.
#[arg(long)]
decode: bool,
/// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: Option<String>,
@ -49,12 +57,18 @@ struct Args {
normalize_embeddings: bool,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: PathBuf,
}
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?;
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
@ -62,7 +76,7 @@ impl Args {
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
let cache = Cache::default().repo(repo);
(
cache
@ -87,18 +101,36 @@ impl Args {
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
device,
config,
weights_filename,
},
tokenizer,
))
}
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let weights =
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let model = t5::T5EncoderModel::load(vb, &config)?;
Ok((model, tokenizer))
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let weights =
unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
fn main() -> Result<()> {
let args = Args::parse();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
@ -110,17 +142,51 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
if !args.decode {
let model = builder.build_encoder()?;
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&input_token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
println!("Took {:?}", start.elapsed());
} else {
let model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
let start = std::time::Instant::now();
for _index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids =
Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
if (next_token_id as usize) == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
tokens.len(),
tokens.len() as f64 / dt.as_secs_f64(),
);
}
}
None => {
let model = builder.build_encoder()?;
let sentences = [
"The cat sits outside",
"A man is playing guitar",

View File

@ -18,6 +18,21 @@ fn default_use_cache() -> bool {
true
}
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
let result = Tensor::from_slice(&mask, (size, size), device)?;
Ok(result)
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
@ -40,8 +55,8 @@ pub struct Config {
is_encoder_decoder: bool,
#[serde(default = "default_use_cache")]
use_cache: bool,
pad_token_id: usize,
eos_token_id: usize,
pub pad_token_id: usize,
pub eos_token_id: usize,
}
impl Default for Config {
@ -233,13 +248,13 @@ struct T5Attention {
}
impl T5Attention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if h {
let relative_attention_bias = if has_relative_attention_bias {
let emb = embedding(
cfg.relative_attention_num_buckets,
cfg.num_heads,
@ -267,26 +282,46 @@ impl T5Attention {
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
key_value_states: Option<&Tensor>,
mask: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
// TODO: Apply the mask(s)?
// Performs Self-attention (if key_value_states is None) or attention
// over source sentence (provided by key_value_states).
// TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
let kv_input = match key_value_states {
None => xs,
Some(key_value_states) => key_value_states,
};
let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
let kv_len = kv_input.dim(1)?;
let q = self.q.forward(xs)?;
let k = self.k.forward(xs)?;
let v = self.v.forward(xs)?;
let k = self.k.forward(kv_input)?;
let v = self.v.forward(kv_input)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.reshape((b_sz, q_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
// TODO: Use flash_attn.
let scores = q.matmul(&k.t()?)?;
let scores = match mask {
None => scores,
Some(mask) => masked_fill(
&scores,
&mask
.unsqueeze(0)?
.unsqueeze(0)?
.repeat((b_sz, self.n_heads))?,
f32::NEG_INFINITY,
)?,
};
let (scores, position_bias) = match position_bias {
Some(position_bias) => (
@ -296,14 +331,12 @@ impl T5Attention {
None => match &self.relative_attention_bias {
None => (scores, None),
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
let max_exact = num_buckets / 2;
let relative_position = (0..query_length as u32)
let relative_position = (0..q_len as u32)
.map(|i| {
(0..key_length as u32)
(0..kv_len as u32)
.map(|j| {
if i < j {
if j - i < max_exact {
@ -348,7 +381,7 @@ impl T5Attention {
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.inner_dim))?;
.reshape((b_sz, q_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok((attn_output, position_bias))
}
@ -375,24 +408,49 @@ impl T5LayerSelfAttention {
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
mask: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let normed_xs = self.layer_norm.forward(xs)?;
let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
let (ys, position_bias) =
self.self_attention
.forward(&normed_xs, position_bias, None, mask)?;
let ys = (xs + ys)?;
Ok((ys, position_bias))
}
}
#[derive(Debug)]
struct T5LayerCrossAttention {}
struct T5LayerCrossAttention {
cross_attention: T5Attention,
layer_norm: T5LayerNorm,
}
impl T5LayerCrossAttention {
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
todo!()
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
cross_attention,
layer_norm,
})
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
fn forward(
&self,
hidden_states: &Tensor,
position_bias: Option<&Tensor>,
key_value_states: &Tensor,
) -> Result<(Tensor, Option<Tensor>)> {
let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
let (ys, position_bias) = self.cross_attention.forward(
&normed_hidden_states,
position_bias,
Some(key_value_states),
None,
)?;
let ys = (hidden_states + ys)?;
Ok((ys, position_bias))
}
}
@ -425,11 +483,17 @@ impl T5Block {
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
// TODO: Cache masks
let mask = match self.cross_attn.is_some() {
true => Some(get_mask(xs.dim(1)?, xs.device())?),
false => None,
};
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
xs = cross_attn.forward(&xs)?;
(xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
// TODO: clamp for f16?
}
let xs = self.ff.forward(&xs)?;
@ -462,13 +526,20 @@ impl T5Stack {
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
fn forward(
&self,
input_ids: &Tensor,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter() {
(hidden_states, position_bias) =
block.forward(&hidden_states, position_bias.as_ref())?
(hidden_states, position_bias) = block.forward(
&hidden_states,
position_bias.as_ref(),
encoder_hidden_states,
)?
}
self.final_layer_norm.forward(&hidden_states)
}
@ -492,7 +563,61 @@ impl T5EncoderModel {
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
self.encoder.forward(input_ids)
self.encoder.forward(input_ids, None)
}
pub fn device(&self) -> &Device {
&self.device
}
}
#[derive(Debug)]
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,
shared: Arc<Embedding>,
device: Device,
}
impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();
encoder_cfg.is_decoder = false;
encoder_cfg.use_cache = false;
encoder_cfg.is_encoder_decoder = false;
let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
let mut decoder_cfg = cfg.clone();
decoder_cfg.is_decoder = true;
decoder_cfg.is_encoder_decoder = false;
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
Ok(Self {
encoder,
decoder,
shared,
device: vb.device().clone(),
})
}
pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
let encoder_output = self.encoder.forward(input_ids, None)?;
let decoder_output = self
.decoder
.forward(decoder_input_ids, Some(&encoder_output))?;
let sequence_output = decoder_output
.narrow(1, decoder_output.dim(1)? - 1, 1)?
.squeeze(1)?;
// TODO: check cfg.tie_word_embeddings to load from model instead.
let lm_head_weights = self.shared.embeddings().t()?;
let output = sequence_output.matmul(&lm_head_weights)?;
// TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
Ok(output)
}
pub fn device(&self) -> &Device {