mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Sketch the candle-transformers crate. (#147)
* Sketch the candle-transformers crate. * Format the empty files.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,3 +19,4 @@ Cargo.lock
|
|||||||
perf.data
|
perf.data
|
||||||
flamegraph.svg
|
flamegraph.svg
|
||||||
*.so
|
*.so
|
||||||
|
*.swp
|
||||||
|
@ -6,6 +6,7 @@ members = [
|
|||||||
"candle-hub",
|
"candle-hub",
|
||||||
"candle-nn",
|
"candle-nn",
|
||||||
"candle-pyo3",
|
"candle-pyo3",
|
||||||
|
"candle-transformers",
|
||||||
]
|
]
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
|
@ -157,6 +157,15 @@ pub enum Error {
|
|||||||
|
|
||||||
#[error("unsupported safetensor dtype {0:?}")]
|
#[error("unsupported safetensor dtype {0:?}")]
|
||||||
UnsupportedSafeTensorDtype(safetensors::Dtype),
|
UnsupportedSafeTensorDtype(safetensors::Dtype),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Wrapped(Box<dyn std::error::Error + Send + Sync>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||||
|
Self::Wrapped(Box::new(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -13,6 +13,7 @@ readme = "README.md"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", default-features=false }
|
candle = { path = "../candle-core", default-features=false }
|
||||||
candle-nn = { path = "../candle-nn", default-features=false }
|
candle-nn = { path = "../candle-nn", default-features=false }
|
||||||
|
candle-transformers = { path = "../candle-transformers", default-features=false }
|
||||||
serde = { version = "1.0.166", features = ["derive"] }
|
serde = { version = "1.0.166", features = ["derive"] }
|
||||||
serde_json = "1.0.99"
|
serde_json = "1.0.99"
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
@ -28,5 +29,5 @@ wav = "1.0.0"
|
|||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["cuda"]
|
default = ["cuda"]
|
||||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||||
|
@ -4,11 +4,11 @@
|
|||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device, Tensor, D};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod model;
|
mod model;
|
||||||
@ -21,10 +21,9 @@ const DTYPE: DType = DType::BF16;
|
|||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Falcon,
|
model: Falcon,
|
||||||
rng: rand::rngs::StdRng,
|
|
||||||
device: Device,
|
device: Device,
|
||||||
temperature: Option<f64>,
|
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TextGeneration {
|
impl TextGeneration {
|
||||||
@ -32,14 +31,14 @@ impl TextGeneration {
|
|||||||
model: Falcon,
|
model: Falcon,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
seed: u64,
|
seed: u64,
|
||||||
temperature: Option<f64>,
|
temp: Option<f64>,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp);
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
logits_processor,
|
||||||
temperature,
|
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -67,20 +66,7 @@ impl TextGeneration {
|
|||||||
let logits = self.model.forward(&input)?;
|
let logits = self.model.forward(&input)?;
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
let next_token = if let Some(temperature) = self.temperature {
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
|
||||||
distr.sample(&mut self.rng) as u32
|
|
||||||
} else {
|
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
|
||||||
logits_v
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
|
||||||
.map(|(i, _)| i as u32)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
println!("> {:?}", start_gen.elapsed());
|
println!("> {:?}", start_gen.elapsed());
|
||||||
|
@ -14,11 +14,11 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor, D};
|
use candle::{DType, Device, Tensor, D};
|
||||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
|
||||||
mod model;
|
mod model;
|
||||||
use model::{Config, Llama};
|
use model::{Config, Llama};
|
||||||
@ -185,8 +185,8 @@ fn main() -> Result<()> {
|
|||||||
println!("pre-computing the positional embeddings");
|
println!("pre-computing the positional embeddings");
|
||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
|
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
@ -207,21 +207,7 @@ fn main() -> Result<()> {
|
|||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
index_pos += ctxt.len();
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
let next_token = if let Some(temperature) = args.temperature {
|
let next_token = logits_processor.sample(&logits)?;
|
||||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
|
||||||
|
|
||||||
distr.sample(&mut rng) as u32
|
|
||||||
} else {
|
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
|
||||||
logits_v
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
|
||||||
.map(|(i, _)| i as u32)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
println!("> {:?}", start_gen.elapsed());
|
println!("> {:?}", start_gen.elapsed());
|
||||||
|
25
candle-transformers/Cargo.toml
Normal file
25
candle-transformers/Cargo.toml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
[package]
|
||||||
|
name = "candle-transformers"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
description = "Pretrained models and inference API for the candle ML framework."
|
||||||
|
repository = "https://github.com/LaurentMazare/candle"
|
||||||
|
keywords = ["blas", "tensor", "machine-learning"]
|
||||||
|
categories = ["science"]
|
||||||
|
license = "MIT/Apache-2.0"
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
candle = { path = "../candle-core", default-features=false }
|
||||||
|
candle-hub = { path = "../candle-hub" }
|
||||||
|
candle-nn = { path = "../candle-nn", default-features=false }
|
||||||
|
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
|
||||||
|
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||||
|
rand = "0.8.5"
|
||||||
|
wav = "1.0.0"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["cuda"]
|
||||||
|
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||||
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
35
candle-transformers/src/generation/mod.rs
Normal file
35
candle-transformers/src/generation/mod.rs
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
use candle::{DType, Error, Result, Tensor, D};
|
||||||
|
use rand::{distributions::Distribution, SeedableRng};
|
||||||
|
|
||||||
|
pub struct LogitsProcessor {
|
||||||
|
rng: rand::rngs::StdRng,
|
||||||
|
temperature: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogitsProcessor {
|
||||||
|
pub fn new(seed: u64, temperature: Option<f64>) -> Self {
|
||||||
|
Self {
|
||||||
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
|
temperature,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||||
|
let logits = logits.to_dtype(DType::F32)?;
|
||||||
|
let next_token = if let Some(temperature) = self.temperature {
|
||||||
|
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||||
|
let prs: Vec<f32> = prs.to_vec1()?;
|
||||||
|
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||||
|
distr.sample(&mut self.rng) as u32
|
||||||
|
} else {
|
||||||
|
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||||
|
logits_v
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
||||||
|
.map(|(i, _)| i as u32)
|
||||||
|
.unwrap()
|
||||||
|
};
|
||||||
|
Ok(next_token)
|
||||||
|
}
|
||||||
|
}
|
3
candle-transformers/src/lib.rs
Normal file
3
candle-transformers/src/lib.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
pub mod generation;
|
||||||
|
pub mod models;
|
||||||
|
pub mod pipelines;
|
1
candle-transformers/src/models/mod.rs
Normal file
1
candle-transformers/src/models/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
1
candle-transformers/src/pipelines/mod.rs
Normal file
1
candle-transformers/src/pipelines/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod text_generation;
|
1
candle-transformers/src/pipelines/text_generation.rs
Normal file
1
candle-transformers/src/pipelines/text_generation.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
Reference in New Issue
Block a user