mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add frame generation.
This commit is contained in:
@ -34,7 +34,7 @@ struct Args {
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
#[arg(long, default_value = "[0]Hey how are you doing?")]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
@ -76,6 +76,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
weights: Option<String>,
|
||||
|
||||
/// The mimi model weight file, in safetensors format.
|
||||
#[arg(long)]
|
||||
mimi_weights: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
@ -139,9 +143,14 @@ fn main() -> Result<()> {
|
||||
.model("meta-llama/Llama-3.2-1B".to_string())
|
||||
.get("tokenizer.json")?,
|
||||
};
|
||||
|
||||
let mimi_filename = match args.mimi_weights {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.model("kyutai/mimi".to_string())
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let _tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = match args.config {
|
||||
@ -152,14 +161,23 @@ fn main() -> Result<()> {
|
||||
}
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (_model, _device) = {
|
||||
let dtype = DType::F32;
|
||||
let (_model, device) = {
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
(model, device)
|
||||
};
|
||||
let _mimi_model = {
|
||||
use candle_transformers::models::mimi;
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
|
||||
let config = mimi::Config::v0_1(None);
|
||||
mimi::Model::new(config, vb)?
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
||||
println!("{prompt:?}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user