diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs index f540e084..5942f9a6 100644 --- a/candle-examples/examples/llama_multiprocess/main.rs +++ b/candle-examples/examples/llama_multiprocess/main.rs @@ -14,12 +14,14 @@ use clap::{Parser, ValueEnum}; use candle::{DType, Device, Tensor}; use candle_transformers::generation::LogitsProcessor; +use candle_transformers::models::llama::LlamaEosToks; use cudarc::driver::safe::CudaDevice; use cudarc::nccl::safe::{Comm, Id}; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::io::Write; use std::rc::Rc; + mod model; use model::{Config, Llama}; @@ -219,9 +221,16 @@ fn main() -> Result<()> { let next_token = logits_processor.sample(&logits)?; tokens.push(next_token); new_tokens.push(next_token); - if Some(next_token) == config.eos_token_id { - break; + match config.eos_token_id { + Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => { + break; + } + Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => { + break; + } + _ => (), } + if rank == 0 { if let Some(t) = tokenizer.next_token(next_token)? { print!("{t}");