Musicgen text embeddings. (#726)

* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
This commit is contained in:
Laurent Mazare
2023-09-03 19:27:48 +02:00
committed by GitHub
parent bbec527bb9
commit 26cd266e65
4 changed files with 85 additions and 15 deletions

View File

@ -18,7 +18,7 @@ mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use anyhow::{Error as E, Result};
use candle::DType;
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -39,6 +39,12 @@ struct Args {
/// The tokenizer config.
#[arg(long)]
tokenizer: Option<String>,
#[arg(
long,
default_value = "90s rock song with loud guitars and heavy drums"
)]
prompt: String,
}
fn main() -> Result<()> {
@ -53,7 +59,10 @@ fn main() -> Result<()> {
.get("tokenizer.json")?,
};
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
@ -69,6 +78,18 @@ fn main() -> Result<()> {
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
let _model = MusicgenForConditionalGeneration::load(vb, config)?;
let model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer
.encode(args.prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("tokens: {tokens:?}");
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
println!("{tokens:?}");
let embeds = model.text_encoder.forward(&tokens)?;
println!("{embeds}");
Ok(())
}

View File

@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
text_encoder: crate::t5_model::T5EncoderModel,
audio_encoder: crate::encodec_model::EncodecModel,
decoder: MusicgenForCausalLM,
pub text_encoder: crate::t5_model::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
}

View File

@ -96,10 +96,9 @@ impl T5LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
@ -167,6 +166,9 @@ struct T5Attention {
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
inner_dim: usize,
}
impl T5Attention {
@ -194,6 +196,9 @@ impl T5Attention {
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
relative_attention_max_distance: cfg.relative_attention_max_distance,
inner_dim,
})
}
@ -206,17 +211,53 @@ impl T5Attention {
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let scores = q.matmul(&k.t()?)?;
// TODO: position_bias_masked
let scores = match &self.relative_attention_bias {
None => scores,
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets / 2;
let relative_position = (0..query_length as u32)
.map(|i| {
(0..key_length as u32)
.map(|j| {
if i < j {
j - i + num_buckets as u32
} else {
i - j
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?;
(scores + position_bias)?
// TODO: position_bias_masked?
}
};
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
}
@ -324,7 +365,7 @@ impl T5Stack {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let (_b_sz, _seq_len) = input_embeds.dims2()?;
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let mut hidden_states = input_embeds;
for block in self.block.iter() {

View File

@ -29,6 +29,14 @@ impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl super::Module for Linear {