Get the sampling to work.

laurent
2025-04-03 14:58:44 +02:00
parent 3fb67e0c2c
commit e319cd78d9
2 changed files with 45 additions and 10 deletions

@@ -161,7 +161,7 @@ fn main() -> Result<()> {
         }
     };
     let device = candle_examples::device(args.cpu)?;
-    let (_model, device) = {
+    let (mut model, device) = {
         let dtype = device.bf16_default_to_f32();
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = Model::new(&config, vb)?;
@@ -176,8 +176,22 @@
     };
     println!("loaded the model in {:?}", start.elapsed());
-    let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
-    println!("{prompt:?}");
+    if args.prompt.ends_with(".safetensors") {
+        let prompt = candle::safetensors::load(args.prompt, &device)?;
+        let tokens = prompt
+            .get("tokens")
+            .expect("no tokens in prompt")
+            .to_dtype(DType::U32)?;
+        let mask = prompt.get("mask").expect("no mask in prompt").clone();
+        println!("tokens:\n{tokens:?}");
+        println!("mask:\n{mask:?}");
+        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None);
+        let frame = model.generate_frame(&tokens, &mask, 0, &mut lp)?;
+        println!("frame:\n{frame:?}");
+    } else {
+        let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
+        println!("{prompt:?}");
+    }
     Ok(())
 }
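
For anyone trying the new .safetensors prompt path, here is a minimal sketch of how such a file could be produced with candle's own safetensors helpers. The tensor names ("tokens" and "mask") match the diff; the shapes, the dummy values, and the mask dtype are assumptions, since the commit does not show how the file is built.

    use candle::{DType, Device, Result, Tensor};
    use std::collections::HashMap;

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Dummy token ids; real ones would come from the model's tokenizer.
        let tokens = Tensor::new(&[[1u32, 2, 3, 4]], &device)?;
        // All-ones mask with the same shape as the tokens (dtype is a guess).
        let mask = Tensor::ones((1, 4), DType::U32, &device)?;
        let tensors = HashMap::from([
            ("tokens".to_string(), tokens),
            ("mask".to_string(), mask),
        ]);
        candle::safetensors::save(&tensors, "prompt.safetensors")?;
        Ok(())
    }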
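
The sampling itself goes through candle_transformers' LogitsProcessor: new(42, Some(0.8), None) seeds the RNG with 42 and samples at temperature 0.8 with no top-p, matching the diff. A standalone sketch with made-up logits:

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    fn main() -> Result<()> {
        // Same settings as the diff: seed 42, temperature 0.8, no top-p.
        let mut lp = LogitsProcessor::new(42, Some(0.8), None);
        // A made-up 1D logits vector; the model produces one of these per step.
        let logits = Tensor::new(&[0.1f32, 2.0, 0.3, -1.0], &Device::Cpu)?;
        // sample() scales by temperature, softmaxes, and draws a token id.
        let token = lp.sample(&logits)?;
        println!("sampled token id: {token}");
        Ok(())
    }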