Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Get the sampling to work.
@@ -161,7 +161,7 @@ fn main() -> Result<()> {
         }
     };
     let device = candle_examples::device(args.cpu)?;
-    let (_model, device) = {
+    let (mut model, device) = {
         let dtype = device.bf16_default_to_f32();
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = Model::new(&config, vb)?;
@@ -176,8 +176,22 @@ fn main() -> Result<()> {
     };

     println!("loaded the model in {:?}", start.elapsed());
-    let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
-    println!("{prompt:?}");
+    if args.prompt.ends_with(".safetensors") {
+        let prompt = candle::safetensors::load(args.prompt, &device)?;
+        let tokens = prompt
+            .get("tokens")
+            .expect("no tokens in prompt")
+            .to_dtype(DType::U32)?;
+        let mask = prompt.get("mask").expect("no mask in prompt").clone();
+        println!("tokens:\n{tokens:?}");
+        println!("mask:\n{mask:?}");
+        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, Some(0.8), None);
+        let frame = model.generate_frame(&tokens, &mask, 0, &mut lp)?;
+        println!("frame:\n{frame:?}");
+    } else {
+        let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
+        println!("{prompt:?}");
+    }

     Ok(())
 }
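
For reference, the tokens/mask prompt file that the new .safetensors branch consumes can be written with candle's own safetensors helpers. A minimal sketch, assuming placeholder token ids and a simple all-ones mask; the real shapes and dtypes depend on the model, and candle_core is the crate the example aliases as candle:

use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical token ids; a real file would hold ids produced by the
    // model's tokenizer, with whatever batch dimension the model expects.
    let tokens = Tensor::new(&[1u32, 42, 1337, 7], &device)?.unsqueeze(0)?;
    // An all-ones mask marking every position as valid (the dtype is an
    // assumption; the diff only clones whatever dtype is stored).
    let mask = tokens.ones_like()?.to_dtype(DType::U8)?;
    let tensors = HashMap::from([
        ("tokens".to_string(), tokens),
        ("mask".to_string(), mask),
    ]);
    candle_core::safetensors::save(&tensors, "prompt.safetensors")
}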
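The sampling itself is delegated to the LogitsProcessor from candle_transformers, seeded with 42 and run at temperature 0.8 as in the diff. A standalone sketch of a single sampling step, with made-up logits:

use candle_core::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Made-up logits over a four-token vocabulary.
    let logits = Tensor::new(&[0.5f32, 2.0, -1.0, 0.1], &device)?;
    // Same arguments as the commit: seed 42, temperature 0.8, no top-p.
    let mut lp = LogitsProcessor::new(42, Some(0.8), None);
    let token = lp.sample(&logits)?;
    println!("sampled token id: {token}");
    Ok(())
}

With Some(0.8) the logits are scaled by the temperature before the softmax and a token is drawn from the resulting distribution; passing None for the temperature makes the processor pick the argmax instead.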