diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index b105955c..340b8437 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -38,7 +38,10 @@ impl TextGeneration {
     }
 
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
         println!("starting the inference loop");
+        print!("{prompt}");
+        std::io::stdout().flush()?;
         let mut tokens = self
             .tokenizer
             .encode(prompt, true)
@@ -49,7 +52,6 @@ impl TextGeneration {
         let mut new_tokens = vec![];
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
-            let start_gen = std::time::Instant::now();
             let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
                 (1, tokens.len().saturating_sub(1))
             } else {
@@ -63,21 +65,17 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            println!("> {:?}", start_gen.elapsed());
-            println!(
-                "{} token: {} '{}'",
-                index + 1,
-                next_token,
-                self.tokenizer
-                    .decode(vec![next_token], true)
-                    .map_err(E::msg)?
-            );
+            let token = self
+                .tokenizer
+                .decode(vec![next_token], true)
+                .map_err(E::msg)?;
+            print!("{token}");
+            std::io::stdout().flush()?;
         }
         let dt = start_gen.elapsed();
         println!(
-            "{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
+            "{sample_len} tokens generated ({:.3} token/s)",
             sample_len as f64 / dt.as_secs_f64(),
-            self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
         );
         Ok(())
     }
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3b8033bb..3f68a5be 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -22,11 +22,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
-fn make_causal_mask(t: usize) -> Result<Tensor> {
+fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
         .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
         .collect();
-    let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+    let mask = Tensor::from_slice(&mask, (t, t), device)?;
     Ok(mask)
 }
 
@@ -327,7 +327,7 @@ impl GPTBigCode {
             .collect::<Result<Vec<_>>>()?;
         let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
         let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
-        let bias = make_causal_mask(cfg.max_position_embeddings)?;
+        let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
         Ok(Self {
             wte,
             wpe,