Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00.
Proper log buckets for t5. (#727)
* Proper log buckets for t5.
* Properly pass the position bias.
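The bucketing this introduces is the standard T5 scheme for relative positions: small distances each keep their own bucket, larger distances are binned logarithmically up to a maximum distance, and the two directions (key before vs. after the query) use separate halves of the bucket range. Below is a minimal stand-alone sketch of that mapping; the 32-bucket / 128-max-distance values in `main` are the usual T5 defaults, assumed here for illustration rather than read from this patch.

// Stand-alone sketch of the bidirectional log-bucket mapping added in the diff below.
fn relative_position_bucket(i: u32, j: u32, total_buckets: u32, max_distance: u32) -> u32 {
    let num_buckets = total_buckets / 2; // half of the buckets per direction
    let max_exact = num_buckets / 2; // small distances keep one bucket each
    if i < j {
        // Key position is to the right of the query position.
        if j - i < max_exact {
            j - i + num_buckets
        } else {
            // Larger distances share buckets on a logarithmic scale.
            let b = f32::log(
                (j - i) as f32 / max_exact as f32,
                max_distance as f32 / max_exact as f32,
            ) * (num_buckets - max_exact) as f32;
            u32::min(max_exact + num_buckets + b as u32, total_buckets - 1)
        }
    } else if i - j < max_exact {
        i - j
    } else {
        let b = f32::log(
            (i - j) as f32 / max_exact as f32,
            max_distance as f32 / max_exact as f32,
        ) * (num_buckets - max_exact) as f32;
        max_exact + b as u32
    }
}

fn main() {
    // With 32 buckets / max distance 128: distances below 8 map one-to-one,
    // larger ones are compressed, so e.g. distances 37 and 47 end up in
    // adjacent buckets.
    for (i, j) in [(5u32, 3u32), (40, 3), (3, 5), (3, 40), (3, 50)] {
        println!("i={i} j={j} -> bucket {}", relative_position_bucket(i, j, 32, 128));
    }
}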
@@ -202,7 +202,11 @@ impl T5Attention {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         // TODO: Apply the mask(s)?
         // TODO: kv caching.
         let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
@@ -223,21 +227,43 @@ impl T5Attention {
             .contiguous()?;
         let scores = q.matmul(&k.t()?)?;

-        let scores = match &self.relative_attention_bias {
-            None => scores,
+        let (scores, position_bias) = match position_bias {
+            Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
+            None => match &self.relative_attention_bias {
+                None => (scores, None),
                 Some(relative_attention_bias) => {
                     let query_length = seq_len;
                     let key_length = seq_len;
                     // This only handles the bidirectional case.
-                    let num_buckets = self.relative_attention_num_buckets / 2;
+                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+                    let max_exact = num_buckets / 2;
                     let relative_position = (0..query_length as u32)
                         .map(|i| {
                             (0..key_length as u32)
                                 .map(|j| {
                                     if i < j {
-                                        j - i + num_buckets as u32
-                                    } else {
+                                        if j - i < max_exact {
+                                            j - i + num_buckets
+                                        } else {
+                                            let b = f32::log(
+                                                (j - i) as f32 / max_exact as f32,
+                                                self.relative_attention_max_distance as f32
+                                                    / max_exact as f32,
+                                            ) * (num_buckets - max_exact) as f32;
+                                            u32::min(
+                                                max_exact + num_buckets + b as u32,
+                                                self.relative_attention_num_buckets as u32 - 1,
+                                            )
+                                        }
+                                    } else if i - j < max_exact {
                                         i - j
+                                    } else {
+                                        let b = f32::log(
+                                            (i - j) as f32 / max_exact as f32,
+                                            self.relative_attention_max_distance as f32
+                                                / max_exact as f32,
+                                        ) * (num_buckets - max_exact) as f32;
+                                        max_exact + b as u32
                                     }
                                 })
                                 .collect::<Vec<u32>>()
@@ -248,9 +274,10 @@ impl T5Attention {
                         .forward(&relative_buckets)?
                         .permute((2, 0, 1))?
                         .unsqueeze(0)?;
-                    (scores + position_bias)?
+                    ((scores + &position_bias)?, Some(position_bias))
                     // TODO: position_bias_masked?
                 }
+            },
         };

         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
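On shapes: the bucket matrix built above is (seq_len, seq_len); looking each index up in the `relative_attention_bias` embedding yields one value per head, i.e. (seq_len, seq_len, n_heads); `permute((2, 0, 1))` and `unsqueeze(0)` then give the (1, n_heads, seq_len, seq_len) bias that broadcasts against the (batch, n_heads, seq_len, seq_len) attention scores. A plain-Rust sketch of that bookkeeping (stand-in values, not candle API):

// Shape bookkeeping behind the position bias; sizes are arbitrary examples.
fn main() {
    let (seq_len, n_heads, num_buckets) = (4usize, 2usize, 32usize);
    // Stand-in for the learned embedding table: one value per (bucket, head).
    let table: Vec<Vec<f32>> = (0..num_buckets)
        .map(|b| (0..n_heads).map(|h| (b * n_heads + h) as f32 * 0.01).collect())
        .collect();
    // (seq_len, seq_len) bucket indices; plain |i - j| here instead of the
    // real log-bucket mapping sketched earlier.
    let buckets: Vec<Vec<usize>> = (0..seq_len)
        .map(|i| (0..seq_len).map(|j| i.abs_diff(j)).collect())
        .collect();
    // Lookup: (seq_len, seq_len) -> (seq_len, seq_len, n_heads). The model then
    // permutes to (n_heads, seq_len, seq_len) and unsqueezes to
    // (1, n_heads, seq_len, seq_len) so the bias broadcasts over the batch
    // dimension when added to the attention scores.
    let bias: Vec<Vec<Vec<f32>>> = buckets
        .iter()
        .map(|row| row.iter().map(|&b| table[b].clone()).collect())
        .collect();
    println!("bias dims: {} x {} x {}", bias.len(), bias[0].len(), bias[0][0].len());
}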
@@ -259,7 +286,7 @@ impl T5Attention {
             .transpose(1, 2)?
             .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
-        Ok(attn_output)
+        Ok((attn_output, position_bias))
     }
 }

@@ -280,11 +307,15 @@ impl T5LayerSelfAttention {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let ys = self.self_attention.forward(&normed_xs)?;
+        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
         let ys = (xs + ys)?;
-        Ok(ys)
+        Ok((ys, position_bias))
     }
 }

@@ -326,8 +357,12 @@ impl T5Block {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.self_attn.forward(xs)?;
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
             xs = cross_attn.forward(&xs)?;
@@ -335,7 +370,7 @@ impl T5Block {
         }
         let xs = self.ff.forward(&xs)?;
         // TODO: clamp for f16?
-        Ok(xs)
+        Ok((xs, position_bias))
     }
 }

@@ -368,8 +403,10 @@ impl T5Stack {
         let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);

         let mut hidden_states = input_embeds;
+        let mut position_bias = None;
         for block in self.block.iter() {
-            hidden_states = block.forward(&hidden_states)?
+            (hidden_states, position_bias) =
+                block.forward(&hidden_states, position_bias.as_ref())?
         }
         let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
         Ok(hidden_states)
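The loop above is where the reuse happens: in T5 only the first block carries a relative-attention-bias table, so that block sees `position_bias == None`, computes the bias once, and returns it; every later block just receives the same tensor through `position_bias.as_ref()` and adds it to its scores. A minimal sketch of this compute-once / thread-through pattern with made-up stand-in types (not candle types):

// Compute-once / reuse pattern mirroring how `position_bias` flows through
// T5Stack above. `Bias` and `Block` are illustrative stand-ins.
#[derive(Clone)]
struct Bias(f32);

struct Block {
    // In T5 only the first block owns the relative-attention-bias table.
    has_relative_bias: bool,
}

impl Block {
    fn forward(&self, x: f32, bias: Option<&Bias>) -> (f32, Option<Bias>) {
        let bias = match bias {
            // Reuse whatever an earlier block already computed.
            Some(b) => Some(b.clone()),
            // First block: compute the bias once (a constant stands in for the
            // real bucket-embedding lookup).
            None if self.has_relative_bias => Some(Bias(0.5)),
            None => None,
        };
        let y = x + bias.as_ref().map_or(0.0, |b| b.0);
        (y, bias)
    }
}

fn main() {
    let blocks = vec![
        Block { has_relative_bias: true },
        Block { has_relative_bias: false },
        Block { has_relative_bias: false },
    ];
    let mut x = 1.0f32;
    let mut bias: Option<Bias> = None;
    for block in &blocks {
        (x, bias) = block.forward(x, bias.as_ref());
    }
    println!("output: {x}, bias reused by later blocks: {}", bias.is_some());
}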