Cuda fix for starcoder. (#266)

* Cuda fix for starcoder.

* Nicer output.
This commit is contained in:
Laurent Mazare
2023-07-28 12:13:41 +01:00
committed by GitHub
parent 54ccf94472
commit 68eab38de6
2 changed files with 13 additions and 15 deletions

View File

@ -38,7 +38,10 @@ impl TextGeneration {
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
@ -49,7 +52,6 @@ impl TextGeneration {
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let start_gen = std::time::Instant::now();
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
(1, tokens.len().saturating_sub(1))
} else {
@ -63,21 +65,17 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
self.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?
);
let token = self
.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
"{sample_len} tokens generated ({:.3} token/s)",
sample_len as f64 / dt.as_secs_f64(),
self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
Ok(())
}

View File

@ -22,11 +22,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
fn make_causal_mask(t: usize) -> Result<Tensor> {
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
let mask = Tensor::from_slice(&mask, (t, t), device)?;
Ok(mask)
}
@ -327,7 +327,7 @@ impl GPTBigCode {
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
let bias = make_causal_mask(cfg.max_position_embeddings)?;
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
Ok(Self {
wte,
wpe,