fix rwkv example eos token (#1785)

This commit is contained in:
Jack Shih
2024-03-01 17:22:28 +08:00
committed by GitHub
parent 64d4038e4f
commit 6980774a91

View File

@ -17,6 +17,8 @@ use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
const EOS_TOKEN_ID: u32 = 261;
enum Model {
M5(M5),
Q5(Q5),
@ -104,6 +106,9 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == EOS_TOKEN_ID || next_token == 0 {
break;
}
print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?;