#![allow(dead_code)]
// https://huggingface.co/facebook/musicgen-small/tree/main
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/modeling_musicgen.py
// TODO: Add an offline mode.
// TODO: Add a KV cache.

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod encodec_model;
mod musicgen_model;
mod nn;

use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

use anyhow::{Error as E, Result};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};

const DTYPE: DType = DType::F32;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weight file, in safetensors format.
    #[arg(long)]
    model: Option<String>,

    /// The tokenizer config.
    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(
        long,
        default_value = "90s rock song with loud guitars and heavy drums"
    )]
    prompt: String,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    // Use the tokenizer passed on the command line, or fetch the default one
    // for facebook/musicgen-small from the Hugging Face Hub.
    let tokenizer = match args.tokenizer {
        Some(tokenizer) => std::path::PathBuf::from(tokenizer),
        None => Api::new()?
            .model("facebook/musicgen-small".to_string())
            .get("tokenizer.json")?,
    };
    let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
    // Disable padding and truncation so the prompt is encoded as-is.
    let tokenizer = tokenizer
        .with_padding(None)
        .with_truncation(None)
        .map_err(E::msg)?;

    // Likewise for the weights: a local safetensors file, or the copy hosted
    // on the refs/pr/13 revision of the hub repo.
    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => Api::new()?
            .repo(Repo::with_revision(
                "facebook/musicgen-small".to_string(),
                RepoType::Model,
                "refs/pr/13".to_string(),
            ))
            .get("model.safetensors")?,
    };
    // Memory-map the weights rather than loading them eagerly, then build the
    // model in f32.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
    let config = GenConfig::small();
    let mut model = MusicgenForConditionalGeneration::load(vb, config)?;

    // Tokenize the prompt and add a batch dimension.
    let tokens = tokenizer
        .encode(args.prompt.as_str(), true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    println!("tokens: {tokens:?}");
    let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
    println!("{tokens:?}");

    // For now this only runs the text encoder and prints the resulting
    // embeddings; audio generation is not wired up yet.
    let embeds = model.text_encoder.forward(&tokens)?;
    println!("{embeds}");
    Ok(())
}
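
// A minimal sketch of how this example is typically invoked, assuming it
// lives under an `examples/musicgen` directory of a cargo workspace (the
// exact path and example name are assumptions, not confirmed by this file):
//
//   cargo run --example musicgen --release -- \
//     --prompt "90s rock song with loud guitars and heavy drums"
//
// Pass --cpu to force CPU execution, or --model / --tokenizer to point at
// local files instead of downloading them from the Hugging Face Hub.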