mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
fix rwkv example eos token (#1785)
This commit is contained in:
@ -17,6 +17,8 @@ use candle_nn::VarBuilder;
|
|||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
|
||||||
|
const EOS_TOKEN_ID: u32 = 261;
|
||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
M5(M5),
|
M5(M5),
|
||||||
Q5(Q5),
|
Q5(Q5),
|
||||||
@ -104,6 +106,9 @@ impl TextGeneration {
|
|||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
|
if next_token == EOS_TOKEN_ID || next_token == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
print!("{}", self.tokenizer.decode(&[next_token])?);
|
print!("{}", self.tokenizer.decode(&[next_token])?);
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user