diff --git a/.gitignore b/.gitignore index 2997bb61..c433e74b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ Cargo.lock perf.data flamegraph.svg *.so +*.swp diff --git a/Cargo.toml b/Cargo.toml index d52bf3e3..218e717b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "candle-hub", "candle-nn", "candle-pyo3", + "candle-transformers", ] [profile.release-with-debug] diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index b9131356..2ab2ec1d 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -157,6 +157,15 @@ pub enum Error { #[error("unsupported safetensor dtype {0:?}")] UnsupportedSafeTensorDtype(safetensors::Dtype), + + #[error(transparent)] + Wrapped(Box<dyn std::error::Error + Send + Sync>), } pub type Result<T> = std::result::Result<T, Error>; + +impl Error { + pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { + Self::Wrapped(Box::new(err)) + } +} diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 889e0051..2cbef233 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -13,6 +13,7 @@ readme = "README.md" [dependencies] candle = { path = "../candle-core", default-features=false } candle-nn = { path = "../candle-nn", default-features=false } +candle-transformers = { path = "../candle-transformers", default-features=false } serde = { version = "1.0.166", features = ["derive"] } serde_json = "1.0.99" num-traits = "0.2.15" @@ -28,5 +29,5 @@ wav = "1.0.0" [features] default = ["cuda"] -cuda = ["candle/cuda", "candle-nn/cuda"] -mkl = ["dep:intel-mkl-src", "candle/mkl"] +cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 3cd1d1f8..4757d2b1 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -4,11 +4,11 @@ extern crate intel_mkl_src; 
use anyhow::{Error as E, Result}; -use candle::{DType, Device, Tensor, D}; +use candle::{DType, Device, Tensor}; use candle_hub::{api::sync::Api, Repo, RepoType}; use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; use clap::Parser; -use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; mod model; @@ -21,10 +21,9 @@ const DTYPE: DType = DType::BF16; struct TextGeneration { model: Falcon, - rng: rand::rngs::StdRng, device: Device, - temperature: Option<f64>, tokenizer: Tokenizer, + logits_processor: LogitsProcessor, } impl TextGeneration { @@ -32,14 +31,14 @@ impl TextGeneration { model: Falcon, tokenizer: Tokenizer, seed: u64, - temperature: Option<f64>, + temp: Option<f64>, device: &Device, ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp); Self { model, tokenizer, - rng: rand::rngs::StdRng::seed_from_u64(seed), - temperature, + logits_processor, device: device.clone(), } } @@ -67,20 +66,7 @@ impl TextGeneration { let logits = self.model.forward(&input)?; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; - let next_token = if let Some(temperature) = self.temperature { - let prs = (&logits / temperature)?.softmax(D::Minus1)?; - let logits_v: Vec<f32> = prs.to_vec1()?; - let distr = rand::distributions::WeightedIndex::new(&logits_v)?; - distr.sample(&mut self.rng) as u32 - } else { - let logits_v: Vec<f32> = logits.to_vec1()?; - logits_v - .iter() - .enumerate() - .max_by(|(_, u), (_, v)| u.total_cmp(v)) - .map(|(i, _)| i as u32) - .unwrap() - }; + let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); new_tokens.push(next_token); println!("> {:?}", start_gen.elapsed()); diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index d21094a4..301b870a 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -14,11 +14,11 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use clap::Parser; 
-use rand::{distributions::Distribution, SeedableRng}; use candle::{DType, Device, Tensor, D}; use candle_hub::{api::sync::Api, Repo, RepoType}; use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; mod model; use model::{Config, Llama}; @@ -185,8 +185,8 @@ fn main() -> Result<()> { println!("pre-computing the positional embeddings"); let freqs_cis = precompute_freqs_cis(&config, &device)?; println!("starting the inference loop"); + let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); let mut new_tokens = vec![]; - let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed); let start_gen = std::time::Instant::now(); let mut index_pos = 0; for index in 0..args.sample_len { @@ -207,21 +207,7 @@ fn main() -> Result<()> { let logits = logits.squeeze(0)?; index_pos += ctxt.len(); - let next_token = if let Some(temperature) = args.temperature { - let prs = (&logits / temperature)?.softmax(D::Minus1)?; - let logits_v: Vec<f32> = prs.to_vec1()?; - let distr = rand::distributions::WeightedIndex::new(&logits_v)?; - - distr.sample(&mut rng) as u32 - } else { - let logits_v: Vec<f32> = logits.to_vec1()?; - logits_v - .iter() - .enumerate() - .max_by(|(_, u), (_, v)| u.total_cmp(v)) - .map(|(i, _)| i as u32) - .unwrap() - }; + let next_token = logits_processor.sample(&logits)?; tokens.push(next_token); new_tokens.push(next_token); println!("> {:?}", start_gen.elapsed()); diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml new file mode 100644 index 00000000..048e0f6b --- /dev/null +++ b/candle-transformers/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "candle-transformers" +version = "0.1.0" +edition = "2021" + +description = "Pretrained models and inference API for the candle ML framework." 
+repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { path = "../candle-core", default-features=false } +candle-hub = { path = "../candle-hub" } +candle-nn = { path = "../candle-nn", default-features=false } +intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]} +tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } +rand = "0.8.5" +wav = "1.0.0" + +[features] +default = ["cuda"] +cuda = ["candle/cuda", "candle-nn/cuda"] +mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs new file mode 100644 index 00000000..f954f322 --- /dev/null +++ b/candle-transformers/src/generation/mod.rs @@ -0,0 +1,35 @@ +use candle::{DType, Error, Result, Tensor, D}; +use rand::{distributions::Distribution, SeedableRng}; + +pub struct LogitsProcessor { + rng: rand::rngs::StdRng, + temperature: Option<f64>, +} + +impl LogitsProcessor { + pub fn new(seed: u64, temperature: Option<f64>) -> Self { + Self { + rng: rand::rngs::StdRng::seed_from_u64(seed), + temperature, + } + } + + pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { + let logits = logits.to_dtype(DType::F32)?; + let next_token = if let Some(temperature) = self.temperature { + let prs = (&logits / temperature)?.softmax(D::Minus1)?; + let prs: Vec<f32> = prs.to_vec1()?; + let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; + distr.sample(&mut self.rng) as u32 + } else { + let logits_v: Vec<f32> = logits.to_vec1()?; + logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() + }; + Ok(next_token) + } +} diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs new file mode 100644 index 00000000..86cb904e --- /dev/null +++ 
b/candle-transformers/src/lib.rs @@ -0,0 +1,3 @@ +pub mod generation; +pub mod models; +pub mod pipelines; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/candle-transformers/src/models/mod.rs @@ -0,0 +1 @@ + diff --git a/candle-transformers/src/pipelines/mod.rs b/candle-transformers/src/pipelines/mod.rs new file mode 100644 index 00000000..d1bc14e2 --- /dev/null +++ b/candle-transformers/src/pipelines/mod.rs @@ -0,0 +1 @@ +pub mod text_generation; diff --git a/candle-transformers/src/pipelines/text_generation.rs b/candle-transformers/src/pipelines/text_generation.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/candle-transformers/src/pipelines/text_generation.rs @@ -0,0 +1 @@ +