Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Add the batcher. (#293)
@@ -319,26 +319,22 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     println!("dataset loaded and encoded: {} tokens", tokens.len());

     let seq_len = model.config.seq_len;
-    let mut inputs = vec![];
-    let mut targets = vec![];
-    for start_idx in (0..tokens.len()).step_by(seq_len) {
+    let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
         if start_idx + seq_len + 1 > tokens.len() {
-            break;
-        }
-        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
-        let targets_ = Tensor::new(&tokens[1..], &device)?;
-        inputs.push(inputs_);
-        targets.push(targets_);
-        if inputs.len() >= args.batch_size {
-            let inp = Tensor::stack(&inputs, 0)?;
-            let tgt = Tensor::stack(&targets, 0)?;
-            let logits = model.forward(&inp, 0)?;
-            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-            println!("{}", loss.to_vec0::<f32>()?);
-            inputs.clear();
-            targets.clear();
+            None
+        } else {
+            let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+            let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
+            let targets = Tensor::new(&tokens[1..], &device).ok();
+            inputs.zip(targets)
         }
+    });
+    let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+    for inp_tgt in batch_iter {
+        let (inp, tgt) = inp_tgt?;
+        let logits = model.forward(&inp, 0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+        println!("{}", loss.to_vec0::<f32>()?);
     }
     Ok(())
 }
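Net effect of this hunk: rather than eagerly pushing tensors into the inputs/targets vectors and stacking them by hand once args.batch_size items have accumulated, the evaluation loop now builds a lazy iterator of (input, target) pairs (windows whose tensor construction fails are silently skipped via .ok() and Option::zip) and hands it to the new candle_nn::dataset::Batcher, which takes over the batching and stacking. The Batcher itself is added in the new file below; a usage sketch follows that file.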
candle-nn/src/dataset.rs (new file, 96 lines)
@@ -0,0 +1,96 @@
use candle::{Result, Tensor};

pub struct Batcher<I> {
    inner: I,
    batch_size: usize,
    return_last_incomplete_batch: bool,
}

impl<I> Batcher<I> {
    fn new(inner: I) -> Self {
        Self {
            inner,
            batch_size: 16,
            return_last_incomplete_batch: false,
        }
    }

    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    pub fn return_last_incomplete_batch(mut self, r: bool) -> Self {
        self.return_last_incomplete_batch = r;
        self
    }
}

pub struct Iter1<I: Iterator<Item = Tensor>> {
    inner: I,
}

pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> {
    inner: I,
}

impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> {
    pub fn new1(inner: I) -> Self {
        Self::new(Iter1 { inner })
    }
}

impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
    pub fn new2(inner: I) -> Self {
        Self::new(Iter2 { inner })
    }
}

impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
    type Item = Result<Tensor>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut items = Vec::with_capacity(self.batch_size);
        for _i in 0..self.batch_size {
            // We have two levels of inner here so that we can have two implementations of the
            // Iterator trait that are different for Iter1 and Iter2. If rust gets better
            // specialization at some point we can get rid of this.
            match self.inner.inner.next() {
                Some(item) => items.push(item),
                None => {
                    if self.return_last_incomplete_batch {
                        break;
                    }
                    return None;
                }
            }
        }
        Some(Tensor::stack(&items, 0))
    }
}

impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
    type Item = Result<(Tensor, Tensor)>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut xs = Vec::with_capacity(self.batch_size);
        let mut ys = Vec::with_capacity(self.batch_size);
        for _i in 0..self.batch_size {
            match self.inner.inner.next() {
                Some((x, y)) => {
                    xs.push(x);
                    ys.push(y)
                }
                None => {
                    if self.return_last_incomplete_batch {
                        break;
                    }
                    return None;
                }
            }
        }
        let xs = Tensor::stack(&xs, 0);
        let ys = Tensor::stack(&ys, 0);
        Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
    }
}
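For reference, a minimal usage sketch of the Batcher API above (not part of the commit; the toy data, batch sizes, and variable names xs/pairs are illustrative only). One caveat visible in the code as written: with return_last_incomplete_batch(true), next() never returns None once the stream is exhausted, and stacking the then-empty buffer yields an Err, so the sketch sticks to the default behaviour of dropping the trailing incomplete batch.

use candle::{Device, Result, Tensor};
use candle_nn::dataset::Batcher;

fn main() -> Result<()> {
    // A stream of ten [1]-shaped tensors; with batch_size 3 the Batcher
    // yields three [3, 1]-shaped batches and drops the tenth item.
    let xs = (0..10).map(|i| Tensor::new(&[i as f32], &Device::Cpu).unwrap());
    for batch in Batcher::new1(xs).batch_size(3) {
        println!("batch: {:?}", batch?.shape());
    }

    // Paired (input, target) streams go through new2 instead, and each batch
    // comes back as a Result<(Tensor, Tensor)>, as in the eval loop above.
    let pairs = (0..10).map(|i| {
        let x = Tensor::new(&[i as f32], &Device::Cpu).unwrap();
        let y = Tensor::new(&[(i + 1) as f32], &Device::Cpu).unwrap();
        (x, y)
    });
    for batch in Batcher::new2(pairs).batch_size(5) {
        let (inp, tgt) = batch?;
        println!("inputs: {:?}, targets: {:?}", inp.shape(), tgt.shape());
    }
    Ok(())
}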
candle-nn/src/lib.rs
@@ -2,6 +2,7 @@
 // error type if needed or add some specialized cases on the candle-core side.
 pub mod activation;
 pub mod conv;
+pub mod dataset;
 pub mod embedding;
 pub mod init;
 pub mod layer_norm;