Fixed TP sharded version.

This commit is contained in:
Nicolas Patry
2023-07-25 19:29:03 +00:00
parent 1735e4831e
commit ed58de7551
4 changed files with 63 additions and 37 deletions

View File

@ -247,20 +247,24 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
if rank == 0 {
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
}
}
let dt = start_gen.elapsed();
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
if rank == 0 {
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
}
Ok(())
}