Proper log buckets for t5. (#727)

* Proper log buckets for t5.

* Properly pass the position bias.
Author: Laurent Mazare
Date: 2023-09-03 21:33:50 +02:00 (committed by GitHub)
Parent: 26cd266e65
Commit: 9c61b0fc9b
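For context, the log-scale bucketing introduced by this commit can be read as a standalone function. The sketch below only mirrors the logic from the diff; the name relative_position_bucket and its parameters are illustrative, not part of the crate. Small offsets (below max_exact) each get their own bucket, larger offsets are spread logarithmically up to max_distance, and the upper half of the buckets is reserved for keys that come after the query.

// Standalone sketch of the bidirectional bucketing used in T5Attention::forward.
// total_buckets and max_distance stand in for relative_attention_num_buckets and
// relative_attention_max_distance.
fn relative_position_bucket(i: u32, j: u32, total_buckets: u32, max_distance: u32) -> u32 {
    // Bidirectional case: half of the buckets are used when the key is after the query.
    let num_buckets = total_buckets / 2;
    // Half of each direction's buckets map small offsets one-to-one.
    let max_exact = num_buckets / 2;
    if i < j {
        if j - i < max_exact {
            j - i + num_buckets
        } else {
            // Larger offsets are placed on a log scale between max_exact and max_distance.
            let b = f32::log(
                (j - i) as f32 / max_exact as f32,
                max_distance as f32 / max_exact as f32,
            ) * (num_buckets - max_exact) as f32;
            u32::min(max_exact + num_buckets + b as u32, total_buckets - 1)
        }
    } else if i - j < max_exact {
        i - j
    } else {
        let b = f32::log(
            (i - j) as f32 / max_exact as f32,
            max_distance as f32 / max_exact as f32,
        ) * (num_buckets - max_exact) as f32;
        max_exact + b as u32
    }
}

fn main() {
    // With the usual T5 defaults (32 buckets, max distance 128), nearby offsets
    // land in exact buckets and distant ones share log-spaced buckets.
    for (i, j) in [(0u32, 0u32), (5, 0), (0, 5), (0, 100)] {
        println!("i={i} j={j} -> bucket {}", relative_position_bucket(i, j, 32, 128));
    }
}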


@@ -202,7 +202,11 @@ impl T5Attention {
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
// TODO: Apply the mask(s)?
// TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
@@ -223,34 +227,57 @@ impl T5Attention {
.contiguous()?;
let scores = q.matmul(&k.t()?)?;
let scores = match &self.relative_attention_bias {
None => scores,
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets / 2;
let relative_position = (0..query_length as u32)
.map(|i| {
(0..key_length as u32)
.map(|j| {
if i < j {
j - i + num_buckets as u32
} else {
i - j
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?;
(scores + position_bias)?
// TODO: position_bias_masked?
}
let (scores, position_bias) = match position_bias {
Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
None => match &self.relative_attention_bias {
None => (scores, None),
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
let max_exact = num_buckets / 2;
let relative_position = (0..query_length as u32)
.map(|i| {
(0..key_length as u32)
.map(|j| {
if i < j {
if j - i < max_exact {
j - i + num_buckets
} else {
let b = f32::log(
(j - i) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
u32::min(
max_exact + num_buckets + b as u32,
self.relative_attention_num_buckets as u32 - 1,
)
}
} else if i - j < max_exact {
i - j
} else {
let b = f32::log(
(i - j) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
max_exact + b as u32
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?;
((scores + &position_bias)?, Some(position_bias))
// TODO: position_bias_masked?
}
},
};
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
@@ -259,7 +286,7 @@ impl T5Attention {
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
Ok((attn_output, position_bias))
}
}
@@ -280,11 +307,15 @@ impl T5LayerSelfAttention {
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let normed_xs = self.layer_norm.forward(xs)?;
let ys = self.self_attention.forward(&normed_xs)?;
let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
let ys = (xs + ys)?;
Ok(ys)
Ok((ys, position_bias))
}
}
@@ -326,8 +357,12 @@ impl T5Block {
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.self_attn.forward(xs)?;
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
xs = cross_attn.forward(&xs)?;
@@ -335,7 +370,7 @@ impl T5Block {
}
let xs = self.ff.forward(&xs)?;
// TODO: clamp for f16?
Ok(xs)
Ok((xs, position_bias))
}
}
@@ -368,8 +403,10 @@ impl T5Stack {
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter() {
hidden_states = block.forward(&hidden_states)?
(hidden_states, position_bias) =
block.forward(&hidden_states, position_bias.as_ref())?
}
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
Ok(hidden_states)
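The second half of the change is plumbing: the first block that has a relative_attention_bias embedding computes the bias once, and every later block receives and reuses that same tensor instead of recomputing it. A toy illustration of the Option-threading added to T5Stack::forward, with made-up types (Block and an f32 "bias") standing in for T5Block and the candle Tensor:

// Toy sketch of the bias threading; not candle code.
struct Block {
    has_bias_table: bool,
}

impl Block {
    fn forward(&self, x: f32, bias: Option<&f32>) -> (f32, Option<f32>) {
        // Reuse the incoming bias if there is one; otherwise only the block that
        // owns the bias table produces it.
        let bias = match bias {
            Some(b) => Some(*b),
            None if self.has_bias_table => Some(0.25), // stand-in for the computed bias
            None => None,
        };
        (x + bias.unwrap_or(0.0), bias)
    }
}

fn main() {
    let blocks = [
        Block { has_bias_table: true },
        Block { has_bias_table: false },
        Block { has_bias_table: false },
    ];
    let (mut x, mut bias) = (1.0f32, None);
    for block in &blocks {
        (x, bias) = block.forward(x, bias.as_ref());
    }
    println!("x = {x}, bias = {bias:?}");
}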