mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Sketching the musicgen model. (#66)
* Skeleton files for musicgen. * Add a musicgen model module. * Sketch the model loading. * Start adding the forward pass. * More forward pass. * Positional embeddings. * Forward for the decoder layers. * Add an empty function. * Fix the musicgen weight names. * More musicgen modeling. * Add the T5 loading bits. * Add the encodec config. * Add the encodec module hierarchy. * More Encodec modeling. * Encodec modeling. * Encodec modeling. * Add more to the encodec modeling. * Load the weights. * Populate the resnet blocks. * Also load the conv transpose weights. * Split musicgen in multiple files.
This commit is contained in:
59
candle-examples/examples/musicgen/main.rs
Normal file
59
candle-examples/examples/musicgen/main.rs
Normal file
@ -0,0 +1,59 @@
|
||||
#![allow(dead_code)]
|
||||
// https://huggingface.co/facebook/musicgen-small/tree/main
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/modeling_musicgen.py
|
||||
// TODO: Add an offline mode.
|
||||
// TODO: Add a KV cache.
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
mod encodec_model;
|
||||
mod musicgen_model;
|
||||
mod nn;
|
||||
mod t5_model;
|
||||
|
||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
use nn::VarBuilder;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device};
|
||||
use clap::Parser;
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: String,
|
||||
|
||||
/// The tokenizer config.
|
||||
#[arg(long)]
|
||||
tokenizer: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
|
||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
|
||||
let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
||||
let model = model.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||
let config = GenConfig::small();
|
||||
let _model = MusicgenForConditionalGeneration::load(&vb, config)?;
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user