Add some batcher variants that handle errors. (#294)

Laurent Mazare
2023-08-01 09:40:34 +01:00
committed by GitHub
parent e1e8127f15
commit 614f911e9e
2 changed files with 79 additions and 4 deletions

@@ -324,12 +324,12 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
             None
         } else {
             let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-            let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
-            let targets = Tensor::new(&tokens[1..], &device).ok();
-            inputs.zip(targets)
+            let inputs = Tensor::new(&tokens[..seq_len], &device);
+            let targets = Tensor::new(&tokens[1..], &device);
+            Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
         }
     });
-    let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
         let logits = model.forward(&inp, 0)?;
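
Below is a minimal, self-contained sketch of the pattern this hunk switches to: instead of converting tensor-creation failures to None with .ok() (which silently drops samples), each dataset item is kept as a Result and handed to the new Result-aware Batcher::new_r2 constructor, so errors surface through `?` at the batch loop. It assumes new_r2 accepts an iterator of Result<(Tensor, Tensor)> and yields batched Results, matching the call pattern in the diff; the toy token data, batch size, crate path (candle_core, which the examples alias as candle), and the shape printing are illustrative, not part of the commit.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy token stream standing in for a tokenized dataset (illustrative only).
    let tokens: Vec<u32> = (0..64).collect();
    let seq_len = 8usize;

    // Each item is a Result<(Tensor, Tensor)>: tensor creation can fail, and the
    // Result-aware batcher propagates that error instead of discarding the sample.
    let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
        if start_idx + seq_len + 1 > tokens.len() {
            None
        } else {
            let chunk = &tokens[start_idx..start_idx + seq_len + 1];
            let inputs = Tensor::new(&chunk[..seq_len], &device);
            let targets = Tensor::new(&chunk[1..], &device);
            Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
        }
    });

    // new_r2 takes an iterator of Result pairs; each batch comes back as a Result,
    // so failures bubble up via `?` rather than being swallowed upstream.
    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(4);
    for inp_tgt in batch_iter {
        let (inp, tgt) = inp_tgt?;
        println!("inputs: {:?}, targets: {:?}", inp.shape(), tgt.shape());
    }
    Ok(())
}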