mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00

Merge remote-tracking branch 'origin/main' into faster-gemv
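This merge brings the recent main-branch changes into the faster-gemv branch. The visible diff splits the dataset and data-loading utilities (the Batcher, the MNIST helpers, and the tinystories loader) out of candle-nn into a new candle-datasets crate, updating the workspace manifest, README, and examples to match.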

Cargo.toml
@@ -1,6 +1,7 @@
 [workspace]
 members = [
     "candle-core",
+    "candle-datasets",
     "candle-examples",
     "candle-nn",
     "candle-pyo3",

README.md
@@ -98,8 +98,9 @@ Cheatsheet:
 - [candle-nn](./candle-nn/): Facilities to build real models
 - [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
 - [candle-kernels](./candle-kernels/): CUDA custom kernels
+- [candle-datasets](./candle-datasets/): Datasets and data loaders.
+- [candle-transformers](./candle-transformers): Transformer related utilities.
+- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
 
 ## FAQ
 

candle-datasets/Cargo.toml (new file, 20 lines)
@@ -0,0 +1,20 @@
+[package]
+name = "candle-datasets"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+readme = "README.md"
+
+[dependencies]
+byteorder = { workspace = true }
+candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.1.0" }
+hf-hub = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+memmap2 = { workspace = true }
+tokenizers = { workspace = true, features = ["onig"] }
+rand = { workspace = true }

candle-datasets/src/lib.rs (new file, 6 lines)
@@ -0,0 +1,6 @@
+//! Datasets & Dataloaders for Candle
+pub mod batcher;
+pub mod nlp;
+pub mod vision;
+
+pub use batcher::Batcher;

candle-datasets/src/nlp/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
+pub mod tinystories;

candle-datasets/src/nlp/tinystories.rs (new file, 122 lines)
@@ -0,0 +1,122 @@
+//! Helper functions for the tinystories dataset. This uses the pre-tokenized version as generated
+//! by the tools from https://github.com/karpathy/llama2.c
+use candle::{Device, Result, Tensor};
+
+pub struct Dataset {
+    valid_tokens: Vec<memmap2::Mmap>,
+    train_tokens: Vec<memmap2::Mmap>,
+}
+
+fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
+    let file = std::fs::File::open(p)?;
+    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+    Ok(mmap)
+}
+
+impl Dataset {
+    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
+        let dir = dir.as_ref();
+        let mut bin_files = vec![];
+        for file in std::fs::read_dir(dir)?.flatten() {
+            let file = file.path();
+            if let Some(extension) = file.extension() {
+                if extension == "bin" {
+                    bin_files.push(file)
+                }
+            }
+        }
+        if bin_files.len() < 2 {
+            candle::bail!("found less than two bin files in {:?}", dir)
+        }
+        bin_files.sort();
+        let valid_tokens = mmap_file(&bin_files[0])?;
+        let train_tokens = bin_files[1..]
+            .iter()
+            .map(mmap_file)
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self {
+            valid_tokens: vec![valid_tokens],
+            train_tokens,
+        })
+    }
+
+    pub fn train_tokens(&self) -> usize {
+        self.train_tokens.len()
+    }
+
+    pub fn valid_tokens(&self) -> usize {
+        self.valid_tokens.len()
+    }
+}
+
+pub struct DatasetRandomIter<'a> {
+    all_tokens: &'a [memmap2::Mmap],
+    tokens: Vec<&'a memmap2::Mmap>,
+    current_tokens: &'a memmap2::Mmap,
+    indexes_in_bytes: Vec<usize>,
+    seq_len: usize,
+    device: Device,
+}
+
+impl<'a> DatasetRandomIter<'a> {
+    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
+        use rand::seq::SliceRandom;
+        use rand::thread_rng;
+
+        let all_tokens = if valid {
+            &ds.valid_tokens
+        } else {
+            &ds.train_tokens
+        };
+        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
+        tokens.shuffle(&mut thread_rng());
+        let current_tokens = tokens.pop().unwrap();
+        let seq_len_in_bytes = seq_len * 2;
+        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
+            .step_by(seq_len_in_bytes)
+            .collect::<Vec<_>>();
+        indexes_in_bytes.shuffle(&mut thread_rng());
+        Self {
+            all_tokens,
+            tokens,
+            current_tokens,
+            indexes_in_bytes,
+            seq_len,
+            device,
+        }
+    }
+}
+
+impl<'a> Iterator for DatasetRandomIter<'a> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        use byteorder::{LittleEndian, ReadBytesExt};
+        use rand::seq::SliceRandom;
+        use rand::thread_rng;
+
+        let seq_len = self.seq_len;
+        if self.indexes_in_bytes.is_empty() {
+            if self.tokens.is_empty() {
+                self.tokens = self.all_tokens.iter().collect();
+                self.tokens.shuffle(&mut thread_rng());
+            }
+            self.current_tokens = self.tokens.pop().unwrap();
+            let seq_len_in_bytes = self.seq_len * 2;
+            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
+                .step_by(seq_len_in_bytes)
+                .collect::<Vec<_>>();
+            self.indexes_in_bytes.shuffle(&mut thread_rng());
+        }
+        let start_idx = self.indexes_in_bytes.pop().unwrap();
+        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
+        let mut tokens = vec![0u16; bytes.len() / 2];
+        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
+            return Some(Err(err.into()));
+        }
+        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
+        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
+        let targets = Tensor::new(&tokens[1..], &self.device);
+        Some(candle::error::zip(inputs, targets))
+    }
+}
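
Taken together, the new crate gives the examples a small end-to-end pipeline: Dataset::new mmaps the pre-tokenized .bin files (two little-endian bytes per token, alphabetically first file used as the validation split), DatasetRandomIter yields shuffled (input, target) windows of seq_len + 1 tokens, and the re-exported Batcher stacks them into batches. A minimal sketch of how the pieces fit together, using only calls that appear in this diff; the "data" directory, seq_len = 256, and batch_size = 32 are illustrative values, not taken from the commit:

    use candle::Device;
    use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
    use candle_datasets::Batcher;

    fn main() -> candle::Result<()> {
        // Assumes "data" holds the llama2.c pre-tokenized *.bin files.
        let dataset = Dataset::new("data")?;
        println!(
            "train: {} files, valid: {} files",
            dataset.train_tokens(),
            dataset.valid_tokens()
        );
        let device = Device::cuda_if_available(0)?;
        // `false` selects the training files; 256 is an illustrative seq_len.
        let iter = DatasetRandomIter::new(&dataset, false, 256, device);
        let batch_iter = Batcher::new_r2(iter).batch_size(32);
        for batch in batch_iter.take(2) {
            let (inputs, targets) = batch?;
            println!("inputs {:?} targets {:?}", inputs.shape(), targets.shape());
        }
        Ok(())
    }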

candle-examples/Cargo.toml
@@ -11,6 +11,7 @@ readme = "README.md"
 
 [dependencies]
 candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
 candle-nn = { path = "../candle-nn", version = "0.1.0" }
 candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
 candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }

@@ -200,7 +200,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
             Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
         }
     });
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
        let logits = model.forward(&inp, 0)?;

@@ -1,118 +1,6 @@
-#![allow(dead_code)]
-#![allow(unused)]
 use crate::model::{Cache, Config, Llama};
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Result};
+use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
-
-pub struct Dataset {
-    valid_tokens: Vec<memmap2::Mmap>,
-    train_tokens: Vec<memmap2::Mmap>,
-}
-
-fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
-    let file = std::fs::File::open(p)?;
-    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
-    Ok(mmap)
-}
-
-impl Dataset {
-    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
-        let dir = dir.as_ref();
-        let mut bin_files = vec![];
-        for file in std::fs::read_dir(dir)?.flatten() {
-            let file = file.path();
-            if let Some(extension) = file.extension() {
-                if extension == "bin" {
-                    bin_files.push(file)
-                }
-            }
-        }
-        if bin_files.len() < 2 {
-            candle::bail!("found less than two bin files in {:?}", dir)
-        }
-        bin_files.sort();
-        let valid_tokens = mmap_file(&bin_files[0])?;
-        let train_tokens = bin_files[1..]
-            .iter()
-            .map(mmap_file)
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self {
-            valid_tokens: vec![valid_tokens],
-            train_tokens,
-        })
-    }
-}
-
-struct DatasetRandomIter<'a> {
-    all_tokens: &'a [memmap2::Mmap],
-    tokens: Vec<&'a memmap2::Mmap>,
-    current_tokens: &'a memmap2::Mmap,
-    indexes_in_bytes: Vec<usize>,
-    seq_len: usize,
-    device: Device,
-}
-
-impl<'a> DatasetRandomIter<'a> {
-    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
-        use rand::seq::SliceRandom;
-        use rand::thread_rng;
-
-        let all_tokens = if valid {
-            &ds.valid_tokens
-        } else {
-            &ds.train_tokens
-        };
-        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
-        tokens.shuffle(&mut thread_rng());
-        let current_tokens = tokens.pop().unwrap();
-        let seq_len_in_bytes = seq_len * 2;
-        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
-            .step_by(seq_len_in_bytes)
-            .collect::<Vec<_>>();
-        indexes_in_bytes.shuffle(&mut thread_rng());
-        Self {
-            all_tokens,
-            tokens,
-            current_tokens,
-            indexes_in_bytes,
-            seq_len,
-            device,
-        }
-    }
-}
-
-impl<'a> Iterator for DatasetRandomIter<'a> {
-    type Item = Result<(Tensor, Tensor)>;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        use byteorder::{LittleEndian, ReadBytesExt};
-        use rand::seq::SliceRandom;
-        use rand::thread_rng;
-
-        let seq_len = self.seq_len;
-        if self.indexes_in_bytes.is_empty() {
-            if self.tokens.is_empty() {
-                self.tokens = self.all_tokens.iter().collect();
-                self.tokens.shuffle(&mut thread_rng());
-            }
-            self.current_tokens = self.tokens.pop().unwrap();
-            let seq_len_in_bytes = self.seq_len * 2;
-            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
-                .step_by(seq_len_in_bytes)
-                .collect::<Vec<_>>();
-            self.indexes_in_bytes.shuffle(&mut thread_rng());
-        }
-        let start_idx = self.indexes_in_bytes.pop().unwrap();
-        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
-        let mut tokens = vec![0u16; bytes.len() / 2];
-        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
-            return Some(Err(err.into()));
-        }
-        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
-        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
-        let targets = Tensor::new(&tokens[1..], &self.device);
-        Some(candle::error::zip(inputs, targets))
-    }
-}
-
 fn valid_loss(
     dataset: &Dataset,

@@ -121,7 +9,7 @@ fn valid_loss(
     device: &Device,
 ) -> Result<f64> {
     let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     let mut sum_ce = 0f64;
     let mut cnt = 0usize;
     for inp_tgt in batch_iter.take(50) {

@@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let dataset = Dataset::new(&args.pretokenized_dir)?;
     println!(
         "loaded dataset, train: {} files, valid: {} files",
-        dataset.train_tokens.len(),
-        dataset.valid_tokens.len()
+        dataset.train_tokens(),
+        dataset.valid_tokens()
     );
     let varmap = candle_nn::VarMap::new();
     let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
     let config = Config::tiny();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
 
     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
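
The switch from dataset.train_tokens.len() to dataset.train_tokens() follows from the move itself: once Dataset lives in a separate crate its private fields are no longer reachable from the example, so it goes through the train_tokens() / valid_tokens() accessors added in tinystories.rs above.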

@@ -63,7 +63,7 @@ struct TrainingArgs {
 }
 
 fn training_loop<M: Model>(
-    m: candle_nn::vision::Dataset,
+    m: candle_datasets::vision::Dataset,
     args: &TrainingArgs,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;

@@ -140,7 +140,7 @@ struct Args {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     // Load the dataset
-    let m = candle_nn::vision::mnist::load_dir("data")?;
+    let m = candle_datasets::vision::mnist::load_dir("data")?;
     println!("train-images: {:?}", m.train_images.shape());
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());

candle-nn/src/lib.rs
@@ -2,7 +2,6 @@
 // error type if needed or add some specialized cases on the candle-core side.
 pub mod activation;
 pub mod conv;
-pub mod dataset;
 pub mod embedding;
 pub mod init;
 pub mod layer_norm;
@@ -11,7 +10,6 @@ pub mod loss;
 pub mod ops;
 pub mod optim;
 pub mod var_builder;
-pub mod vision;
 
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
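
The net effect for downstream code is a set of import-path moves; a before/after sketch assembled from the hunks above (no new APIs, just new paths; the candle::Result return type of load_dir is an assumption consistent with the `?` in the mnist example):

    // Before this merge the data utilities lived in candle-nn:
    //   candle_nn::dataset::Batcher
    //   candle_nn::vision::mnist::load_dir
    // After it they come from the new crate:
    //   candle_datasets::Batcher
    //   candle_datasets::vision::mnist::load_dir
    use candle_datasets::vision::mnist;

    // Same call as in the mnist-training example above, via the new path.
    fn load() -> candle::Result<candle_datasets::vision::Dataset> {
        mnist::load_dir("data")
    }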