diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index f9bbe149..ff2a53fe 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -319,26 +319,22 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     println!("dataset loaded and encoded: {} tokens", tokens.len());
 
     let seq_len = model.config.seq_len;
-    let mut inputs = vec![];
-    let mut targets = vec![];
-    for start_idx in (0..tokens.len()).step_by(seq_len) {
+    let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
         if start_idx + seq_len + 1 > tokens.len() {
-            break;
-        }
-        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
-        let targets_ = Tensor::new(&tokens[1..], &device)?;
-        inputs.push(inputs_);
-        targets.push(targets_);
-        if inputs.len() >= args.batch_size {
-            let inp = Tensor::stack(&inputs, 0)?;
-            let tgt = Tensor::stack(&targets, 0)?;
-            let logits = model.forward(&inp, 0)?;
-            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-            println!("{}", loss.to_vec0::<f32>()?);
-            inputs.clear();
-            targets.clear();
+            None
+        } else {
+            let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+            let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
+            let targets = Tensor::new(&tokens[1..], &device).ok();
+            inputs.zip(targets)
         }
+    });
+    let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+    for inp_tgt in batch_iter {
+        let (inp, tgt) = inp_tgt?;
+        let logits = model.forward(&inp, 0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+        println!("{}", loss.to_vec0::<f32>()?);
     }
     Ok(())
 }
diff --git a/candle-nn/src/dataset.rs b/candle-nn/src/dataset.rs
new file mode 100644
index 00000000..affe7b48
--- /dev/null
+++ b/candle-nn/src/dataset.rs
@@ -0,0 +1,96 @@
+use candle::{Result, Tensor};
+
+pub struct Batcher<I> {
+    inner: I,
+    batch_size: usize,
+    return_last_incomplete_batch: bool,
+}
+
+impl<I> Batcher<I> {
+    fn new(inner: I) -> Self {
+        Self {
+            inner,
+            batch_size: 16,
+            return_last_incomplete_batch: false,
+        }
+    }
+
+    pub fn batch_size(mut self, batch_size: usize) -> Self {
+        self.batch_size = batch_size;
+        self
+    }
+
+    pub fn return_last_incomplete_batch(mut self, r: bool) -> Self {
+        self.return_last_incomplete_batch = r;
+        self
+    }
+}
+
+pub struct Iter1<I: Iterator<Item = Tensor>> {
+    inner: I,
+}
+
+pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> {
+    inner: I,
+}
+
+impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> {
+    pub fn new1(inner: I) -> Self {
+        Self::new(Iter1 { inner })
+    }
+}
+
+impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
+    pub fn new2(inner: I) -> Self {
+        Self::new(Iter2 { inner })
+    }
+}
+
+impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
+    type Item = Result<Tensor>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut items = Vec::with_capacity(self.batch_size);
+        for _i in 0..self.batch_size {
+            // We have two levels of inner here so that we can have two implementations of the
+            // Iterator trait that are different for Iter1 and Iter2. If rust gets better
+            // specialization at some point we can get rid of this.
+            match self.inner.inner.next() {
+                Some(item) => items.push(item),
+                None => {
+                    if self.return_last_incomplete_batch {
+                        break;
+                    }
+                    return None;
+                }
+            }
+        }
+        Some(Tensor::stack(&items, 0))
+    }
+}
+
+impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut xs = Vec::with_capacity(self.batch_size);
+        let mut ys = Vec::with_capacity(self.batch_size);
+        for _i in 0..self.batch_size {
+            match self.inner.inner.next() {
+                Some((x, y)) => {
+                    xs.push(x);
+                    ys.push(y)
+                }
+                None => {
+                    if self.return_last_incomplete_batch {
+                        break;
+                    }
+                    return None;
+                }
+            }
+        }
+        let xs = Tensor::stack(&xs, 0);
+        let ys = Tensor::stack(&ys, 0);
+        Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
+    }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index d0b62dbb..e8086238 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -2,6 +2,7 @@
 // error type if needed or add some specialized cases on the candle-core side.
 pub mod activation;
 pub mod conv;
+pub mod dataset;
 pub mod embedding;
 pub mod init;
 pub mod layer_norm;