From 620f83cf66073f033d1fdc9846123c155422677e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 5 Aug 2023 08:56:50 +0100 Subject: [PATCH] Add the candle-datasets crate (#322) * Move the vision datasets to a separate crate. * Move the batcher bits. * Update the readme. * Move the tiny-stories bits. --------- Co-authored-by: Jane Doe --- Cargo.toml | 1 + README.md | 5 +- candle-datasets/Cargo.toml | 20 +++ .../src/batcher.rs | 0 candle-datasets/src/lib.rs | 6 + candle-datasets/src/nlp/mod.rs | 1 + candle-datasets/src/nlp/tinystories.rs | 122 +++++++++++++++++ .../src/vision/cifar.rs | 0 .../src/vision/mnist.rs | 0 .../src/vision/mod.rs | 0 candle-examples/Cargo.toml | 1 + candle-examples/examples/llama2-c/main.rs | 2 +- candle-examples/examples/llama2-c/training.rs | 124 +----------------- .../examples/mnist-training/main.rs | 4 +- candle-nn/src/lib.rs | 2 - 15 files changed, 163 insertions(+), 125 deletions(-) create mode 100644 candle-datasets/Cargo.toml rename candle-nn/src/dataset.rs => candle-datasets/src/batcher.rs (100%) create mode 100644 candle-datasets/src/lib.rs create mode 100644 candle-datasets/src/nlp/mod.rs create mode 100644 candle-datasets/src/nlp/tinystories.rs rename {candle-nn => candle-datasets}/src/vision/cifar.rs (100%) rename {candle-nn => candle-datasets}/src/vision/mnist.rs (100%) rename {candle-nn => candle-datasets}/src/vision/mod.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 301451a0..ea008f00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "candle-core", + "candle-datasets", "candle-examples", "candle-nn", "candle-pyo3", diff --git a/README.md b/README.md index 3232fe7a..f0cdb332 100644 --- a/README.md +++ b/README.md @@ -98,8 +98,9 @@ Cheatsheet: - [candle-nn](./candle-nn/): Facilities to build real models - [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings - [candle-kernels](./candle-kernels/): CUDA custom kernels - - +- [candle-datasets](./candle-datasets/): Datasets and data loaders. +- [candle-transformers](./candle-transformers): Transformer related utilities. +- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer. ## FAQ diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml new file mode 100644 index 00000000..12169daf --- /dev/null +++ b/candle-datasets/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "candle-datasets" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +byteorder = { workspace = true } +candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.1.0" } +hf-hub = { workspace = true} +intel-mkl-src = { workspace = true, optional = true } +memmap2 = { workspace = true } +tokenizers = { workspace = true, features = ["onig"] } +rand = { workspace = true } diff --git a/candle-nn/src/dataset.rs b/candle-datasets/src/batcher.rs similarity index 100% rename from candle-nn/src/dataset.rs rename to candle-datasets/src/batcher.rs diff --git a/candle-datasets/src/lib.rs b/candle-datasets/src/lib.rs new file mode 100644 index 00000000..42ad5d62 --- /dev/null +++ b/candle-datasets/src/lib.rs @@ -0,0 +1,6 @@ +//! Datasets & Dataloaders for Candle +pub mod batcher; +pub mod nlp; +pub mod vision; + +pub use batcher::Batcher; diff --git a/candle-datasets/src/nlp/mod.rs b/candle-datasets/src/nlp/mod.rs new file mode 100644 index 00000000..42e9d288 --- /dev/null +++ b/candle-datasets/src/nlp/mod.rs @@ -0,0 +1 @@ +pub mod tinystories; diff --git a/candle-datasets/src/nlp/tinystories.rs b/candle-datasets/src/nlp/tinystories.rs new file mode 100644 index 00000000..c657c9eb --- /dev/null +++ b/candle-datasets/src/nlp/tinystories.rs @@ -0,0 +1,122 @@ +//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated +//! by the tools from https://github.com/karpathy/llama2.c +use candle::{Device, Result, Tensor}; + +pub struct Dataset { + valid_tokens: Vec, + train_tokens: Vec, +} + +fn mmap_file(p: &std::path::PathBuf) -> Result { + let file = std::fs::File::open(p)?; + let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; + Ok(mmap) +} + +impl Dataset { + pub fn new>(dir: P) -> Result { + let dir = dir.as_ref(); + let mut bin_files = vec![]; + for file in std::fs::read_dir(dir)?.flatten() { + let file = file.path(); + if let Some(extension) = file.extension() { + if extension == "bin" { + bin_files.push(file) + } + } + } + if bin_files.len() < 2 { + candle::bail!("found less than two bin files in {:?}", dir) + } + bin_files.sort(); + let valid_tokens = mmap_file(&bin_files[0])?; + let train_tokens = bin_files[1..] + .iter() + .map(mmap_file) + .collect::>>()?; + Ok(Self { + valid_tokens: vec![valid_tokens], + train_tokens, + }) + } + + pub fn train_tokens(&self) -> usize { + self.train_tokens.len() + } + + pub fn valid_tokens(&self) -> usize { + self.valid_tokens.len() + } +} + +pub struct DatasetRandomIter<'a> { + all_tokens: &'a [memmap2::Mmap], + tokens: Vec<&'a memmap2::Mmap>, + current_tokens: &'a memmap2::Mmap, + indexes_in_bytes: Vec, + seq_len: usize, + device: Device, +} + +impl<'a> DatasetRandomIter<'a> { + pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self { + use rand::seq::SliceRandom; + use rand::thread_rng; + + let all_tokens = if valid { + &ds.valid_tokens + } else { + &ds.train_tokens + }; + let mut tokens = all_tokens.iter().collect::>(); + tokens.shuffle(&mut thread_rng()); + let current_tokens = tokens.pop().unwrap(); + let seq_len_in_bytes = seq_len * 2; + let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes) + .step_by(seq_len_in_bytes) + .collect::>(); + indexes_in_bytes.shuffle(&mut thread_rng()); + Self { + all_tokens, + tokens, + current_tokens, + indexes_in_bytes, + seq_len, + device, + } + } +} + +impl<'a> Iterator for DatasetRandomIter<'a> { + type Item = Result<(Tensor, Tensor)>; + + fn next(&mut self) -> Option { + use byteorder::{LittleEndian, ReadBytesExt}; + use rand::seq::SliceRandom; + use rand::thread_rng; + + let seq_len = self.seq_len; + if self.indexes_in_bytes.is_empty() { + if self.tokens.is_empty() { + self.tokens = self.all_tokens.iter().collect(); + self.tokens.shuffle(&mut thread_rng()); + } + self.current_tokens = self.tokens.pop().unwrap(); + let seq_len_in_bytes = self.seq_len * 2; + self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes) + .step_by(seq_len_in_bytes) + .collect::>(); + self.indexes_in_bytes.shuffle(&mut thread_rng()); + } + let start_idx = self.indexes_in_bytes.pop().unwrap(); + let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)]; + let mut tokens = vec![0u16; bytes.len() / 2]; + if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::(&mut tokens) { + return Some(Err(err.into())); + } + let tokens = tokens.into_iter().map(|v| v as u32).collect::>(); + let inputs = Tensor::new(&tokens[..seq_len], &self.device); + let targets = Tensor::new(&tokens[1..], &self.device); + Some(candle::error::zip(inputs, targets)) + } +} diff --git a/candle-nn/src/vision/cifar.rs b/candle-datasets/src/vision/cifar.rs similarity index 100% rename from candle-nn/src/vision/cifar.rs rename to candle-datasets/src/vision/cifar.rs diff --git a/candle-nn/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs similarity index 100% rename from candle-nn/src/vision/mnist.rs rename to candle-datasets/src/vision/mnist.rs diff --git a/candle-nn/src/vision/mod.rs b/candle-datasets/src/vision/mod.rs similarity index 100% rename from candle-nn/src/vision/mod.rs rename to candle-datasets/src/vision/mod.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 0db960ca..47490f42 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.1.0" } candle-nn = { path = "../candle-nn", version = "0.1.0" } candle-transformers = { path = "../candle-transformers", version = "0.1.0" } candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true } diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 612dc358..c3b94df1 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -200,7 +200,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> { Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets)))) } }); - let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size); + let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); for inp_tgt in batch_iter { let (inp, tgt) = inp_tgt?; let logits = model.forward(&inp, 0)?; diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs index e55c686c..3e93c786 100644 --- a/candle-examples/examples/llama2-c/training.rs +++ b/candle-examples/examples/llama2-c/training.rs @@ -1,118 +1,6 @@ -#![allow(dead_code)] -#![allow(unused)] use crate::model::{Cache, Config, Llama}; -use candle::{DType, Device, Result, Tensor}; - -pub struct Dataset { - valid_tokens: Vec, - train_tokens: Vec, -} - -fn mmap_file(p: &std::path::PathBuf) -> Result { - let file = std::fs::File::open(p)?; - let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; - Ok(mmap) -} - -impl Dataset { - pub fn new>(dir: P) -> Result { - let dir = dir.as_ref(); - let mut bin_files = vec![]; - for file in std::fs::read_dir(dir)?.flatten() { - let file = file.path(); - if let Some(extension) = file.extension() { - if extension == "bin" { - bin_files.push(file) - } - } - } - if bin_files.len() < 2 { - candle::bail!("found less than two bin files in {:?}", dir) - } - bin_files.sort(); - let valid_tokens = mmap_file(&bin_files[0])?; - let train_tokens = bin_files[1..] - .iter() - .map(mmap_file) - .collect::>>()?; - Ok(Self { - valid_tokens: vec![valid_tokens], - train_tokens, - }) - } -} - -struct DatasetRandomIter<'a> { - all_tokens: &'a [memmap2::Mmap], - tokens: Vec<&'a memmap2::Mmap>, - current_tokens: &'a memmap2::Mmap, - indexes_in_bytes: Vec, - seq_len: usize, - device: Device, -} - -impl<'a> DatasetRandomIter<'a> { - pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self { - use rand::seq::SliceRandom; - use rand::thread_rng; - - let all_tokens = if valid { - &ds.valid_tokens - } else { - &ds.train_tokens - }; - let mut tokens = all_tokens.iter().collect::>(); - tokens.shuffle(&mut thread_rng()); - let current_tokens = tokens.pop().unwrap(); - let seq_len_in_bytes = seq_len * 2; - let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes) - .step_by(seq_len_in_bytes) - .collect::>(); - indexes_in_bytes.shuffle(&mut thread_rng()); - Self { - all_tokens, - tokens, - current_tokens, - indexes_in_bytes, - seq_len, - device, - } - } -} - -impl<'a> Iterator for DatasetRandomIter<'a> { - type Item = Result<(Tensor, Tensor)>; - - fn next(&mut self) -> Option { - use byteorder::{LittleEndian, ReadBytesExt}; - use rand::seq::SliceRandom; - use rand::thread_rng; - - let seq_len = self.seq_len; - if self.indexes_in_bytes.is_empty() { - if self.tokens.is_empty() { - self.tokens = self.all_tokens.iter().collect(); - self.tokens.shuffle(&mut thread_rng()); - } - self.current_tokens = self.tokens.pop().unwrap(); - let seq_len_in_bytes = self.seq_len * 2; - self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes) - .step_by(seq_len_in_bytes) - .collect::>(); - self.indexes_in_bytes.shuffle(&mut thread_rng()); - } - let start_idx = self.indexes_in_bytes.pop().unwrap(); - let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)]; - let mut tokens = vec![0u16; bytes.len() / 2]; - if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::(&mut tokens) { - return Some(Err(err.into())); - } - let tokens = tokens.into_iter().map(|v| v as u32).collect::>(); - let inputs = Tensor::new(&tokens[..seq_len], &self.device); - let targets = Tensor::new(&tokens[1..], &self.device); - Some(candle::error::zip(inputs, targets)) - } -} +use candle::{DType, Device, Result}; +use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter}; fn valid_loss( dataset: &Dataset, @@ -121,7 +9,7 @@ fn valid_loss( device: &Device, ) -> Result { let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone()); - let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size); + let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let mut sum_ce = 0f64; let mut cnt = 0usize; for inp_tgt in batch_iter.take(50) { @@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> { let dataset = Dataset::new(&args.pretokenized_dir)?; println!( "loaded dataset, train: {} files, valid: {} files", - dataset.train_tokens.len(), - dataset.valid_tokens.len() + dataset.train_tokens(), + dataset.valid_tokens() ); let varmap = candle_nn::VarMap::new(); let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device); let config = Config::tiny(); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); - let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size); + let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let cache = Cache::new(false, &config, vb.pp("rot"))?; let model = Llama::load(vb, &cache, config)?; diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index e251f6e9..d9e596ce 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -63,7 +63,7 @@ struct TrainingArgs { } fn training_loop( - m: candle_nn::vision::Dataset, + m: candle_datasets::vision::Dataset, args: &TrainingArgs, ) -> anyhow::Result<()> { let dev = candle::Device::cuda_if_available(0)?; @@ -140,7 +140,7 @@ struct Args { pub fn main() -> anyhow::Result<()> { let args = Args::parse(); // Load the dataset - let m = candle_nn::vision::mnist::load_dir("data")?; + let m = candle_datasets::vision::mnist::load_dir("data")?; println!("train-images: {:?}", m.train_images.shape()); println!("train-labels: {:?}", m.train_labels.shape()); println!("test-images: {:?}", m.test_images.shape()); diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 46a83800..3f54bd43 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -2,7 +2,6 @@ // error type if needed or add some specialized cases on the candle-core side. pub mod activation; pub mod conv; -pub mod dataset; pub mod embedding; pub mod init; pub mod layer_norm; @@ -11,7 +10,6 @@ pub mod loss; pub mod ops; pub mod optim; pub mod var_builder; -pub mod vision; pub use activation::Activation; pub use conv::{Conv1d, Conv1dConfig};