Add some batcher variants that handle errors. (#294)

This commit is contained in:
Laurent Mazare
2023-08-01 09:40:34 +01:00
committed by GitHub
parent e1e8127f15
commit 614f911e9e
2 changed files with 79 additions and 4 deletions

View File

@ -324,12 +324,12 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
None
} else {
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
let targets = Tensor::new(&tokens[1..], &device).ok();
inputs.zip(targets)
let inputs = Tensor::new(&tokens[..seq_len], &device);
let targets = Tensor::new(&tokens[1..], &device);
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;

View File

@ -46,6 +46,26 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
}
}
pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> {
inner: I,
}
pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> {
inner: I,
}
impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> {
pub fn new_r1(inner: I) -> Self {
Self::new(IterResult1 { inner })
}
}
impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> {
pub fn new_r2(inner: I) -> Self {
Self::new(IterResult2 { inner })
}
}
impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
type Item = Result<Tensor>;
@ -94,3 +114,58 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
}
}
impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
type Item = Result<Tensor>;
fn next(&mut self) -> Option<Self::Item> {
let mut items = Vec::with_capacity(self.batch_size);
for _i in 0..self.batch_size {
// We have two levels of inner here so that we can have two implementations of the
// Iterator trait that are different for Iter1 and Iter2. If rust gets better
// specialization at some point we can get rid of this.
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
break;
}
return None;
}
}
}
let items = items.into_iter().collect::<Result<Vec<Tensor>>>();
Some(items.and_then(|items| Tensor::stack(&items, 0)))
}
}
impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
let mut xs = Vec::with_capacity(self.batch_size);
let mut ys = Vec::with_capacity(self.batch_size);
let mut errs = vec![];
for _i in 0..self.batch_size {
match self.inner.inner.next() {
Some(Ok((x, y))) => {
xs.push(x);
ys.push(y)
}
Some(Err(err)) => errs.push(err),
None => {
if self.return_last_incomplete_batch {
break;
}
return None;
}
}
}
if !errs.is_empty() {
return Some(Err(errs.swap_remove(0)));
}
let xs = Tensor::stack(&xs, 0);
let ys = Tensor::stack(&ys, 0);
Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
}
}