Musicgen text embeddings. (#726)

* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
Laurent Mazare
2023-09-03 19:27:48 +02:00
committed by GitHub
parent bbec527bb9
commit 26cd266e65
4 changed files with 85 additions and 15 deletions

@@ -96,10 +96,9 @@ impl T5LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
-        let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
-        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
-        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
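
The corrected code matches the reference T5 layer norm, which is an RMS-style norm: the activations are scaled by 1 / sqrt(mean(x^2) + eps) over the last dimension (no mean subtraction), then multiplied by the learned weight. Below is a minimal standalone sketch of the same computation, assuming the published candle-core crate (imported here as candle_core); the function name and shapes are illustrative, not part of the diff.

// Standalone sketch of the T5/RMS-style layer norm above; illustrative names.
use candle_core::{DType, Device, Result, Tensor, D};

fn t5_layer_norm(xs: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = xs.dtype();
    // variance = mean(x^2) over the last dim, computed in f32 for stability;
    // there is no mean subtraction, which is what makes this an RMS-style norm.
    let xs_f32 = xs.to_dtype(DType::F32)?;
    let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
    let xs = xs.broadcast_div(&(variance + eps)?.sqrt()?)?;
    // Cast back to the input dtype and apply the learned per-channel scale.
    xs.to_dtype(dtype)?.broadcast_mul(weight)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?;
    let weight = Tensor::ones(8, DType::F32, &dev)?;
    let ys = t5_layer_norm(&xs, &weight, 1e-6)?;
    println!("{:?}", ys.dims());
    Ok(())
}
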
@@ -167,6 +166,9 @@ struct T5Attention {
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
+    relative_attention_num_buckets: usize,
+    relative_attention_max_distance: usize,
+    inner_dim: usize,
}
impl T5Attention {
@@ -194,6 +196,9 @@ impl T5Attention {
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
+            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            inner_dim,
})
}
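
The constructor simply wires the two relative-attention settings through from the model configuration. For context, here is a purely illustrative config struct carrying the same fields; the defaults shown (8 heads, d_kv 64, 32 buckets, max distance 128) are the usual Hugging Face T5 values and are not taken from this diff.

// Illustrative config holding the fields referenced above; not from the diff.
#[derive(Debug, Clone)]
struct Config {
    num_heads: usize,
    d_kv: usize,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            num_heads: 8,
            d_kv: 64,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
        }
    }
}

fn main() {
    let cfg = Config::default();
    println!("{cfg:?}");
}
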
@@ -206,17 +211,53 @@ impl T5Attention {
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
let scores = q.matmul(&k.t()?)?;
-        // TODO: position_bias_masked
+        let scores = match &self.relative_attention_bias {
+            None => scores,
+            Some(relative_attention_bias) => {
+                let query_length = seq_len;
+                let key_length = seq_len;
+                // This only handles the bidirectional case.
+                let num_buckets = self.relative_attention_num_buckets / 2;
+                let relative_position = (0..query_length as u32)
+                    .map(|i| {
+                        (0..key_length as u32)
+                            .map(|j| {
+                                if i < j {
+                                    j - i + num_buckets as u32
+                                } else {
+                                    i - j
+                                }
+                            })
+                            .collect::<Vec<u32>>()
+                    })
+                    .collect::<Vec<Vec<_>>>();
+                let relative_buckets = Tensor::new(relative_position, q.device())?;
+                let position_bias = relative_attention_bias
+                    .forward(&relative_buckets)?
+                    .permute((2, 0, 1))?
+                    .unsqueeze(0)?;
+                (scores + position_bias)?
+                // TODO: position_bias_masked?
+            }
+        };
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
}
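
The new match arm materializes the T5 relative position bias: it builds a (seq_len, seq_len) table of bucket indices for the bidirectional case (keys to the right of the query are offset by relative_attention_num_buckets / 2), looks the buckets up in the relative_attention_bias embedding to get a (q_len, k_len, n_heads) tensor, permutes it to (n_heads, q_len, k_len), unsqueezes a batch dimension, and adds the result to the raw attention scores before the softmax. Note that relative_attention_max_distance is stored on the struct but not consulted here, and the logarithmic bucketing that the reference T5 implementation applies to larger distances is not reproduced. A standalone sketch of just the bucket computation, assuming the published candle-core crate and an illustrative helper name:

// Sketch of the bidirectional bucket table built in the match arm above.
use candle_core::{Device, Result, Tensor};

fn relative_buckets(q_len: usize, k_len: usize, num_buckets: usize, dev: &Device) -> Result<Tensor> {
    let half = (num_buckets / 2) as u32;
    let rows = (0..q_len as u32)
        .map(|i| {
            (0..k_len as u32)
                // Keys at or left of the query keep the raw distance i - j;
                // keys to the right are shifted by num_buckets / 2 so the two
                // directions map to disjoint bucket ranges.
                .map(|j| if i < j { j - i + half } else { i - j })
                .collect::<Vec<u32>>()
        })
        .collect::<Vec<_>>();
    // Shape (q_len, k_len), dtype U32 -- ready to feed into the
    // relative_attention_bias embedding, whose output is then permuted to
    // (n_heads, q_len, k_len) and unsqueezed to broadcast over the batch.
    Tensor::new(rows, dev)
}

fn main() -> Result<()> {
    let buckets = relative_buckets(4, 4, 32, &Device::Cpu)?;
    println!("{buckets}");
    Ok(())
}
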
@@ -324,7 +365,7 @@ impl T5Stack {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.dims2()?;
+        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let mut hidden_states = input_embeds;
for block in self.block.iter() {
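
The last hunk is a small shape fix: input_embeds comes out of the shared embedding as a rank-3 tensor (batch, seq_len, d_model), so dims2() can no longer be used to read the batch and sequence sizes, while dim(0) / dim(1) read the leading dimensions regardless of rank. A minimal illustration, assuming candle-core and illustrative shapes:

// Why dims2() had to go: it errors on a rank-3 tensor.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let input_embeds = Tensor::zeros((2, 7, 512), DType::F32, &dev)?;
    assert!(input_embeds.dims2().is_err()); // unexpected rank: 3 instead of 2
    let (b_sz, seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
    println!("batch={b_sz} seq_len={seq_len}");
    Ok(())
}
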