diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs index a2717170..e971a1cc 100644 --- a/candle-examples/examples/rwkv/main.rs +++ b/candle-examples/examples/rwkv/main.rs @@ -17,6 +17,8 @@ use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; +const EOS_TOKEN_ID: u32 = 261; + enum Model { M5(M5), Q5(Q5), @@ -104,6 +106,9 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); generated_tokens += 1; + if next_token == EOS_TOKEN_ID || next_token == 0 { + break; + } print!("{}", self.tokenizer.decode(&[next_token])?); std::io::stdout().flush()?;