Add the candle-datasets crate (#322)

* Move the vision datasets to a separate crate.

* Move the batcher bits.

* Update the readme.

* Move the tiny-stories bits.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
This commit is contained in:
Laurent Mazare
2023-08-05 08:56:50 +01:00
committed by GitHub
parent f7b2a0391d
commit 620f83cf66
15 changed files with 163 additions and 125 deletions

View File

@ -200,7 +200,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;

View File

@ -1,118 +1,6 @@
#![allow(dead_code)]
#![allow(unused)]
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result, Tensor};
pub struct Dataset {
valid_tokens: Vec<memmap2::Mmap>,
train_tokens: Vec<memmap2::Mmap>,
}
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(mmap)
}
impl Dataset {
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref();
let mut bin_files = vec![];
for file in std::fs::read_dir(dir)?.flatten() {
let file = file.path();
if let Some(extension) = file.extension() {
if extension == "bin" {
bin_files.push(file)
}
}
}
if bin_files.len() < 2 {
candle::bail!("found less than two bin files in {:?}", dir)
}
bin_files.sort();
let valid_tokens = mmap_file(&bin_files[0])?;
let train_tokens = bin_files[1..]
.iter()
.map(mmap_file)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
valid_tokens: vec![valid_tokens],
train_tokens,
})
}
}
struct DatasetRandomIter<'a> {
all_tokens: &'a [memmap2::Mmap],
tokens: Vec<&'a memmap2::Mmap>,
current_tokens: &'a memmap2::Mmap,
indexes_in_bytes: Vec<usize>,
seq_len: usize,
device: Device,
}
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
} else {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
let current_tokens = tokens.pop().unwrap();
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
Self {
all_tokens,
tokens,
current_tokens,
indexes_in_bytes,
seq_len,
device,
}
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
let mut tokens = vec![0u16; bytes.len() / 2];
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
return Some(Err(err.into()));
}
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
let targets = Tensor::new(&tokens[1..], &self.device);
Some(candle::error::zip(inputs, targets))
}
}
use candle::{DType, Device, Result};
use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
fn valid_loss(
dataset: &Dataset,
@ -121,7 +9,7 @@ fn valid_loss(
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
dataset.train_tokens.len(),
dataset.valid_tokens.len()
dataset.train_tokens(),
dataset.valid_tokens()
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;