Add the candle-datasets crate (#322)

* Move the vision datasets to a separate crate.

* Move the batcher bits.

* Update the readme.

* Move the tiny-stories bits.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
This commit is contained in:
Laurent Mazare
2023-08-05 08:56:50 +01:00
committed by GitHub
parent f7b2a0391d
commit 620f83cf66
15 changed files with 163 additions and 125 deletions

View File

@ -0,0 +1,171 @@
use candle::{Result, Tensor};
pub struct Batcher<I> {
inner: I,
batch_size: usize,
return_last_incomplete_batch: bool,
}
impl<I> Batcher<I> {
fn new(inner: I) -> Self {
Self {
inner,
batch_size: 16,
return_last_incomplete_batch: false,
}
}
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
pub fn return_last_incomplete_batch(mut self, r: bool) -> Self {
self.return_last_incomplete_batch = r;
self
}
}
pub struct Iter1<I: Iterator<Item = Tensor>> {
inner: I,
}
pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> {
inner: I,
}
impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> {
pub fn new1(inner: I) -> Self {
Self::new(Iter1 { inner })
}
}
impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
pub fn new2(inner: I) -> Self {
Self::new(Iter2 { inner })
}
}
pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> {
inner: I,
}
pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> {
inner: I,
}
impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> {
pub fn new_r1(inner: I) -> Self {
Self::new(IterResult1 { inner })
}
}
impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> {
pub fn new_r2(inner: I) -> Self {
Self::new(IterResult2 { inner })
}
}
impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
type Item = Result<Tensor>;
fn next(&mut self) -> Option<Self::Item> {
let mut items = Vec::with_capacity(self.batch_size);
for _i in 0..self.batch_size {
// We have two levels of inner here so that we can have two implementations of the
// Iterator trait that are different for Iter1 and Iter2. If rust gets better
// specialization at some point we can get rid of this.
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
break;
}
return None;
}
}
}
Some(Tensor::stack(&items, 0))
}
}
impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
let mut xs = Vec::with_capacity(self.batch_size);
let mut ys = Vec::with_capacity(self.batch_size);
for _i in 0..self.batch_size {
match self.inner.inner.next() {
Some((x, y)) => {
xs.push(x);
ys.push(y)
}
None => {
if self.return_last_incomplete_batch {
break;
}
return None;
}
}
}
let xs = Tensor::stack(&xs, 0);
let ys = Tensor::stack(&ys, 0);
Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
}
}
impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
type Item = Result<Tensor>;
fn next(&mut self) -> Option<Self::Item> {
let mut items = Vec::with_capacity(self.batch_size);
for _i in 0..self.batch_size {
// We have two levels of inner here so that we can have two implementations of the
// Iterator trait that are different for Iter1 and Iter2. If rust gets better
// specialization at some point we can get rid of this.
match self.inner.inner.next() {
Some(item) => items.push(item),
None => {
if self.return_last_incomplete_batch {
break;
}
return None;
}
}
}
let items = items.into_iter().collect::<Result<Vec<Tensor>>>();
Some(items.and_then(|items| Tensor::stack(&items, 0)))
}
}
impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
let mut xs = Vec::with_capacity(self.batch_size);
let mut ys = Vec::with_capacity(self.batch_size);
let mut errs = vec![];
for _i in 0..self.batch_size {
match self.inner.inner.next() {
Some(Ok((x, y))) => {
xs.push(x);
ys.push(y)
}
Some(Err(err)) => errs.push(err),
None => {
if self.return_last_incomplete_batch {
break;
}
return None;
}
}
}
if !errs.is_empty() {
return Some(Err(errs.swap_remove(0)));
}
let xs = Tensor::stack(&xs, 0);
let ys = Tensor::stack(&ys, 0);
Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
}
}

View File

@ -0,0 +1,6 @@
//! Datasets & Dataloaders for Candle
pub mod batcher;
pub mod nlp;
pub mod vision;
pub use batcher::Batcher;

View File

@ -0,0 +1 @@
pub mod tinystories;

View File

@ -0,0 +1,122 @@
//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
//! by the tools from https://github.com/karpathy/llama2.c
use candle::{Device, Result, Tensor};
pub struct Dataset {
valid_tokens: Vec<memmap2::Mmap>,
train_tokens: Vec<memmap2::Mmap>,
}
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(mmap)
}
impl Dataset {
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref();
let mut bin_files = vec![];
for file in std::fs::read_dir(dir)?.flatten() {
let file = file.path();
if let Some(extension) = file.extension() {
if extension == "bin" {
bin_files.push(file)
}
}
}
if bin_files.len() < 2 {
candle::bail!("found less than two bin files in {:?}", dir)
}
bin_files.sort();
let valid_tokens = mmap_file(&bin_files[0])?;
let train_tokens = bin_files[1..]
.iter()
.map(mmap_file)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
valid_tokens: vec![valid_tokens],
train_tokens,
})
}
pub fn train_tokens(&self) -> usize {
self.train_tokens.len()
}
pub fn valid_tokens(&self) -> usize {
self.valid_tokens.len()
}
}
pub struct DatasetRandomIter<'a> {
all_tokens: &'a [memmap2::Mmap],
tokens: Vec<&'a memmap2::Mmap>,
current_tokens: &'a memmap2::Mmap,
indexes_in_bytes: Vec<usize>,
seq_len: usize,
device: Device,
}
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
} else {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
let current_tokens = tokens.pop().unwrap();
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
Self {
all_tokens,
tokens,
current_tokens,
indexes_in_bytes,
seq_len,
device,
}
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
let mut tokens = vec![0u16; bytes.len() / 2];
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
return Some(Err(err.into()));
}
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
let targets = Tensor::new(&tokens[1..], &self.device);
Some(candle::error::zip(inputs, targets))
}
}

View File

@ -0,0 +1,62 @@
//! The CIFAR-10 dataset.
//!
//! The files can be downloaded from the following page:
//! <https://www.cs.toronto.edu/~kriz/cifar.html>
//! The binary version of the dataset is used.
use crate::vision::Dataset;
use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{BufReader, Read};
const W: usize = 32;
const H: usize = 32;
const C: usize = 3;
const BYTES_PER_IMAGE: usize = W * H * C + 1;
const SAMPLES_PER_FILE: usize = 10000;
fn read_file(filename: &std::path::Path) -> Result<(Tensor, Tensor)> {
let mut buf_reader = BufReader::new(File::open(filename)?);
let mut data = vec![0u8; SAMPLES_PER_FILE * BYTES_PER_IMAGE];
buf_reader.read_exact(&mut data)?;
let mut images = vec![];
let mut labels = vec![];
for index in 0..SAMPLES_PER_FILE {
let content_offset = BYTES_PER_IMAGE * index;
labels.push(data[content_offset]);
images.push(&data[1 + content_offset..content_offset + BYTES_PER_IMAGE]);
}
let images: Vec<u8> = images
.iter()
.copied()
.flatten()
.copied()
.collect::<Vec<_>>();
let labels = Tensor::from_vec(labels, SAMPLES_PER_FILE, &Device::Cpu)?;
let images = Tensor::from_vec(images, (SAMPLES_PER_FILE, C, H, W), &Device::Cpu)?;
let images = (images.to_dtype(DType::F32)? / 255.)?;
Ok((images, labels))
}
pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
let dir = dir.as_ref();
let (test_images, test_labels) = read_file(&dir.join("test_batch.bin"))?;
let train_images_and_labels = [
"data_batch_1.bin",
"data_batch_2.bin",
"data_batch_3.bin",
"data_batch_4.bin",
"data_batch_5.bin",
]
.iter()
.map(|x| read_file(&dir.join(x)))
.collect::<Result<Vec<_>>>()?;
let (train_images, train_labels): (Vec<_>, Vec<_>) =
train_images_and_labels.into_iter().unzip();
Ok(Dataset {
train_images: Tensor::cat(&train_images, 0)?,
train_labels: Tensor::cat(&train_labels, 0)?,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -0,0 +1,65 @@
//! The MNIST hand-written digit dataset.
//!
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Result, Tensor};
use std::fs::File;
use std::io::{self, BufReader, Read};
fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
let mut b = vec![0u8; 4];
reader.read_exact(&mut b)?;
let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
(s + basis * u64::from(x), basis * 256)
});
Ok(result as u32)
}
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
let magic_number = read_u32(reader)?;
if magic_number != expected {
Err(io::Error::new(
io::ErrorKind::Other,
format!("incorrect magic number {magic_number} != {expected}"),
))?;
}
Ok(())
}
fn read_labels(filename: &std::path::Path) -> Result<Tensor> {
let mut buf_reader = BufReader::new(File::open(filename)?);
check_magic_number(&mut buf_reader, 2049)?;
let samples = read_u32(&mut buf_reader)?;
let mut data = vec![0u8; samples as usize];
buf_reader.read_exact(&mut data)?;
let samples = data.len();
Tensor::from_vec(data, samples, &Device::Cpu)
}
fn read_images(filename: &std::path::Path) -> Result<Tensor> {
let mut buf_reader = BufReader::new(File::open(filename)?);
check_magic_number(&mut buf_reader, 2051)?;
let samples = read_u32(&mut buf_reader)? as usize;
let rows = read_u32(&mut buf_reader)? as usize;
let cols = read_u32(&mut buf_reader)? as usize;
let data_len = samples * rows * cols;
let mut data = vec![0u8; data_len];
buf_reader.read_exact(&mut data)?;
let tensor = Tensor::from_vec(data, (samples, rows * cols), &Device::Cpu)?;
tensor.to_dtype(DType::F32)? / 255.
}
pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Dataset> {
let dir = dir.as_ref();
let train_images = read_images(&dir.join("train-images-idx3-ubyte"))?;
let train_labels = read_labels(&dir.join("train-labels-idx1-ubyte"))?;
let test_images = read_images(&dir.join("t10k-images-idx3-ubyte"))?;
let test_labels = read_labels(&dir.join("t10k-labels-idx1-ubyte"))?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -0,0 +1,12 @@
use candle::Tensor;
pub struct Dataset {
pub train_images: Tensor,
pub train_labels: Tensor,
pub test_images: Tensor,
pub test_labels: Tensor,
pub labels: usize,
}
pub mod cifar;
pub mod mnist;