mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add some batcher variants that handle errors. (#294)
This commit is contained in:
@ -324,12 +324,12 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
None
|
||||
} else {
|
||||
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
|
||||
let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
|
||||
let targets = Tensor::new(&tokens[1..], &device).ok();
|
||||
inputs.zip(targets)
|
||||
let inputs = Tensor::new(&tokens[..seq_len], &device);
|
||||
let targets = Tensor::new(&tokens[1..], &device);
|
||||
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
|
||||
}
|
||||
});
|
||||
let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
|
||||
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
for inp_tgt in batch_iter {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
|
Reference in New Issue
Block a user