Fix the musicgen example. (#724)

* Fix the musicgen example.

* Retrieve the weights from the hub.
Author: Laurent Mazare
Date: 2023-09-03 15:50:39 +02:00 (committed by GitHub)
Commit: bbec527bb9 (parent f7980e07e0)
5 changed files with 62 additions and 134 deletions
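
The "Retrieve the weights from the hub" step boils down to the hf-hub sync API that the hunks below (the musicgen example's main.rs) start using. As a minimal, self-contained sketch of that pattern, not part of the commit: the helper name fetch_musicgen_files is illustrative only, while the repo id, file names and the refs/pr/13 revision are the ones appearing in the diff.

// Sketch (not part of the commit) of the hf-hub fetch pattern the example
// now relies on: download (or reuse from the local cache) a file from the
// Hugging Face Hub and get back a std::path::PathBuf.
use anyhow::Result;
use hf_hub::{api::sync::Api, Repo, RepoType};

fn fetch_musicgen_files() -> Result<(std::path::PathBuf, std::path::PathBuf)> {
    let api = Api::new()?;
    // tokenizer.json comes from the main revision of the model repo.
    let tokenizer = api
        .model("facebook/musicgen-small".to_string())
        .get("tokenizer.json")?;
    // model.safetensors comes from the refs/pr/13 revision of the same repo,
    // as in the diff below.
    let weights = api
        .repo(Repo::with_revision(
            "facebook/musicgen-small".to_string(),
            RepoType::Model,
            "refs/pr/13".to_string(),
        ))
        .get("model.safetensors")?;
    Ok((tokenizer, weights))
}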


@@ -16,11 +16,12 @@ mod nn;
 mod t5_model;
 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
-use nn::VarBuilder;
 use anyhow::{Error as E, Result};
 use candle::DType;
+use candle_nn::VarBuilder;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 const DTYPE: DType = DType::F32;
@@ -33,11 +34,11 @@ struct Args {
     /// The model weight file, in safetensor format.
     #[arg(long)]
-    model: String,
+    model: Option<String>,
     /// The tokenizer config.
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
 }
 fn main() -> Result<()> {
@@ -45,10 +46,26 @@ fn main() -> Result<()> {
     let args = Args::parse();
     let device = candle_examples::device(args.cpu)?;
-    let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
+    let tokenizer = match args.tokenizer {
+        Some(tokenizer) => std::path::PathBuf::from(tokenizer),
+        None => Api::new()?
+            .model("facebook/musicgen-small".to_string())
+            .get("tokenizer.json")?,
+    };
+    let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
-    let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let model = match args.model {
+        Some(model) => std::path::PathBuf::from(model),
+        None => Api::new()?
+            .repo(Repo::with_revision(
+                "facebook/musicgen-small".to_string(),
+                RepoType::Model,
+                "refs/pr/13".to_string(),
+            ))
+            .get("model.safetensors")?,
+    };
+    let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
     let model = model.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
     let config = GenConfig::small();
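
Net effect of the hunks above: --model and --tokenizer become optional. When omitted, the example fetches tokenizer.json from facebook/musicgen-small and model.safetensors from the refs/pr/13 revision of that repo; passing explicit paths keeps the previous behaviour. The example also switches from the example-local nn::VarBuilder to candle_nn::VarBuilder, so the memory-mapped safetensors weights flow through the shared candle_nn loader.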