mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Musicgen text embeddings. (#726)
* Musicgen text embeddings.
* Bugfix for layer norm.
* Proper position bias.
* Expose the weights.
This commit is contained in:
@@ -18,7 +18,7 @@ mod t5_model;
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::DType;
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
@@ -39,6 +39,12 @@ struct Args {
|
||||
/// The tokenizer config.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "90s rock song with loud guitars and heavy drums"
|
||||
)]
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@@ -53,7 +59,10 @@ fn main() -> Result<()> {
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
@@ -69,6 +78,18 @@ fn main() -> Result<()> {
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
let _model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
let model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
println!("tokens: {tokens:?}");
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
println!("{tokens:?}");
|
||||
let embeds = model.text_encoder.forward(&tokens)?;
|
||||
println!("{embeds}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MusicgenForConditionalGeneration {
|
||||
text_encoder: crate::t5_model::T5EncoderModel,
|
||||
audio_encoder: crate::encodec_model::EncodecModel,
|
||||
decoder: MusicgenForCausalLM,
|
||||
pub text_encoder: crate::t5_model::T5EncoderModel,
|
||||
pub audio_encoder: crate::encodec_model::EncodecModel,
|
||||
pub decoder: MusicgenForCausalLM,
|
||||
cfg: GenConfig,
|
||||
}
|
||||
|
||||
|
@@ -96,10 +96,9 @@ impl T5LayerNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let dtype = xs.dtype();
|
||||
let xs_f32 = xs.to_dtype(DType::F32)?;
|
||||
let xs2_f32 = (&xs_f32 * &xs_f32)?;
|
||||
let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
|
||||
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
|
||||
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
|
||||
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
|
||||
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
|
||||
let xs = xs.to_dtype(dtype)?;
|
||||
let xs = xs.broadcast_mul(&self.weight)?;
|
||||
Ok(xs)
|
||||
@@ -167,6 +166,9 @@ struct T5Attention {
|
||||
n_heads: usize,
|
||||
d_kv: usize,
|
||||
relative_attention_bias: Option<Embedding>,
|
||||
relative_attention_num_buckets: usize,
|
||||
relative_attention_max_distance: usize,
|
||||
inner_dim: usize,
|
||||
}
|
||||
|
||||
impl T5Attention {
|
||||
@@ -194,6 +196,9 @@ impl T5Attention {
|
||||
n_heads: cfg.num_heads,
|
||||
d_kv: cfg.d_kv,
|
||||
relative_attention_bias,
|
||||
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
|
||||
relative_attention_max_distance: cfg.relative_attention_max_distance,
|
||||
inner_dim,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -206,17 +211,53 @@ impl T5Attention {
|
||||
let v = self.v.forward(xs)?;
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?;
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?;
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
|
||||
.transpose(1, 2)?;
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let scores = q.matmul(&k.t()?)?;
|
||||
// TODO: position_bias_masked
|
||||
|
||||
let scores = match &self.relative_attention_bias {
|
||||
None => scores,
|
||||
Some(relative_attention_bias) => {
|
||||
let query_length = seq_len;
|
||||
let key_length = seq_len;
|
||||
// This only handles the bidirectional case.
|
||||
let num_buckets = self.relative_attention_num_buckets / 2;
|
||||
let relative_position = (0..query_length as u32)
|
||||
.map(|i| {
|
||||
(0..key_length as u32)
|
||||
.map(|j| {
|
||||
if i < j {
|
||||
j - i + num_buckets as u32
|
||||
} else {
|
||||
i - j
|
||||
}
|
||||
})
|
||||
.collect::<Vec<u32>>()
|
||||
})
|
||||
.collect::<Vec<Vec<_>>>();
|
||||
let relative_buckets = Tensor::new(relative_position, q.device())?;
|
||||
let position_bias = relative_attention_bias
|
||||
.forward(&relative_buckets)?
|
||||
.permute((2, 0, 1))?
|
||||
.unsqueeze(0)?;
|
||||
(scores + position_bias)?
|
||||
// TODO: position_bias_masked?
|
||||
}
|
||||
};
|
||||
|
||||
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, seq_len, self.inner_dim))?;
|
||||
let attn_output = self.o.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
@@ -324,7 +365,7 @@ impl T5Stack {
|
||||
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||
let (_b_sz, _seq_len) = input_embeds.dims2()?;
|
||||
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
|
||||
|
||||
let mut hidden_states = input_embeds;
|
||||
for block in self.block.iter() {
|
||||
|
Reference in New Issue
Block a user