mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Fix the musicgen example. (#724)
* Fix the musicgen example. * Retrieve the weights from the hub.
This commit is contained in:
@ -16,11 +16,12 @@ mod nn;
|
||||
mod t5_model;
|
||||
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
use nn::VarBuilder;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::DType;
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
@ -33,11 +34,11 @@ struct Args {
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: String,
|
||||
model: Option<String>,
|
||||
|
||||
/// The tokenizer config.
|
||||
#[arg(long)]
|
||||
tokenizer: String,
|
||||
tokenizer: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -45,10 +46,26 @@ fn main() -> Result<()> {
|
||||
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
||||
let tokenizer = match args.tokenizer {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => Api::new()?
|
||||
.model("facebook/musicgen-small".to_string())
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.repo(Repo::with_revision(
|
||||
"facebook/musicgen-small".to_string(),
|
||||
RepoType::Model,
|
||||
"refs/pr/13".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
|
Reference in New Issue
Block a user