Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Tmp state.
@@ -329,14 +329,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .get_ids()
         .to_vec();
 
+    println!("{tokens:?}");
+
     let start_gen = std::time::Instant::now();
-    for index in 0.. {
+    for index in 0..1 {
         if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
+        // println!("Input {}", input);
+        // println!("Input {}", input.to_device(&candle::Device::Cpu)?);
         let logits = model.forward(&input, index_pos)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
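The hunk caps the generation loop at a single iteration (for index in 0..1) and adds debug prints, consistent with the "Tmp state." commit message. The loop it touches follows the usual incremental-decoding pattern: the first iteration feeds the whole prompt, every later iteration feeds only the newest token, and only the last position of the logits is kept for sampling. Below is a minimal, self-contained sketch of that pattern, not the example's actual code: fake_forward, the prompt token ids, and the seq_len value are hypothetical stand-ins, and the import path follows the candle crate name used in the diff.

use candle::{Device, IndexOp, Result, Tensor};

// Hypothetical stand-in for the model: returns random "logits" of shape
// (batch, seq, vocab) so the loop structure can be exercised without weights.
fn fake_forward(input: &Tensor, vocab: usize) -> Result<Tensor> {
    let (b, t) = input.dims2()?;
    Tensor::rand(0f32, 1f32, (b, t, vocab), input.device())
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let seq_len = 8usize; // stand-in for config.seq_len
    let mut tokens: Vec<u32> = vec![1, 15043, 29892]; // hypothetical prompt ids

    for index in 0.. {
        if tokens.len() >= seq_len {
            break;
        }
        // Same rule as in the hunk: the full prompt on step 0, one token afterwards.
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;

        let logits = fake_forward(&input, 32000)?;
        // Keep only the last position of the first batch entry, as in the hunk.
        let last_logits = logits.i((0, logits.dim(1)? - 1))?;
        println!("step {index}: fed {} token(s), logits shape {:?}", ctxt.len(), last_logits.shape());
        // A real run would sample the next token from last_logits; push a fixed
        // dummy id here just to advance the loop.
        tokens.push(42);
    }
    println!("{tokens:?}");
    Ok(())
}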