Mirror of https://github.com/huggingface/candle.git
Implement T5 decoding (#864)
* Load t5 decoder
* Run enc, dec, and lm head, but no cross attn
* Cross-attention over key_value_states
* New arg for decoder input ids
* Add mask, don't forward position biases through decoder
* Update t5 examples
* Clippy + rustfmt
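For context, a rough sketch of how the new two-tensor forward introduced in this commit could drive greedy decoding. It assumes the types from this file are in scope, an f32 model, and token ids produced elsewhere (as in the updated t5 example); the helper name and the plain argmax loop are illustrative, not part of the diff.

use candle::{Result, Tensor};

// Illustrative greedy decoding against the API added below. `model.forward`
// returns the logits for the last decoder position, shape (1, vocab_size).
// T5 uses the pad token as the decoder start token; the pad/eos ids come from
// the config fields this commit makes public.
fn greedy_decode(
    model: &T5ForConditionalGeneration,
    input_ids: &Tensor,
    pad_token_id: u32,
    eos_token_id: u32,
    max_len: usize,
) -> Result<Vec<u32>> {
    let mut decoder_tokens = vec![pad_token_id];
    for _ in 0..max_len {
        let decoder_input_ids = Tensor::from_slice(
            decoder_tokens.as_slice(),
            (1, decoder_tokens.len()),
            input_ids.device(),
        )?;
        let logits = model.forward(input_ids, &decoder_input_ids)?;
        let logits = logits.squeeze(0)?.to_vec1::<f32>()?;
        let next = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap_or(eos_token_id);
        if next == eos_token_id {
            break;
        }
        decoder_tokens.push(next);
    }
    Ok(decoder_tokens)
}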
@@ -18,6 +18,21 @@ fn default_use_cache() -> bool {
     true
 }
 
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    let result = Tensor::from_slice(&mask, (size, size), device)?;
+    Ok(result)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
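The new get_mask helper builds the causal mask that the decoder blocks pass to self-attention: entry (i, j) is 1 exactly when j > i, i.e. when a query at step i would look at a future position, and masked_fill then overwrites those scores (with f32::NEG_INFINITY further down in this diff). A small self-contained illustration, not part of the commit, of what it produces:

use candle::{Device, Result, Tensor};

// Same construction as the get_mask added above, repeated so the snippet
// stands alone.
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    let result = Tensor::from_slice(&mask, (size, size), device)?;
    Ok(result)
}

fn main() -> Result<()> {
    // Prints [[0, 1, 1], [0, 0, 1], [0, 0, 0]]: row i flags the future
    // positions j > i that the attention scores must not see.
    let mask = get_mask(3, &Device::Cpu)?;
    println!("{:?}", mask.to_vec2::<u8>()?);
    Ok(())
}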
@@ -40,8 +55,8 @@ pub struct Config {
     is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
     use_cache: bool,
-    pad_token_id: usize,
-    eos_token_id: usize,
+    pub pad_token_id: usize,
+    pub eos_token_id: usize,
 }
 
 impl Default for Config {
@@ -233,13 +248,13 @@ struct T5Attention {
 }
 
 impl T5Attention {
-    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
         let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
         let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
         let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
         let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
-        let relative_attention_bias = if h {
+        let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
@@ -267,26 +282,46 @@ impl T5Attention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        key_value_states: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        // TODO: Apply the mask(s)?
+        // Performs Self-attention (if key_value_states is None) or attention
+        // over source sentence (provided by key_value_states).
+        // TODO: kv caching.
-        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_input = match key_value_states {
+            None => xs,
+            Some(key_value_states) => key_value_states,
+        };
+        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_len = kv_input.dim(1)?;
         let q = self.q.forward(xs)?;
-        let k = self.k.forward(xs)?;
-        let v = self.v.forward(xs)?;
+        let k = self.k.forward(kv_input)?;
+        let v = self.v.forward(kv_input)?;
         let q = q
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
            .contiguous()?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         // TODO: Use flash_attn.
         let scores = q.matmul(&k.t()?)?;
+        let scores = match mask {
+            None => scores,
+            Some(mask) => masked_fill(
+                &scores,
+                &mask
+                    .unsqueeze(0)?
+                    .unsqueeze(0)?
+                    .repeat((b_sz, self.n_heads))?,
+                f32::NEG_INFINITY,
+            )?,
+        };
 
         let (scores, position_bias) = match position_bias {
             Some(position_bias) => (
@@ -296,14 +331,12 @@ impl T5Attention {
             None => match &self.relative_attention_bias {
                 None => (scores, None),
                 Some(relative_attention_bias) => {
-                    let query_length = seq_len;
-                    let key_length = seq_len;
                     // This only handles the bidirectional case.
                     let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                     let max_exact = num_buckets / 2;
-                    let relative_position = (0..query_length as u32)
+                    let relative_position = (0..q_len as u32)
                         .map(|i| {
-                            (0..key_length as u32)
+                            (0..kv_len as u32)
                                 .map(|j| {
                                     if i < j {
                                         if j - i < max_exact {
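For reference, the bucketing that this hunk starts to port (the rest falls outside the visible context) follows the standard bidirectional T5 scheme: half of the buckets are reserved for keys to the right of the query, small distances below max_exact each get their own bucket, and larger distances are binned logarithmically up to relative_attention_max_distance. A scalar sketch of that upstream formula, written here for illustration and not copied from this diff:

// relative_position is the key index j minus the query index i.
// num_buckets / max_distance correspond to cfg.relative_attention_num_buckets
// and cfg.relative_attention_max_distance.
fn relative_position_bucket(relative_position: i64, num_buckets: u32, max_distance: u32) -> u32 {
    // Bidirectional case: half of the buckets are for keys after the query.
    let num_buckets = num_buckets / 2;
    let side = if relative_position > 0 { num_buckets } else { 0 };
    let rel = relative_position.unsigned_abs() as u32;
    let max_exact = num_buckets / 2;
    if rel < max_exact {
        // Small distances each get an exact bucket.
        side + rel
    } else {
        // Larger distances are binned logarithmically, saturating at the last bucket.
        let log_ratio = (rel as f32 / max_exact as f32).ln()
            / (max_distance as f32 / max_exact as f32).ln();
        let large = max_exact + (log_ratio * (num_buckets - max_exact) as f32) as u32;
        side + large.min(num_buckets - 1)
    }
}

With the T5 defaults (32 buckets, max distance 128), the bidirectional case uses 16 buckets per direction, the first 8 of which cover exact distances 0 through 7.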
@@ -348,7 +381,7 @@ impl T5Attention {
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, seq_len, self.inner_dim))?;
+            .reshape((b_sz, q_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
@@ -375,24 +408,49 @@ impl T5LayerSelfAttention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
+        let (ys, position_bias) =
+            self.self_attention
+                .forward(&normed_xs, position_bias, None, mask)?;
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
 }
 
 #[derive(Debug)]
-struct T5LayerCrossAttention {}
+struct T5LayerCrossAttention {
+    cross_attention: T5Attention,
+    layer_norm: T5LayerNorm,
+}
 
 impl T5LayerCrossAttention {
-    fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
-        todo!()
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        Ok(Self {
+            cross_attention,
+            layer_norm,
+        })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        position_bias: Option<&Tensor>,
+        key_value_states: &Tensor,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+        let (ys, position_bias) = self.cross_attention.forward(
+            &normed_hidden_states,
+            position_bias,
+            Some(key_value_states),
+            None,
+        )?;
+        let ys = (hidden_states + ys)?;
+        Ok((ys, position_bias))
     }
 }
 
@@ -425,11 +483,17 @@ impl T5Block {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
+        // TODO: Cache masks
+        let mask = match self.cross_attn.is_some() {
+            true => Some(get_mask(xs.dim(1)?, xs.device())?),
+            false => None,
+        };
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
-            xs = cross_attn.forward(&xs)?;
+            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
             // TODO: clamp for f16?
         }
         let xs = self.ff.forward(&xs)?;
@@ -462,13 +526,20 @@ impl T5Stack {
         })
     }
 
-    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        input_ids: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
-            (hidden_states, position_bias) =
-                block.forward(&hidden_states, position_bias.as_ref())?
+            (hidden_states, position_bias) = block.forward(
+                &hidden_states,
+                position_bias.as_ref(),
+                encoder_hidden_states,
+            )?
         }
         self.final_layer_norm.forward(&hidden_states)
     }
@@ -492,7 +563,61 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        self.encoder.forward(input_ids)
+        self.encoder.forward(input_ids, None)
     }
 
     pub fn device(&self) -> &Device {
+        &self.device
+    }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+    encoder: T5Stack,
+    decoder: T5Stack,
+    shared: Arc<Embedding>,
+    device: Device,
+}
+
+impl T5ForConditionalGeneration {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        assert!(cfg.is_encoder_decoder);
+        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Arc::new(shared);
+
+        let mut encoder_cfg = cfg.clone();
+        encoder_cfg.is_decoder = false;
+        encoder_cfg.use_cache = false;
+        encoder_cfg.is_encoder_decoder = false;
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+        let mut decoder_cfg = cfg.clone();
+        decoder_cfg.is_decoder = true;
+        decoder_cfg.is_encoder_decoder = false;
+        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+        let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+        Ok(Self {
+            encoder,
+            decoder,
+            shared,
+            device: vb.device().clone(),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_output = self.encoder.forward(input_ids, None)?;
+        let decoder_output = self
+            .decoder
+            .forward(decoder_input_ids, Some(&encoder_output))?;
+        let sequence_output = decoder_output
+            .narrow(1, decoder_output.dim(1)? - 1, 1)?
+            .squeeze(1)?;
+        // TODO: check cfg.tie_word_embeddings to load from model instead.
+        let lm_head_weights = self.shared.embeddings().t()?;
+        let output = sequence_output.matmul(&lm_head_weights)?;
+        // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+        Ok(output)
+    }
+
+    pub fn device(&self) -> &Device {
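The two TODOs above are worth spelling out: the lm head simply reuses the transposed shared embedding matrix, and the upstream implementation additionally rescales the decoder output by d_model^-0.5 before this projection when the word embeddings are tied. A hedged sketch of what that rescaling could look like (d_model would have to be threaded through, e.g. stored on the struct; nothing here is part of the commit):

use candle::{Result, Tensor};

// Hypothetical helper mirroring the tied-embedding rescale of upstream T5.
// affine(mul, add) computes x * mul + add elementwise.
fn project_to_vocab(
    sequence_output: &Tensor, // (batch, d_model)
    lm_head_weights: &Tensor, // (d_model, vocab_size), transposed shared embeddings
    d_model: usize,
) -> Result<Tensor> {
    let scaled = sequence_output.affine((d_model as f64).powf(-0.5), 0.)?;
    scaled.matmul(lm_head_weights)
}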