Compare commits

...

20 Commits

Author SHA1 Message Date
2e273ddf31 Fixing the mkl dependency hell. 2025-03-27 18:01:21 +01:00
cb02b389d5 Fix reinforcement learning example (#2837) 2025-03-26 16:27:45 +01:00
0d4097031c fixed rand import for mnist-training (#2833) 2025-03-26 08:10:03 +01:00
10853b803c fixed rand imports for whisper-microphone example (#2834) 2025-03-26 08:09:27 +01:00
f3d472952f fix: candle-flash-attn linux and msvc build (#2829)
* fix: candle-flash-attn linux and msvc build

* Missing newline at eof.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-03-25 08:45:12 +01:00
67b85f79f1 Pickle decoder fix and Long1 opcode addition. (#2824)
* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-03-23 08:10:08 +01:00
0b24f7f0a4 Fix for whisper example. rand::distribution is now rand::distr (#2811) 2025-03-16 19:14:55 +01:00
3afb04925a Allow for growing the default KV cache when needed. (#2810) 2025-03-16 17:30:25 +01:00
cbf5fc80c2 Add Gemma 3 1b IT toe Gemma examples (#2809)
- Updates the Gemma example to include Gemma 3 1b instruction tuned.
2025-03-16 17:00:48 +01:00
468d1d525f Bump the crate version to 0.8.4. (#2808) 2025-03-15 07:42:24 +01:00
c930ab7e1a upgrade half library to fix rand (#2806)
fix lints
2025-03-14 09:01:54 +01:00
111edbc4ea Gemma 3 initial setup (text only). (#2802)
* Gemma 3 initial setup (text only).

* Use the rotating kv cache for the sliding window.
2025-03-14 07:56:02 +01:00
e286cf7cc9 Parse the json config for siglip models. (#2800)
* Parse the json config for siglip models.

* Bump the tokenizers dependency.

* Add a v2 model.

* Support more v2 model.s
2025-03-09 14:01:09 +01:00
e4ffb85228 Add ModernBert sentency classifier (#2796) 2025-03-08 14:48:22 +01:00
37db86ff79 Allow ModernBert to be used to generate embeddings. (#2791) 2025-03-03 12:39:04 +01:00
add3a714aa phi-4-mini (#2790) 2025-03-01 10:07:29 +01:00
26c16923b9 Make sorted_nodes pub function (#2780) 2025-02-22 10:23:45 +01:00
9e8bf70333 Avoid some clippy lints on 1.85. (#2778)
* Avoid some clippy lints on 1.85.

* Upload artifacts v4.
2025-02-22 10:23:22 +01:00
ac9cdbd448 Refactor From<Tuple> implementations by using macros, add tests (#2762) 2025-02-19 10:58:29 +01:00
e6cc76fc37 Implement DeepSeek V2 (#2744)
* Add deepseek v2

* Fix

* Remove unused

* Add kv cache

* Remove from cargo.toml

* Fix dtype selection logic

* Fix unnecessary u32->f32->gather->u32

* Remove fromstr impl

* Use local scopes for some clarity

* Typo

* Repeat k_pe

* Chain calls to remove mut

* Actually, remove all muts

* Update readme
2025-02-19 10:51:01 +01:00
36 changed files with 2326 additions and 175 deletions

Binary file not shown.

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.8.3"
version = "0.8.4"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,40 +33,40 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.3" }
candle-datasets = { path = "./candle-datasets", version = "0.8.3" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.3" }
candle-kernels = { path = "./candle-kernels", version = "0.8.3" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.3" }
candle-nn = { path = "./candle-nn", version = "0.8.3" }
candle-onnx = { path = "./candle-onnx", version = "0.8.3" }
candle-transformers = { path = "./candle-transformers", version = "0.8.3" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
candle-nn = { path = "./candle-nn", version = "0.8.4" }
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
intel-mkl-src = { version = "0.8.1" }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "51.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rand = "0.9.0"
rand_distr = "0.5.1"
rayon = "1.7.0"
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.19.1", default-features = false }
tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"

View File

@ -32,7 +32,7 @@ impl Tensor {
/// elements having dependencies on the latter ones, e.g. the first element if any is the
/// argument.
/// This assumes that the op graph is a DAG.
fn sorted_nodes(&self) -> Vec<&Tensor> {
pub fn sorted_nodes(&self) -> Vec<&Tensor> {
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
// to get around some lifetime limitations.
fn walk<'a>(

View File

@ -2482,15 +2482,15 @@ impl BackendDevice for CpuDevice {
use rand::prelude::*;
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
}
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<bf16, _>(uniform))
}
@ -2498,8 +2498,8 @@ impl BackendDevice for CpuDevice {
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f16, _>(uniform))
}
@ -2507,7 +2507,8 @@ impl BackendDevice for CpuDevice {
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
let uniform =
rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(uniform))
}
@ -2515,7 +2516,7 @@ impl BackendDevice for CpuDevice {
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min, max);
let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(uniform))
}
@ -2528,7 +2529,7 @@ impl BackendDevice for CpuDevice {
use rand::prelude::*;
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())

View File

@ -45,6 +45,7 @@ pub enum OpCode {
BinFloat = b'G',
Append = b'a',
Appends = b'e',
Long1 = 0x8a,
}
// Avoid using FromPrimitive so as not to drag another dependency.
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
0x8a => Ok(Self::Long1),
value => Err(value),
}
}
@ -106,6 +108,7 @@ pub enum Object {
class_name: String,
},
Int(i32),
Long(i64),
Float(f64),
Unicode(String),
Bool(bool),
@ -170,6 +173,14 @@ impl Object {
}
}
pub fn int_or_long(self) -> OResult<i64> {
match self {
Self::Int(t) => Ok(t as i64),
Self::Long(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
@ -590,6 +601,15 @@ impl Stack {
let obj = self.new_obj(class, args)?;
self.push(obj)
}
OpCode::Long1 => {
let n_bytes = r.read_u8()?;
let mut v = 0;
// Decode the next n bytes in little endian
for i in 0..n_bytes {
v |= (r.read_u8()? as i64) << (i * 8);
}
self.push(Object::Long(v))
}
}
Ok(false)
}
@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let offset = args.remove(1).int_or_long()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let storage_size = storage.remove(4).int_or_long()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
let layout = Layout::new(
crate::Shape::from(size),
stride,
offset * dtype.size_in_bytes(),
);
Ok((layout, dtype, path, storage_size))
}

View File

@ -43,43 +43,22 @@ impl From<usize> for Shape {
}
}
impl From<(usize,)> for Shape {
fn from(d1: (usize,)) -> Self {
Self(vec![d1.0])
macro_rules! impl_from_tuple {
($tuple:ty, $($index:tt),+) => {
impl From<$tuple> for Shape {
fn from(d: $tuple) -> Self {
Self(vec![$(d.$index,)+])
}
}
}
}
impl From<(usize, usize)> for Shape {
fn from(d12: (usize, usize)) -> Self {
Self(vec![d12.0, d12.1])
}
}
impl From<(usize, usize, usize)> for Shape {
fn from(d123: (usize, usize, usize)) -> Self {
Self(vec![d123.0, d123.1, d123.2])
}
}
impl From<(usize, usize, usize, usize)> for Shape {
fn from(d1234: (usize, usize, usize, usize)) -> Self {
Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
}
}
impl From<(usize, usize, usize, usize, usize)> for Shape {
fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
}
}
impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
Self(vec![
d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
])
}
}
impl_from_tuple!((usize,), 0);
impl_from_tuple!((usize, usize), 0, 1);
impl_from_tuple!((usize, usize, usize), 0, 1, 2);
impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
@ -636,4 +615,20 @@ mod tests {
let shape = Shape::from((299, 792, 458));
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
#[test]
fn test_from_tuple() {
let shape = Shape::from((2,));
assert_eq!(shape.dims(), &[2]);
let shape = Shape::from((2, 3));
assert_eq!(shape.dims(), &[2, 3]);
let shape = Shape::from((2, 3, 4));
assert_eq!(shape.dims(), &[2, 3, 4]);
let shape = Shape::from((2, 3, 4, 5));
assert_eq!(shape.dims(), &[2, 3, 4, 5]);
let shape = Shape::from((2, 3, 4, 5, 6));
assert_eq!(shape.dims(), &[2, 3, 4, 5, 6]);
let shape = Shape::from((2, 3, 4, 5, 6, 7));
assert_eq!(shape.dims(), &[2, 3, 4, 5, 6, 7]);
}
}

View File

@ -880,10 +880,10 @@ fn get_random_tensors(
let mut rng = StdRng::seed_from_u64(314159265358979);
let lhs = (0..m * k)
.map(|_| rng.gen::<f32>() - 0.5)
.map(|_| rng.random::<f32>() - 0.5)
.collect::<Vec<_>>();
let rhs = (0..n * k)
.map(|_| rng.gen::<f32>() - 0.5)
.map(|_| rng.random::<f32>() - 0.5)
.collect::<Vec<_>>();
let lhs = Tensor::from_vec(lhs, (m, k), device)?;

View File

@ -60,8 +60,8 @@ pub struct DatasetRandomIter<'a> {
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::rng;
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
@ -69,13 +69,13 @@ impl<'a> DatasetRandomIter<'a> {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
tokens.shuffle(&mut rng());
let current_tokens = tokens.pop().unwrap();
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
indexes_in_bytes.shuffle(&mut rng());
Self {
all_tokens,
tokens,
@ -92,21 +92,21 @@ impl Iterator for DatasetRandomIter<'_> {
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::rng;
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
self.tokens.shuffle(&mut rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
self.indexes_in_bytes.shuffle(&mut rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];

View File

@ -0,0 +1,33 @@
# DeepSeek V2
DeepSeek V2 an MoE model featuring MLA (Multi-Latent Attention). There is a lite (16B) and a full (236B) model.
- Context length of **32k tokens** (Lite model), **128k tokens** (full model)
- 64 routed experts (Lite model), 160 routed experts (full model)
## Running the example
```bash
$ cargo run --example deepseekv2 --release --features metal -- --prompt "Recursive fibonacci code in Rust:" --which lite --sample-len 150
fn fibonacci(n: u32) -> u32 {
if n <= 1 {
return n;
} else {
return fibonacci(n - 1) + fibonacci(n - 2);
}
}
## Fibonacci code in Python:
def fibonacci(n):
if n <= 1:
return n
else:
return fibonacci(n-1) + fibonacci(n-2)
## Fibonacci code in JavaScript:
function fibonacci(n) {
if (n <= 1
```

View File

@ -0,0 +1,282 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::deepseek2::{DeepSeekV2, DeepSeekV2Config};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: DeepSeekV2,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: DeepSeekV2,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: Option<usize>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = {
let temperature = temp.unwrap_or(0.);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<end▁of▁sentence>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <end▁of▁sentence> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "lite")]
Lite,
#[value(name = "lite-chat")]
LiteChat,
#[value(name = "coder-lite-chat")]
CoderLiteChat,
#[value(name = "v2")]
V2,
#[value(name = "v2-chat")]
V2Chat,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "lite")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::CoderLiteChat => "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct".to_string(),
Which::LiteChat => "deepseek-ai/DeepSeek-V2-Lite-Chat".to_string(),
Which::Lite => "deepseek-ai/DeepSeek-V2-Lite".to_string(),
Which::V2 => "deepseek-ai/DeepSeek-V2".to_string(),
Which::V2Chat => "deepseek-ai/DeepSeek-V2-Chat".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: DeepSeekV2Config = {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = if device.is_cpu() {
DType::F16
} else {
DType::BF16
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = DeepSeekV2::new(&config, vb)?;
(model, device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -9,6 +9,7 @@ use clap::Parser;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -47,29 +48,16 @@ enum Which {
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
}
impl Which {
fn is_v1(&self) -> bool {
match self {
Self::Base2B
| Self::Base7B
| Self::Instruct2B
| Self::Instruct7B
| Self::InstructV1_1_2B
| Self::InstructV1_1_7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::CodeInstruct2B
| Self::CodeInstruct7B => true,
Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
}
}
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
@ -77,6 +65,7 @@ impl Model {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
@ -284,6 +273,8 @@ fn main() -> Result<()> {
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -304,7 +295,10 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
None => match args.which {
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -317,14 +311,31 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = if args.which.is_v1() {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
} else {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
let model = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -16,7 +16,7 @@ use candle_transformers::models::quantized_metavoice::transformer as qtransforme
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};
pub const ENCODEC_NTOKENS: u32 = 1024;
@ -250,7 +250,7 @@ fn main() -> Result<()> {
let logits = logits.i(step)?.to_dtype(DType::F32)?;
let logits = &(&logits / 1.0)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
let distr = rand::distr::weighted::WeightedIndex::new(prs.as_slice())?;
let sample = distr.sample(&mut rng) as u32;
codes_.push(sample)
}

View File

@ -7,6 +7,7 @@ extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use rand::prelude::*;
use rand::rng;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
@ -138,7 +139,7 @@ fn training_loop_cnn(
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
for epoch in 1..args.epochs {
let mut sum_loss = 0f32;
batch_idxs.shuffle(&mut thread_rng());
batch_idxs.shuffle(&mut rng());
for batch_idx in batch_idxs.iter() {
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;

View File

@ -148,6 +148,8 @@ enum WhichModel {
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V4Mini,
#[value(name = "4-mini")]
V2Old,
PuffinPhiV2,
PhiHermes,
@ -261,6 +263,7 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
WhichModel::V4Mini => "microsoft/Phi-4-mini-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -281,6 +284,7 @@ fn main() -> Result<()> {
WhichModel::V2
| WhichModel::V3
| WhichModel::V3Medium
| WhichModel::V4Mini
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
@ -296,7 +300,8 @@ fn main() -> Result<()> {
| WhichModel::V2
| WhichModel::V2Old
| WhichModel::V3
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
| WhichModel::V3Medium
| WhichModel::V4Mini => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@ -312,19 +317,21 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
WhichModel::V3 | WhichModel::V3Medium | WhichModel::V4Mini => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?
}
WhichModel::V2
| WhichModel::V2Old
| WhichModel::V3
| WhichModel::V3Medium
| WhichModel::V4Mini => candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?,
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
}
@ -341,7 +348,7 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
WhichModel::V3 | WhichModel::V3Medium => {
WhichModel::V3 | WhichModel::V3Medium | WhichModel::V4Mini => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
@ -361,7 +368,10 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
if args.model == WhichModel::V3
|| args.model == WhichModel::V3Medium
|| args.model == WhichModel::V4Mini
{
device.bf16_default_to_f32()
} else {
DType::F32
@ -377,7 +387,7 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V3 | WhichModel::V3Medium => {
WhichModel::V3 | WhichModel::V3Medium | WhichModel::V4Mini => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;

View File

@ -5,7 +5,7 @@ use candle_nn::{
func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};
use super::gym_env::GymEnv;
@ -103,8 +103,8 @@ impl ReplayBuffer {
if self.size < batch_size {
Ok(None)
} else {
let transitions: Vec<&Transition> = thread_rng()
.sample_iter(Uniform::from(0..self.size))
let transitions: Vec<&Transition> = rng()
.sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
.take(batch_size)
.map(|i| self.buffer.get(i).unwrap())
.collect();
@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;

View File

@ -1,9 +1,8 @@
use std::collections::VecDeque;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};
use candle::{DType, Device, Module, Result, Tensor};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
// fed to the model so that it performs a backward pass.
if memory.len() > BATCH_SIZE {
// Sample randomly from the memory.
let batch = thread_rng()
.sample_iter(Uniform::from(0..memory.len()))
let batch = rng()
.sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
.take(BATCH_SIZE)
.map(|i| memory.get(i).unwrap().clone())
.collect::<Vec<_>>();

View File

@ -4,7 +4,7 @@ use candle_nn::{
linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
use rand::{distr::Distribution, rngs::ThreadRng, Rng};
fn new_model(
input_shape: &[usize],
@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
}
fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
let mut rng = rng;
Ok(distribution.sample(&mut rng))
}
@ -65,10 +65,10 @@ pub fn run() -> Result<()> {
let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for epoch_idx in 0..100 {
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut steps: Vec<Step<i64>> = vec![];
loop {
@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
steps.push(step.copy_with_obs(&state));
if step.terminated || step.truncated {
state = env.reset(rng.gen::<u64>())?;
state = env.reset(rng.random::<u64>())?;
if steps.len() > 5000 {
break;
}

View File

@ -13,11 +13,40 @@ use candle_transformers::models::siglip;
use tokenizers::Tokenizer;
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum Which {
#[value(name = "v1-base-patch16-224")]
V1BasePatch16_224,
#[value(name = "v2-base-patch16-224")]
V2BasePatch16_224,
#[value(name = "v2-base-patch16-256")]
V2BasePatch16_256,
#[value(name = "v2-base-patch16-384")]
V2BasePatch16_384,
#[value(name = "v2-base-patch16-512")]
V2BasePatch16_512,
#[value(name = "v2-large-patch16-256")]
V2LargePatch16_256,
#[value(name = "v2-large-patch16-384")]
V2LargePatch16_384,
#[value(name = "v2-large-patch16-512")]
V2LargePatch16_512,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
config: Option<String>,
#[arg(long)]
hf_repo: Option<String>,
#[arg(long, default_value = "v1-base-patch16-224")]
which: Which,
#[arg(long)]
tokenizer: Option<String>,
@ -66,16 +95,37 @@ fn load_images<T: AsRef<std::path::Path>>(
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let hf_repo = match args.hf_repo.as_ref() {
Some(hf_repo) => hf_repo,
None => match args.which {
Which::V1BasePatch16_224 => "google/siglip-base-patch16-224",
Which::V2BasePatch16_224 => "google/siglip2-base-patch16-224",
Which::V2BasePatch16_256 => "google/siglip2-base-patch16-256",
Which::V2BasePatch16_384 => "google/siglip2-base-patch16-384",
Which::V2BasePatch16_512 => "google/siglip2-base-patch16-512",
Which::V2LargePatch16_256 => "google/siglip2-large-patch16-256",
Which::V2LargePatch16_384 => "google/siglip2-large-patch16-384",
Which::V2LargePatch16_512 => "google/siglip2-large-patch16-512",
},
};
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
let api = api.model(hf_repo.to_string());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = siglip::Config::base_patch16_224();
let config_file = match args.config {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(hf_repo.to_string());
api.get("config.json")?
}
Some(config) => config.into(),
};
let tokenizer = get_tokenizer(hf_repo, args.tokenizer)?;
let config: siglip::Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
@ -114,11 +164,11 @@ pub fn main() -> anyhow::Result<()> {
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
pub fn get_tokenizer(hf_repo: &str, tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
let api = api.model(hf_repo.to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),

View File

@ -617,7 +617,7 @@ fn run(args: Args) -> Result<()> {
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
// If a seed is not given, generate a random seed and print it
let seed = seed.unwrap_or(rand::thread_rng().gen_range(0u64..u64::MAX));
let seed = seed.unwrap_or(rand::rng().random_range(0u64..u64::MAX));
println!("Using seed {seed}");
device.set_seed(seed)?;
let use_guide_scale = guidance_scale > 1.0;

View File

@ -9,7 +9,7 @@ use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};
use tokenizers::Tokenizer;
mod multilingual;
@ -204,7 +204,7 @@ impl Decoder {
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;

View File

@ -14,7 +14,9 @@ use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use rand::distr::weighted::WeightedIndex;
use rand::distr::Distribution;
use rand::SeedableRng;
use tokenizers::Tokenizer;
mod multilingual;
@ -208,7 +210,7 @@ impl Decoder {
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let distr = WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.3"
version = "0.8.4"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.3" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -88,19 +88,26 @@ fn main() -> Result<()> {
.arg("--use_fast_math")
.arg("--verbose");
let mut is_target_msvc = false;
if let Ok(target) = std::env::var("TARGET") {
if target.contains("msvc") {
is_target_msvc = true;
builder = builder.arg("-D_USE_MATH_DEFINES");
}
}
if !is_target_msvc {
builder = builder.arg("-Xcompiler").arg("-fPIC");
}
let out_file = build_dir.join("libflashattention.a");
builder.build_lib(out_file);
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
if !is_target_msvc {
println!("cargo:rustc-link-lib=dylib=stdc++");
}
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.3"
version = "0.8.4"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.8.3"
version = "0.8.4"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -11,6 +11,7 @@ pub struct Cache {
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
grow_by: usize,
max_seq_len: usize,
}
@ -20,6 +21,7 @@ impl Cache {
all_data: None,
dim,
current_seq_len: 0,
grow_by: max_seq_len,
max_seq_len,
}
}
@ -65,11 +67,11 @@ impl Cache {
};
let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!(
"kv-cache: above max-seq-len {}+{seq_len}>{}",
self.current_seq_len,
self.max_seq_len
)
let mut shape = src.dims().to_vec();
shape[self.dim] = self.grow_by;
let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
*ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
self.max_seq_len += self.grow_by;
}
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;

View File

@ -83,7 +83,7 @@ fn rms_norml(device: &Device) -> Result<()> {
let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
@ -130,7 +130,7 @@ fn layer_norml(device: &Device) -> Result<()> {
let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
@ -161,12 +161,12 @@ fn ropei(device: &Device) -> Result<()> {
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.map(|_| rng.random::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.map(|_| rng.random::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@ -188,12 +188,12 @@ fn rope(device: &Device) -> Result<()> {
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.map(|_| rng.random::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.map(|_| rng.random::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@ -215,12 +215,12 @@ fn rope_thd(device: &Device) -> Result<()> {
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.map(|_| rng.random::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.map(|_| rng.random::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.8.3"
version = "0.8.4"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.8.3" }
candle-nn = { path = "../candle-nn", version = "0.8.3" }
candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
candle-nn = { path = "../candle-nn", version = "0.8.4" }
prost = "0.12.1"
[build-dependencies]

View File

@ -1,4 +1,5 @@
#![allow(clippy::redundant_closure_call)]
#![allow(clippy::useless_conversion)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;

View File

@ -4,7 +4,7 @@
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
use candle::{Context, DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};
#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
@ -50,7 +50,7 @@ impl LogitsProcessor {
}
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
let next_token = distr.sample(&mut self.rng) as u32;
Ok(next_token)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,483 @@
//! Gemma LLM architecture (Google) inference implementation.
//!
//! See ["Introducing Gemma 3: The most capable model you can run on a single GPU or TPU"](https://blog.google/technology/developers/gemma-3/)
//!
//! Based on implementations from HuggingFace transformers.
use std::sync::Arc;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
pub hidden_activation: Activation,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub vocab_size: usize,
pub final_logit_softcapping: Option<f64>,
pub attn_logit_softcapping: Option<f64>,
pub query_pre_attn_scalar: usize,
pub sliding_window: usize,
pub sliding_window_pattern: usize,
pub max_position_embeddings: usize,
}
#[derive(Debug, Clone)]
struct RmsNorm {
weight: Tensor,
eps: f64,
}
impl RmsNorm {
fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(dim, "weight")?;
Ok(Self { weight, eps })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&(&self.weight + 1.0)?)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: candle_nn::Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_activation,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
enum KvCache {
Normal(candle_nn::kv_cache::KvCache),
Rotating(candle_nn::kv_cache::RotatingKvCache),
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
q_norm: RmsNorm,
k_norm: RmsNorm,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
attn_logit_softcapping: Option<f64>,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: KvCache,
use_flash_attn: bool,
}
impl Attention {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = cfg.head_dim;
let bias = cfg.attention_bias;
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
let kv_cache = if is_sliding {
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
2,
cfg.sliding_window,
))
} else {
KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
};
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
attn_logit_softcapping: cfg.attn_logit_softcapping,
rotary_emb,
kv_cache,
use_flash_attn,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let query_states = self.q_norm.forward(&query_states)?;
let key_states = self.k_norm.forward(&key_states)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &mut self.kv_cache {
KvCache::Normal(cache) => cache.append(&key_states, &value_states)?,
KvCache::Rotating(cache) => cache.append(&key_states, &value_states)?,
};
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = query_states.transpose(1, 2)?;
let k = key_states.transpose(1, 2)?;
let v = value_states.transpose(1, 2)?;
let scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
} else {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match self.attn_logit_softcapping {
None => attn_weights,
Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
};
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, ()))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
match &mut self.kv_cache {
KvCache::Normal(c) => c.reset(),
KvCache::Rotating(c) => c.reset(),
}
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: RmsNorm,
pre_feedforward_layernorm: RmsNorm,
post_feedforward_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
is_sliding: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Attention::new(
rotary_emb,
use_flash_attn,
is_sliding,
cfg,
vb.pp("self_attn"),
)?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let pre_feedforward_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("pre_feedforward_layernorm"),
)?;
let post_feedforward_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_feedforward_layernorm"),
)?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
pre_feedforward_layernorm,
post_feedforward_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = xs.apply(&self.post_attention_layernorm)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.pre_feedforward_layernorm)?;
let xs = xs.apply(&self.mlp)?;
let xs = xs.apply(&self.post_feedforward_layernorm)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
final_logit_softcapping: Option<f64>,
device: Device,
dtype: DType,
hidden_size: usize,
sliding_window: usize,
}
impl Model {
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
let layer = DecoderLayer::new(
rotary_emb.clone(),
use_flash_attn,
is_sliding,
cfg,
vb_l.pp(layer_idx),
)?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
final_logit_softcapping: cfg.final_logit_softcapping,
device: vb.device().clone(),
dtype: vb.dtype(),
hidden_size: cfg.hidden_size,
sliding_window: cfg.sliding_window,
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = match Some(self.sliding_window) {
None => (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect(),
Some(sliding_window) => (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect(),
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
let logits = xs
.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)?;
let logits = match self.final_logit_softcapping {
None => logits,
Some(sc) => ((logits / sc)?.tanh()? * sc)?,
};
Ok(logits)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}

View File

@ -29,6 +29,7 @@ pub mod convmixer;
pub mod convnext;
pub mod dac;
pub mod debertav2;
pub mod deepseek2;
pub mod depth_anything_v2;
pub mod dinov2;
pub mod dinov2reg4;
@ -42,6 +43,7 @@ pub mod fastvit;
pub mod flux;
pub mod gemma;
pub mod gemma2;
pub mod gemma3;
pub mod glm4;
pub mod granite;
pub mod helium;

View File

@ -6,14 +6,15 @@
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
//!
use candle::{DType, Device, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
Module, VarBuilder,
embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
Linear, Module, VarBuilder,
};
use serde::Deserialize;
use core::f32;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Deserialize)]
@ -30,6 +31,24 @@ pub struct Config {
pub global_rope_theta: f64,
pub local_attention: usize,
pub local_rope_theta: f64,
#[serde(default)]
#[serde(flatten)]
pub classifier_config: Option<ClassifierConfig>,
}
#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
#[serde(rename_all = "lowercase")]
pub enum ClassifierPooling {
#[default]
CLS,
MEAN,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ClassifierConfig {
pub id2label: HashMap<String, String>,
pub label2id: HashMap<String, String>,
pub classifier_pooling: ClassifierPooling,
}
#[derive(Debug, Clone)]
@ -310,12 +329,11 @@ pub struct ModernBert {
norm: LayerNorm,
layers: Vec<ModernBertLayer>,
final_norm: LayerNorm,
head: ModernBertHead,
local_attention_size: usize,
}
impl ModernBert {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings = embedding(
config.vocab_size,
config.hidden_size,
@ -359,19 +377,17 @@ impl ModernBert {
config.layer_norm_eps,
vb.pp("model.final_norm"),
)?;
let head = ModernBertHead::load(vb.pp("head"), config)?;
Ok(Self {
word_embeddings,
norm,
layers,
final_norm,
head,
local_attention_size: config.local_attention,
})
}
fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
let seq_len = xs.shape().dims()[1];
let global_attention_mask =
prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
@ -381,7 +397,7 @@ impl ModernBert {
for layer in self.layers.iter() {
xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
}
let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
let xs = xs.apply(&self.final_norm)?;
Ok(xs)
}
}
@ -391,17 +407,98 @@ impl ModernBert {
pub struct ModernBertForMaskedLM {
model: ModernBert,
decoder: ModernBertDecoder,
head: ModernBertHead,
}
impl ModernBertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let model = ModernBert::load(vb.clone(), config)?;
let decoder = ModernBertDecoder::load(vb.clone(), config)?;
Ok(Self { model, decoder })
let head = ModernBertHead::load(vb.pp("head"), config)?;
Ok(Self {
model,
decoder,
head,
})
}
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
let xs = self
.model
.forward(xs, mask)?
.apply(&self.head)?
.apply(&self.decoder)?;
Ok(xs)
}
}
#[derive(Clone)]
pub struct ModernBertClassifier {
classifier: Linear,
}
impl ModernBertClassifier {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
// The decoder weights are tied with the embeddings layer weights
let classifier = linear(
config.hidden_size,
config
.classifier_config
.as_ref()
.map(|cc| cc.id2label.len())
.unwrap_or_default(),
vb.pp("classifier"),
)?;
Ok(Self { classifier })
}
}
impl Module for ModernBertClassifier {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.classifier)?;
softmax(&xs, D::Minus1)
}
}
#[derive(Clone)]
pub struct ModernBertForSequenceClassification {
model: ModernBert,
head: ModernBertHead,
classifier: ModernBertClassifier,
classifier_pooling: ClassifierPooling,
}
impl ModernBertForSequenceClassification {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let model = ModernBert::load(vb.clone(), config)?;
let classifier = ModernBertClassifier::load(vb.clone(), config)?;
let head = ModernBertHead::load(vb.pp("head"), config)?;
Ok(Self {
model,
head,
classifier,
classifier_pooling: config
.classifier_config
.as_ref()
.map(|cc| cc.classifier_pooling)
.unwrap_or_default(),
})
}
pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
let output = self.model.forward(xs, mask)?;
let last_hidden_state = match self.classifier_pooling {
ClassifierPooling::CLS => output.i((.., .., 0))?,
ClassifierPooling::MEAN => {
let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
}
};
let xs = self
.head
.forward(&last_hidden_state)?
.apply(&self.classifier)?;
Ok(xs)
}
}

View File

@ -10,33 +10,133 @@ use crate::models::clip::div_l2_norm;
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};
fn default_text_vocab_size() -> usize {
32000
}
fn default_text_hidden_size() -> usize {
768
}
fn default_text_intermediate_size() -> usize {
3072
}
fn default_text_num_hidden_layers() -> usize {
12
}
fn default_text_num_attention_heads() -> usize {
12
}
fn default_text_max_position_embeddings() -> usize {
64
}
fn default_text_layer_norm_eps() -> f64 {
1e-6
}
fn default_text_pad_token_id() -> u32 {
1
}
fn default_text_bos_token_id() -> u32 {
49406
}
fn default_text_eos_token_id() -> u32 {
49407
}
fn default_text_hidden_act() -> candle_nn::Activation {
candle_nn::Activation::GeluPytorchTanh
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
#[serde(default = "default_text_vocab_size")]
pub vocab_size: usize,
#[serde(default = "default_text_hidden_size")]
pub hidden_size: usize,
#[serde(default = "default_text_intermediate_size")]
pub intermediate_size: usize,
#[serde(default = "default_text_num_hidden_layers")]
pub num_hidden_layers: usize,
#[serde(default = "default_text_num_attention_heads")]
pub num_attention_heads: usize,
#[serde(default = "default_text_max_position_embeddings")]
pub max_position_embeddings: usize,
#[serde(default = "default_text_hidden_act")]
pub hidden_act: candle_nn::Activation,
#[serde(default = "default_text_layer_norm_eps")]
pub layer_norm_eps: f64,
#[serde(default = "default_text_pad_token_id")]
pub pad_token_id: u32,
#[serde(default = "default_text_bos_token_id")]
pub bos_token_id: u32,
#[serde(default = "default_text_eos_token_id")]
pub eos_token_id: u32,
}
fn default_vision_hidden_size() -> usize {
768
}
fn default_vision_intermediate_size() -> usize {
3072
}
fn default_vision_num_hidden_layers() -> usize {
12
}
fn default_vision_num_attention_heads() -> usize {
12
}
fn default_vision_num_channels() -> usize {
3
}
fn default_vision_image_size() -> usize {
224
}
fn default_vision_batch_size() -> usize {
16
}
fn default_vision_layer_norm_eps() -> f64 {
1e-6
}
fn default_vision_hidden_act() -> candle_nn::Activation {
candle_nn::Activation::GeluPytorchTanh
}
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
#[serde(default = "default_vision_hidden_size")]
pub hidden_size: usize,
#[serde(default = "default_vision_intermediate_size")]
pub intermediate_size: usize,
#[serde(default = "default_vision_num_hidden_layers")]
pub num_hidden_layers: usize,
#[serde(default = "default_vision_num_attention_heads")]
pub num_attention_heads: usize,
#[serde(default = "default_vision_num_channels")]
pub num_channels: usize,
#[serde(default = "default_vision_image_size")]
pub image_size: usize,
#[serde(default = "default_vision_batch_size")]
pub patch_size: usize,
#[serde(default = "default_vision_hidden_act")]
pub hidden_act: candle_nn::Activation,
#[serde(default = "default_vision_layer_norm_eps")]
pub layer_norm_eps: f64,
}

View File

@ -3,7 +3,7 @@ use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
pub use candle_transformers::models::whisper::{self as m, Config};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@ -221,7 +221,7 @@ impl Decoder {
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;