mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Musicgen text embeddings. (#726)
* Musicgen text embeddings. * Bugfix for layer norm. * Proper position bias. * Expose the weights.
This commit is contained in:
@ -18,7 +18,7 @@ mod t5_model;
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::DType;
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
@ -39,6 +39,12 @@ struct Args {
|
||||
/// The tokenizer config.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "90s rock song with loud guitars and heavy drums"
|
||||
)]
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -53,7 +59,10 @@ fn main() -> Result<()> {
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
@ -69,6 +78,18 @@ fn main() -> Result<()> {
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
let _model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
let model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
println!("tokens: {tokens:?}");
|
||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
println!("{tokens:?}");
|
||||
let embeds = model.text_encoder.forward(&tokens)?;
|
||||
println!("{embeds}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user