Add the batcher. (#293)

This commit is contained in:
Laurent Mazare
2023-08-01 09:16:10 +01:00
committed by GitHub
parent fa98ca0c35
commit e1e8127f15
3 changed files with 111 additions and 18 deletions

View File

@ -319,26 +319,22 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
println!("dataset loaded and encoded: {} tokens", tokens.len());
let seq_len = model.config.seq_len;
let mut inputs = vec![];
let mut targets = vec![];
for start_idx in (0..tokens.len()).step_by(seq_len) {
let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
if start_idx + seq_len + 1 > tokens.len() {
break;
}
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
let targets_ = Tensor::new(&tokens[1..], &device)?;
inputs.push(inputs_);
targets.push(targets_);
if inputs.len() >= args.batch_size {
let inp = Tensor::stack(&inputs, 0)?;
let tgt = Tensor::stack(&targets, 0)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
inputs.clear();
targets.clear();
None
} else {
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
let targets = Tensor::new(&tokens[1..], &device).ok();
inputs.zip(targets)
}
});
let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
}
Ok(())
}