mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.
* Nicer output.
@@ -38,7 +38,10 @@ impl TextGeneration {
     }
 
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
         println!("starting the inference loop");
+        print!("{prompt}");
+        std::io::stdout().flush()?;
         let mut tokens = self
             .tokenizer
             .encode(prompt, true)
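The three added lines above are the "nicer output" half of the commit: the prompt is echoed right away and stdout is flushed so streamed tokens become visible immediately. A minimal standalone sketch of why the flush matters (the token list and the sleep are made up to stand in for real generation):

```rust
use std::io::Write;

fn main() -> std::io::Result<()> {
    // print! emits no newline, and stdout is line-buffered, so without an
    // explicit flush the tokens would only appear at program exit.
    for token in ["Hello", ",", " world", "!"] {
        print!("{token}");
        std::io::stdout().flush()?;
        // Stand-in for per-token generation latency.
        std::thread::sleep(std::time::Duration::from_millis(100));
    }
    println!();
    Ok(())
}
```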
@@ -49,7 +52,6 @@ impl TextGeneration {
         let mut new_tokens = vec![];
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
-            let start_gen = std::time::Instant::now();
             let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
                 (1, tokens.len().saturating_sub(1))
             } else {
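The `(context_size, past_len)` pair decides how much of the sequence is fed through the model each step: with a KV cache, only the newest token needs to run after the first iteration. A small hypothetical helper mirroring that logic (the function name and asserts are mine, not from the repo):

```rust
// Mirrors the diff's branch: with a cache, feed one new token and treat
// the rest as `past_len` cached positions; otherwise re-run everything.
fn context_window(use_cache: bool, index: usize, total_tokens: usize) -> (usize, usize) {
    if use_cache && index > 0 {
        (1, total_tokens.saturating_sub(1))
    } else {
        (total_tokens, 0)
    }
}

fn main() {
    assert_eq!(context_window(true, 0, 5), (5, 0)); // first step: full prompt
    assert_eq!(context_window(true, 3, 8), (1, 7)); // later steps: last token only
    assert_eq!(context_window(false, 3, 8), (8, 0)); // cache disabled: full re-run
}
```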
@@ -63,21 +65,17 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            println!("> {:?}", start_gen.elapsed());
-            println!(
-                "{} token: {} '{}'",
-                index + 1,
-                next_token,
-                self.tokenizer
-                    .decode(vec![next_token], true)
-                    .map_err(E::msg)?
-            );
+            let token = self
+                .tokenizer
+                .decode(vec![next_token], true)
+                .map_err(E::msg)?;
+            print!("{token}");
+            std::io::stdout().flush()?;
         }
         let dt = start_gen.elapsed();
         println!(
-            "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
+            "{sample_len} tokens generated ({:.3} token/s)",
             sample_len as f64 / dt.as_secs_f64(),
-            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
         );
         Ok(())
     }
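The reworked summary line prints only the generation rate, with `{:.3}` limiting it to three decimals, instead of re-decoding and dumping the accumulated tokens between `----` markers. A self-contained sketch of the timing arithmetic (the sleep stands in for the real generation loop):

```rust
use std::time::Instant;

fn main() {
    let sample_len = 100usize;
    let start_gen = Instant::now();
    // The generation loop would run here; simulate it with a fixed delay.
    std::thread::sleep(std::time::Duration::from_millis(250));
    let dt = start_gen.elapsed();
    // Same format as the updated example: tokens per second, 3 decimals.
    println!(
        "{sample_len} tokens generated ({:.3} token/s)",
        sample_len as f64 / dt.as_secs_f64(),
    );
}
```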