From 614f911e9e91eefafb55c7701fea712413625d4b Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Tue, 1 Aug 2023 09:40:34 +0100
Subject: [PATCH] Add some batcher variants that handle errors. (#294)

---
 candle-examples/examples/llama2-c/main.rs |  8 +--
 candle-nn/src/dataset.rs                  | 75 +++++++++++++++++++++++
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index ff2a53fe..2cf71bb5 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -324,12 +324,12 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
                 None
             } else {
                 let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-                let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
-                let targets = Tensor::new(&tokens[1..], &device).ok();
-                inputs.zip(targets)
+                let inputs = Tensor::new(&tokens[..seq_len], &device);
+                let targets = Tensor::new(&tokens[1..], &device);
+                Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
             }
         });
-    let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
         let logits = model.forward(&inp, 0)?;
diff --git a/candle-nn/src/dataset.rs b/candle-nn/src/dataset.rs
index affe7b48..b74f1417 100644
--- a/candle-nn/src/dataset.rs
+++ b/candle-nn/src/dataset.rs
@@ -46,6 +46,26 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
     }
 }
 
+pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> {
+    inner: I,
+}
+
+pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> {
+    inner: I,
+}
+
+impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> {
+    pub fn new_r1(inner: I) -> Self {
+        Self::new(IterResult1 { inner })
+    }
+}
+
+impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> {
+    pub fn new_r2(inner: I) -> Self {
+        Self::new(IterResult2 { inner })
+    }
+}
+
 impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
     type Item = Result<Tensor>;
 
@@ -94,3 +114,58 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
         Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
     }
 }
+
+impl<I: Iterator<Item = Result<Tensor>>>
Iterator for Batcher<IterResult1<I>> {
+    type Item = Result<Tensor>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut items = Vec::with_capacity(self.batch_size);
+        for _i in 0..self.batch_size {
+            // We have two levels of inner here so that we can have two implementations of the
+            // Iterator trait that are different for Iter1 and Iter2. If rust gets better
+            // specialization at some point we can get rid of this.
+            match self.inner.inner.next() {
+                Some(item) => items.push(item),
+                None => {
+                    if self.return_last_incomplete_batch {
+                        break;
+                    }
+                    return None;
+                }
+            }
+        }
+        let items = items.into_iter().collect::<Result<Vec<Tensor>>>();
+        Some(items.and_then(|items| Tensor::stack(&items, 0)))
+    }
+}
+
+impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut xs = Vec::with_capacity(self.batch_size);
+        let mut ys = Vec::with_capacity(self.batch_size);
+        let mut errs = vec![];
+        for _i in 0..self.batch_size {
+            match self.inner.inner.next() {
+                Some(Ok((x, y))) => {
+                    xs.push(x);
+                    ys.push(y)
+                }
+                Some(Err(err)) => errs.push(err),
+                None => {
+                    if self.return_last_incomplete_batch {
+                        break;
+                    }
+                    return None;
+                }
+            }
+        }
+        if !errs.is_empty() {
+            return Some(Err(errs.swap_remove(0)));
+        }
+        let xs = Tensor::stack(&xs, 0);
+        let ys = Tensor::stack(&ys, 0);
+        Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
+    }
+}