diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 8dcef6d2..3794c22d 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -18,7 +18,7 @@ mod t5_model;
 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
 
 use anyhow::{Error as E, Result};
-use candle::DType;
+use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -39,6 +39,12 @@ struct Args {
     /// The tokenizer config.
     #[arg(long)]
     tokenizer: Option<String>,
+
+    #[arg(
+        long,
+        default_value = "90s rock song with loud guitars and heavy drums"
+    )]
+    prompt: String,
 }
 
 fn main() -> Result<()> {
@@ -53,7 +59,10 @@ fn main() -> Result<()> {
             .get("tokenizer.json")?,
     };
     let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
-    let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
 
     let model = match args.model {
         Some(model) => std::path::PathBuf::from(model),
@@ -69,6 +78,18 @@ fn main() -> Result<()> {
     let model = model.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
     let config = GenConfig::small();
-    let _model = MusicgenForConditionalGeneration::load(vb, config)?;
+    let model = MusicgenForConditionalGeneration::load(vb, config)?;
+
+    let tokens = tokenizer
+        .encode(args.prompt.as_str(), true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    println!("tokens: {tokens:?}");
+    let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+    println!("{tokens:?}");
+    let embeds = model.text_encoder.forward(&tokens)?;
+    println!("{embeds}");
+
     Ok(())
 }
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 751e0226..7e272fd7 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
 
 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
-    text_encoder: crate::t5_model::T5EncoderModel,
-    audio_encoder: crate::encodec_model::EncodecModel,
-    decoder: MusicgenForCausalLM,
+    pub text_encoder: crate::t5_model::T5EncoderModel,
+    pub audio_encoder: crate::encodec_model::EncodecModel,
+    pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
 }
 
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 33b11b95..607b5c93 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -96,10 +96,9 @@ impl T5LayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
-        let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
-        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
-        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
         let xs = xs.broadcast_mul(&self.weight)?;
         Ok(xs)
@@ -167,6 +166,9 @@ struct T5Attention {
     n_heads: usize,
     d_kv: usize,
     relative_attention_bias: Option<Embedding>,
+    relative_attention_num_buckets: usize,
+    relative_attention_max_distance: usize,
+    inner_dim: usize,
 }
 
 impl T5Attention {
@@ -194,6 +196,9 @@ impl T5Attention {
             n_heads: cfg.num_heads,
             d_kv: cfg.d_kv,
             relative_attention_bias,
+            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            inner_dim,
         })
     }
 
@@ -206,17 +211,53 @@ impl T5Attention {
         let v = self.v.forward(xs)?;
         let q = q
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let v = v
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let scores = q.matmul(&k.t()?)?;
-        // TODO: position_bias_masked
+
+        let scores = match &self.relative_attention_bias {
+            None => scores,
+            Some(relative_attention_bias) => {
+                let query_length = seq_len;
+                let key_length = seq_len;
+                // This only handles the bidirectional case.
+                let num_buckets = self.relative_attention_num_buckets / 2;
+                let relative_position = (0..query_length as u32)
+                    .map(|i| {
+                        (0..key_length as u32)
+                            .map(|j| {
+                                if i < j {
+                                    j - i + num_buckets as u32
+                                } else {
+                                    i - j
+                                }
+                            })
+                            .collect::<Vec<u32>>()
+                    })
+                    .collect::<Vec<Vec<u32>>>();
+                let relative_buckets = Tensor::new(relative_position, q.device())?;
+                let position_bias = relative_attention_bias
+                    .forward(&relative_buckets)?
+                    .permute((2, 0, 1))?
+                    .unsqueeze(0)?;
+                (scores + position_bias)?
+                // TODO: position_bias_masked?
+            }
+        };
+
         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok(attn_output)
     }
@@ -324,7 +365,7 @@ impl T5Stack {
 
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.dims2()?;
+        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
 
         let mut hidden_states = input_embeds;
         for block in self.block.iter() {
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 14250ed2..7028f68c 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -29,6 +29,14 @@ impl Linear {
     pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
         Self { weight, bias }
     }
+
+    pub fn weight(&self) -> &Tensor {
+        &self.weight
+    }
+
+    pub fn bias(&self) -> Option<&Tensor> {
+        self.bias.as_ref()
+    }
 }
 
 impl super::Module for Linear {
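
Note on the T5LayerNorm change: the old code divided the squared activations by their sum rather than taking their mean, so it did not compute the RMS-style norm the comment describes (variance = mean of squares over the last dimension). Below is a minimal standalone sketch of the corrected computation on a plain slice, without candle; t5_layer_norm is a hypothetical helper for illustration, not part of this patch.

// Sketch only: RMS-style T5 layer norm over the last dimension, mirroring
// the patched candle code: variance = mean(x^2); y = x / sqrt(variance + eps) * w.
fn t5_layer_norm(xs: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let variance = xs.iter().map(|x| x * x).sum::<f32>() / xs.len() as f32;
    let scale = (variance + eps).sqrt();
    xs.iter().zip(weight).map(|(x, w)| x / scale * w).collect()
}

fn main() {
    let xs = [1.0_f32, 2.0, 3.0];
    let weight = [1.0_f32, 1.0, 1.0];
    // variance = (1 + 4 + 9) / 3 ≈ 4.667, scale ≈ 2.160
    println!("{:?}", t5_layer_norm(&xs, &weight, 1e-6));
}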
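Note on the relative attention bias: the new match arm implements only the simplified bidirectional bucketing, where a key to the right of the query gets an offset of relative_attention_num_buckets / 2; the logarithmic large-distance buckets driven by relative_attention_max_distance in the reference T5 implementation are not applied yet, which is why that field is stored but unused. A standalone sketch of the same index computation follows; relative_position_buckets is a hypothetical helper, not part of this patch.

// Sketch only: bucket index for each (query i, key j) pair, as in the diff:
// j > i (key to the right) maps to j - i + num_buckets / 2, otherwise to i - j.
fn relative_position_buckets(query_len: u32, key_len: u32, num_buckets: u32) -> Vec<Vec<u32>> {
    let half = num_buckets / 2;
    (0..query_len)
        .map(|i| {
            (0..key_len)
                .map(|j| if i < j { j - i + half } else { i - j })
                .collect()
        })
        .collect()
}

fn main() {
    // With 4 tokens and 32 buckets: the diagonal is bucket 0, keys to the
    // left use buckets 1..=3, and keys to the right use buckets 17..=19.
    for row in relative_position_buckets(4, 4, 32) {
        println!("{row:?}");
    }
}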