Mirror of https://github.com/huggingface/candle.git
Implement T5 decoding (#864)
* Load t5 decoder
* Run enc, dec, and lm head, but no cross attn
* Cross-attention over key_value_states
* New arg for decoder input ids
* Add mask, don't forward position biases through decoder
* Update t5 examples
* Clippy + rustfmt
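In outline, the decoding path added here runs the encoder once over the prompt, then greedily extends the decoder input one token at a time until EOS. A condensed sketch of the loop introduced in `main.rs` below (names match the diff; printing and timing elided):

```rust
// Greedy decoding with the new T5ForConditionalGeneration (condensed from the diff below).
let model = builder.build_conditional_generation()?;
let mut logits_processor = LogitsProcessor::new(299792458, None, None); // no temperature => argmax
let mut output_token_ids = vec![builder.config.pad_token_id as u32]; // decoder starts from the pad token
while output_token_ids.len() <= 512 {
    let decoder_token_ids = Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
    // Runs encoder, decoder and the tied lm head; returns logits for the last position.
    let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
    let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
    if (next_token_id as usize) == builder.config.eos_token_id {
        break;
    }
    output_token_ids.push(next_token_id);
}
```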
candle-examples/examples/t5/README.md
@@ -1,17 +1,25 @@
 # candle-t5
 
-Generates embeddings using a T5 model. It doesn't support generation yet.
+## Encoder-decoder example:
 
 ```bash
-$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
-Loaded and encoded 2.014244792s
-[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387],
-  [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
-  [-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300],
-...
-  [-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306],
-  [-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384],
-  [ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]]
-Tensor[[1, 9, 1024], f32]
-Took 2.1363425s
+$ cargo run --example t5 -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
+...
+Running on CPU, to run on GPU, build this example with `--features cuda`
+Eine schöne Kerze.
+9 tokens generated (2.42 token/s)
+```
+
+## Sentence embedding example:
+
+```bash
+$ cargo run --example t5 -- --model-id "t5-small" --prompt "A beautiful candle."
+...
+[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
+  [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
+  [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
+  [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
+  [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
+Tensor[[1, 5, 512], f32]
+Took 303.766583ms
 ```
candle-examples/examples/t5/main.rs
@@ -3,18 +3,22 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
+use std::io::Write;
+use std::path::PathBuf;
+
 use candle_transformers::models::t5;
 
 use anyhow::{anyhow, Error as E, Result};
-use candle::{DType, Tensor};
+use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
 
-#[derive(Parser, Debug)]
+#[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]
 struct Args {
     /// Run on CPU rather than on GPU.
@@ -36,7 +40,11 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// Compute embeddings for this prompt, otherwise compute sentence similarities.
+    /// Enable decoding.
+    #[arg(long)]
+    decode: bool,
+
+    /// Use this prompt, otherwise compute sentence similarities.
     #[arg(long)]
     prompt: Option<String>,
 
@@ -49,12 +57,18 @@ struct Args {
     normalize_embeddings: bool,
 }
 
-impl Args {
-    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
-        let device = candle_examples::device(self.cpu)?;
+struct T5ModelBuilder {
+    device: Device,
+    config: t5::Config,
+    weights_filename: PathBuf,
+}
+
+impl T5ModelBuilder {
+    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
+        let device = candle_examples::device(args.cpu)?;
         let default_model = "t5-small".to_string();
         let default_revision = "refs/pr/15".to_string();
-        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
             (Some(model_id), Some(revision)) => (model_id, revision),
             (Some(model_id), None) => (model_id, "main".to_string()),
             (None, Some(revision)) => (default_model, revision),
@@ -62,7 +76,7 @@ impl Args {
         };
 
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+        let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
             let cache = Cache::default().repo(repo);
             (
                 cache
@@ -87,18 +101,36 @@ impl Args {
         let config = std::fs::read_to_string(config_filename)?;
         let config: t5::Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        Ok((
+            Self {
+                device,
+                config,
+                weights_filename,
+            },
+            tokenizer,
+        ))
+    }
 
-        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-        let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
-        let model = t5::T5EncoderModel::load(vb, &config)?;
-        Ok((model, tokenizer))
+    pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
+        let weights =
+            unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+        Ok(t5::T5EncoderModel::load(vb, &self.config)?)
+    }
+
+    pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
+        let weights =
+            unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+        Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }
 }
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
     let tokenizer = tokenizer
         .with_padding(None)
         .with_truncation(None)
@@ -110,17 +142,51 @@ fn main() -> Result<()> {
                 .map_err(E::msg)?
                 .get_ids()
                 .to_vec();
-            let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
-            for idx in 0..args.n {
-                let start = std::time::Instant::now();
-                let ys = model.forward(&token_ids)?;
-                if idx == 0 {
-                    println!("{ys}");
+            let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
+            if !args.decode {
+                let model = builder.build_encoder()?;
+                for idx in 0..args.n {
+                    let start = std::time::Instant::now();
+                    let ys = model.forward(&input_token_ids)?;
+                    if idx == 0 {
+                        println!("{ys}");
+                    }
+                    println!("Took {:?}", start.elapsed());
                 }
-                println!("Took {:?}", start.elapsed());
+            } else {
+                let model = builder.build_conditional_generation()?;
+                let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+                let mut logits_processor = LogitsProcessor::new(299792458, None, None);
+                let start = std::time::Instant::now();
+
+                for _index in 0.. {
+                    if output_token_ids.len() > 512 {
+                        break;
+                    }
+                    let decoder_token_ids =
+                        Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
+                    let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
+                    let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
+                    if (next_token_id as usize) == builder.config.eos_token_id {
+                        break;
+                    }
+                    output_token_ids.push(next_token_id);
+                    if let Some(text) = tokenizer.id_to_token(next_token_id) {
+                        let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                        print!("{text}");
+                        std::io::stdout().flush()?;
+                    }
+                }
+                let dt = start.elapsed();
+                println!(
+                    "\n{} tokens generated ({:.2} token/s)\n",
+                    tokens.len(),
+                    tokens.len() as f64 / dt.as_secs_f64(),
+                );
             }
         }
         None => {
+            let model = builder.build_encoder()?;
             let sentences = [
                 "The cat sits outside",
                 "A man is playing guitar",
candle-transformers/src/models/t5.rs
@@ -18,6 +18,21 @@ fn default_use_cache() -> bool {
     true
 }
 
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    let result = Tensor::from_slice(&mask, (size, size), device)?;
+    Ok(result)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
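For `size = 3`, `get_mask` yields the strictly upper-triangular byte mask `[[0, 1, 1], [0, 0, 1], [0, 0, 0]]`: a 1 wherever column j lies in the future of row i, so each decoder position can only attend to itself and earlier tokens once `masked_fill` overwrites the flagged scores. A quick hypothetical check (`to_vec2` is candle's accessor for rank-2 tensors):

```rust
// Hypothetical test for the causal-mask helper added above.
let mask = get_mask(3, &Device::Cpu)?;
assert_eq!(
    mask.to_vec2::<u8>()?,
    vec![vec![0, 1, 1], vec![0, 0, 1], vec![0, 0, 0]]
);
```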
@@ -40,8 +55,8 @@ pub struct Config {
     is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
     use_cache: bool,
-    pad_token_id: usize,
-    eos_token_id: usize,
+    pub pad_token_id: usize,
+    pub eos_token_id: usize,
 }
 
 impl Default for Config {
@@ -233,13 +248,13 @@ struct T5Attention {
 }
 
 impl T5Attention {
-    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
         let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
         let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
         let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
         let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
-        let relative_attention_bias = if h {
+        let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
@@ -267,26 +282,46 @@ impl T5Attention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        key_value_states: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        // TODO: Apply the mask(s)?
+        // Performs Self-attention (if key_value_states is None) or attention
+        // over source sentence (provided by key_value_states).
         // TODO: kv caching.
-        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_input = match key_value_states {
+            None => xs,
+            Some(key_value_states) => key_value_states,
+        };
+        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_len = kv_input.dim(1)?;
         let q = self.q.forward(xs)?;
-        let k = self.k.forward(xs)?;
-        let v = self.v.forward(xs)?;
+        let k = self.k.forward(kv_input)?;
+        let v = self.v.forward(kv_input)?;
         let q = q
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
+        // TODO: Use flash_attn.
         let scores = q.matmul(&k.t()?)?;
+        let scores = match mask {
+            None => scores,
+            Some(mask) => masked_fill(
+                &scores,
+                &mask
+                    .unsqueeze(0)?
+                    .unsqueeze(0)?
+                    .repeat((b_sz, self.n_heads))?,
+                f32::NEG_INFINITY,
+            )?,
+        };
 
         let (scores, position_bias) = match position_bias {
             Some(position_bias) => (
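On the decoder self-attention path, `mask` is the square causal mask from `get_mask`; the two `unsqueeze` calls and the `repeat` expand it across the batch and head dimensions of `scores`, and `masked_fill` replaces every future position with `-inf` so softmax gives it zero weight. A tiny illustration with made-up values, dropping the batch/head dimensions that the real call expands:

```rust
// Illustration (hypothetical scores): masking a 2x2 score matrix directly.
let scores = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
let mask = get_mask(2, &Device::Cpu)?;
let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
// masked == [[1.0, -inf], [3.0, 4.0]]: position 0 cannot attend to position 1.
```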
@@ -296,14 +331,12 @@ impl T5Attention {
             None => match &self.relative_attention_bias {
                 None => (scores, None),
                 Some(relative_attention_bias) => {
-                    let query_length = seq_len;
-                    let key_length = seq_len;
                     // This only handles the bidirectional case.
                     let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                     let max_exact = num_buckets / 2;
-                    let relative_position = (0..query_length as u32)
+                    let relative_position = (0..q_len as u32)
                         .map(|i| {
-                            (0..key_length as u32)
+                            (0..kv_len as u32)
                                 .map(|j| {
                                     if i < j {
                                         if j - i < max_exact {
@@ -348,7 +381,7 @@ impl T5Attention {
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, seq_len, self.inner_dim))?;
+            .reshape((b_sz, q_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
@@ -375,24 +408,49 @@ impl T5LayerSelfAttention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
+        let (ys, position_bias) =
+            self.self_attention
+                .forward(&normed_xs, position_bias, None, mask)?;
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
 }
 
 #[derive(Debug)]
-struct T5LayerCrossAttention {}
+struct T5LayerCrossAttention {
+    cross_attention: T5Attention,
+    layer_norm: T5LayerNorm,
+}
 
 impl T5LayerCrossAttention {
-    fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
-        todo!()
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        Ok(Self {
+            cross_attention,
+            layer_norm,
+        })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        position_bias: Option<&Tensor>,
+        key_value_states: &Tensor,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+        let (ys, position_bias) = self.cross_attention.forward(
+            &normed_hidden_states,
+            position_bias,
+            Some(key_value_states),
+            None,
+        )?;
+        let ys = (hidden_states + ys)?;
+        Ok((ys, position_bias))
     }
 }
 
@@ -425,11 +483,17 @@ impl T5Block {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
+        // TODO: Cache masks
+        let mask = match self.cross_attn.is_some() {
+            true => Some(get_mask(xs.dim(1)?, xs.device())?),
+            false => None,
+        };
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
-            xs = cross_attn.forward(&xs)?;
+            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
             // TODO: clamp for f16?
         }
         let xs = self.ff.forward(&xs)?;
@@ -462,13 +526,20 @@ impl T5Stack {
         })
     }
 
-    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        input_ids: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
-            (hidden_states, position_bias) =
-                block.forward(&hidden_states, position_bias.as_ref())?
+            (hidden_states, position_bias) = block.forward(
+                &hidden_states,
+                position_bias.as_ref(),
+                encoder_hidden_states,
+            )?
         }
         self.final_layer_norm.forward(&hidden_states)
     }
@@ -492,7 +563,61 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        self.encoder.forward(input_ids)
+        self.encoder.forward(input_ids, None)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+    encoder: T5Stack,
+    decoder: T5Stack,
+    shared: Arc<Embedding>,
+    device: Device,
+}
+
+impl T5ForConditionalGeneration {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        assert!(cfg.is_encoder_decoder);
+        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Arc::new(shared);
+
+        let mut encoder_cfg = cfg.clone();
+        encoder_cfg.is_decoder = false;
+        encoder_cfg.use_cache = false;
+        encoder_cfg.is_encoder_decoder = false;
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+        let mut decoder_cfg = cfg.clone();
+        decoder_cfg.is_decoder = true;
+        decoder_cfg.is_encoder_decoder = false;
+        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+        let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+        Ok(Self {
+            encoder,
+            decoder,
+            shared,
+            device: vb.device().clone(),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_output = self.encoder.forward(input_ids, None)?;
+        let decoder_output = self
+            .decoder
+            .forward(decoder_input_ids, Some(&encoder_output))?;
+        let sequence_output = decoder_output
+            .narrow(1, decoder_output.dim(1)? - 1, 1)?
+            .squeeze(1)?;
+        // TODO: check cfg.tie_word_embeddings to load from model instead.
+        let lm_head_weights = self.shared.embeddings().t()?;
+        let output = sequence_output.matmul(&lm_head_weights)?;
+        // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+        Ok(output)
     }
 
     pub fn device(&self) -> &Device {
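Because the lm head is tied to the shared embedding table, the final projection in `T5ForConditionalGeneration::forward` is just a matmul against its transpose, applied to the last decoder position only. With t5-small's stock config (d_model = 512, vocab_size = 32128; values from the published config, not this diff), the shapes work out as annotated below:

```rust
// Shape walk-through of the tied lm head above (t5-small sizes assumed).
let sequence_output = decoder_output                    // [1, dec_len, 512]
    .narrow(1, decoder_output.dim(1)? - 1, 1)?          // [1, 1, 512]: keep the final position
    .squeeze(1)?;                                       // [1, 512]
let lm_head_weights = self.shared.embeddings().t()?;    // [512, 32128]
let logits = sequence_output.matmul(&lm_head_weights)?; // [1, 32128]
```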