Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Fix build issue in EOS Token in llama-multiprocess (#2420)
@@ -14,12 +14,14 @@ use clap::{Parser, ValueEnum};
 use candle::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::models::llama::LlamaEosToks;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 use std::rc::Rc;
 
 mod model;
 use model::{Config, Llama};
 
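For reference, the newly imported LlamaEosToks distinguishes models with a single EOS token from models with several. Below is a minimal sketch of its shape, inferred from how the patch in the next hunk matches on it; the real definition lives in candle_transformers::models::llama and may carry additional derives:

// Sketch of the enum this patch matches on (shape inferred from the diff,
// not copied from the library source).
pub enum LlamaEosToks {
    // One EOS token id (classic Llama stops on a single </s> token).
    Single(u32),
    // Several EOS token ids (Llama 3 stops on either <|end_of_text|> or <|eot_id|>).
    Multiple(Vec<u32>),
}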
@@ -219,9 +221,16 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        if Some(next_token) == config.eos_token_id {
-            break;
-        }
+        match config.eos_token_id {
+            Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
+        }
+
         if rank == 0 {
             if let Some(t) = tokenizer.next_token(next_token)? {
                 print!("{t}");
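Taken out of the example binary, the new stopping condition can be exercised on its own. The sketch below re-declares a local stand-in for LlamaEosToks and wraps the match in a hypothetical is_eos helper; neither name is part of the candle API, they only illustrate the logic the patch introduces:

// Local stand-in for candle_transformers::models::llama::LlamaEosToks,
// repeated here so the example compiles without the candle crates.
enum LlamaEosToks {
    Single(u32),
    Multiple(Vec<u32>),
}

// Hypothetical helper mirroring the patched loop body: stop generation
// when the sampled token matches the configured EOS token(s), if any.
fn is_eos(eos_token_id: &Option<LlamaEosToks>, next_token: u32) -> bool {
    match eos_token_id {
        Some(LlamaEosToks::Single(eos_tok_id)) => next_token == *eos_tok_id,
        Some(LlamaEosToks::Multiple(eos_ids)) => eos_ids.contains(&next_token),
        None => false,
    }
}

fn main() {
    // A single-EOS configuration, as in the original Llama models.
    let classic = Some(LlamaEosToks::Single(2));
    assert!(is_eos(&classic, 2));

    // A Llama 3 style configuration with two EOS ids.
    let eos = Some(LlamaEosToks::Multiple(vec![128001, 128009]));
    for token in [42u32, 7, 128009] {
        if is_eos(&eos, token) {
            println!("token {token}: EOS reached, stop generation");
            break;
        }
        println!("token {token}: keep generating");
    }
}

Matching on the enum rather than comparing against a single Option<u32> is what lets the same generation loop serve Llama 3 checkpoints, whose tokenizer config lists multiple EOS ids.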