Fix build issue in EOS Token in llama-multiprocess (#2420)

Author: Hadi
Date: 2024-08-16 12:46:31 -04:00
Committed by: GitHub
Parent: 53ce65f706
Commit: 2b75dd9551


@@ -14,12 +14,14 @@ use clap::{Parser, ValueEnum};
 use candle::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::models::llama::LlamaEosToks;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 use std::rc::Rc;
 mod model;
 use model::{Config, Llama};
@@ -219,9 +221,16 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        if Some(next_token) == config.eos_token_id {
-            break;
-        }
+        match config.eos_token_id {
+            Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
+        }
         if rank == 0 {
             if let Some(t) = tokenizer.next_token(next_token)? {
                 print!("{t}");