From 6980774a914a9f8d012b1528336c62c6ff411827 Mon Sep 17 00:00:00 2001 From: Jack Shih Date: Fri, 1 Mar 2024 17:22:28 +0800 Subject: [PATCH] fix rwkv example eos token (#1785) --- candle-examples/examples/rwkv/main.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs index a2717170..e971a1cc 100644 --- a/candle-examples/examples/rwkv/main.rs +++ b/candle-examples/examples/rwkv/main.rs @@ -17,6 +17,8 @@ use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; +const EOS_TOKEN_ID: u32 = 261; + enum Model { M5(M5), Q5(Q5), @@ -104,6 +106,9 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); generated_tokens += 1; + if next_token == EOS_TOKEN_ID || next_token == 0 { + break; + } print!("{}", self.tokenizer.decode(&[next_token])?); std::io::stdout().flush()?;