Mirror of https://github.com/huggingface/candle.git
Proper log buckets for t5. (#727)
* Proper log buckets for t5.
* Properly pass the position bias.
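The larger hunk below inlines T5's relative-position bucketing: half the buckets are used for keys to the right of the query and half for keys at or to its left; within each half, small offsets get their own exact bucket and larger offsets share logarithmically spaced buckets up to `relative_attention_max_distance`. As a reading aid, here is a minimal standalone sketch of that rule for the bidirectional (encoder) case; the function name and signature are illustrative only and not part of the diff or of candle's API (`num_buckets` / `max_distance` stand for `relative_attention_num_buckets` / `relative_attention_max_distance`):

```rust
// Illustrative only: restates the bucketing rule that T5Attention::forward
// inlines below. `num_buckets` is the full relative_attention_num_buckets
// value before it is halved between the two directions.
fn relative_position_bucket(
    query_pos: u32,
    key_pos: u32,
    num_buckets: u32,
    max_distance: u32,
) -> u32 {
    let num_buckets = num_buckets / 2; // one half of the buckets per direction
    let max_exact = num_buckets / 2; // exact buckets for small offsets
    let (offset, distance) = if query_pos < key_pos {
        (num_buckets, key_pos - query_pos) // key to the right: upper half
    } else {
        (0, query_pos - key_pos) // key at or before the query: lower half
    };
    if distance < max_exact {
        offset + distance
    } else {
        // Log-spaced buckets between max_exact and max_distance, saturating
        // at the last bucket of the half.
        let b = f32::log(
            distance as f32 / max_exact as f32,
            max_distance as f32 / max_exact as f32,
        ) * (num_buckets - max_exact) as f32;
        u32::min(offset + max_exact + b as u32, offset + num_buckets - 1)
    }
}
```

With the usual T5 defaults (32 buckets, max distance 128) each half has 8 exact buckets for offsets 0–7 and 8 log-spaced buckets for offsets 8–127; anything farther away saturates at the last bucket of its half.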
@@ -202,7 +202,11 @@ impl T5Attention {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         // TODO: Apply the mask(s)?
         // TODO: kv caching.
         let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
@@ -223,21 +227,43 @@ impl T5Attention {
             .contiguous()?;
         let scores = q.matmul(&k.t()?)?;

-        let scores = match &self.relative_attention_bias {
-            None => scores,
+        let (scores, position_bias) = match position_bias {
+            Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
+            None => match &self.relative_attention_bias {
+                None => (scores, None),
                 Some(relative_attention_bias) => {
                     let query_length = seq_len;
                     let key_length = seq_len;
                     // This only handles the bidirectional case.
-                    let num_buckets = self.relative_attention_num_buckets / 2;
+                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+                    let max_exact = num_buckets / 2;
                     let relative_position = (0..query_length as u32)
                         .map(|i| {
                             (0..key_length as u32)
                                 .map(|j| {
                                     if i < j {
-                                        j - i + num_buckets as u32
+                                        if j - i < max_exact {
+                                            j - i + num_buckets
                                         } else {
+                                            let b = f32::log(
+                                                (j - i) as f32 / max_exact as f32,
+                                                self.relative_attention_max_distance as f32
+                                                    / max_exact as f32,
+                                            ) * (num_buckets - max_exact) as f32;
+                                            u32::min(
+                                                max_exact + num_buckets + b as u32,
+                                                self.relative_attention_num_buckets as u32 - 1,
+                                            )
+                                        }
+                                    } else if i - j < max_exact {
                                         i - j
+                                    } else {
+                                        let b = f32::log(
+                                            (i - j) as f32 / max_exact as f32,
+                                            self.relative_attention_max_distance as f32
+                                                / max_exact as f32,
+                                        ) * (num_buckets - max_exact) as f32;
+                                        max_exact + b as u32
                                     }
                                 })
                                 .collect::<Vec<u32>>()
@@ -248,9 +274,10 @@ impl T5Attention {
                         .forward(&relative_buckets)?
                         .permute((2, 0, 1))?
                         .unsqueeze(0)?;
-                    (scores + position_bias)?
+                    ((scores + &position_bias)?, Some(position_bias))
                     // TODO: position_bias_masked?
                 }
+            },
         };

         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
@@ -259,7 +286,7 @@ impl T5Attention {
             .transpose(1, 2)?
             .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
-        Ok(attn_output)
+        Ok((attn_output, position_bias))
     }
 }

@@ -280,11 +307,15 @@ impl T5LayerSelfAttention {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let ys = self.self_attention.forward(&normed_xs)?;
+        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
         let ys = (xs + ys)?;
-        Ok(ys)
+        Ok((ys, position_bias))
     }
 }

@@ -326,8 +357,12 @@ impl T5Block {
         })
     }

-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.self_attn.forward(xs)?;
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
             xs = cross_attn.forward(&xs)?;
@@ -335,7 +370,7 @@ impl T5Block {
         }
         let xs = self.ff.forward(&xs)?;
         // TODO: clamp for f16?
-        Ok(xs)
+        Ok((xs, position_bias))
     }
 }

@@ -368,8 +403,10 @@ impl T5Stack {
         let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);

         let mut hidden_states = input_embeds;
+        let mut position_bias = None;
         for block in self.block.iter() {
-            hidden_states = block.forward(&hidden_states)?
+            (hidden_states, position_bias) =
+                block.forward(&hidden_states, position_bias.as_ref())?
         }
         let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
         Ok(hidden_states)
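The "properly pass the position bias" half of the change works because only the first layer of a T5 stack carries the learned `relative_attention_bias` table (the field is an `Option` for the others): that layer computes the bias once, returns it alongside its output, and the T5Stack loop above feeds it back into every later block. For reference, a hedged sketch of how the bucket indices become the broadcastable bias that is added to the scores, mirroring the forward/permute/unsqueeze chain in the diff; the `position_bias` function and its parameter names are illustrative, not candle API, and it reuses the hypothetical `relative_position_bucket` helper sketched earlier (`candle` here is the candle-core crate as aliased inside the repo's own crates):

```rust
use candle::{Device, Result, Tensor};
use candle_nn::{Embedding, Module};

// Illustrative helper, not part of the diff: builds the (1, n_heads, q_len, kv_len)
// bias from the learned (num_buckets, n_heads) embedding table, using the same
// forward -> permute((2, 0, 1)) -> unsqueeze(0) chain as T5Attention::forward above.
fn position_bias(
    rel_bias: &Embedding, // learned (num_buckets, n_heads) table
    q_len: usize,
    kv_len: usize,
    num_buckets: u32,
    max_distance: u32,
    device: &Device,
) -> Result<Tensor> {
    // Bucket index for every (query, key) pair, flattened row-major.
    let mut buckets = Vec::with_capacity(q_len * kv_len);
    for i in 0..q_len as u32 {
        for j in 0..kv_len as u32 {
            buckets.push(relative_position_bucket(i, j, num_buckets, max_distance));
        }
    }
    let buckets = Tensor::from_vec(buckets, (q_len, kv_len), device)?;
    // (q_len, kv_len, n_heads) -> (n_heads, q_len, kv_len) -> (1, n_heads, q_len, kv_len)
    rel_bias.forward(&buckets)?.permute((2, 0, 1))?.unsqueeze(0)
}
```

The bias depends only on positions, not on the hidden states, and cloning a candle `Tensor` just bumps a reference count, so handing the same bias to every layer via the `Some(position_bias) => (..., Some(position_bias.clone()))` arm above is cheap.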