mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Fix build issue in EOS Token in llama-multiprocess (#2420)
This commit is contained in:
@ -14,12 +14,14 @@ use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::models::llama::LlamaEosToks;
|
||||
use cudarc::driver::safe::CudaDevice;
|
||||
use cudarc::nccl::safe::{Comm, Id};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
use std::rc::Rc;
|
||||
|
||||
|
||||
mod model;
|
||||
use model::{Config, Llama};
|
||||
|
||||
@ -219,9 +221,16 @@ fn main() -> Result<()> {
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
if Some(next_token) == config.eos_token_id {
|
||||
break;
|
||||
match config.eos_token_id {
|
||||
Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
|
||||
break;
|
||||
}
|
||||
Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
|
||||
break;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
if rank == 0 {
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
|
Reference in New Issue
Block a user