Proper log buckets for t5. (#727)

* Proper log buckets for t5.

* Properly pass the position bias.
Laurent Mazare, 2023-09-03 21:33:50 +02:00, committed by GitHub
parent 26cd266e65
commit 9c61b0fc9b

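The first part of the change ("Proper log buckets") replaces the flat bucket assignment with T5's standard scheme: offsets smaller than max_exact each get their own bucket, larger offsets share logarithmically spaced buckets up to relative_attention_max_distance, and the upper half of the buckets is reserved for the j > i direction. Below is a minimal standalone sketch of that mapping, with an illustrative helper name and parameter values; the clamping follows the usual T5 formulation rather than the exact expression in this diff.

// Illustrative sketch of T5-style bidirectional relative-position bucketing.
// `num_buckets` and `max_distance` play the role of
// `relative_attention_num_buckets` and `relative_attention_max_distance`.
fn relative_position_bucket(i: u32, j: u32, num_buckets: u32, max_distance: u32) -> u32 {
    // Half of the buckets are reserved for the "key after query" direction (j > i).
    let half = num_buckets / 2;
    let max_exact = half / 2;
    let (offset, base) = if i < j { (j - i, half) } else { (i - j, 0) };
    if offset < max_exact {
        // Small offsets map one-to-one onto their own bucket.
        base + offset
    } else {
        // Larger offsets are compressed onto log-spaced buckets up to max_distance.
        let b = f32::log(
            offset as f32 / max_exact as f32,
            max_distance as f32 / max_exact as f32,
        ) * (half - max_exact) as f32;
        base + u32::min(max_exact + b as u32, half - 1)
    }
}

fn main() {
    // Nearby offsets keep distinct buckets, distant ones collapse together.
    for j in [1u32, 2, 8, 120] {
        println!("offset {j} -> bucket {}", relative_position_bucket(0, j, 32, 128));
    }
}

With 32 buckets and a maximum distance of 128, offsets 1, 2, 8 and 120 land in buckets 17, 18, 24 and 31, i.e. the resolution drops as the offset grows.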

@@ -202,7 +202,11 @@ impl T5Attention {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         // TODO: Apply the mask(s)?
         // TODO: kv caching.
         let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
@@ -223,21 +227,43 @@ impl T5Attention {
             .contiguous()?;
         let scores = q.matmul(&k.t()?)?;

-        let scores = match &self.relative_attention_bias {
-            None => scores,
+        let (scores, position_bias) = match position_bias {
+            Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
+            None => match &self.relative_attention_bias {
+                None => (scores, None),
             Some(relative_attention_bias) => {
                 let query_length = seq_len;
                 let key_length = seq_len;
                 // This only handles the bidirectional case.
-                let num_buckets = self.relative_attention_num_buckets / 2;
+                let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+                let max_exact = num_buckets / 2;
                 let relative_position = (0..query_length as u32)
                     .map(|i| {
                         (0..key_length as u32)
                             .map(|j| {
                                 if i < j {
-                                    j - i + num_buckets as u32
-                                } else {
+                                    if j - i < max_exact {
+                                        j - i + num_buckets
+                                    } else {
+                                        let b = f32::log(
+                                            (j - i) as f32 / max_exact as f32,
+                                            self.relative_attention_max_distance as f32
+                                                / max_exact as f32,
+                                        ) * (num_buckets - max_exact) as f32;
+                                        u32::min(
+                                            max_exact + num_buckets + b as u32,
+                                            self.relative_attention_num_buckets as u32 - 1,
+                                        )
+                                    }
+                                } else if i - j < max_exact {
                                     i - j
+                                } else {
+                                    let b = f32::log(
+                                        (i - j) as f32 / max_exact as f32,
+                                        self.relative_attention_max_distance as f32
+                                            / max_exact as f32,
+                                    ) * (num_buckets - max_exact) as f32;
+                                    max_exact + b as u32
                                 }
                             })
                             .collect::<Vec<u32>>()
@@ -248,9 +274,10 @@ impl T5Attention {
                     .forward(&relative_buckets)?
                     .permute((2, 0, 1))?
                     .unsqueeze(0)?;
-                (scores + position_bias)?
+                ((scores + &position_bias)?, Some(position_bias))
                 // TODO: position_bias_masked?
             }
+            },
         };

         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
@@ -259,7 +286,7 @@ impl T5Attention {
             .transpose(1, 2)?
             .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
-        Ok(attn_output)
+        Ok((attn_output, position_bias))
     }
 }
@@ -280,11 +307,15 @@ impl T5LayerSelfAttention {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let ys = self.self_attention.forward(&normed_xs)?;
+        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
         let ys = (xs + ys)?;
-        Ok(ys)
+        Ok((ys, position_bias))
     }
 }
@@ -326,8 +357,12 @@ impl T5Block {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.self_attn.forward(xs)?;
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
             xs = cross_attn.forward(&xs)?;
@@ -335,7 +370,7 @@ impl T5Block {
         }
         let xs = self.ff.forward(&xs)?;
         // TODO: clamp for f16?
-        Ok(xs)
+        Ok((xs, position_bias))
     }
 }
@@ -368,8 +403,10 @@ impl T5Stack {
         let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);

         let mut hidden_states = input_embeds;
+        let mut position_bias = None;
         for block in self.block.iter() {
-            hidden_states = block.forward(&hidden_states)?
+            (hidden_states, position_bias) =
+                block.forward(&hidden_states, position_bias.as_ref())?
         }
         let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
         Ok(hidden_states)
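
The second part of the change ("Properly pass the position bias") makes the attention layer return the bias it computed so that T5Stack can feed it back into every subsequent block: only the first block, which owns relative_attention_bias, builds the tensor, and all later blocks receive it through the new Option<&Tensor> argument and add it to their own attention scores. The following is a simplified sketch of that threading pattern, with plain structs standing in for the candle tensors and layers; the names and values are illustrative, not the actual candle API.

// Simplified sketch of the position-bias threading in T5Stack::forward.
// `Bias` stands in for a candle Tensor holding the relative position bias.
#[derive(Clone, Debug)]
struct Bias(f32);

struct Block {
    // Only the first block carries the relative attention bias embedding.
    has_relative_attention_bias: bool,
}

impl Block {
    // Mirrors the new signature: take an optional bias, return the one used.
    fn forward(&self, hidden: f32, position_bias: Option<&Bias>) -> (f32, Option<Bias>) {
        let bias = match position_bias {
            // Later blocks reuse the bias handed down by the stack.
            Some(b) => Some(b.clone()),
            // The first block computes it; blocks without the table add nothing.
            None if self.has_relative_attention_bias => Some(Bias(0.25)),
            None => None,
        };
        // The bias (when present) is added to the attention scores; elided here.
        let scores = hidden + bias.as_ref().map_or(0.0, |b| b.0);
        (scores, bias)
    }
}

fn main() {
    let blocks = vec![
        Block { has_relative_attention_bias: true },
        Block { has_relative_attention_bias: false },
        Block { has_relative_attention_bias: false },
    ];
    let mut hidden = 1.0;
    let mut position_bias: Option<Bias> = None;
    for block in blocks.iter() {
        // Same shape as the new T5Stack loop: compute once, then reuse.
        let (h, bias) = block.forward(hidden, position_bias.as_ref());
        hidden = h;
        position_bias = bias;
    }
    println!("hidden = {hidden}, bias was reused: {}", position_bias.is_some());
}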