Sketch the candle-transformers crate. (#147)

* Sketch the candle-transformers crate.

* Format the empty files.

Author: Laurent Mazare (committed via GitHub)
Date: 2023-07-12 13:49:31 +01:00
Commit: ba35d895e7 (parent eae646d322)

12 changed files with 90 additions and 40 deletions

.gitignore (vendored)

@@ -19,3 +19,4 @@ Cargo.lock
 perf.data
 flamegraph.svg
 *.so
+*.swp

Cargo.toml

@@ -6,6 +6,7 @@ members = [
     "candle-hub",
     "candle-nn",
     "candle-pyo3",
+    "candle-transformers",
 ]

 [profile.release-with-debug]

candle-core/src/error.rs

@@ -157,6 +157,15 @@ pub enum Error {
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
+
+    #[error(transparent)]
+    Wrapped(Box<dyn std::error::Error + Send + Sync>),
 }

 pub type Result<T> = std::result::Result<T, Error>;
+
+impl Error {
+    pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+        Self::Wrapped(Box::new(err))
+    }
+}
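
Context for the hunk above: `Error::wrap` lets code outside candle-core fold an arbitrary third-party error into candle's `Result` so `?` keeps working. A minimal sketch of a caller, using the same rand 0.8 error type that `generation.rs` below adapts (`build_weighted_index` is a hypothetical helper, not part of this commit):

    use candle::{Error, Result};
    use rand::distributions::WeightedIndex;

    // Hypothetical helper: build a sampling distribution over token
    // probabilities, adapting rand's foreign `WeightedError` via `Error::wrap`.
    fn build_weighted_index(prs: Vec<f32>) -> Result<WeightedIndex<f32>> {
        WeightedIndex::new(prs).map_err(Error::wrap)
    }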

candle-examples/Cargo.toml

@@ -13,6 +13,7 @@ readme = "README.md"
 [dependencies]
 candle = { path = "../candle-core", default-features=false }
 candle-nn = { path = "../candle-nn", default-features=false }
+candle-transformers = { path = "../candle-transformers", default-features=false }
 serde = { version = "1.0.166", features = ["derive"] }
 serde_json = "1.0.99"
 num-traits = "0.2.15"
@@ -28,5 +29,5 @@ wav = "1.0.0"
 [features]
 default = ["cuda"]
-cuda = ["candle/cuda", "candle-nn/cuda"]
-mkl = ["dep:intel-mkl-src", "candle/mkl"]
+cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
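
Note that the `cuda` and `mkl` features now fan out across the workspace: enabling a feature on the examples crate enables it on candle, candle-nn, and candle-transformers alike, so a single flag configures the whole dependency chain. An illustrative invocation (run from the candle-examples directory; exact example names are an assumption):

    cargo run --example llama --release --no-default-features --features mkl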

candle-examples/examples/falcon/main.rs

@@ -4,11 +4,11 @@
 extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor, D};
+use candle::{DType, Device, Tensor};
 use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
-use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;

 mod model;
@@ -21,10 +21,9 @@ const DTYPE: DType = DType::BF16;

 struct TextGeneration {
     model: Falcon,
-    rng: rand::rngs::StdRng,
     device: Device,
-    temperature: Option<f64>,
     tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
 }

 impl TextGeneration {
@@ -32,14 +31,14 @@ impl TextGeneration {
         model: Falcon,
         tokenizer: Tokenizer,
         seed: u64,
-        temperature: Option<f64>,
+        temp: Option<f64>,
         device: &Device,
     ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp);
         Self {
             model,
             tokenizer,
-            rng: rand::rngs::StdRng::seed_from_u64(seed),
-            temperature,
+            logits_processor,
             device: device.clone(),
         }
     }
@@ -67,20 +66,7 @@ impl TextGeneration {
             let logits = self.model.forward(&input)?;
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
-            let next_token = if let Some(temperature) = self.temperature {
-                let prs = (&logits / temperature)?.softmax(D::Minus1)?;
-                let logits_v: Vec<f32> = prs.to_vec1()?;
-                let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-                distr.sample(&mut self.rng) as u32
-            } else {
-                let logits_v: Vec<f32> = logits.to_vec1()?;
-                logits_v
-                    .iter()
-                    .enumerate()
-                    .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                    .map(|(i, _)| i as u32)
-                    .unwrap()
-            };
+            let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
             println!("> {:?}", start_gen.elapsed());

candle-examples/examples/llama/main.rs

@@ -14,11 +14,11 @@ extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
 use clap::Parser;
-use rand::{distributions::Distribution, SeedableRng};

 use candle::{DType, Device, Tensor, D};
 use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;

 mod model;
 use model::{Config, Llama};
@@ -185,8 +185,8 @@ fn main() -> Result<()> {
     println!("pre-computing the positional embeddings");
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
     let mut new_tokens = vec![];
-    let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     for index in 0..args.sample_len {
@@ -207,21 +207,7 @@ fn main() -> Result<()> {
         let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
-        let next_token = if let Some(temperature) = args.temperature {
-            let prs = (&logits / temperature)?.softmax(D::Minus1)?;
-            let logits_v: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-            distr.sample(&mut rng) as u32
-        } else {
-            let logits_v: Vec<f32> = logits.to_vec1()?;
-            logits_v
-                .iter()
-                .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
-        };
+        let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
         println!("> {:?}", start_gen.elapsed());
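
With these two hunks, both examples reduce their sampling step to a single `logits_processor.sample(&logits)?` call. A condensed, hypothetical sketch of the loop shape they now share (the `forward` closure stands in for the model-specific forward pass; the real examples also manage a context window and kv-cache, elided here):

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    // Hypothetical helper distilling the decode loop of both examples.
    fn generate(
        mut forward: impl FnMut(&Tensor) -> Result<Tensor>,
        mut tokens: Vec<u32>,
        seed: u64,
        temperature: Option<f64>,
        sample_len: usize,
        device: &Device,
    ) -> Result<Vec<u32>> {
        let mut logits_processor = LogitsProcessor::new(seed, temperature);
        for _ in 0..sample_len {
            // Feed the whole sequence back in; the real examples only feed
            // the not-yet-processed suffix for efficiency.
            let input = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
            let logits = forward(&input)?.squeeze(0)?;
            tokens.push(logits_processor.sample(&logits)?);
        }
        Ok(tokens)
    }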

candle-transformers/Cargo.toml (new file)

@@ -0,0 +1,25 @@
+[package]
+name = "candle-transformers"
+version = "0.1.0"
+edition = "2021"
+description = "Pretrained models and inference API for the candle ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../candle-core", default-features=false }
+candle-hub = { path = "../candle-hub" }
+candle-nn = { path = "../candle-nn", default-features=false }
+intel-mkl-src = { version = "0.8.1", optional = true, features = ["mkl-dynamic-lp64-iomp"] }
+tokenizers = { version = "0.13.3", default-features=false, features = ["onig"] }
+rand = "0.8.5"
+wav = "1.0.0"
+
+[features]
+default = ["cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda"]
+mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]

candle-transformers/src/generation.rs (new file)

@@ -0,0 +1,35 @@
+use candle::{DType, Error, Result, Tensor, D};
+use rand::{distributions::Distribution, SeedableRng};
+
+pub struct LogitsProcessor {
+    rng: rand::rngs::StdRng,
+    temperature: Option<f64>,
+}
+
+impl LogitsProcessor {
+    pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+        Self {
+            rng: rand::rngs::StdRng::seed_from_u64(seed),
+            temperature,
+        }
+    }
+
+    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
+        let logits = logits.to_dtype(DType::F32)?;
+        let next_token = if let Some(temperature) = self.temperature {
+            let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+            let prs: Vec<f32> = prs.to_vec1()?;
+            let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+            distr.sample(&mut self.rng) as u32
+        } else {
+            let logits_v: Vec<f32> = logits.to_vec1()?;
+            logits_v
+                .iter()
+                .enumerate()
+                .max_by(|(_, u), (_, v)| u.total_cmp(v))
+                .map(|(i, _)| i as u32)
+                .unwrap()
+        };
+        Ok(next_token)
+    }
+}
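
A minimal usage sketch of the new `LogitsProcessor`: with no temperature it is a deterministic argmax; with a temperature it draws from `softmax(logits / t)` using the seeded rng, so runs are reproducible. The logits values here are made up for illustration; in the examples above they come from a model forward pass.

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    fn main() -> Result<()> {
        // Made-up logits over a 4-token vocabulary.
        let logits = Tensor::new(&[0.1f32, 2.0, 0.5, 0.3], &Device::Cpu)?;

        // temperature = None: greedy argmax, always picks token 1 here.
        let mut greedy = LogitsProcessor::new(42, None);
        assert_eq!(greedy.sample(&logits)?, 1);

        // temperature = Some(t): seeded stochastic sampling.
        let mut sampler = LogitsProcessor::new(42, Some(0.8));
        println!("sampled token id: {}", sampler.sample(&logits)?);
        Ok(())
    }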

candle-transformers/src/lib.rs (new file)

@@ -0,0 +1,3 @@
+pub mod generation;
+pub mod models;
+pub mod pipelines;

candle-transformers/src/models.rs (new file, empty placeholder)

@@ -0,0 +1 @@
+

candle-transformers/src/pipelines.rs (new file)

@@ -0,0 +1 @@
+pub mod text_generation;

candle-transformers/src/pipelines/text_generation.rs (new file, empty placeholder)

@@ -0,0 +1 @@
+