Cuda fix for starcoder. (#266)

* Cuda fix for starcoder.

* Nicer output.
This commit is contained in:
Laurent Mazare
2023-07-28 12:13:41 +01:00
committed by GitHub
parent 54ccf94472
commit 68eab38de6
2 changed files with 13 additions and 15 deletions

View File

@ -38,7 +38,10 @@ impl TextGeneration {
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
@ -49,7 +52,6 @@ impl TextGeneration {
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let start_gen = std::time::Instant::now();
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
(1, tokens.len().saturating_sub(1))
} else {
@ -63,21 +65,17 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
self.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?
);
let token = self
.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
"{sample_len} tokens generated ({:.3} token/s)",
sample_len as f64 / dt.as_secs_f64(),
self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
Ok(())
}