Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00.

Compare commits: opt-attn-m...phi2-gguf (97 commits).
Commit SHAs in this range:

3754b834f4, d79041d94d, af11b2d461, 2817643db9, 4d14777673, f135b7963d, af955f260c,
8ad822a983, e198bb0816, f7d5bf5b97, c119600d6e, c449f65b12, db7dbf3071, 4ecedb1598,
53e5380bf6, 50e49ecc5f, 4c88c3ce06, 8b8fb630df, fb805b8ca2, 79e3bec789, e6d412b156,
26cbbf8d84, 2bf413caa3, 3ad4770eb6, a0460cd2b1, b81ecf712d, a4d5a414e3, 798e0335cd,
718671a0d5, c5fe4a7f89, 7f354473cf, 33c9b66554, 9fd52b3b71, e662431acf, ab892274d1,
b869a659ec, 88f7793598, 2ac302a5d1, ace282e5c2, c87381fc96, c5626b8271, e6a5b82ba6,
5aebe53dd2, f76bb7794a, 30b145150f, f48c07e242, 8967c46563, 1e46cf8b19, bd8db2a771,
318d143224, 2be1a35710, 26226068a4, cd6b9e317c, 08c049def3, d17b2cdad9, fb918a23c8,
b23436bf90, be9c200cbb, ea0d8d3753, 308ea070ed, b20acd622c, 5522bbc57c, 888c09a3db,
318cb82f16, c7557b65dc, cd29c7ccd4, f9954b73ba, eead1dcead, 92f81d2fcb, 3144150b8d,
b190fd8592, efe4a0c84b, 665da30487, 356a170ae9, 7ecbc6d50b, 8ad12a0e81, eb1b27abcd,
708e422456, c5092f2c29, cdc8b57b5c, b0340d72ec, b3484e7a5e, ada5d7c096, 13ae5a34c7,
ab86cd37c8, a9abde5f93, 75b6d4b0da, 66f0a4eeea, 4523ecfb2a, f5dfe883d7, 196765e995,
60676780a9, d3a8d291d5, cd254074f3, e7f8e72588, 1b98f84a2b, cf7d7fcf2f
Cargo.toml (21 lines changed)
```diff
@@ -9,6 +9,7 @@ members = [
     "candle-transformers",
     "candle-wasm-examples/*",
     "candle-wasm-tests",
+    "tensor-tools",
 ]
 exclude = [
     "candle-flash-attn",
@@ -19,7 +20,7 @@ exclude = [
 resolver = "2"

 [workspace.package]
-version = "0.4.2"
+version = "0.5.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -32,14 +33,14 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" }
-candle-datasets = { path = "./candle-datasets", version = "0.4.2" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" }
-candle-kernels = { path = "./candle-kernels", version = "0.4.2" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" }
-candle-nn = { path = "./candle-nn", version = "0.4.2" }
-candle-onnx = { path = "./candle-onnx", version = "0.4.2" }
-candle-transformers = { path = "./candle-transformers", version = "0.4.2" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
+candle-nn = { path = "./candle-nn", version = "0.5.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.10.0", features = ["f16"] }
@@ -55,7 +56,7 @@ log = "0.4"
 memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
-parquet = { version = "50.0.0" }
+parquet = { version = "51.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
```
README.md (12 lines changed)
````diff
@@ -63,8 +63,9 @@ We also provide some command line based examples using state of the art models:
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
   the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
-  Deepmind.
+- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
+- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
+  Griffin based models from Google that mix attention with a RNN like state.
 - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets. Also supports
@@ -125,10 +126,14 @@ We also provide some command line based examples using state of the art models:
   [RepVGG](./candle-examples/examples/repvgg): computer vision models.
 - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
   generate captions for an image.
 - [CLIP](./candle-examples/examples/clip/): multi-model vision and language
   model.
 - [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
   dedicated submodels for hand-writing and printed recognition.
 - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
   model, generates the translated text from the input text.
+- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
+  that can answer real-world questions about images.

 Run them using commands like:
 ```
@@ -172,6 +177,7 @@ And then head over to
 - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
   serving local LLMs including an OpenAI compatible API server.
 - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
 - [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
 - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
@@ -206,7 +212,7 @@ If you have an addition to this list, please submit a pull request.
 - Replit-code-v1.5-3B.
 - Bert.
 - Yi-6B and Yi-34B.
-- Qwen1.5.
+- Qwen1.5, Qwen1.5 MoE.
 - RWKV v5 and v6.
 - Quantized LLMs.
 - Llama 7b, 13b, 70b, as well as the chat and code variants.
````
The criterion_main! harness registers the new quantized matmul benches:

```diff
@@ -7,4 +7,5 @@ criterion_main!(
     benchmarks::random::benches,
     benchmarks::where_cond::benches,
     benchmarks::conv_transpose2d::benches,
+    benchmarks::qmatmul::benches,
 );
```
and the benchmark module list gains the matching entry:

```diff
@@ -1,6 +1,7 @@
 pub(crate) mod affine;
 pub(crate) mod conv_transpose2d;
 pub(crate) mod matmul;
+pub(crate) mod qmatmul;
 pub(crate) mod random;
 pub(crate) mod where_cond;
```
candle-core/benches/benchmarks/qmatmul.rs (new file, 72 lines)
```rust
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{
    quantized::{self, GgmlDType, QMatMul},
    Device, Module, Tensor,
};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(matmul: &QMatMul, x: &Tensor) {
    matmul.forward(&x).unwrap();
}

fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
    let b = 1;
    let m = 1;
    let n = 1024;
    let k = 1024;

    let lhs = (0..(m * k))
        .map(|v| v as f32 / (m * k) as f32)
        .collect::<Vec<_>>();
    let rhs = (0..(k * n))
        .map(|v| v as f32 / (n * k) as f32)
        .collect::<Vec<_>>();

    let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
    let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();

    let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
    let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();

    let flops = b * m * n * k;

    let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
    group.sample_size(200);
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&matmul), black_box(&lhs));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        for dtype in vec![
            GgmlDType::F32,
            GgmlDType::F16,
            GgmlDType::Q4_0,
            GgmlDType::Q4_1,
            GgmlDType::Q5_0,
            GgmlDType::Q5_1,
            GgmlDType::Q8_0,
            GgmlDType::Q2K,
            GgmlDType::Q3K,
            GgmlDType::Q4K,
            GgmlDType::Q5K,
            GgmlDType::Q6K,
        ] {
            run_bench(c, &device, dtype);
        }
    }
}

criterion_group!(benches, criterion_benchmark);
```
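For context, a minimal sketch of the API this new benchmark exercises: quantize a weight tensor to a GGML dtype, wrap it in a `QMatMul`, and run a forward pass. The shapes here are illustrative, but the calls mirror the ones in `run_bench` above. Note that the harness feeds a flop count to `Throughput::Bytes`, so criterion's "bytes/s" readout should be interpreted as FLOP/s.

```rust
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A (k, n) weight, transposed before quantization exactly as in run_bench.
    let weight = Tensor::randn(0f32, 1.0, (1024, 1024), &device)?;
    let qweight = QTensor::quantize(&weight.t()?, GgmlDType::Q4K)?;
    let matmul = QMatMul::from_qtensor(qweight)?;
    // A single-row activation, matching the m = 1 case benchmarked above.
    let x = Tensor::randn(0f32, 1.0, (1, 1024), &device)?;
    let y = matmul.forward(&x)?;
    println!("{:?}", y.shape()); // (1, 1024)
    Ok(())
}
```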
In the BackendDevice trait:

```diff
@@ -142,4 +142,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

     fn set_seed(&self, _: u64) -> Result<()>;
+
+    /// Synchronize should block until all the operations on the device are completed.
+    fn synchronize(&self) -> Result<()>;
 }
```
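Because GPU work is queued asynchronously, wall-clock timing is only meaningful after a sync, which is what the qmatmul benchmark's `device.sync().unwrap()` call relies on. A hedged sketch of the timing pattern the new method enables; `time_with_sync` and its arguments are illustrative names, not part of the candle API:

```rust
use std::time::{Duration, Instant};

// Illustrative helper: run `f` `iters` times, then block on `sync` so any
// queued device work is included in the measured time.
fn time_with_sync<F, S, E>(iters: u64, mut f: F, sync: S) -> Result<Duration, E>
where
    F: FnMut(),
    S: Fn() -> Result<(), E>,
{
    let start = Instant::now();
    for _ in 0..iters {
        f();
    }
    sync()?; // e.g. BackendDevice::synchronize on a CUDA device
    Ok(start.elapsed())
}
```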
In Tensor's backward pass:

```diff
@@ -112,7 +112,8 @@
         }
         Op::Unary(_node, UnaryOp::Ceil)
         | Op::Unary(_node, UnaryOp::Floor)
-        | Op::Unary(_node, UnaryOp::Round) => nodes,
+        | Op::Unary(_node, UnaryOp::Round)
+        | Op::Unary(_node, UnaryOp::Sign) => nodes,
         Op::Reshape(node)
         | Op::UpsampleNearest1D { arg: node, .. }
         | Op::UpsampleNearest2D { arg: node, .. }
@@ -488,7 +489,6 @@
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.add(&grad)?;
         }
-        Op::Cmp(_args, _) => {}
         Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
             let node = broadcast_back(arg, node, reduced_dims)?;
             let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -578,20 +578,18 @@
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.add(&arg_grad)?
         }
-        Op::Reduce(_, ReduceOp::ArgMin, _) => {}
-        Op::Reduce(_, ReduceOp::ArgMax, _) => {}
+        Op::Unary(_, UnaryOp::Floor)
+        | Op::Unary(_, UnaryOp::Round)
+        | Op::Reduce(_, ReduceOp::ArgMin, _)
+        | Op::Reduce(_, ReduceOp::ArgMax, _)
+        | Op::Unary(_, UnaryOp::Sign)
+        | Op::Cmp(_, _) => {}
         Op::Reshape(arg) => {
             let arg_grad = grad.reshape(arg.dims())?;
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.add(&arg_grad)?
         }
         Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
-        Op::Unary(_, UnaryOp::Floor) => {
-            Err(Error::BackwardNotSupported { op: "floor" })?
-        }
-        Op::Unary(_, UnaryOp::Round) => {
-            Err(Error::BackwardNotSupported { op: "round" })?
-        }
         Op::Unary(arg, UnaryOp::Gelu) => {
             let sum_grad = grads.or_insert(arg)?;
             let cube = arg.powf(3.)?;
```
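So `ceil` still reports `BackwardNotSupported`, while `floor`, `round`, the new `sign`, argmin/argmax and comparisons are now treated as non-differentiable no-ops: they are piecewise constant, their derivative is zero almost everywhere, and no gradient is propagated to their inputs. A hedged sketch of the observable behavior, assuming a build with this change and that `Tensor::sign` is the user-facing helper for the new `UnaryOp::Sign`:

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let x = Var::new(&[-2f32, 0.5, 3.0], &Device::Cpu)?;
    // sign() is piecewise constant, so backward() no longer errors out;
    // the sign branch simply contributes no gradient to x.
    let y = (x.sign()? + x.as_tensor())?.sum_all()?;
    let grads = y.backward()?;
    // Only the identity path contributes: d/dx (sign(x) + x) = 1 a.e.
    let gx = grads.get(x.as_tensor()).expect("grad for x");
    println!("{gx}");
    Ok(())
}
```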
In the CPU backend module:

```diff
@@ -4,6 +4,11 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 use rayon::prelude::*;

+mod utils;
+pub use utils::{
+    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+};
+
 const USE_IM2COL_CONV1D: bool = true;
 const USE_IM2COL_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
```
```diff
@@ -24,102 +29,6 @@ pub enum CpuStorage {
 #[derive(Debug, Clone)]
 pub struct CpuDevice;

-(96 lines removed: the Map1, Map1Any, Map2 and Map2U8 kernel-helper traits and
- the `type C = CpuStorage;` alias, moved into cpu_backend/utils.rs shown below)

 struct Cmp(CmpOp);
 impl Map2U8 for Cmp {
     const OP: &'static str = "cmp";
```
```diff
@@ -366,275 +275,6 @@ impl<'a> Map1 for ReduceSum<'a> {
     }
 }

-(269 lines removed: the unary_map, unary_map_vec, binary_map and
- binary_map_vec helpers, reproduced verbatim in cpu_backend/utils.rs below)

 struct Affine(f64, f64);

 impl Map1 for Affine {
```
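The `*_vec` variants being moved out all rely on the same unsafe idiom: allocate uninitialized capacity, hand it to the vectorized kernel, then `set_len` once every slot is written. A standalone sketch of that pattern; `fill_doubled` is a made-up example, not candle code:

```rust
// Minimal model of the spare_capacity_mut + set_len pattern used by
// unary_map_vec and binary_map_vec.
fn fill_doubled(src: &[f32]) -> Vec<f32> {
    let mut ys: Vec<f32> = Vec::with_capacity(src.len());
    let ys_to_set = ys.spare_capacity_mut();
    for (dst, &s) in ys_to_set.iter_mut().zip(src.iter()) {
        dst.write(s * 2.0); // MaybeUninit::write, never reads uninit memory
    }
    // SAFETY: all src.len() elements were initialized in the loop above.
    unsafe { ys.set_len(src.len()) };
    ys
}
```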
```diff
@@ -1564,6 +1204,30 @@ impl MatMul {
         }))
         .bt()
     }
+
+    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+        let (_b, m, n, k) = self.0;
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [_, stride] if lhs_l.dims()[0] == 1 => stride,
+            [stride, _] if lhs_l.dims()[1] == 1 => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [_, stride] if rhs_l.dims()[0] == 1 => stride,
+            [stride, _] if rhs_l.dims()[1] == 1 => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
+        Ok((a_skip, b_skip))
+    }
 }

 impl Map2 for MatMul {
```
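The skip values are the per-batch strides of the two matmul operands, now computed once in this shared helper (and extended with the two `dims()[..] == 1` arms). An illustrative model of the lhs rule; `a_skip_model` is not candle code:

```rust
// Given the dims and strides of the leading (batch) dimensions, compute how
// far to jump between consecutive (m, k) matrices, mirroring ab_skip.
fn a_skip_model(dims: &[usize], stride: &[usize], m: usize, k: usize) -> Option<usize> {
    let rank = stride.len();
    match stride[..rank - 2] {
        [s1, st] if s1 == st * dims[1] => Some(st), // two contiguous batch dims
        [_, st] if dims[0] == 1 => Some(st),        // leading batch dim of size 1
        [st, _] if dims[1] == 1 => Some(st),        // second batch dim of size 1
        [st] => Some(st),                           // single batch dim
        [] => Some(m * k),                          // no batch dim: dense (m, k)
        _ => None,                                  // would be a striding error
    }
}

fn main() {
    // Contiguous (b, m, k) = (2, 3, 4): strides (12, 4, 1), so consecutive
    // batch elements are m * k = 12 values apart.
    assert_eq!(a_skip_model(&[2, 3, 4], &[12, 4, 1], 3, 4), Some(12));
    // A (1, b, m, k) view keeps the same per-batch jump.
    assert_eq!(a_skip_model(&[1, 2, 3, 4], &[24, 12, 4, 1], 3, 4), Some(12));
}
```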
The three per-backend matmul implementations then call the shared helper instead of duplicating the match, and the gemm layout checks are relaxed for size-1 dimensions:

```diff
@@ -1597,18 +1261,7 @@ impl Map2 for MatMul {
         let rhs_cs = rhs_stride[rank - 1];
         let rhs_rs = rhs_stride[rank - 2];

-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
         let c_skip: usize = m * n;

         let dst_shape: Shape = (m, n).into();
@@ -1668,20 +1321,8 @@
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();
-
-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
         let c_skip: usize = m * n;

         let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1689,7 +1330,7 @@
         let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
         let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

-        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
             (n as i32, b'N')
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
@@ -1697,7 +1338,7 @@
             Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
             (k as i32, b'N')
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
@@ -1771,20 +1412,8 @@
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();
-
-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
         let c_skip: usize = m * n;

         let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1792,7 +1421,7 @@
         let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
         let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

-        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
             (n as i32, b'N')
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
@@ -1800,7 +1429,7 @@
             Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
             (k as i32, b'N')
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
```
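The relaxed conditions accept layouts where a dimension of size 1 makes the stride of that axis meaningless. A tiny illustration with made-up stride values:

```rust
fn main() {
    // For a k x n rhs with n == 1, the last-axis stride can be anything,
    // since that axis holds a single element; the relaxed condition
    // (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) still selects 'N'.
    let (k, n) = (4usize, 1usize);
    let (rhs_m1, rhs_m2) = (3usize, 1usize); // e.g. strides from a transposed view
    let old = rhs_m1 == 1 && rhs_m2 == n;
    let new = (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1);
    assert!(!old && new); // previously rejected as non-contiguous, now accepted
}
```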
The CPU device implements the new trait method as a no-op, since CPU ops run synchronously:

```diff
@@ -2999,6 +2628,10 @@ impl BackendDevice for CpuDevice {
         };
         Ok(storage)
     }
+
+    fn synchronize(&self) -> Result<()> {
+        Ok(())
+    }
 }

 #[macro_export]
```
candle-core/src/cpu_backend/utils.rs (new file, 350 lines)
```rust
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};

type C = super::CpuStorage;
pub trait Map1 {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
        match vs {
            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
        }
    }
}

pub trait Map1Any {
    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;

    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
        match vs {
            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
        }
    }
}

pub trait Map2 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub trait Map2U8 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
            [start_offset..start_offset + len]
            .iter()
            .map(|&v| f(v))
            .collect(),
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut result = Vec::with_capacity(layout.shape().elem_count());
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
            } else {
                for index in block_start_index {
                    for offset in 0..block_len {
                        let v = unsafe { vs.get_unchecked(index + offset) };
                        result.push(f(*v))
                    }
                }
            }
            result
        }
    }
}

pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                let mut result = Vec::with_capacity(el_count);
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}
```
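The `offsets_b` fast paths in `binary_map` visit one operand as a repeated block: `len` distinct values, each repeated `right_broadcast` times, with the whole block repeated `left_broadcast` times. A standalone model of that traversal; `broadcast_walk` is illustrative, not candle code:

```rust
// Expand a block of `len` values the way binary_map's offsets_b loop visits
// them: each value repeated `right_broadcast` times, the whole block
// repeated `left_broadcast` times.
fn broadcast_walk(block: &[f32], right_broadcast: usize, left_broadcast: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(block.len() * right_broadcast * left_broadcast);
    for _ in 0..left_broadcast {
        for &v in block {
            for _ in 0..right_broadcast {
                out.push(v);
            }
        }
    }
    out
}

fn main() {
    // A rhs of shape (1, 3, 1) broadcast against a (2, 3, 4) lhs is visited
    // with len = 3, right_broadcast = 4, left_broadcast = 2.
    let walked = broadcast_walk(&[1.0, 2.0, 3.0], 4, 2);
    assert_eq!(walked.len(), 2 * 3 * 4);
    assert_eq!(&walked[..4], &[1.0; 4]);
}
```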
candle-core/src/cuda_backend/device.rs (new file, 415 lines)
@ -0,0 +1,415 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||
|
||||
/// Unique identifier for cuda devices.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct DeviceId(usize);
|
||||
|
||||
impl DeviceId {
|
||||
fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
struct CudaRng(cudarc::curand::CudaRng);
|
||||
unsafe impl Send for CudaRng {}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CudaDevice {
|
||||
id: DeviceId,
|
||||
device: Arc<cudarc::driver::CudaDevice>,
|
||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
curand: Arc<Mutex<CudaRng>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for CudaDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "CudaDevice({:?})", self.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaDevice {
|
||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
||||
self.device.clone()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
||||
let params = (&data, v as u8, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||
let params = (&data, v as u32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||
let params = (&data, v as i64, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||
let params = (&data, bf16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||
let params = (&data, f16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||
let params = (&data, v as f32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||
let params = (&data, v, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||
if !self.has_func(module_name, module_name) {
|
||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||
// done once per kernel name.
|
||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||
.map_err(|cuda| CudaError::Load {
|
||||
cuda,
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
.w()?;
|
||||
}
|
||||
self.get_func(module_name, module_name)
|
||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||
// able to only build the error value if needed.
|
||||
.ok_or(CudaError::MissingKernel {
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
.w()
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for CudaDevice {
|
||||
type Storage = CudaStorage;
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||
// state will be identical and the same random numbers will be generated.
|
||||
let mut curand = self.curand.lock().unwrap();
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Cuda {
|
||||
gpu_id: self.device.ordinal(),
|
||||
}
|
||||
}
|
||||
|
||||
fn same_device(&self, rhs: &Self) -> bool {
|
||||
self.id == rhs.id
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
let slice = match dtype {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_uniform",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||
curand.0.fill_with_uniform(&mut data).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
let slice = if lo == 0. && up == 1.0 {
|
||||
slice
|
||||
} else {
|
||||
use super::utils::Map1;
|
||||
let layout = Layout::contiguous(shape);
|
||||
super::Affine(up - lo, lo).map(&slice, self, &layout)?
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
||||
// cudarc changes.
|
||||
let elem_count = shape.elem_count();
|
||||
let curand = self.curand.lock().unwrap();
|
||||
// curand can only generate an odd number of values.
|
||||
// https://github.com/huggingface/candle/issues/734
|
||||
let elem_count_round = if elem_count % 2 == 1 {
|
||||
elem_count + 1
|
||||
} else {
|
||||
elem_count
|
||||
};
|
||||
let slice = match dtype {
|
||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
||||
Err(CudaError::UnsupportedDtype {
|
||||
dtype,
|
||||
op: "rand_normal",
|
||||
})
|
||||
.w()?
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
||||
curand
|
||||
.0
|
||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||
.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                curand.0.fill_with_normal(&mut data, mean, std).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn synchronize(&self) -> Result<()> {
        self.device.synchronize().map_err(crate::Error::wrap)?;
        Ok(())
    }
}
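A note on the constructor pair above: `storage_from_cpu_storage` borrows the host data and goes through the blocking `htod_sync_copy`, while `storage_from_cpu_storage_owned` takes the `CpuStorage` by value so cudarc's `htod_copy` can consume the vector directly. From user code both are reached through the usual device transfer; a minimal sketch, assuming a CUDA-enabled build and GPU ordinal 0:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    // Host data moves to the GPU through the storage_from_cpu_storage path.
    let t = Tensor::new(&[1f32, 2., 3., 4.], &dev)?.reshape((2, 2))?;
    println!("{t}");
    Ok(())
}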
candle-core/src/cuda_backend/error.rs (new file, 62 lines)
@@ -0,0 +1,62 @@
use crate::{DType, Layout};

/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
    #[error(transparent)]
    Cuda(#[from] cudarc::driver::DriverError),

    #[error(transparent)]
    Compiler(#[from] cudarc::nvrtc::CompileError),

    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),

    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),

    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },

    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),

    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Layout,
        rhs_stride: Layout,
        mnk: (usize, usize, usize),
    },

    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    #[error("{cuda} when loading {module_name}")]
    Load {
        cuda: cudarc::driver::DriverError,
        module_name: String,
    },
}

impl From<CudaError> for crate::Error {
    fn from(val: CudaError) -> Self {
        crate::Error::Cuda(Box::new(val)).bt()
    }
}

pub trait WrapErr<O> {
    fn w(self) -> std::result::Result<O, crate::Error>;
}

impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
    fn w(self) -> std::result::Result<O, crate::Error> {
        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
    }
}
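`WrapErr` is the small adapter that lets every raw `cudarc` call return candle's own `Error`; note that this extracted version attaches a backtrace via `.bt()` inside `w()`, which the older in-module copy removed further below did not. A minimal sketch of the pattern (the `alloc_f32` helper is hypothetical):

use candle_core::cuda_backend::cudarc::driver::CudaSlice;
use candle_core::cuda_backend::{CudaDevice, WrapErr};

// Hypothetical helper: any Result<_, E> with E: Into<CudaError> gains `.w()`.
fn alloc_f32(dev: &CudaDevice, n: usize) -> candle_core::Result<CudaSlice<f32>> {
    // `.w()?` converts the cudarc DriverError into crate::Error::Cuda.
    let data = dev.alloc_zeros::<f32>(n).w()?;
    Ok(data)
}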
@@ -5,11 +5,18 @@ pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
-    CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
-    ValidAsZeroBits,
+    CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
 };
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex};

 #[cfg(feature = "cudnn")]
 pub mod cudnn;
+mod device;
+mod error;
+mod utils;
+pub use device::{CudaDevice, DeviceId};
+pub use error::{CudaError, WrapErr};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};

 enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
@@ -36,467 +43,6 @@ impl SlicePtrOrNull<usize> {
    }
}

/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
    #[error(transparent)]
    Cuda(#[from] cudarc::driver::DriverError),

    #[error(transparent)]
    Compiler(#[from] cudarc::nvrtc::CompileError),

    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),

    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),

    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },

    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),

    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },

    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    #[error("{cuda} when loading {module_name}")]
    Load {
        cuda: cudarc::driver::DriverError,
        module_name: String,
    },
}

impl From<CudaError> for crate::Error {
    fn from(val: CudaError) -> Self {
        crate::Error::Cuda(Box::new(val)).bt()
    }
}

/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}

#[derive(Clone)]
pub struct CudaDevice {
    id: DeviceId,
    device: Arc<cudarc::driver::CudaDevice>,
    blas: Arc<cudarc::cublas::CudaBlas>,
    curand: Arc<Mutex<CudaRng>>,
}

impl std::fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CudaDevice({:?})", self.id)
    }
}

impl std::ops::Deref for CudaDevice {
    type Target = Arc<cudarc::driver::CudaDevice>;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

pub trait WrapErr<O> {
    fn w(self) -> std::result::Result<O, crate::Error>;
}

impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
    fn w(self) -> std::result::Result<O, crate::Error> {
        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
    }
}

impl CudaDevice {
    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
        self.device.clone()
    }

    pub fn id(&self) -> DeviceId {
        self.id
    }

    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let slice = match dtype {
            DType::U8 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
                let params = (&data, v as u8, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
                let params = (&data, v as u32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
                let params = (&data, v as i64, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
                let params = (&data, bf16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
                let params = (&data, f16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                let params = (&data, v as f32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                let params = (&data, v, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
        if !self.has_func(module_name, module_name) {
            // Leaking the string here is a bit sad but we need a &'static str and this is only
            // done once per kernel name.
            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
            self.load_ptx(ptx.into(), module_name, &[static_module_name])
                .map_err(|cuda| CudaError::Load {
                    cuda,
                    module_name: module_name.to_string(),
                })
                .w()?;
        }
        self.get_func(module_name, module_name)
            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
            // able to only build the error value if needed.
            .ok_or(CudaError::MissingKernel {
                module_name: module_name.to_string(),
            })
            .w()
    }
}

impl BackendDevice for CudaDevice {
    type Storage = CudaStorage;

    fn new(ordinal: usize) -> Result<Self> {
        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
        Ok(Self {
            id: DeviceId::new(),
            device,
            blas: Arc::new(blas),
            curand: Arc::new(Mutex::new(CudaRng(curand))),
        })
    }

    fn set_seed(&self, seed: u64) -> Result<()> {
        // We do not call set_seed but instead create a new curand object. This ensures that the
        // state will be identical and the same random numbers will be generated.
        let mut curand = self.curand.lock().unwrap();
        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
        Ok(())
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),
        }
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.id == rhs.id
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc_zeros::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc_zeros::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc_zeros::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc_zeros::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc_zeros::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc_zeros::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
            // cudarc changes.
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_uniform",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        let slice = if lo == 0. && up == 1.0 {
            slice
        } else {
            let layout = Layout::contiguous(shape);
            Affine(up - lo, lo).map(&slice, self, &layout)?
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
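The uniform path only ever asks curand for U(0, 1) samples; any other `(lo, up)` range is obtained by running the existing `Affine` kernel over the freshly filled buffer, using the identity x ↦ (up - lo) · x + lo. A CPU analogue of the rescaling:

// A draw u from U(0, 1) rescaled to U(lo, up):
let (lo, up) = (-1.0f64, 1.0);
let u = 0.25f64; // sample from U(0, 1)
let x = (up - lo) * u + lo;
assert_eq!(x, -0.5);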

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
        // cudarc changes.
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        // curand can only generate an odd number of values.
        // https://github.com/huggingface/candle/issues/734
        let elem_count_round = if elem_count % 2 == 1 {
            elem_count + 1
        } else {
            elem_count
        };
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_normal",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                curand
                    .0
                    .fill_with_normal(&mut data, mean as f32, std as f32)
                    .w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                curand.0.fill_with_normal(&mut data, mean, std).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
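The `elem_count_round` dance exists because `curandGenerateNormal` rejects odd sample counts (the comment's wording is inverted: it is an even number of values that curand requires, see issue #734). An odd request is padded by one, and the extra sample is never exposed since the layout still covers only `elem_count` elements. The rounding rule in isolation:

fn round_for_curand(elem_count: usize) -> usize {
    // curandGenerateNormal needs an even number of values; pad odd counts by one.
    if elem_count % 2 == 1 { elem_count + 1 } else { elem_count }
}

assert_eq!(round_for_curand(7), 8);
assert_eq!(round_for_curand(8), 8);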

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
}

#[derive(Debug)]
pub enum CudaStorageSlice {
    U8(CudaSlice<u8>),
@@ -507,133 +53,6 @@ pub enum CudaStorageSlice {
    F32(CudaSlice<f32>),
    F64(CudaSlice<f64>),
}
type S = CudaStorageSlice;

pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}

pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        };
        Ok(out)
    }
}

pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()>;

    fn map(
        &self,
        dst: &mut S,
        dst_s: &Shape,
        src: &S,
        src_l: &Layout,
        d: &CudaDevice,
    ) -> Result<()> {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        }
    }
}

pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
        wrap: W,
    ) -> Result<S>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
            S::F64(s) => self.f(s, d, l, S::F64)?,
        };
        Ok(out)
    }
}

pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<S>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}

struct Clone;
impl Map1 for Clone {
@@ -1651,26 +1070,30 @@ fn gemm_config<T>(
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
     // The a tensor has dims batching, k, n (rhs)
-    let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
         (n as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?
     };
     // The b tensor has dims batching, m, k (lhs)
-    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
         (m as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?
     };
@@ -1691,21 +1114,25 @@ fn gemm_config<T>(

     let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
         [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+        [_, stride] if lhs_l.dims()[0] == 1 => stride,
+        [stride, _] if lhs_l.dims()[1] == 1 => stride,
         [stride] => stride,
         [] => m * k,
         _ => Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?,
     };
     let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
         [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+        [_, stride] if rhs_l.dims()[0] == 1 => stride,
+        [stride, _] if rhs_l.dims()[1] == 1 => stride,
         [stride] => stride,
         [] => n * k,
         _ => Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?,
     };
@@ -2274,6 +1701,11 @@ impl BackendStorage for CudaStorage {
         let dev = &self.device;
         let d1 = d1 as u32;
         let d2 = d2 as u32;
+        // Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
+        // argument with a null pointer.
+        if d1 == 0 || d2 == 0 {
+            return Ok(());
+        }
         let dst_s = dst_s as u32;
         let src_s = src_s as u32;
         let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
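To make the relaxed stride checks concrete: for a row-major contiguous rhs of shape (b, k, n) the last two strides are (n, 1), so `rhs_m1 == 1` and `rhs_m2 == n`, and cuBLAS sees a non-transposed operand with `lda = n`; an operand produced by a transpose instead has strides (1, k) and takes the `CUBLAS_OP_T` branch. The added `|| n == 1` / `|| k == 1` clauses accept degenerate axes whose stride is irrelevant because they hold a single element. A small standalone sketch of the same dispatch, assuming a plain stride slice (`true` stands in for `CUBLAS_OP_T`):

fn lda_transa(rhs_stride: &[usize], (n, k): (usize, usize)) -> Option<(usize, bool)> {
    let m1 = rhs_stride[rhs_stride.len() - 1];
    let m2 = rhs_stride[rhs_stride.len() - 2];
    if (m1 == 1 || n == 1) && (m2 == n || k == 1) {
        Some((n, false)) // operand already laid out as (.., k, n)
    } else if (m1 == k || n == 1) && (m2 == 1 || k == 1) {
        Some((k, true)) // transposed operand, let cuBLAS undo it
    } else {
        None // would surface as MatMulNonContiguous
    }
}

assert_eq!(lda_transa(&[12, 4, 1], (4, 3)), Some((4, false))); // contiguous (b, k, n)
assert_eq!(lda_transa(&[12, 1, 3], (4, 3)), Some((3, true))); // transposed operand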
candle-core/src/cuda_backend/utils.rs (new file, 134 lines)
@@ -0,0 +1,134 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};

use super::{CudaDevice, CudaError, WrapErr};

pub type S = super::CudaStorageSlice;

pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}

pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        };
        Ok(out)
    }
}

pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()>;

    fn map(
        &self,
        dst: &mut S,
        dst_s: &Shape,
        src: &S,
        src_l: &Layout,
        d: &CudaDevice,
    ) -> Result<()> {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        }
    }
}

pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
        wrap: W,
    ) -> Result<S>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
            S::F64(s) => self.f(s, d, l, S::F64)?,
        };
        Ok(out)
    }
}

pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<S>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}
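These `Map*` traits let one generic kernel launcher cover every variant of `CudaStorageSlice`: implement `f` once over `T` and `map` performs the dtype dispatch. A sketch of a hypothetical `Map1` implementor launching a custom unary kernel; the kernel name `uneg` and the `MY_PTX` source are assumptions, not part of candle, and the layout is assumed contiguous for brevity:

use candle_core::cuda_backend::cudarc::driver::{
    CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use candle_core::cuda_backend::{CudaDevice, Map1, WrapErr};
use candle_core::{Layout, Result, WithDType};

const MY_PTX: &str = "..."; // placeholder: PTX compiled from your own kernel source

struct Negate;

impl Map1 for Negate {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let el = layout.shape().elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        // One kernel per dtype, e.g. "uneg_f32" (hypothetical names).
        let name = format!("uneg_{}", T::DTYPE.as_str());
        let func = dev.get_or_load_func(&name, MY_PTX)?;
        // SAFETY: the output is fully overwritten by the kernel launch below.
        let out = unsafe { dev.alloc::<T>(el) }.w()?;
        let params = (el, src, &out);
        unsafe { func.launch(cfg, params) }.w()?;
        Ok(out)
    }
}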
@@ -337,4 +337,12 @@ impl Device {
             }
         }
     }
+
+    pub fn synchronize(&self) -> Result<()> {
+        match self {
+            Self::Cpu => Ok(()),
+            Self::Cuda(d) => d.synchronize(),
+            Self::Metal(d) => d.synchronize(),
+        }
+    }
 }
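Since both the CUDA and Metal backends execute asynchronously, the new `Device::synchronize` is the right tool when wall-clock timing an op; a minimal sketch (the matmul and shapes are placeholders):

use candle_core::{Device, Tensor};
use std::time::Instant;

fn time_matmul(dev: &Device) -> candle_core::Result<f64> {
    let a = Tensor::randn(0f32, 1., (1024, 1024), dev)?;
    let b = Tensor::randn(0f32, 1., (1024, 1024), dev)?;
    let start = Instant::now();
    let _c = a.matmul(&b)?;
    // Without this, we would only measure the time to enqueue the kernel.
    dev.synchronize()?;
    Ok(start.elapsed().as_secs_f64())
}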
@@ -229,4 +229,8 @@ impl crate::backend::BackendDevice for CudaDevice {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    fn synchronize(&self) -> Result<()> {
+        Ok(())
+    }
 }
@@ -241,4 +241,8 @@ impl crate::backend::BackendDevice for MetalDevice {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
+
+    fn synchronize(&self) -> Result<()> {
+        Ok(())
+    }
 }
@@ -14,7 +14,7 @@
 //!
 //! ## Features
 //!
-//! - Simple syntax (looks and like PyTorch)
+//! - Simple syntax (looks and feels like PyTorch)
 //! - CPU and Cuda backends (and M1 support)
 //! - Enable serverless (CPU) small and fast deployments
 //! - Model training
@@ -37,14 +37,12 @@
 mod accelerate;
 pub mod backend;
 pub mod backprop;
-mod conv;
+pub mod conv;
 mod convert;
 pub mod cpu;
 pub mod cpu_backend;
 #[cfg(feature = "cuda")]
 pub mod cuda_backend;
-#[cfg(feature = "cudnn")]
-pub mod cudnn;
 mod custom_op;
 mod device;
 pub mod display;
@@ -59,7 +57,7 @@ pub mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
-mod op;
+pub mod op;
 pub mod pickle;
 pub mod quantized;
 pub mod safetensors;
@@ -73,10 +71,13 @@ pub mod test_utils;
 pub mod utils;
 mod variable;

+#[cfg(feature = "cudnn")]
+pub use cuda_backend::cudnn;
+
 pub use cpu_backend::CpuStorage;
 pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
 pub use device::{Device, DeviceLocation, NdArray};
-pub use dtype::{DType, FloatDType, IntDType, WithDType};
+pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;
 pub use layout::Layout;
candle-core/src/metal_backend/device.rs (new file, 287 lines)
@@ -0,0 +1,287 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};

use super::MetalError;

/// Unique identifier for Metal devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    pub(crate) fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;

#[derive(Clone)]
pub struct MetalDevice {
    /// Unique identifier; the registryID is not sufficient as it identifies the GPU rather than
    /// the device itself.
    pub(crate) id: DeviceId,

    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
    pub(crate) device: metal::Device,

    /// Single command queue for the entire device.
    pub(crate) command_queue: CommandQueue,
    /// One command buffer at a time.
    /// The scheduler works by allowing multiple
    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
    /// prevents overlapping of CPU and GPU commands (because a command buffer needs to be
    /// committed before it starts running).
    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
    /// by their START time, but there's no guarantee that command buffer1 will finish before
    /// command buffer2 starts (or there are metal bugs there).
    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
    /// Keeps track of the current amount of compute command encoders on the current
    /// command buffer.
    /// Arc, RwLock because of the interior mutability.
    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
    pub(crate) compute_per_buffer: usize,
    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
    /// Heavily used by [`candle_metal_kernels`]
    pub(crate) kernels: Arc<Kernels>,
    /// Simple allocator struct.
    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
    /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
    /// (could be linked to FFI communication overhead).
    ///
    /// Whenever a buffer has a strong_count of 1, we can reuse it: it was dropped in the
    /// graph calculation and only we, the allocator, kept a reference to it, therefore it's free
    /// to be reused. However, in order for this to work, we need to guarantee the order of
    /// operation, so that this buffer is not being used by another kernel at the same time.
    /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
    ///
    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
    /// (strong_count == 1).
    pub(crate) buffers: AllocatedBuffers,
    /// Seed for random number generation.
    pub(crate) seed: Arc<Mutex<Buffer>>,
}

impl std::fmt::Debug for MetalDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetalDevice({:?})", self.id)
    }
}

impl std::ops::Deref for MetalDevice {
    type Target = metal::DeviceRef;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl MetalDevice {
    pub fn id(&self) -> DeviceId {
        self.id
    }

    pub fn metal_device(&self) -> &metal::Device {
        &self.device
    }

    pub fn command_queue(&self) -> &CommandQueue {
        &self.command_queue
    }

    pub fn command_buffer(&self) -> Result<CommandBuffer> {
        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
        let mut command_buffer = command_buffer_lock.to_owned();
        let mut index = self
            .command_buffer_index
            .try_write()
            .map_err(MetalError::from)?;
        if *index > self.compute_per_buffer {
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffer_lock = command_buffer.clone();
            *index = 0;

            self.drop_unused_buffers()?;
        }
        *index += 1;
        Ok(command_buffer)
    }

    pub fn wait_until_completed(&self) -> Result<()> {
        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
        match command_buffer.status() {
            metal::MTLCommandBufferStatus::Committed
            | metal::MTLCommandBufferStatus::Scheduled
            | metal::MTLCommandBufferStatus::Completed => {
                panic!("Already committed");
            }
            _ => {}
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();
        *command_buffer = self.command_queue.new_command_buffer().to_owned();

        Ok(())
    }

    pub fn kernels(&self) -> &Kernels {
        &self.kernels
    }

    pub fn device(&self) -> &metal::Device {
        &self.device
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer data cannot be read on the CPU directly.
    ///
    /// [`name`] is only used to keep track of the resource origin in case of bugs
    pub fn new_buffer(
        &self,
        element_count: usize,
        dtype: DType,
        name: &str,
    ) -> Result<Arc<Buffer>> {
        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer can be read on the CPU but will require manual
    /// synchronization when the CPU memory is modified
    /// Used as a bridge to gather data back from the GPU
    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
    }

    /// Creates a new buffer from data.
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    ///
    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
        let size = core::mem::size_of_val(data) as NSUInteger;
        let new_buffer = self.device.new_buffer_with_data(
            data.as_ptr() as *const c_void,
            size,
            MTLResourceOptions::StorageModeManaged,
        );
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        let subbuffers = buffers
            .entry((size, MTLResourceOptions::StorageModeManaged))
            .or_insert(vec![]);

        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
        let buffer = self.allocate_buffer(
            size_in_bytes as NSUInteger,
            MTLResourceOptions::StorageModePrivate,
            "allocate_zeros",
        )?;
        let command_buffer = self.command_buffer()?;
        command_buffer.set_label("zeros");
        let blit = command_buffer.new_blit_command_encoder();
        blit.fill_buffer(
            &buffer,
            metal::NSRange {
                location: 0,
                length: buffer.length(),
            },
            0,
        );
        blit.end_encoding();
        Ok(buffer)
    }

    fn find_available_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        buffers: &RwLockWriteGuard<BufferMap>,
    ) -> Option<Arc<Buffer>> {
        let mut best_buffer: Option<&Arc<Buffer>> = None;
        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
                for sub in subbuffers {
                    if Arc::strong_count(sub) == 1 {
                        best_buffer = Some(sub);
                        best_buffer_size = *buffer_size;
                    }
                }
            }
        }
        best_buffer.cloned()
    }

    fn drop_unused_buffers(&self) -> Result<()> {
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        for subbuffers in buffers.values_mut() {
            let newbuffers = subbuffers
                .iter()
                .filter(|s| Arc::strong_count(*s) > 1)
                .map(Arc::clone)
                .collect();
            *subbuffers = newbuffers;
        }
        Ok(())
    }

    /// The critical allocator algorithm
    fn allocate_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        _name: &str,
    ) -> Result<Arc<Buffer>> {
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
            // Cloning also ensures we increment the strong count
            return Ok(b.clone());
        }

        let size = buf_size(size);
        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());

        Ok(new_buffer)
    }

    /// Create a metal GPU capture trace on [`path`].
    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let capture = metal::CaptureManager::shared();
        let descriptor = metal::CaptureDescriptor::new();
        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
        descriptor.set_capture_device(self);
        descriptor.set_output_url(path);

        capture
            .start_capture(&descriptor)
            .map_err(MetalError::from)?;
        Ok(())
    }
}

fn buf_size(size: NSUInteger) -> NSUInteger {
    size.saturating_sub(1).next_power_of_two() as NSUInteger
}
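`buf_size` is what makes the allocator bucketed: the requested size is rounded to a power of two before being used as the `BufferMap` key, so tensors of slightly different sizes share buckets, and a freed buffer (one whose `strong_count` is 1) gets picked up by later allocations via `find_available_buffer`. A quick illustration of how nearby requests collapse into one bucket:

// Nearby request sizes land in the same power-of-two bucket:
assert_eq!(899u64.saturating_sub(1).next_power_of_two(), 1024);
assert_eq!(999u64.saturating_sub(1).next_power_of_two(), 1024);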
@ -2,14 +2,21 @@ use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels::CallConvTranspose2dCfg;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
|
||||
mod device;
|
||||
pub use device::{DeviceId, MetalDevice};
|
||||
|
||||
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||
BufferOffset {
|
||||
buffer,
|
||||
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
|
||||
}
|
||||
}
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -36,13 +43,6 @@ pub enum MetalError {
|
||||
Message(String),
|
||||
#[error(transparent)]
|
||||
KernelError(#[from] candle_metal_kernels::MetalKernelError),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
#[error("{0:?}")]
|
||||
LockError(LockError),
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
@ -59,263 +59,6 @@ impl From<String> for MetalError {
|
||||
}
|
||||
}
|
||||
|
||||
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
||||
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||
device: metal::Device,
|
||||
|
||||
/// Single command queue for the entire device.
|
||||
command_queue: CommandQueue,
|
||||
/// One command buffer at a time.
|
||||
/// The scheduler works by allowing multiple
|
||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
||||
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
||||
/// to start to work).
|
||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||
/// command buffer2 starts (or there are metal bugs there)
|
||||
command_buffer: Arc<RwLock<CommandBuffer>>,
|
||||
/// Keeps track of the current amount of compute command encoders on the current
|
||||
/// command buffer
|
||||
/// Arc, RwLock because of the interior mutability.
|
||||
command_buffer_index: Arc<RwLock<usize>>,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
kernels: Arc<Kernels>,
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||
/// (could be linked to FFI communication overhead).
|
||||
///
|
||||
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
|
||||
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
|
||||
/// to be reused. However, in order for this to work, we need to guarantee the order of
|
||||
/// operation, so that this buffer is not being used by another kernel at the same time.
|
||||
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
||||
///
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
||||
/// (strong_count = 1).
|
||||
buffers: AllocatedBuffers,
|
||||
/// Seed for random number generation.
|
||||
seed: Arc<Mutex<Buffer>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "MetalDevice({:?})", self.device.registry_id())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for MetalDevice {
|
||||
type Target = metal::DeviceRef;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
pub fn id(&self) -> NSUInteger {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
pub fn metal_device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
|
||||
let mut command_buffer = command_buffer_lock.to_owned();
|
||||
let mut index = self
|
||||
.command_buffer_index
|
||||
.try_write()
|
||||
.map_err(MetalError::from)?;
|
||||
if *index > self.compute_per_buffer {
|
||||
command_buffer.commit();
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*command_buffer_lock = command_buffer.clone();
|
||||
*index = 0;
|
||||
|
||||
self.drop_unused_buffers()?;
|
||||
}
|
||||
*index += 1;
|
||||
Ok(command_buffer)
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) -> Result<()> {
|
||||
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
| metal::MTLCommandBufferStatus::Completed => {
|
||||
panic!("Already committed");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
&self.kernels
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Creates a new buffer (not necessarily zeroed).
|
||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
/// This means the buffer data cannot be read on the CPU directly.
|
||||
///
|
||||
/// [`name`] is only used to keep track of the resource origin in case of bugs
|
||||
pub fn new_buffer(
|
||||
&self,
|
||||
element_count: usize,
|
||||
dtype: DType,
|
||||
name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}

/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}

/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);

let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}

pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}

fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}

fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}

/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}

let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());

Ok(new_buffer)
}
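// Editor's note — a self-contained sketch of the reuse scheme implemented by
// `allocate_buffer`/`find_available_buffer` above, with plain Rust types standing
// in for the Metal ones (`BufKey`/`Bucket` are illustrative names, not candle API):
// buffers are bucketed by (power-of-two size, resource options), and a buffer whose
// `Arc` strong count is 1 is only held by the allocator, so it is free to reuse.
use std::collections::HashMap;
use std::sync::Arc;

type BufKey = (usize, u8); // (rounded-up size, resource options)
type Bucket = Vec<Arc<Vec<u8>>>; // stand-in for Vec<Arc<metal::Buffer>>

fn rounded_buf_size(size: usize) -> usize {
    // Same rounding as `buf_size` later in this diff: next power of two
    // keeps the number of buckets small.
    (size - 1).next_power_of_two()
}

fn allocate(buffers: &mut HashMap<BufKey, Bucket>, size: usize, opt: u8) -> Arc<Vec<u8>> {
    for ((bucket_size, bucket_opt), bucket) in buffers.iter() {
        if *bucket_size >= size && *bucket_opt == opt {
            if let Some(b) = bucket.iter().find(|b| Arc::strong_count(b) == 1) {
                return b.clone(); // cloning bumps the strong count, marking it in use
            }
        }
    }
    let size = rounded_buf_size(size);
    let b = Arc::new(vec![0u8; size]);
    buffers.entry((size, opt)).or_default().push(b.clone());
    b
}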
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
descriptor.set_output_url(path);

capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}

#[derive(Debug, Clone)]
pub struct MetalStorage {
/// The actual buffer containing the data.
@ -364,7 +107,8 @@ impl BackendStorage for MetalStorage {

let buffer = device.new_buffer(el, self.dtype, "affine")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
@ -377,7 +121,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
mul as f32,
add as f32,
@ -396,9 +140,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
mul as f32,
add as f32,
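// Editor's note — a recurring change in these hunks replaces a raw `&self.buffer`
// plus a separately passed byte offset with a single `src` value built by
// `buffer_o`. The real `BufferOffset`/`buffer_o` live in candle-metal-kernels;
// the sketch below is inferred from the `BufferOffset { buffer, offset_in_bytes }`
// construction appearing later in this diff, and is not the exact definition.
pub struct BufferOffset<'a> {
    pub buffer: &'a Buffer,
    pub offset_in_bytes: usize,
}

fn buffer_o<'a>(buffer: &'a Buffer, layout: &Layout, dtype: DType) -> BufferOffset<'a> {
    BufferOffset {
        buffer,
        // layouts track their start offset in elements; kernels want bytes
        offset_in_bytes: layout.start_offset() * dtype.size_in_bytes(),
    }
}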
@ -417,7 +160,8 @@ impl BackendStorage for MetalStorage {

let buffer = device.new_buffer(el, self.dtype, "powf")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
@ -430,7 +174,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
pow as f32,
)
@ -448,9 +192,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
pow as f32,
)
@ -468,7 +211,8 @@ impl BackendStorage for MetalStorage {

let buffer = device.new_buffer(el, self.dtype, "elu")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
@ -481,7 +225,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
alpha as f32,
)
@ -499,9 +243,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
alpha as f32,
)
@ -571,6 +314,7 @@ impl BackendStorage for MetalStorage {
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
@ -579,8 +323,7 @@ impl BackendStorage for MetalStorage {
&dims,
&stride,
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -606,7 +349,8 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
@ -654,8 +398,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -682,9 +425,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -701,46 +443,69 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;

let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("usilu", DType::F16) => contiguous::silu::HALF,
("uabs", DType::F16) => contiguous::abs::HALF,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
("uceil", DType::F16) => contiguous::ceil::HALF,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("ucos", DType::F32) => contiguous::cos::FLOAT,
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
("uerf", DType::F16) => contiguous::erf::HALF,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
("uexp", DType::F16) => contiguous::exp::HALF,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
("ulog", DType::F16) => contiguous::log::HALF,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ulog", DType::BF16) => contiguous::log::BFLOAT,
("uneg", DType::F16) => contiguous::neg::HALF,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
("urelu", DType::F16) => contiguous::relu::HALF,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
("uround", DType::F16) => contiguous::round::HALF,
("uround", DType::F32) => contiguous::round::FLOAT,
("uround", DType::BF16) => contiguous::round::BFLOAT,
("usilu", DType::F16) => contiguous::silu::HALF,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
("usin", DType::F16) => contiguous::sin::HALF,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usin", DType::BF16) => contiguous::sin::BFLOAT,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous::tanh::HALF,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
("usign", DType::F16) => contiguous::sign::HALF,
("usign", DType::F32) => contiguous::sign::FLOAT,
("usign", DType::BF16) => contiguous::sign::BFLOAT,
("usign", DType::I64) => contiguous::sign::I64,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
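// Editor's note — each unary op has both a contiguous and a strided kernel:
// a contiguous tensor is processed as a flat array of `el_count` elements,
// while a strided one must map every output index back through dims/strides.
// A sketch of that index arithmetic (assuming row-major iteration; this mirrors
// the math, not candle's kernel code):
fn strided_offset(mut i: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    // Walk axes from fastest-varying (last) to slowest, peeling off coordinates.
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        offset += (i % dim) * stride;
        i /= dim;
    }
    offset
}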
@ -751,7 +516,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -775,6 +540,7 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,

("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
@ -792,21 +558,39 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,

("ucos", DType::BF16) => strided::cos::BFLOAT,
("usin", DType::BF16) => strided::sin::BFLOAT,
("usqr", DType::BF16) => strided::sqr::BFLOAT,
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
("uneg", DType::BF16) => strided::neg::BFLOAT,
("uexp", DType::BF16) => strided::exp::BFLOAT,
("ulog", DType::BF16) => strided::log::BFLOAT,
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
("uerf", DType::BF16) => strided::erf::BFLOAT,
("usilu", DType::BF16) => strided::silu::BFLOAT,
("uabs", DType::BF16) => strided::abs::BFLOAT,
("uceil", DType::BF16) => strided::ceil::BFLOAT,
("ufloor", DType::BF16) => strided::floor::BFLOAT,
("urelu", DType::BF16) => strided::relu::BFLOAT,
("uround", DType::BF16) => strided::round::BFLOAT,
("utanh", DType::BF16) => strided::tanh::BFLOAT,

(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
let dst = BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
0,
dst,
)
.map_err(MetalError::from)?;
}
@ -853,21 +637,21 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
let t = buffer_o(&t.buffer, t_l, t.dtype);
let f = buffer_o(&f.buffer, f_l, f.dtype);
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
dims,
&self.buffer,
(
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
src,
layout.stride(),
t,
t_l.stride(),
f,
f_l.stride(),
&buffer,
)
.map_err(MetalError::from)?;
@ -900,6 +684,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
@ -908,8 +693,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -1021,8 +805,13 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col_f32",
DType::F16 => "im2col_f16",
DType::BF16 => "im2col_bf16",
DType::U8 => "im2col_u8",
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
@ -1031,8 +820,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -1117,8 +905,8 @@ impl BackendStorage for MetalStorage {
padding: params.padding,
output_padding: params.output_padding,
c_out: params.c_out,
out_h: out_h,
out_w: out_w,
out_h,
out_w,
b_size: params.b_size,
input_dims: l.dims(),
input_stride: l.stride(),
@ -1233,6 +1021,10 @@ impl BackendStorage for MetalStorage {
}
let name = match self.dtype {
DType::F32 => "upsample_nearest2d_f32",
DType::F16 => "upsample_nearest2d_f16",
DType::BF16 => "upsample_nearest2d_bf16",
DType::U8 => "upsample_nearest2d_u8",
DType::U32 => "upsample_nearest2d_u32",
dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
};

@ -1241,6 +1033,7 @@ impl BackendStorage for MetalStorage {
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, inp_l, self.dtype);
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
@ -1250,8 +1043,7 @@ impl BackendStorage for MetalStorage {
strides,
out_w,
out_h,
&self.buffer,
inp_l.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -1259,9 +1051,8 @@ impl BackendStorage for MetalStorage {
}

fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
if !ids_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
};
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
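// Editor's note — gather/scatter-add/index-add now fail fast with
// `RequiresContiguous` instead of unpacking contiguous offsets. A caller that
// may hold a strided index tensor (e.g. the result of a transpose) can satisfy
// the requirement up front; hedged usage sketch against the public Tensor API:
fn gather_rows(t: &Tensor, ids: &Tensor) -> Result<Tensor> {
    let ids = ids.contiguous()?; // cheap clone when already contiguous
    t.gather(&ids, 1)
}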
@ -1271,9 +1062,12 @@ impl BackendStorage for MetalStorage {
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
@ -1282,10 +1076,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_o1 * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1303,13 +1095,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
@ -1328,6 +1115,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
&self.device.device,
&command_buffer,
@ -1336,10 +1125,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1366,11 +1153,17 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",

(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",

(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@ -1382,10 +1175,8 @@ impl BackendStorage for MetalStorage {
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1403,13 +1194,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16",
@ -1440,6 +1226,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_add(
&self.device.device,
&command_buffer,
@ -1449,10 +1237,8 @@ impl BackendStorage for MetalStorage {
l.dims(),
ids_l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1586,17 +1372,20 @@ impl BackendStorage for MetalStorage {
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, src_l, self.dtype);
let dst = BufferOffset {
buffer: &dst.buffer,
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
src,
src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&dst.buffer,
dst_offset * dst.dtype.size_in_bytes(),
dst,
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
@ -1630,10 +1419,9 @@ impl MetalStorage {
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &op[..1] != "b"
{
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
use candle_metal_kernels::binary::contiguous;

let (kernel_name, dtype) = match (op, self.dtype) {
@ -1714,8 +1502,8 @@ impl MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
lhs,
rhs,
&buffer,
)
.map_err(MetalError::from)?;
@ -1813,12 +1601,10 @@ impl MetalStorage {
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1867,6 +1653,7 @@ impl BackendDevice for MetalDevice {
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self {
id: DeviceId::new(),
device,
command_queue,
command_buffer,
@ -1885,7 +1672,7 @@ impl BackendDevice for MetalDevice {
}

fn same_device(&self, rhs: &Self) -> bool {
self.device.registry_id() == rhs.device.registry_id()
self.id == rhs.id
}

unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
@ -2023,10 +1810,10 @@ impl BackendDevice for MetalDevice {

Ok(())
}
}

fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
fn synchronize(&self) -> Result<()> {
self.wait_until_completed()
}
}

fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

@ -66,6 +66,7 @@ pub enum UnaryOp {
Floor,
Ceil,
Round,
Sign,
}

#[derive(Clone)]
@ -254,6 +255,7 @@ pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;

macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -457,6 +459,13 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;

/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@ -469,7 +478,7 @@ impl UnaryOpT for Gelu {
* v
* (bf16::ONE
+ bf16::tanh(
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
@ -480,22 +489,18 @@ impl UnaryOpT for Gelu {
* v
* (f16::ONE
+ f16::tanh(
(f16::from_f32_const(2.0) / f16::PI).sqrt()
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn f64(v: f64) -> f64 {
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn u8(_: u8) -> u8 {
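// Editor's note — the hunk above stops computing sqrt(2/pi) at runtime (the
// per-dtype rounding of that computation was the subject of issue #1982) and
// hardcodes it instead. The approximation being implemented is:
//
//   gelu(x) ≈ 0.5 · x · (1 + tanh( sqrt(2/π) · (x + 0.044715 · x³) ))
//
// with sqrt(2/π) ≈ 0.7978845608…, matching SQRT_TWO_OVER_PI_F32/F64 above.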
@ -922,3 +927,37 @@ impl std::ops::Deref for BackpropOp {
&self.0
}
}

impl UnaryOpT for Sign {
const NAME: &'static str = "sign";
const KERNEL: &'static str = "usign";
const V: Self = Sign;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
f32::from(v > 0.) - f32::from(v < 0.)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
f64::from(v > 0.) - f64::from(v < 0.)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
u8::min(1, v)
}
#[inline(always)]
fn u32(v: u32) -> u32 {
u32::min(1, v)
}
#[inline(always)]
fn i64(v: i64) -> i64 {
(v > 0) as i64 - (v < 0) as i64
}
}
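// Editor's note — `Sign` uses the branch-free identity (v > 0) - (v < 0),
// which yields -1/0/1 and maps NaN to 0; for unsigned types it degrades to
// min(1, v). Hedged usage sketch once the op is wired into Tensor via the
// `unary_op!(sign, Sign)` line later in this diff:
use candle_core::{Device, Tensor};

fn sign_demo() -> candle_core::Result<()> {
    let t = Tensor::new(&[-2.0f32, 0.0, 3.5], &Device::Cpu)?;
    let s = t.sign()?; // [-1.0, 0.0, 1.0]
    println!("{s}");
    Ok(())
}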
@ -1,22 +1,63 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};

use cudarc::driver::{CudaSlice, DeviceSlice};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};

#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}

static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}

pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;

fn ceil_div(p: usize, q: usize) -> usize {
(p + q - 1) / q
}

fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
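// Editor's note — quick check of the two helpers above:
fn helpers_demo() {
    assert_eq!(ceil_div(1000, 256), 4); // four 256-wide blocks cover 1000 elements
    assert_eq!(pad(1000, 512), 1024); // rows get padded up to MATRIX_ROW_PADDING
    assert_eq!(pad(1024, 512), 1024); // already aligned values are unchanged
}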
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;

let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}

fn dequantize(
data: &CudaSlice<u8>,
@ -30,26 +71,18 @@ fn dequantize(
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q5_1 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q5_0 => (
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
@ -60,7 +93,7 @@ fn dequantize(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -83,9 +116,9 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mut_mal_vec(
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &cudarc::driver::CudaView<f32>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
@ -93,6 +126,13 @@ fn dequantize_mut_mal_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
@ -107,8 +147,8 @@ fn dequantize_mut_mal_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
@ -120,9 +160,147 @@ fn dequantize_mut_mal_vec(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 4 {
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;

let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
let nblocks = if b_size == 1 {
nrows as u32
} else {
(nrows as u32 + 1) / 2
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};

let params = (
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
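// Editor's note — `mul_mat_vec_via_q8_1` mirrors the two-step scheme from
// llama.cpp's CUDA path: first quantize the f32 activations to Q8_1 (rows
// padded to MATRIX_ROW_PADDING), then run a fused int8 dot-product kernel
// against the already-quantized weights. The kernels are compiled per batch
// size, hence the `format!("{kernel_name}{b_size}")` suffix; a sketch of that
// selection (illustrative helper, not part of the diff):
fn mmv_kernel_name(base: &str, b_size: usize) -> Result<String> {
    if b_size == 0 || b_size > 4 {
        crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
    }
    Ok(format!("{base}{b_size}")) // e.g. "mul_mat_vec_q4_0_q8_1_cuda3"
}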
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;

let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;

let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};

let params = (
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
@ -140,6 +318,12 @@ impl QCudaStorage {
}

pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
let vec = slice.to_vec();
T::to_float(&vec, dst)
}

let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
@ -158,69 +342,25 @@ impl QCudaStorage {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
use crate::quantized::k_quants::GgmlType;

let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let slice =
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
out.copy_from_slice(slice)
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
}

self.device
@ -255,7 +395,17 @@ impl QCudaStorage {
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
4
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
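// Editor's note — condensed view of the dispatch above: batches with at most
// four rows in total go through the matrix-vector path, everything else
// through the full matmul; FORCE_DMMV additionally restricts the vector path
// to single rows (and switches its implementation, see below).
fn use_vec_kernel(dims: &[usize], force_dmmv: bool) -> bool {
    let max_bm = if force_dmmv { 1 } else { 4 };
    match dims {
        [b, m, _k] => b * m <= max_bm,
        [b, _k] => *b <= max_bm,
        _ => false,
    }
}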
@ -276,22 +426,31 @@ impl QCudaStorage {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}

let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out_shape = if with_batch {
vec![1, 1, nrows]
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
vec![1, nrows]
mul_mat_vec_via_q8_1(
&self.data,
&rhs,
self.dtype,
ncols,
nrows,
b_size,
self.device(),
)?
};
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into()))
}

@ -312,9 +471,30 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}

let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ b * m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
@ -322,11 +502,6 @@ impl QCudaStorage {
}
}

fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
slice.to_vec()
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
@ -341,3 +516,101 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
dtype: T::DTYPE,
}))
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}

#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
/* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);

let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}

#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();

/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
}

@ -149,8 +149,11 @@ impl QMetalStorage {
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();

// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};

@ -171,7 +171,7 @@ impl Shape {
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
if stride != acc {
if dim > 1 && stride != acc {
return false;
}
acc *= dim;
@ -186,7 +186,7 @@ impl Shape {
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
if stride != acc {
if dim > 1 && stride != acc {
return false;
}
acc *= dim;
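// Editor's note — the relaxed contiguity check: a size-1 dimension is never
// stepped over, so its stride is irrelevant. Self-contained sketch mirroring
// the loop above:
fn is_c_contiguous(dims: &[usize], stride: &[usize]) -> bool {
    let mut acc = 1;
    for (&stride, &dim) in stride.iter().zip(dims.iter()).rev() {
        if dim > 1 && stride != acc {
            return false;
        }
        acc *= dim;
    }
    true
}

fn contiguity_demo() {
    // shape (1, 4): the stride of the leading size-1 dim does not matter
    assert!(is_c_contiguous(&[1, 4], &[100, 1]));
    assert!(!is_c_contiguous(&[2, 4], &[100, 1]));
}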

@ -44,9 +44,19 @@ impl Storage {
}

pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.device().location();
let rhs = rhs.device().location();
if lhs != rhs {
let lhs_device = self.device();
let rhs_device = rhs.device();
let lhs = lhs_device.location();
let rhs = rhs_device.location();
let same_device = if self.device().is_metal() {
// On metal, we require the device to be exactly the same rather than
// having the same location. In cuda this is not necessary as all CudaDevice on the
// same GPU will use the same cuda stream.
lhs_device.same_device(&rhs_device)
} else {
lhs == rhs
};
if !same_device {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
} else {
Ok(())

@ -79,6 +79,9 @@ macro_rules! unary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
@ -92,6 +95,9 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -114,6 +120,9 @@ macro_rules! binary_op_scalar {
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -510,6 +519,7 @@ impl Tensor {
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
unary_op!(sign, Sign);

/// Round element of the input tensor to the nearest integer.
///
@ -645,6 +655,9 @@ impl Tensor {
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().affine(self.layout(), mul, add)?;
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
Ok(from_storage(storage, self.shape(), op, false))
@ -652,6 +665,9 @@ impl Tensor {

/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().elu(self.layout(), alpha)?;
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
Ok(from_storage(storage, self.shape(), op, false))
@ -659,6 +675,9 @@ impl Tensor {

/// Raise the tensor to some float exponent `e`.
pub fn powf(&self, e: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().powf(self.layout(), e)?;
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
Ok(from_storage(storage, self.shape(), op, false))
@ -1153,6 +1172,9 @@ impl Tensor {
|
||||
let n = b_dims[dim - 1];
|
||||
|
||||
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
|
||||
if c_shape.elem_count() == 0 || k == 0 {
|
||||
return Tensor::zeros(c_shape, self.dtype(), self.device());
|
||||
}
|
||||
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||
let batching_b: usize = b_dims[..dim - 2].iter().product();
|
||||
if k != k2 || batching != batching_b {

@ -2007,6 +2029,16 @@ impl Tensor {
}
}

/// Returns a tensor that is in row major order. This always makes a copy.
pub fn force_contiguous(&self) -> Result<Tensor> {
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
Ok(from_storage(storage, shape.clone(), op, false))
}

/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
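
Unlike `contiguous`, which can hand back the input unchanged when the layout is already row major, `force_contiguous` always allocates and copies. That is exactly what the `mm_layout` test below exploits to build a tensor that is contiguous but carries a non-canonical stride on a size-1 dim; a usage sketch against the public API (the stride values in the comment follow from the layout rules above and are stated as expectations, not asserted):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    // contiguous() on an already contiguous tensor is a cheap no-op...
    let b = a.contiguous()?;
    assert_eq!(b.stride(), a.stride());
    // ...while force_contiguous() re-materializes the data, so a
    // transpose / copy / transpose round trip leaves a stride like
    // [16, 4, 4, 1] rather than the canonical [16, 16, 4, 1]: still
    // contiguous under the relaxed check, just not the default layout.
    let c = a.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
    assert_eq!(c.dims(), a.dims());
    Ok(())
}
```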

@ -58,20 +58,18 @@ impl Tensor {
}
}
}
if dim == 0 {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else if dim == 0 {
Self::cat0(args)
} else {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
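
After the reorder, the fast all-contiguous path is tried first regardless of the axis; only non-contiguous inputs fall back to `cat0` (for dim 0) or the transpose-concat-transpose trick (for other dims). That fallback can be reproduced with the public API; a hedged sketch of what the slow path does:

```rust
use candle_core::{Device, Result, Tensor};

// Concatenate along `dim` by moving that axis to the front, doing a
// dim-0 concat, and moving it back: the same trick as the fallback above.
fn cat_via_transpose(a: &Tensor, b: &Tensor, dim: usize) -> Result<Tensor> {
    let parts = [a.transpose(0, dim)?, b.transpose(0, dim)?];
    Tensor::cat(&parts, 0)?.transpose(0, dim)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::arange(0f32, 6f32, &dev)?.reshape((2, 3))?;
    let b = Tensor::arange(6f32, 12f32, &dev)?.reshape((2, 3))?;
    let direct = Tensor::cat(&[&a, &b], 1)?;
    let indirect = cat_via_transpose(&a, &b, 1)?;
    assert_eq!(direct.to_vec2::<f32>()?, indirect.to_vec2::<f32>()?);
    Ok(())
}
```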

106 candle-core/tests/matmul_tests.rs (new file)
@ -0,0 +1,106 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};

fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;

let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);

let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];

let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);

// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);

let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);

assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
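
The expected `[3, 6, 4, 2]` follows the usual broadcasting rule: the trailing matmul dims pair up as (4, 5) x (5, 2) -> (4, 2), while the leading batch dims (3, 1) and (6,) broadcast element-wise to (3, 6). A small shape-only sketch of that rule, assuming both inputs have rank at least 2 (a hypothetical helper, not the candle implementation):

```rust
// Shape-only model of broadcast_matmul: batch dims broadcast like
// element-wise ops, the last two dims contract as a matmul.
fn broadcast_matmul_dims(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let (lb, lm) = lhs.split_at(lhs.len() - 2);
    let (rb, rm) = rhs.split_at(rhs.len() - 2);
    if lm[1] != rm[0] {
        return None; // inner dimensions must match
    }
    // Right-align the batch dims; missing dims count as 1, and 1
    // broadcasts against anything.
    let n = lb.len().max(rb.len());
    let mut out = Vec::with_capacity(n + 2);
    for i in 0..n {
        let l = lb.len().checked_sub(n - i).map(|j| lb[j]).unwrap_or(1);
        let r = rb.len().checked_sub(n - i).map(|j| rb[j]).unwrap_or(1);
        match (l, r) {
            (l, r) if l == r => out.push(l),
            (1, r) => out.push(r),
            (l, 1) => out.push(l),
            _ => return None,
        }
    }
    out.extend_from_slice(&[lm[0], rm[1]]);
    Some(out)
}

fn main() {
    assert_eq!(
        broadcast_matmul_dims(&[3, 1, 4, 5], &[6, 5, 2]),
        Some(vec![3, 6, 4, 2])
    );
}
```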

// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
let x = a.i((.., seq_len - 1, ..))?;
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
let x = x.matmul(&w)?;
assert_eq!(x.dims(), &[1, 32]);
Ok(())
}

// https://github.com/huggingface/candle/issues/1992
fn mm_layout(device: &Device) -> Result<()> {
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
let mm1 = a.matmul(&b)?;
// Forces the layout to be:
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
// This is still a contiguous matrix, but the matmul checks only expect the
// last two dimensions to have non-1 sizes and may be reluctant to handle
// this layout.
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
let mm2 = a.matmul(&b)?;
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}

test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);

@ -3,7 +3,7 @@ use candle_core::{
quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
Device, IndexOp, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
@ -47,18 +47,14 @@ fn test_matmul(
}

fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
);
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
&[
@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {

let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
[341970.0, 994574.0, 1656181.0, 2302182.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84866.0, 214045.0, 344676.0, 473707.0],
[213425.0, 604313.0, 1000431.0, 1387960.0],
[342030.0, 994630.0, 1656248.0, 2302250.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
),
}

test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;

Ok(())
}
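
As an aside on the constants in this test: the `8` in `vec![k_quants::BlockQ4_0::zeros(); 8]` comes from the Q4_0 block size. Each block packs 32 weights (QK4_0 in the ggml naming, stated here as an assumption about the format), and the quantized rhs holds k * n = 64 * 4 = 256 values, i.e. 256 / 32 = 8 blocks:

```rust
fn main() {
    // Q4_0 quantizes weights in blocks of 32 values; the rhs above is
    // k x n = 64 x 4.
    const QK4_0: usize = 32;
    let (k, n) = (64usize, 4usize);
    let n_blocks = k * n / QK4_0;
    assert_eq!(n_blocks, 8); // matches vec![BlockQ4_0::zeros(); 8]
}
```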

fn quantized_matmul_neg(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
let lhs_s = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
-196472.0, 63012.0, 324585.0, 587902.0
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
to_vec2_round(&mm, 0)?,
&[
@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {

let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
[-196102.0, 63022.0, 324233.0, 587191.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243740.0, -19762.0, -285476.0, -550498.0],
[23774.0, 21645.0, 19395.0, 18364.0],
[-196045.0, 63030.0, 324120.0, 587079.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
@ -160,22 +166,50 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
]
),
}

let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
}
Ok(())
}

test_device!(
quantized_matmul,
quantized_matmul_cpu,
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
fn qmm_batch(dev: &Device) -> Result<()> {
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.shape().dims(), [2, 6]);
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
let mm2 = rhs.forward(&lhs2)?;
assert_eq!(mm2.shape().dims(), [4, 6]);
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff2, 0.0);
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
assert!(diff3 < 1e-4)
} else {
assert_eq!(diff3, 0.0)
};
Ok(())
}

test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);

fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();

@ -106,6 +106,9 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
@ -148,6 +151,14 @@ fn unary_op(device: &Device) -> Result<()> {
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
let tensor = Tensor::new(
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
device,
)?;
assert_eq!(
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
Ok(())
}

@ -707,6 +718,8 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}

@ -734,44 +747,47 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0]
]
);
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
for dtype in [DType::U8, DType::U32, DType::I64] {
let ids = ids.to_dtype(dtype)?;
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);

// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
}

Ok(())
}
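
As a reminder of the semantics being exercised: `index_select(ids, dim)` gathers whole slices along `dim`, so for `dim = 0` row `i` of the output is row `ids[i]` of the input, and `ids` may repeat entries or differ in length from the source dimension. A tiny standalone model of the dim-0 case:

```rust
// Reference model of index_select along dim 0 on a Vec-of-rows "tensor".
fn index_select0(t: &[Vec<f32>], ids: &[usize]) -> Vec<Vec<f32>> {
    ids.iter().map(|&i| t[i].clone()).collect()
}

fn main() {
    let t = vec![
        vec![0.0, 1.0, 2.0],
        vec![3.0, 4.0, 5.0],
        vec![6.0, 7.0, 8.0],
    ];
    // Same ids as the test above: rows repeat and the id count exceeds
    // the number of source rows.
    let ids = [0, 2, 1, 0, 2, 1];
    let hs = index_select0(&t, &ids);
    assert_eq!(hs.len(), 6);
    assert_eq!(hs[0], vec![0.0, 1.0, 2.0]);
    assert_eq!(hs[3], vec![0.0, 1.0, 2.0]);
}
```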

@ -933,74 +949,6 @@ fn gather(device: &Device) -> Result<()> {
Ok(())
}

fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;

let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);

let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];

let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);

// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);

let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);

assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}

fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1135,6 +1083,27 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}

fn zero_dim(device: &Device) -> Result<()> {
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
assert_eq!(t.dims3()?, (4, 0, 1));
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
assert_eq!(t_cat.dims3()?, (4, 3, 1));
let t_cat = Tensor::cat(&[&t, &t], 1)?;
assert_eq!(t_cat.dims3()?, (4, 0, 1));
let t_unary = t.sqrt()?;
assert_eq!(t_unary.dims3()?, (4, 0, 1));
let t_plus = (&t + 1.)?;
assert_eq!(t_plus.dims3()?, (4, 0, 1));
let t_mm = t2.matmul(&t.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 3, 0));
let t_mm = t.matmul(&t2.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 0, 3));
let t_mm = t.t()?.matmul(&t)?;
assert_eq!(t_mm.dims3()?, (4, 1, 1));
Ok(())
}

test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1154,13 +1123,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
@ -1190,6 +1152,7 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1317,8 +1280,8 @@ fn pow() -> Result<()> {
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
test_utils::to_vec2_round(&res, 3)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
);
Ok(())
}

@ -25,7 +25,7 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }

46 candle-examples/examples/clip/README.md (new file)
@ -0,0 +1,46 @@
Contrastive Language-Image Pre-Training

Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.

https://github.com/openai/CLIP

https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

## Running an example on cpu

```
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

## Running an example with the metal feature (mac)

```
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

202 candle-examples/examples/clip/main.rs (new file)
@ -0,0 +1,202 @@
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use candle_transformers::models::clip;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
images: Option<Vec<String>>,
|
||||
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
sequences: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
|
||||
let img = image::io::Reader::open(path)?.decode()?;
|
||||
let (height, width) = (image_size, image_size);
|
||||
let img = img.resize_to_fill(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
|
||||
let img = img.to_rgb8();
|
||||
|
||||
let img = img.into_raw();
|
||||
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
|
||||
.permute((2, 0, 1))?
|
||||
.to_dtype(DType::F32)?
|
||||
.affine(2. / 255., -1.)?;
|
||||
// .unsqueeze(0)?;
|
||||
Ok(img)
|
||||
}
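
The `affine(2. / 255., -1.)` call maps each u8-derived channel value from [0, 255] to [-1, 1], since `affine` computes `x * mul + add` element-wise. A quick check of the endpoints:

```rust
fn main() {
    // affine(mul, add) computes x * mul + add; with mul = 2/255 and
    // add = -1 the pixel range [0, 255] lands on [-1, 1].
    let normalize = |x: f32| x * (2. / 255.) - 1.;
    assert!((normalize(0.) + 1.).abs() < 1e-6);
    assert!((normalize(255.) - 1.).abs() < 1e-6);
    assert!(normalize(127.5).abs() < 1e-6); // mid-grey maps to ~0
}
```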

fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];

for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}

let images = Tensor::stack(&images, 0)?;

Ok(images)
}

pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");

let args = Args::parse();

tracing_subscriber::fmt::init();

let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;

let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));

api.get("model.safetensors")?
}
Some(model) => model.into(),
};

let tokenizer = get_tokenizer(args.tokenizer)?;

let config = clip::ClipConfig::vit_base_patch32();

let device = candle_examples::device(args.cpu)?;

let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};

// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;

let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };

let model = clip::ClipModel::new(vb, &config)?;

let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;

let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;

let softmax_image = softmax(&logits_per_image, 1)?;

let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;

info!("softmax_image_vec: {:?}", softmax_image_vec);

let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();

let probability_per_image = probability_vec.len() / vec_imgs.len();

for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);

for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}

Ok(())
}

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};

Tokenizer::from_file(tokenizer).map_err(E::msg)
}

pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = *tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;

let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};

let mut tokens = vec![];

for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}

let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);

// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}

let input_ids = Tensor::new(tokens, device)?;

Ok((input_ids, vec_seq))
}
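
Since `Tensor::new` needs a rectangular batch of token ids, every sequence is right-padded with the `<|endoftext|>` id up to the longest sequence in the batch. The padding loop in isolation (pad_id here is a stand-in for the id looked up from the tokenizer vocab above):

```rust
// Right-pad every token sequence to the batch-wide maximum length,
// mirroring tokenize_sequences above.
fn pad_to_max(tokens: &mut Vec<Vec<u32>>, pad_id: u32) {
    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
    for token_vec in tokens.iter_mut() {
        let len_diff = max_len - token_vec.len();
        if len_diff > 0 {
            token_vec.extend(vec![pad_id; len_diff]);
        }
    }
}

fn main() {
    let mut tokens = vec![vec![1, 2, 3], vec![4, 5]];
    pad_to_max(&mut tokens, 0);
    assert_eq!(tokens, vec![vec![1, 2, 3], vec![4, 5, 0]]);
}
```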

@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
}

struct TextGeneration {
model: Model,
device: Device,
@ -165,6 +189,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
}

fn main() -> Result<()> {
@ -196,14 +224,19 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => match model_id.as_str() {
"7b-it" => "google/gemma-7b-it".to_string(),
"7b" => "google/gemma-7b".to_string(),
"2b-it" => "google/gemma-2b-it".to_string(),
"2b" => "google/gemma-2b".to_string(),
_ => model_id.to_string(),
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
},
None => "google/gemma-2b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,

@ -54,6 +54,7 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let dtype = self.model.dtype();
let mut tokens = self
.tokenizer
.tokenizer()
@ -66,7 +67,7 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut state = State::new(1, &self.config, dtype, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
@ -84,7 +85,7 @@ impl TextGeneration {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
@ -210,6 +211,9 @@ struct Args {
#[arg(long)]
config_file: Option<String>,

#[arg(long, default_value = "f32")]
dtype: String,

/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@ -220,6 +224,7 @@ struct Args {
}

fn main() -> Result<()> {
use std::str::FromStr;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

@ -279,7 +284,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let dtype = DType::from_str(&args.dtype)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());


@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

@ -39,11 +39,26 @@ impl TextGeneration {
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: Option<usize>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
let logits_processor = {
let temperature = temp.unwrap_or(0.);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};

Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
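
The nested match above maps the two optional CLI knobs onto a `Sampling` strategy: greedy argmax when the temperature is zero or unset, otherwise plain temperature sampling, optionally filtered by top-k, top-p, or top-k followed by top-p. The same dispatch isolated into a helper, using the `LogitsProcessor`/`Sampling` API exactly as shown in the diff:

```rust
use candle_transformers::generation::{LogitsProcessor, Sampling};

// Map (temperature, top_k, top_p) onto a Sampling strategy, as the hunk
// above does inside the constructor.
fn make_logits_processor(
    seed: u64,
    temp: Option<f64>,
    top_k: Option<usize>,
    top_p: Option<f64>,
) -> LogitsProcessor {
    let temperature = temp.unwrap_or(0.);
    let sampling = if temperature <= 0. {
        Sampling::ArgMax
    } else {
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(seed, sampling)
}

fn main() {
    // e.g. temperature 0.8 with top-k 40 selects TopK sampling.
    let _lp = make_logits_processor(299792458, Some(0.8), Some(40), None);
}
```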

@ -159,6 +174,10 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,

/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -196,6 +215,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
}

fn main() -> Result<()> {
@ -203,6 +226,9 @@ fn main() -> Result<()> {
use tracing_subscriber::prelude::*;

let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -307,6 +333,7 @@ fn main() -> Result<()> {
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,
26 candle-examples/examples/moondream/README.md (new file)
@ -0,0 +1,26 @@
# candle-moondream

[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model that can answer real-world questions about images. It's tiny by today's standards, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.

## Running some examples
First download an example image
```bash
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
```

<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">

Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"

avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
retrieved the files in 3.395583ms
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded the model in 5.485493792s
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
starting the inference loop
The girl is eating a hamburger.<
9 tokens generated (0.68 token/s)
```

343 candle-examples/examples/moondream/main.rs (new file)
@ -0,0 +1,343 @@
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::{
|
||||
generation::LogitsProcessor,
|
||||
models::{moondream, quantized_moondream},
|
||||
};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
Moondream(moondream::Model),
|
||||
Quantized(quantized_moondream::Model),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
if tokens.is_empty() {
|
||||
anyhow::bail!("Empty prompts are not supported in the Moondream model.")
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
|
||||
// Moondream tokenizer bos_token and eos_token is "<|endoftext|>"
|
||||
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
|
||||
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the special token"),
|
||||
};
|
||||
let (bos_token, eos_token) = (special_token, special_token);
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut load_t = std::time::Duration::from_secs_f64(0f64);
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = if index > 0 {
|
||||
match self.model {
|
||||
Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
|
||||
Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
|
||||
}
|
||||
} else {
|
||||
let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
|
||||
let logits = match self.model {
|
||||
Model::Moondream(ref mut model) => {
|
||||
model
|
||||
.text_model
|
||||
.forward_with_img(&bos_token, &input, image_embeds)?
|
||||
}
|
||||
Model::Quantized(ref mut model) => {
|
||||
model
|
||||
.text_model
|
||||
.forward_with_img(&bos_token, &input, image_embeds)?
|
||||
}
|
||||
};
|
||||
load_t = start_gen.elapsed();
|
||||
println!("load_t: {:?}", load_t);
|
||||
logits
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
|
||||
break;
|
||||
}
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let dt = start_gen.elapsed() - load_t;
|
||||
println!(
|
||||
"\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
dt.as_secs_f64(),
|
||||
(generated_tokens - 1) as f64 / dt.as_secs_f64()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
seed: u64,
|
||||
|
||||
#[arg(long, default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Use f16 precision for all the computations rather than f32.
|
||||
#[arg(long)]
|
||||
f16: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
}
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 378, 378).
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||
.broadcast_sub(&mean)?
|
||||
.broadcast_div(&std)
|
||||
}
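
With mean = std = 0.5 per channel, the `(x / 255 - 0.5) / 0.5` normalization here is algebraically the same `2x / 255 - 1` mapping that the CLIP loader earlier in this diff expresses with `affine`; both send [0, 255] to [-1, 1]. A quick equivalence check:

```rust
fn main() {
    // Moondream: (x / 255 - mean) / std with mean = std = 0.5.
    let moondream = |x: f32| (x / 255. - 0.5) / 0.5;
    // CLIP loader: x * (2 / 255) - 1.
    let clip = |x: f32| x * (2. / 255.) - 1.;
    for x in [0f32, 63.75, 127.5, 191.25, 255.] {
        assert!((moondream(x) - clip(x)).abs() < 1e-6);
    }
}
```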

#[tokio::main]
async fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();

let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);

let start = std::time::Instant::now();
let api = hf_hub::api::tokio::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
"santiagomed/candle-moondream".to_string()
} else {
"vikhyatk/moondream2".to_string()
}
}
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
args.revision,
));
let model_file = match args.model_file {
Some(m) => m.into(),
None => {
if args.quantized {
repo.get("model-q4_0.gguf").await?
} else {
repo.get("model.safetensors").await?
}
}
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json").await?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = moondream::Config::v2();
let dtype = if args.quantized {
if args.f16 {
anyhow::bail!("Quantized model does not support f16");
}
DType::F32
} else if device.is_cuda() || args.f16 {
DType::F16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&model_file,
&device,
)?;
let model = quantized_moondream::Model::new(&config, vb)?;
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let model = moondream::Model::new(&config, vb)?;
Model::Moondream(model)
};
println!("loaded the model in {:?}", start.elapsed());

let start = std::time::Instant::now();
let image = load_image(args.image)?
.to_device(&device)?
.to_dtype(dtype)?;
let image_embeds = image.unsqueeze(0)?;
let image_embeds = match model {
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
};
println!(
"loaded and encoded the image {image:?} in {:?}",
start.elapsed()
);

let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&prompt, &image_embeds, args.sample_len)?;

Ok(())
}

@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

## Using custom models
@@ -10,7 +10,7 @@ use tokenizers::Tokenizer;

use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
@@ -67,6 +67,8 @@ enum Which {
    Mixtral,
    #[value(name = "mixtral-instruct")]
    MixtralInstruct,
    #[value(name = "phi-2")]
    Phi2,
}

impl Which {
@@ -82,7 +84,8 @@ impl Which {
            | Self::L13bCode
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b => false,
            | Self::Leo13b
            | Self::Phi2 => false,
            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
            // same way. Starling is a fine tuned version of OpenChat.
            Self::OpenChat35
@@ -116,6 +119,7 @@ impl Which {
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::OpenChat35
            | Self::Phi2
            | Self::Starling7bAlpha => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
@@ -139,6 +143,7 @@ impl Which {
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::Phi2
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta => false,
            Self::OpenChat35 | Self::Starling7bAlpha => true,
@@ -147,26 +152,27 @@ impl Which {

    fn tokenizer_repo(&self) -> &'static str {
        match self {
            Which::L7b
            | Which::L13b
            | Which::L70b
            | Which::L7bChat
            | Which::L13bChat
            | Which::L70bChat
            | Which::L7bCode
            | Which::L13bCode
            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
            Which::Leo7b => "LeoLM/leo-hessianai-7b",
            Which::Leo13b => "LeoLM/leo-hessianai-13b",
            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
            Which::Mistral7b
            | Which::Mistral7bInstruct
            | Which::Mistral7bInstructV02
            | Which::Zephyr7bAlpha
            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
            Which::OpenChat35 => "openchat/openchat_3.5",
            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
            Self::L7b
            | Self::L13b
            | Self::L70b
            | Self::L7bChat
            | Self::L13bChat
            | Self::L70bChat
            | Self::L7bCode
            | Self::L13bCode
            | Self::L34bCode => "hf-internal-testing/llama-tokenizer",
            Self::Leo7b => "LeoLM/leo-hessianai-7b",
            Self::Leo13b => "LeoLM/leo-hessianai-13b",
            Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
            Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
            Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
            Self::OpenChat35 => "openchat/openchat_3.5",
            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
            Self::Phi2 => "microsoft/phi-2",
        }
    }
}
@@ -200,6 +206,10 @@ struct Args {
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
@@ -235,6 +245,10 @@ struct Args {
    /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
    #[arg(long)]
    gqa: Option<usize>,

    /// Use the slower dmmv cuda kernel.
    #[arg(long)]
    force_dmmv: bool,
}

impl Args {
@@ -314,6 +328,7 @@ impl Args {
                "TheBloke/Starling-LM-7B-alpha-GGUF",
                "starling-lm-7b-alpha.Q4_K_M.gguf",
            ),
            Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
        };
        let api = hf_hub::api::sync::Api::new()?;
        let api = api.model(repo.to_string());
@@ -341,11 +356,10 @@ fn main() -> anyhow::Result<()> {
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let temperature = if args.temperature == 0. {
        None
    } else {
        Some(args.temperature)
    };

    #[cfg(feature = "cuda")]
    candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
@@ -413,7 +427,8 @@ fn main() -> anyhow::Result<()> {
        | Which::L13bCode
        | Which::L34bCode
        | Which::Leo7b
        | Which::Leo13b => 1,
        | Which::Leo13b
        | Which::Phi2 => 1,
        Which::Mixtral
        | Which::MixtralInstruct
        | Which::Mistral7b
@@ -492,7 +507,20 @@ fn main() -> anyhow::Result<()> {
        prompt_tokens
    };
    let mut all_tokens = vec![];
    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
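The hunk above replaces the fixed `LogitsProcessor::new` constructor with an explicit `Sampling` strategy. A minimal standalone sketch of the same selection logic, using only names that appear in this diff:

```rust
use candle_transformers::generation::{LogitsProcessor, Sampling};

// Mirror of the selection above: argmax for zero temperature, otherwise the
// most specific strategy supported by the provided top-k/top-p flags.
fn build_processor(
    seed: u64,
    temperature: f64,
    top_k: Option<usize>,
    top_p: Option<f64>,
) -> LogitsProcessor {
    let sampling = if temperature <= 0. {
        Sampling::ArgMax
    } else {
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(seed, sampling)
}
```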
27
candle-examples/examples/qwen/README.md
Normal file
@@ -0,0 +1,27 @@
# candle-qwen: large language model series from Alibaba Cloud

Qwen 1.5 is a series of large language models that provide strong performance
on English and Chinese.

- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
  mixture-of-experts (MoE) variant.

## Running the example

```bash
$ cargo run --example qwen --release -- --prompt "Hello there "
```

Various model sizes are available via the `--model` argument, including the MoE
variant.

```bash
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
def print_prime(n: int): # n is the number of primes to be printed
    for i in range(2, n + 1):
        if all(i % j != 0 for j in range(2, i)):
            print(i)
```
@@ -7,7 +7,8 @@ extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config, Model};
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    Base(ModelBase),
    Moe(ModelMoe),
}

impl Model {
    fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
        match self {
            Self::Moe(ref mut m) => m.forward(xs, s),
            Self::Base(ref mut m) => m.forward(xs, s),
        }
    }
}

struct TextGeneration {
    model: Model,
    device: Device,
@@ -127,6 +142,8 @@ enum WhichModel {
    W14b,
    #[value(name = "72b")]
    W72b,
    #[value(name = "moe-a2.7b")]
    MoeA27b,
}

#[derive(Parser, Debug)]
@@ -224,6 +241,7 @@ fn main() -> Result<()> {
                WhichModel::W7b => "7B",
                WhichModel::W14b => "14B",
                WhichModel::W72b => "72B",
                WhichModel::MoeA27b => "MoE-A2.7B",
            };
            format!("Qwen/Qwen1.5-{size}")
        }
@@ -244,7 +262,11 @@ fn main() -> Result<()> {
            .collect::<Vec<_>>(),
        None => match args.model {
            WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
            WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
            WhichModel::W4b
            | WhichModel::W7b
            | WhichModel::W14b
            | WhichModel::W72b
            | WhichModel::MoeA27b => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        },
@@ -254,7 +276,6 @@ fn main() -> Result<()> {

    let start = std::time::Instant::now();
    let config_file = repo.get("config.json")?;
    let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
@@ -262,7 +283,16 @@ fn main() -> Result<()> {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;
    let model = match args.model {
        WhichModel::MoeA27b => {
            let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Moe(ModelMoe::new(&config, vb)?)
        }
        _ => {
            let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Base(ModelBase::new(&config, vb)?)
        }
    };

    println!("loaded the model in {:?}", start.elapsed());

9
candle-examples/examples/recurrent-gemma/README.md
Normal file
@@ -0,0 +1,9 @@
# candle-recurrent-gemma

This example runs the 2B base version of the RecurrentGemma model; see the
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).

```bash
cargo run --features cuda -r --example recurrent-gemma -- \
  --prompt "Write me a poem about Machine Learning."
```

321
candle-examples/examples/recurrent-gemma/main.rs
Normal file
@@ -0,0 +1,321 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
|
||||
use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
B(BModel),
|
||||
Q(QModel),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::B(m) => m.forward(xs, pos),
|
||||
Self::Q(m) => m.forward(xs, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "2b")]
|
||||
Base2B,
|
||||
#[value(name = "2b-it")]
|
||||
Instruct2B,
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
top_k: usize,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let sampling = match temp {
|
||||
None => candle_transformers::generation::Sampling::ArgMax,
|
||||
Some(temperature) => match top_p {
|
||||
None => candle_transformers::generation::Sampling::TopK {
|
||||
temperature,
|
||||
k: top_k,
|
||||
},
|
||||
Some(top_p) => candle_transformers::generation::Sampling::TopKThenTopP {
|
||||
temperature,
|
||||
k: top_k,
|
||||
p: top_p,
|
||||
},
|
||||
},
|
||||
};
|
||||
let logits_processor = LogitsProcessor::from_sampling(seed, sampling);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
#[arg(long, default_value_t = 250)]
|
||||
top_k: usize,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 8000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "2b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => match args.which {
|
||||
Which::Base2B => "google/recurrentgemma-2b".to_string(),
|
||||
Which::Instruct2B => "google/recurrentgemma-2b-it".to_string(),
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
let filename = match args.which {
|
||||
Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
|
||||
Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
|
||||
};
|
||||
let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
|
||||
vec![filename]
|
||||
} else {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let model = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&filenames[0],
|
||||
&device,
|
||||
)?;
|
||||
Model::Q(QModel::new(&config, vb.pp("model"))?)
|
||||
} else {
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
Model::B(BModel::new(&config, vb.pp("model"))?)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.top_k,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
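Given the `--quantized` flag and the gguf filenames above, a plausible quantized invocation is the following sketch (the files come from the `lmz/candle-gemma` repo referenced in the code):

```bash
$ cargo run --features cuda -r --example recurrent-gemma -- \
    --quantized --prompt "Write me a poem about Machine Learning."
```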
@@ -42,7 +42,7 @@ impl GymEnv {
    /// Creates a new session of the specified OpenAI Gym environment.
    pub fn new(name: &str) -> Result<GymEnv> {
        Python::with_gil(|py| {
            let gym = py.import("gymnasium")?;
            let gym = py.import_bound("gymnasium")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name,))?;
            let action_space = env.getattr("action_space")?;
@@ -66,10 +66,10 @@ impl GymEnv {
    /// Resets the environment, returning the observation tensor.
    pub fn reset(&self, seed: u64) -> Result<Tensor> {
        let state: Vec<f32> = Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            let kwargs = PyDict::new_bound(py);
            kwargs.set_item("seed", seed)?;
            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
            state.as_ref(py).get_item(0)?.extract()
            let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
            state.bind(py).get_item(0)?.extract()
        })
        .map_err(w)?;
        Tensor::new(state, &Device::Cpu)
@@ -81,8 +81,10 @@ impl GymEnv {
        action: A,
    ) -> Result<Step<A>> {
        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
            let step = step.as_ref(py);
            let step = self
                .env
                .call_method_bound(py, "step", (action.clone(),), None)?;
            let step = step.bind(py);
            let state: Vec<f32> = step.get_item(0)?.extract()?;
            let reward: f64 = step.get_item(1)?.extract()?;
            let terminated: bool = step.get_item(2)?.extract()?;
@@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
impl VecGymEnv {
    pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
        Python::with_gil(|py| {
            let sys = py.import("sys")?;
            let sys = py.import_bound("sys")?;
            let path = sys.getattr("path")?;
            let _ = path.call_method1(
                "append",
                ("candle-examples/examples/reinforcement-learning",),
            )?;
            let gym = py.import("atari_wrappers")?;
            let gym = py.import_bound("atari_wrappers")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name, img_dir, nprocesses))?;
            let action_space = env.getattr("action_space")?;
@@ -60,10 +60,10 @@ impl VecGymEnv {

    pub fn step(&self, action: Vec<usize>) -> Result<Step> {
        let (obs, reward, is_done) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action,), None)?;
            let step = step.as_ref(py);
            let step = self.env.call_method_bound(py, "step", (action,), None)?;
            let step = step.bind(py);
            let obs = step.get_item(0)?.call_method("flatten", (), None)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
            let obs: Vec<u8> = obs_buffer.to_vec(py)?;
            let reward: Vec<f32> = step.get_item(1)?.extract()?;
            let is_done: Vec<f32> = step.get_item(2)?.extract()?;
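Both hunks above are the PyO3 0.21 `Bound` API migration: `import`, `PyDict::new`, `call_method`, and GIL-ref conversions become `import_bound`, `new_bound`, `call_method_bound`, and `bind`. A self-contained sketch of the same pattern, assuming pyo3 0.21 (illustrative, not code from this repository):

```rust
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// Call gymnasium.make("CartPole-v1").reset(seed=42) via the Bound API.
fn reset_env() -> PyResult<()> {
    Python::with_gil(|py| {
        let gym = py.import_bound("gymnasium")?; // was: py.import(...)
        let env = gym.getattr("make")?.call1(("CartPole-v1",))?;
        let kwargs = PyDict::new_bound(py); // was: PyDict::new(py)
        kwargs.set_item("seed", 42u64)?;
        // On a Bound<'_, PyAny>, call_method takes the kwargs dict by reference.
        let _state = env.call_method("reset", (), Some(&kwargs))?;
        Ok(())
    })
}
```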
@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--num-samples`: the number of samples to generate iteratively.
- `--bsize`: the number of samples to generate simultaneously (see the example below).
- `--final-image`: the filename for the generated image(s).
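For example, the following sketch (flag spellings as defined in `main.rs` below) renders eight images as two iterative rounds of four simultaneous samples:

```bash
$ cargo run --example stable-diffusion --release --features cuda -- \
    --prompt "A rusty robot holding a candle" \
    --num-samples 2 --bsize 4
```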
### Using flash-attention
@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;

use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use stable_diffusion::vae::AutoEncoderKL;
use tokenizers::Tokenizer;

#[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
    #[arg(long)]
    n_steps: Option<usize>,

    /// The number of samples to generate.
    /// The number of samples to generate iteratively.
    #[arg(long, default_value_t = 1)]
    num_samples: i64,
    num_samples: usize,

    /// The number of samples to generate simultaneously.
    #[arg(long, default_value_t = 1)]
    bsize: usize,

    /// The name of the final image to generate.
    #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {

fn output_filename(
    basename: &str,
    sample_idx: i64,
    num_samples: i64,
    sample_idx: usize,
    num_samples: usize,
    timestep_idx: Option<usize>,
) -> String {
    let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@ fn output_filename(
    }
}

#[allow(clippy::too_many_arguments)]
fn save_image(
    vae: &AutoEncoderKL,
    latents: &Tensor,
    vae_scale: f64,
    bsize: usize,
    idx: usize,
    final_image: &str,
    num_samples: usize,
    timestep_ids: Option<usize>,
) -> Result<()> {
    let images = vae.decode(&(latents / vae_scale)?)?;
    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
    for batch in 0..bsize {
        let image = images.i(batch)?;
        let image_filename = output_filename(
            final_image,
            (bsize * idx) + batch + 1,
            batch + num_samples,
            timestep_ids,
        );
        candle_examples::save_image(&image, image_filename)?;
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn text_embeddings(
    prompt: &str,
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
        final_image,
        sliced_attention_size,
        num_samples,
        bsize,
        sd_version,
        clip_weights,
        vae_weights,
@@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
        .collect::<Result<Vec<_>>>()?;

    let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
    println!("{text_embeddings:?}");

    println!("Building the autoencoder.");
@@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
    } else {
        0
    };
    let bsize = 1;

    let vae_scale = match sd_version {
        StableDiffusionVersion::V1_5
@@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
            println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

            if args.intermediary_images {
                let image = vae.decode(&(&latents / vae_scale)?)?;
                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
                let image_filename =
                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
                candle_examples::save_image(&image, image_filename)?
                save_image(
                    &vae,
                    &latents,
                    vae_scale,
                    bsize,
                    idx,
                    &final_image,
                    num_samples,
                    Some(timestep_index + 1),
                )?;
            }
        }

@@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
            idx + 1,
            num_samples
        );
        let image = vae.decode(&(&latents / vae_scale)?)?;
        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
        candle_examples::save_image(&image, image_filename)?
        save_image(
            &vae,
            &latents,
            vae_scale,
            bsize,
            idx,
            &final_image,
            num_samples,
            None,
        )?;
    }
    Ok(())
}
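The batched generation above leans on `Tensor::repeat` to tile the prompt embedding across the batch dimension. A small sketch of that behavior, assuming the candle API as used in this diff:

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // A single prompt embedding of shape (1, tokens, hidden).
    let emb = Tensor::zeros((1usize, 77, 768), DType::F32, &Device::Cpu)?;
    // Tile it bsize times along the batch dimension, as run() does above.
    let batched = emb.repeat((4usize, 1, 1))?;
    assert_eq!(batched.dims(), &[4, 77, 768]);
    Ok(())
}
```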
@@ -12,12 +12,23 @@ use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

const DTYPE: DType = DType::F32;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Base,
    T5Small,
    T5Large,
    T5_3B,
    Mt5Base,
    Mt5Small,
    Mt5Large,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -36,6 +47,15 @@ struct Args {
    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    model_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Enable decoding.
    #[arg(long)]
    decode: bool,
@@ -71,6 +91,10 @@ struct Args {
    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model to be used.
    #[arg(long, default_value = "t5-small")]
    which: Which,
}

struct T5ModelBuilder {
@@ -82,8 +106,17 @@ struct T5ModelBuilder {
impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = candle_examples::device(args.cpu)?;
        let default_model = "t5-small".to_string();
        let default_revision = "refs/pr/15".to_string();
        let (default_model, default_revision) = match args.which {
            Which::T5Base => ("t5-base", "main"),
            Which::T5Small => ("t5-small", "refs/pr/15"),
            Which::T5Large => ("t5-large", "main"),
            Which::T5_3B => ("t5-3b", "main"),
            Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
            Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
            Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
        };
        let default_model = default_model.to_string();
        let default_revision = default_revision.to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
@@ -93,14 +126,35 @@ impl T5ModelBuilder {

        let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = api.get("config.json")?;
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
        {
            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
        } else {
            vec![api.get("model.safetensors")?]
        let repo = api.repo(repo);
        let config_filename = match &args.config_file {
            None => repo.get("config.json")?,
            Some(f) => f.into(),
        };
        let tokenizer_filename = match &args.tokenizer_file {
            None => match args.which {
                Which::Mt5Base => api
                    .model("lmz/mt5-tokenizers".into())
                    .get("mt5-base.tokenizer.json")?,
                Which::Mt5Small => api
                    .model("lmz/mt5-tokenizers".into())
                    .get("mt5-small.tokenizer.json")?,
                Which::Mt5Large => api
                    .model("lmz/mt5-tokenizers".into())
                    .get("mt5-large.tokenizer.json")?,
                _ => repo.get("tokenizer.json")?,
            },
            Some(f) => f.into(),
        };
        let weights_filename = match &args.model_file {
            Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
            None => {
                if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
                    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
                } else {
                    vec![repo.get("model.safetensors")?]
                }
            }
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;
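With the new `--which` flag the same example can load the mt5 checkpoints, with tokenizers fetched from `lmz/mt5-tokenizers` as above. A plausible invocation, sketched from the flags in this diff (the `--prompt` flag is assumed from the rest of the example):

```bash
$ cargo run --example t5 --release -- \
    --which mt5-small --decode \
    --prompt "translate English to German: A beautiful candle."
```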
@@ -115,7 +115,7 @@ pub fn main() -> anyhow::Result<()> {
    let processor = image_processor::ViTImageProcessor::new(&processor_config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;
    let image = processor.preprocess(image)?.to_device(&device)?;

    let encoder_xs = model.encoder().forward(&image)?;

BIN
candle-examples/examples/yolo-v8/assets/bike.pp.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 175 KiB
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.2"
version = "0.5.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.2" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.2"
version = "0.5.0"
edition = "2021"

description = "CUDA kernels for Candle"
@@ -1,5 +1,8 @@
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-changed=src/compatibility.cuh");
    println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
    println!("cargo:rerun-if-changed=src/binary_op_macros.cuh");

    let builder = bindgen_cuda::Builder::default();
    println!("cargo:info={builder:?}");
@@ -14,7 +14,7 @@ __device__ bool is_contiguous(
    size_t acc = 1;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        if (acc != strides[dim_idx]) {
        if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
            return false;
        }
        acc *= dims[dim_idx];
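The fix relaxes the contiguity check: a size-1 dimension never contributes to addressing, so its stride can be arbitrary without breaking contiguity. A standalone Rust mirror of the corrected predicate, for illustration:

```rust
/// True when a dense row-major traversal of `dims` matches `strides`.
/// Size-1 dimensions are skipped: their stride is never actually used,
/// so any value is acceptable (this is exactly what the diff above fixes).
fn is_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut acc = 1;
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        if dim > 1 && acc != stride {
            return false;
        }
        acc *= dim;
    }
    true
}

fn main() {
    // A (2, 1, 3) tensor stored densely: the middle stride is irrelevant.
    assert!(is_contiguous(&[2, 1, 3], &[3, 99, 1]));
    assert!(!is_contiguous(&[2, 2, 3], &[3, 99, 1]));
}
```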
File diff suppressed because it is too large
@@ -147,6 +147,65 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
    }
}

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

    uint32_t rope_idx = idx % (td / 2);
    T c = cos[rope_idx];
    T s = sin[rope_idx];

    dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
    dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}

template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

    uint32_t i_bh = idx / (td / 2);
    uint32_t i_td = idx - (td / 2) * i_bh;
    uint32_t i_t = i_td / (d / 2);
    uint32_t i_d = i_td - (d / 2) * i_t;
    uint32_t i1 = i_bh * td + i_t * d + i_d;
    uint32_t i2 = i1 + d / 2;
    uint32_t i_cs = i_t * (d / 2) + i_d;
    T c = cos[i_cs];
    T s = sin[i_cs];

    dst[i1] = src[i1] * c - src[i2] * s;
    dst[i2] = src[i1] * s + src[i2] * c;
}

template <typename T>
__device__ void rope_thd(
    const T * src,
    const T * cos,
    const T * sin,
    T * dst,
    const uint32_t b,
    const uint32_t t,
    const uint32_t h,
    const uint32_t d
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= b * t * h * d) return;

    uint32_t i_bth = idx / (d / 2);
    uint32_t i_d = idx - (d / 2) * i_bth;
    uint32_t i_t = (i_bth / h) % t;
    uint32_t i1 = i_bth * d + i_d;
    uint32_t i2 = i1 + d / 2;
    uint32_t i_cs = i_t * (d / 2) + i_d;
    T c = cos[i_cs];
    T s = sin[i_cs];

    dst[i1] = src[i1] * c - src[i2] * s;
    dst[i2] = src[i1] * s + src[i2] * c;
}

template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@@ -402,9 +461,42 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \

#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td) { \
    ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
} \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td, \
    const uint32_t d) { \
    rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
} \
extern "C" __global__ void FN_NAME_THD( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t b, \
    const uint32_t t, \
    const uint32_t h, \
    const uint32_t d) { \
    rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
} \

#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
@@ -412,6 +504,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@@ -423,6 +516,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)

FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
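In the interleaved variant, thread `idx` rotates the adjacent pair `(src[2*idx], src[2*idx+1])` by the angle stored at `idx % (td/2)`. A scalar CPU reference with the same indexing, handy for validating the kernel (a sketch, not repository code):

```rust
/// CPU reference for the `ropei` kernel above. `src` holds `bh * td`
/// elements; `cos` and `sin` each hold `td / 2` precomputed values.
fn ropei(src: &[f32], cos: &[f32], sin: &[f32], bh: usize, td: usize) -> Vec<f32> {
    let mut dst = vec![0f32; bh * td];
    for idx in 0..(bh * td) / 2 {
        let rope_idx = idx % (td / 2);
        let (c, s) = (cos[rope_idx], sin[rope_idx]);
        dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
        dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
    }
    dst
}

fn main() {
    let (bh, td) = (1, 4);
    let src = [1.0f32, 0.0, 0.0, 1.0];
    let (cos, sin) = ([1.0f32, 0.0], [0.0f32, 1.0]);
    // First pair rotated by angle 0, second by pi/2.
    assert_eq!(ropei(&src, &cos, &sin, bh, td), vec![1.0, 0.0, -1.0, 0.0]);
}
```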
@@ -86,6 +86,11 @@ extern "C" __global__ void FN_NAME( \
    } \
} \

template<typename T>
__device__ T sign_(T t) {
    return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}

#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
@@ -110,6 +115,7 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif

#if __CUDA_ARCH__ >= 530
@@ -135,6 +141,7 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
UNARY_OP(__half, usign_f16, sign_(x))
#endif

UNARY_OP(uint8_t, ucopy_u8, x)
@@ -184,3 +191,5 @@ UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))
UNARY_OP(float, usign_f32, sign_(x))
UNARY_OP(double, usign_f64, sign_(x))
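`sign_` computes the sign without branching: the two boolean comparisons cast to 0/1 and their difference is -1, 0, or 1. The same trick in Rust, for reference:

```rust
// Branchless sign, matching the CUDA helper above.
fn sign(t: f32) -> f32 {
    ((t > 0.0) as i32 - (t < 0.0) as i32) as f32
}

fn main() {
    assert_eq!(sign(3.5), 1.0);
    assert_eq!(sign(-0.1), -1.0);
    assert_eq!(sign(0.0), 0.0);
}
```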
@@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.2"
version = "0.5.0"
edition = "2021"

description = "Metal kernels for Candle"
@@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);

#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);

#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);

#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);

BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
@@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)

BFLOAT_BINARY_OP_OUT(eq, x == y)
BFLOAT_BINARY_OP_OUT(ne, x != y)
BFLOAT_BINARY_OP_OUT(le, x <= y)
BFLOAT_BINARY_OP_OUT(lt, x < y)
BFLOAT_BINARY_OP_OUT(ge, x >= y)
BFLOAT_BINARY_OP_OUT(gt, x > y)
#endif
@@ -486,16 +486,24 @@ kernel void FN_NAME( \
} \

IM2COL_OP(float, im2col_f32)
IM2COL_OP(half, im2col_f16)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
#if defined(__HAVE_BFLOAT__)
IM2COL_OP(bfloat, im2col_bf16)
#endif

IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
#if defined(__HAVE_BFLOAT__)
UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16)
#endif

MAXPOOL2D_OP(float, max_pool2d_f32)
MAXPOOL2D_OP(half, max_pool2d_f16)
@@ -187,6 +187,12 @@ kernel void NAME( \
}


INDEX_OP(is_i64_f32, int64_t, float)
INDEX_OP(is_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif

INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
@@ -201,6 +207,9 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)

GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif

SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
@@ -242,4 +251,4 @@ INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif
#endif
File diff suppressed because it is too large
@ -21,6 +21,52 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void argmin(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device uint *dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup T *shared_memory,
|
||||
threadgroup uint *shared_indices
|
||||
) {
|
||||
bool notset = true;
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
if (notset || src[strided_i] < shared_memory[tid]) {
|
||||
shared_memory[tid] = src[strided_i];
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */
|
||||
shared_indices[tid] = idx % dims[num_dims - 1];
|
||||
notset = false;
|
||||
}
|
||||
idx += block_dim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
// reduction in shared memory
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
|
||||
shared_indices[tid] = shared_indices[tid + s];
|
||||
shared_memory[tid] = shared_memory[tid + s];
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_indices[0];
|
||||
}
|
||||
}
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
@ -35,53 +81,63 @@ kernel void NAME( \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
argmin<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
|
||||
} \
|
||||
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void argmax(
|
||||
constant size_t & num_dims,
|
||||
constant size_t * dims,
|
||||
constant size_t * strides,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device uint * dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup T * shared_memory,
|
||||
threadgroup uint * shared_indices
|
||||
) {
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
bool notset = true;
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
if (notset || shared_memory[tid] < src[strided_i]) {
|
||||
shared_memory[tid] = src[strided_i];
|
||||
shared_indices[tid] = idx % dims[num_dims - 1];
|
||||
notset = false;
|
||||
}
|
||||
idx += block_dim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// reduction in shared memory
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
|
||||
shared_indices[tid] = shared_indices[tid + s];
|
||||
shared_memory[tid] = shared_memory[tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
// Thread 0 writes the result of the reduction
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_indices[0];
|
||||
}
|
||||
}
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
@ -95,223 +151,337 @@ kernel void NAME( \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
argmax<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t & num_dims,
|
||||
constant size_t * dims,
|
||||
constant size_t * strides,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device T * dst,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup T * shared_memory,
|
||||
T (*fn)(T, T)
|
||||
) {
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
size_t idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
// TODO: Fast version for the contiguous case.
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||
T x = shared_memory[tid];
|
||||
T y = src[strided_i];
|
||||
shared_memory[tid] = fn(x, y);
|
||||
idx += block_dim;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// reduction in shared memory
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
T x = shared_memory[tid];
|
||||
T y = shared_memory[tid + s];
|
||||
shared_memory[tid] = fn(x, y);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
dst[dst_id] = shared_memory[0];
|
||||
}
|
||||
}
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            T x = shared_memory[tid]; \
            T y = shared_memory[tid + s]; \
            shared_memory[tid] = FN; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    dst[dst_id] = shared_memory[0]; \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = START; \
    reduce<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \
} \

#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = -INFINITY; \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    \
    \
    float tmp = -INFINITY; \
    while (idx < stop_idx) { \
        tmp = MAX(tmp, float(src[idx])); \
        idx += block_dim; \
    } \
    shared_memory[tid] = tmp; \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    /* wait for shared_memory[0] to be filled */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    float _max = shared_memory[0]; \
    \
    /* prevent tid=0 from overwriting _max before other threads have written it */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    shared_memory[tid] = 0; \
    \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        const float val = exp(float(src[idx]) - _max); \
        dst[idx] = T(val); \
        shared_memory[tid] += val; \
        idx += block_dim; \
    } \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] += shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    const T inv_acc = T(1.0/shared_memory[0]); \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        dst[idx] *= inv_acc; \
        idx += block_dim; \
    } \
} \

#define RMSNORM(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    device const T *alpha, \
    constant float &eps, \
    \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = 0; \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    \
    \
    float tmp = 0; \
    while (idx < stop_idx) { \
        tmp = tmp + float(src[idx]) * float(src[idx]); \
        idx += block_dim; \
    } \
    shared_memory[tid] = tmp; \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    /* wait for shared_memory[0] to be filled */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
    float inv_norm = 1.0f / norm; \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        float val = float(src[idx]) * inv_norm; \
        if (alpha != nullptr) { \
            val *= float(alpha[idx - start_idx]); \
        } \
        dst[idx] = T(val); \
        idx += block_dim; \
    } \
} \

template<typename T>
METAL_FUNC void softmax(
    constant size_t & src_numel,
    constant size_t & el_to_sum_per_block,
    device const T * src,
    device T * dst,
    uint id,
    uint tid,
    uint dst_id,
    uint block_dim,
    threadgroup float * shared_memory
) {
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    float tmp = -INFINITY;
    while (idx < stop_idx) {
        tmp = MAX(tmp, float(src[idx]));
        idx += block_dim;
    }
    shared_memory[tid] = tmp;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = block_dim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    /* wait for shared_memory[0] to be filled */
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float _max = shared_memory[0];

    /* prevent tid=0 from overwriting _max before other threads have written it */
    threadgroup_barrier(mem_flags::mem_threadgroup);
    shared_memory[tid] = 0;

    idx = start_idx + tid;
    while (idx < stop_idx) {
        const float val = exp(float(src[idx]) - _max);
        dst[idx] = T(val);
        shared_memory[tid] += val;
        idx += block_dim;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = block_dim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] += shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const T inv_acc = T(1.0 / shared_memory[0]);
    idx = start_idx + tid;
    while (idx < stop_idx) {
        dst[idx] *= inv_acc;
        idx += block_dim;
    }
}
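For reference, the three passes of the template above compute the numerically stable softmax,

    softmax(x)_i = exp(x_i - m) / sum_k exp(x_k - m),   with m = max_j x_j,

first reducing the per-block maximum, then writing exp(x_i - m) while accumulating the sum,
and finally scaling by 1/sum. Subtracting m keeps exp() from overflowing for large inputs.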

#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = -INFINITY; \
    softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
} \

template<typename T>
METAL_FUNC void rmsnorm(
    constant size_t & src_numel,
    constant size_t & el_to_sum_per_block,
    device const T * src,
    device T * dst,
    device const T * alpha,
    constant float & eps,
    uint id,
    uint tid,
    uint dst_id,
    uint block_dim,
    threadgroup float * shared_memory
) {
    size_t start_idx = dst_id * el_to_sum_per_block;
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
    size_t idx = start_idx + tid;

    float tmp = 0;
    while (idx < stop_idx) {
        tmp = tmp + float(src[idx]) * float(src[idx]);
        idx += block_dim;
    }
    shared_memory[tid] = tmp;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = block_dim / 2; s > 0; s >>= 1) {
        if (tid < s) {
            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    /* wait for shared_memory[0] to be filled */
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
    float inv_norm = 1.0f / norm;
    idx = start_idx + tid;
    while (idx < stop_idx) {
        float val = float(src[idx]) * inv_norm;
        if (alpha != nullptr) {
            val *= float(alpha[idx - start_idx]);
        }
        dst[idx] = T(val);
        idx += block_dim;
    }
}
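In formula form, the rmsnorm template above computes

    rmsnorm(x)_i = x_i / sqrt(mean_j(x_j^2) + eps) * alpha_i

with the optional elementwise scale alpha skipped when the alpha pointer is null.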

#define RMSNORM(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    device const T *alpha, \
    constant float &eps, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = 0; \
    rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \

template<typename T>
METAL_FUNC void ropei(
    constant size_t &bh,
    constant size_t &td,
    device const T *src,
    device const T *cos,
    device const T *sin,
    device T *dst,
    uint tid
) {
    if (2 * tid >= bh * td) {
        return;
    }
    size_t rope_idx = tid % (td / 2);
    T c = cos[rope_idx];
    T s = sin[rope_idx];
    dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
    dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c;
}

template<typename T>
METAL_FUNC void rope(
    constant size_t &bh,
    constant size_t &td,
    constant size_t &d,
    device const T *src,
    device const T *cos,
    device const T *sin,
    device T *dst,
    uint idx
) {
    if (2 * idx >= bh * td) {
        return;
    }
    size_t i_bh = idx / (td / 2);
    size_t i_td = idx - (td / 2) * i_bh;
    size_t i_t = i_td / (d / 2);
    size_t i_d = i_td - (d / 2) * i_t;
    size_t i1 = i_bh * td + i_t * d + i_d;
    size_t i2 = i1 + d / 2;
    size_t i_cs = i_t * (d / 2) + i_d;
    T c = cos[i_cs];
    T s = sin[i_cs];
    dst[i1] = src[i1] * c - src[i2] * s;
    dst[i2] = src[i1] * s + src[i2] * c;
}

template<typename T>
METAL_FUNC void rope_thd(
    constant size_t &b,
    constant size_t &t,
    constant size_t &h,
    constant size_t &d,
    device const T *src,
    device const T *cos,
    device const T *sin,
    device T *dst,
    uint idx
) {
    if (2 * idx >= b * t * h * d) {
        return;
    }
    const size_t i_bth = idx / (d / 2);
    const size_t i_d = idx - (d / 2) * i_bth;
    const size_t i_t = (i_bth / h) % t;
    const size_t i1 = i_bth * d + i_d;
    const size_t i2 = i1 + d / 2;
    const size_t i_cs = i_t * (d / 2) + i_d;
    T c = cos[i_cs];
    T s = sin[i_cs];
    dst[i1] = src[i1] * c - src[i2] * s;
    dst[i2] = src[i1] * s + src[i2] * c;
}

#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \
kernel void FN_NAME_I( \
    constant size_t &bh, \
    constant size_t &td, \
    device const TYPENAME *src, \
    device const TYPENAME *cos, \
    device const TYPENAME *sin, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
    constant size_t &bh, \
    constant size_t &td, \
    constant size_t &d, \
    device const TYPENAME *src, \
    device const TYPENAME *cos, \
    device const TYPENAME *sin, \
    device TYPENAME *dst, \
    uint idx [[ thread_position_in_grid ]] \
) { \
    rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
    constant size_t &b, \
    constant size_t &t, \
    constant size_t &h, \
    constant size_t &d, \
    device const TYPENAME *src, \
    device const TYPENAME *cos, \
    device const TYPENAME *sin, \
    device TYPENAME *dst, \
    uint idx [[ thread_position_in_grid ]] \
) { \
    rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
}\

REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
@@ -341,6 +511,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)

#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -352,11 +524,16 @@ ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)

#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x + y, fast_sum_bf16_strided, half, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif

@@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    let ptr = data.as_ptr() as *const c_void;
    let size = (data.len() * std::mem::size_of::<T>()) as u64;
    let size = std::mem::size_of_val(data) as u64;
    device.new_buffer_with_data(ptr, size, options)
}

@@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let input = BufferOffset {
        buffer: &input,
        offset_in_bytes: 0,
    };
    let output = new_buffer(&device, v);
    call_unary_contiguous(
        &device,
@@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
        &kernels,
        name,
        v.len(),
        &input,
        input,
        &output,
    )
    .unwrap();
@@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
        &kernels,
        name,
        x.len(),
        &left,
        &right,
        BufferOffset::zero_offset(&left),
        BufferOffset::zero_offset(&right),
        &output,
    )
    .unwrap();
@@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    let input = BufferOffset {
        buffer: &input,
        offset_in_bytes: offset,
    };
    let output_b = new_buffer(&device, v);
    let output = BufferOffset {
        buffer: &output_b,
        offset_in_bytes: 0,
    };
    let kernels = Kernels::new();
    call_unary_strided(
        &device,
@@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
        &kernels,
        kernel,
        shape,
        &input,
        input,
        strides,
        offset,
        &output,
        0,
        output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, v.len())
    read_to_vec(&output_b, v.len())
}

#[test]
@@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
        &kernels,
        name,
        v.len(),
        &input,
        0,
        BufferOffset::zero_offset(&input),
        &output,
    )
    .unwrap();
@@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
        &kernels,
        "affine_f32",
        size,
        &input,
        BufferOffset::zero_offset(&input),
        &output,
        mul as f32,
        add as f32,
@@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
        &kernels,
        "affine_f32_strided",
        shape,
        &input,
        BufferOffset::zero_offset(&input),
        strides,
        0,
        &output,
        mul as f32,
        add as f32,
@@ -633,7 +641,7 @@ fn index_select_strided() {
fn index_select_f16() {
    let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        .into_iter()
        .map(|x| f16::from_f32(x))
        .map(f16::from_f32)
        .collect();
    let shape = [5, 2];
    let stride = [2, 1];
@@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, &embeddings);
    let ids_buffer = new_buffer(&device, &ids);
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
@@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        command_buffer,
        &kernels,
        name,
        shape,
@@ -720,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
        true,
        shape,
        stride,
        &embeddings_buffer,
        0,
        &ids_buffer,
        0,
        BufferOffset::zero_offset(&embeddings_buffer),
        BufferOffset::zero_offset(&ids_buffer),
        &dst_buffer,
    )
    .unwrap();
@@ -746,8 +752,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, &embeddings);
    let ids_buffer = new_buffer(&device, &ids);
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
@@ -757,7 +763,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        command_buffer,
        &kernels,
        name,
        shape,
@@ -766,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
        false,
        shape,
        stride,
        &embeddings_buffer,
        0,
        &ids_buffer,
        0,
        BufferOffset::zero_offset(&embeddings_buffer),
        BufferOffset::zero_offset(&ids_buffer),
        &dst_buffer,
    )
    .unwrap();
@@ -811,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
        &dims,
        &strides,
        out_length,
        &input,
        0,
        BufferOffset::zero_offset(&input),
        &output,
    )
    .unwrap();
@@ -931,6 +934,7 @@ fn softmax() {
    );
}

#[allow(clippy::too_many_arguments)]
fn run_where_cond<I: Clone, T: Clone>(
    shape: &[usize],
    cond: &[I],
@@ -965,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
    );

    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    let cond = BufferOffset {
        buffer: &cond,
        offset_in_bytes: cond_offset,
    };
    let left = BufferOffset {
        buffer: &left,
        offset_in_bytes: left_offset,
    };
    let right = BufferOffset {
        buffer: &right,
        offset_in_bytes: cond_offset,
    };
    call_where_cond_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        &cond,
        (&cond_stride, cond_offset),
        &left,
        (&left_stride, left_offset),
        &right,
        (&cond_stride, cond_offset),
        cond,
        &cond_stride,
        left,
        &left_stride,
        right,
        &cond_stride,
        &output,
    )
    .unwrap();
@@ -1148,7 +1164,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
#[test]
fn random() {
    fn calc_mean(data: &[f32]) -> f32 {
        let sum = data.iter().sum::<f32>() as f32;
        let sum = data.iter().sum::<f32>();
        let count = data.len();
        assert!(count > 0);
        sum / count as f32
@@ -1162,7 +1178,7 @@ fn random() {
    let variance = data
        .iter()
        .map(|value| {
            let diff = mean - (*value as f32);
            let diff = mean - *value;
            diff * diff
        })
        .sum::<f32>()
@@ -1241,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &ids_buffer,
        0,
        BufferOffset::zero_offset(&input_buffer),
        BufferOffset::zero_offset(&ids_buffer),
        &output,
    )
    .unwrap();
@@ -1346,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &indices_buffer,
        0,
        BufferOffset::zero_offset(&input_buffer),
        BufferOffset::zero_offset(&indices_buffer),
        &output,
    )
    .unwrap();
@@ -1787,6 +1799,7 @@ fn avg_pool2d_u32() {
    assert_eq!(results, expected);
}

#[allow(clippy::too_many_arguments)]
fn run_conv_transpose1d<T: Clone>(
    input: &[T],
    input_shape: &[usize],
@@ -104,21 +104,17 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &d1, \
    constant size_t &d2, \
    constant size_t &src_s, \
    constant size_t &dst_s, \
    constant int64_t &d1, \
    constant int64_t &d2, \
    constant int64_t &src_s, \
    constant int64_t &dst_s, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
    uint2 idx [[thread_position_in_grid]] \
) { \
    if (tid >= d1 * d2) { \
        return; \
    } \
    size_t idx1 = tid / d2; \
    size_t idx2 = tid - idx1 * d2; \
    size_t src_idx = idx1 * src_s + idx2; \
    size_t dst_idx = idx1 * dst_s + idx2; \
    if (idx.x >= d1 || idx.y >= d2) return; \
    int64_t src_idx = idx.x * src_s + idx.y; \
    int64_t dst_idx = idx.x * dst_s + idx.y; \
    output[dst_idx] = input[src_idx]; \
}
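Since the reworked COPY2D kernel now reads a uint2 thread position instead of a flat index,
the host side has to dispatch a two dimensional d1 x d2 grid. A hedged sketch of what such
a dispatch could look like with the metal crate (not the actual call site in
candle-metal-kernels; get_block_dims is the helper added in candle-metal-kernels/src/utils.rs
later in this diff):

    // Assumed: `encoder` is a ComputeCommandEncoderRef with the copy2d pipeline bound.
    let grid_size = MTLSize { width: d1, height: d2, depth: 1 };
    let group_size = get_block_dims(d1, d2, 1);
    encoder.dispatch_threads(grid_size, group_size); // non-uniform threadgroup dispatch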

@@ -145,6 +141,7 @@ UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -174,8 +171,9 @@ BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

COPY2D(copy2d_bf64, bfloat)
COPY2D(copy2d_bf16, bfloat)
#endif
162  candle-metal-kernels/src/utils.rs  Normal file
@@ -0,0 +1,162 @@
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
use std::ffi::c_void;

/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    let count = (size + width - 1) / width;
    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}
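A worked example of the split (ours, restated in plain Rust without the Metal types):

    // ceil-div of the buffer length by the capped threadgroup width
    fn split(max_threads_per_group: u64, length: u64) -> (u64, u64) {
        let width = max_threads_per_group.min(length);
        let count = (length + width - 1) / width;
        (count, width) // (threadgroup count, threads per threadgroup)
    }
    assert_eq!(split(1024, 4096), (4, 1024)); // four full threadgroups
    assert_eq!(split(1024, 10), (1, 10));     // one small threadgroup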

// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
    let mut pows0 = 0u64;
    let mut pows1 = 0u64;
    let mut pows2 = 0u64;
    let mut sum = 0u64;
    loop {
        let presum = sum;
        // Check all the pows
        if dim0 >= (1 << (pows0 + 1)) {
            pows0 += 1;
            sum += 1;
        }
        if sum == 10 {
            break;
        }
        if dim1 >= (1 << (pows1 + 1)) {
            pows1 += 1;
            sum += 1;
        }
        if sum == 10 {
            break;
        }
        if dim2 >= (1 << (pows2 + 1)) {
            pows2 += 1;
            sum += 1;
        }
        if sum == presum || sum == 10 {
            break;
        }
    }
    MTLSize {
        width: 1 << pows0,
        height: 1 << pows1,
        depth: 1 << pows2,
    }
}
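The loop hands out powers of two round-robin until 2^10 = 1024 threads are reached or no
dimension can grow further. A worked trace (ours): get_block_dims(6, 512, 512) stops with
pows = (2, 4, 4) because dim0 = 6 cannot absorb a third doubling, so:

    // Hedged check sketch:
    let dims = get_block_dims(6, 512, 512);
    assert_eq!((dims.width, dims.height, dims.depth), (4, 16, 16)); // 4 * 16 * 16 = 1024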

pub(crate) fn set_param<P: EncoderParam>(
    encoder: &ComputeCommandEncoderRef,
    position: u64,
    data: P,
) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
    ($type:ty) => {
        impl EncoderParam for $type {
            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
                encoder.set_bytes(
                    position,
                    core::mem::size_of::<$type>() as u64,
                    &data as *const $type as *const c_void,
                );
            }
        }
    };
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);

pub struct BufferOffset<'a> {
    pub buffer: &'a Buffer,
    pub offset_in_bytes: usize,
}

impl<'a> BufferOffset<'a> {
    pub fn zero_offset(buffer: &'a Buffer) -> Self {
        Self {
            buffer,
            offset_in_bytes: 0,
        }
    }
}

impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}

impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}

impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

impl<'a> EncoderParam for &BufferOffset<'a> {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
    }
}

impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}

impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            $crate::utils::set_param($encoder, _index, $param);
            _index += 1;
        )*
    );
}
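A typical call site then binds all kernel arguments on one line, with positions assigned
left to right starting at buffer index 0 (sketch, assuming an `encoder` identifier in
scope, mirroring how the kernels in this crate are dispatched):

    set_params!(encoder, (length, &input, output));

This expands to three set_param calls at indices 0, 1 and 2, each picking the matching
EncoderParam impl (set_bytes for the usize length, set_buffer for the buffer arguments).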
@@ -25,6 +25,8 @@ candle-metal-kernels = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
rand = { workspace = true }
criterion = { workspace = true }

[features]
default = []
@@ -32,3 +34,7 @@ accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]

[[bench]]
name = "bench_main"
harness = false
4  candle-nn/benches/bench_main.rs  Normal file
@@ -0,0 +1,4 @@
mod benchmarks;

use criterion::criterion_main;
criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches);
54  candle-nn/benches/benchmarks/conv.rs  Normal file
@@ -0,0 +1,54 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Module, Tensor};
use candle_nn::{Conv2d, Conv2dConfig};
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;

const B: usize = 1;
const C: usize = 1;
const M: usize = 128;
const K: usize = 128;
const K_SIZE: usize = 3;

fn run(input: Tensor, weight: Tensor, bias: Tensor, config: Conv2dConfig) {
    Conv2d::new(weight, Some(bias), config)
        .forward(&input)
        .unwrap();
}

fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
    let weight = Tensor::ones((1, 1, K_SIZE, K_SIZE), dtype, device)
        .unwrap()
        .to_dtype(dtype)
        .unwrap();
    let bias = Tensor::zeros(K, dtype, device).unwrap();
    let input = Tensor::ones((B, C, M, K), dtype, device).unwrap();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(
                    black_box(input.clone()),
                    black_box(weight.clone()),
                    black_box(bias.clone()),
                    Default::default(),
                );
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let device = BenchDeviceHandler::new().unwrap();
    for d in device.devices {
        run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32");
        run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16");
    }
}

criterion_group!(benches, criterion_benchmark);
48  candle-nn/benches/benchmarks/layer_norm.rs  Normal file
@@ -0,0 +1,48 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Module, Tensor};
use candle_nn::LayerNorm;
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;

fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) {
    let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input);
}

const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;

fn run_layer_norm_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
    let elements = B * M * K;

    let weight = Tensor::arange(0.0, elements as f32, device)
        .unwrap()
        .to_dtype(dtype)
        .unwrap();
    let bias = weight.ones_like().unwrap();
    let input = weight.ones_like().unwrap();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&input), black_box(&weight), black_box(&bias));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let device = BenchDeviceHandler::new().unwrap();
    for d in device.devices {
        run_layer_norm_benchmark(c, &d, DType::F32, "layer_norm_f32");
        run_layer_norm_benchmark(c, &d, DType::BF16, "layer_norm_bf16");
        run_layer_norm_benchmark(c, &d, DType::F16, "layer_norm_f16");
    }
}

criterion_group!(benches, criterion_benchmark);
64  candle-nn/benches/benchmarks/mod.rs  Normal file
@@ -0,0 +1,64 @@
pub(crate) mod conv;
pub(crate) mod layer_norm;

use candle::{Device, Result};

pub(crate) trait BenchDevice {
    fn sync(&self) -> Result<()>;

    fn bench_name<S: Into<String>>(&self, name: S) -> String;
}

impl BenchDevice for Device {
    fn sync(&self) -> Result<()> {
        match self {
            Device::Cpu => Ok(()),
            Device::Cuda(device) => {
                #[cfg(feature = "cuda")]
                return Ok(device.synchronize()?);
                #[cfg(not(feature = "cuda"))]
                panic!("Cuda device without cuda feature enabled: {:?}", device)
            }
            Device::Metal(device) => {
                #[cfg(feature = "metal")]
                return Ok(device.wait_until_completed()?);
                #[cfg(not(feature = "metal"))]
                panic!("Metal device without metal feature enabled: {:?}", device)
            }
        }
    }

    fn bench_name<S: Into<String>>(&self, name: S) -> String {
        match self {
            Device::Cpu => {
                let cpu_type = if cfg!(feature = "accelerate") {
                    "accelerate"
                } else if cfg!(feature = "mkl") {
                    "mkl"
                } else {
                    "cpu"
                };
                format!("{}_{}", cpu_type, name.into())
            }
            Device::Cuda(_) => format!("cuda_{}", name.into()),
            Device::Metal(_) => format!("metal_{}", name.into()),
        }
    }
}

struct BenchDeviceHandler {
    devices: Vec<Device>,
}

impl BenchDeviceHandler {
    pub fn new() -> Result<Self> {
        let mut devices = Vec::new();
        if cfg!(feature = "metal") {
            devices.push(Device::new_metal(0)?);
        } else if cfg!(feature = "cuda") {
            devices.push(Device::new_cuda(0)?);
        }
        devices.push(Device::Cpu);
        Ok(Self { devices })
    }
}

@@ -12,6 +12,7 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
pub mod rotary_emb;
pub mod sequential;
pub mod var_builder;
pub mod var_map;
@@ -31,7 +31,7 @@ pub trait RNN {
        let (_b_size, seq_len, _features) = input.dims3()?;
        let mut output = Vec::with_capacity(seq_len);
        for seq_index in 0..seq_len {
            let input = input.i((.., seq_index, ..))?;
            let input = input.i((.., seq_index, ..))?.contiguous()?;
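The added .contiguous() matters because indexing the middle dimension yields a strided
view, which the per-step kernels cannot always consume directly. A quick illustration of
the layout (ours, not part of the diff; IndexOp is already in scope in this module):

    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    assert!(!t.i((.., 1, ..))?.is_contiguous()); // shape (2, 4) with strides (12, 1)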
            let state = if seq_index == 0 {
                self.step(&input, init_state)?
            } else {
730  candle-nn/src/rotary_emb.rs  Normal file
@@ -0,0 +1,730 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;

/// Interleaved variant of rotary embeddings.
/// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
/// The resulting y0 and y1 are also interleaved with:
///   y0 = x0*cos - x1*sin
///   y1 = x0*sin + x1*cos
#[derive(Debug, Clone)]
struct RotaryEmbI;

impl candle::CustomOp3 for RotaryEmbI {
    fn name(&self) -> &'static str {
        "rotary-emb-int"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        fn inner<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            l_src: &Layout,
            cos: &[T],
            l_cos: &Layout,
            sin: &[T],
            l_sin: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("input src has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("input cos has to be contiguous"),
                Some((o1, o2)) => &cos[o1..o2],
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("input sin has to be contiguous"),
                Some((o1, o2)) => &sin[o1..o2],
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            let el_count = b * h * t * d;
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(t * d)
                .zip(dst.par_chunks_mut(t * d))
                .for_each(|(src, dst)| {
                    for i_over_2 in 0..t * d / 2 {
                        let i = 2 * i_over_2;
                        dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
                        dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, (b, h, t, d).into()))
        }

        use candle::backend::BackendStorage;
        use CpuStorage::{BF16, F16, F32, F64};
        match (s1, s2, s3) {
            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};

        fn inner<T: DeviceRepr + WithDType>(
            src: &CudaSlice<T>,
            l_src: &Layout,
            cos: &CudaSlice<T>,
            l_cos: &Layout,
            sin: &CudaSlice<T>,
            l_sin: &Layout,
            dev: &CudaDevice,
        ) -> Result<CudaSlice<T>> {
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("src input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("cos input has to be contiguous"),
                Some((o1, o2)) => cos.slice(o1..o2),
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("sin input has to be contiguous"),
                Some((o1, o2)) => sin.slice(o1..o2),
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            let el = b * h * t * d;
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
            let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
            // SAFETY: ffi.
            unsafe { func.launch(cfg, params) }.w()?;
            Ok(dst)
        }

        use candle::backend::BackendStorage;
        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
        let dev = s1.device();
        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        };
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        src: &candle::MetalStorage,
        l_src: &Layout,
        cos: &candle::MetalStorage,
        l_cos: &Layout,
        sin: &candle::MetalStorage,
        l_sin: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = src.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
            candle::bail!(
                "dtype mismatch in rope-i {:?} {:?} {:?}",
                src.dtype(),
                cos.dtype(),
                sin.dtype()
            )
        }
        let name = match src.dtype() {
            candle::DType::F32 => "rope_i_f32",
            candle::DType::F16 => "rope_i_f16",
            candle::DType::BF16 => "rope_i_bf16",
            dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
        };
        let (b, h, t, d) = l_src.shape().dims4()?;
        let el = b * h * t * d;
        let output = device.new_buffer(el, src.dtype(), "rope-i")?;
        candle_metal_kernels::call_rope_i(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            b * h,
            t * d,
            src.buffer(),
            l_src.start_offset() * src.dtype().size_in_bytes(),
            cos.buffer(),
            l_cos.start_offset() * cos.dtype().size_in_bytes(),
            sin.buffer(),
            l_sin.start_offset() * sin.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
        Ok((out, l_src.shape().clone()))
    }
}

pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
    if cos_n_embd * 2 != n_embd
        || sin_n_embd * 2 != n_embd
        || seq_len > cos_seq_len
        || seq_len > sin_seq_len
    {
        candle::bail!(
            "inconsistent last dim size in rope {:?} {:?} {:?}",
            xs.shape(),
            cos.shape(),
            sin.shape()
        )
    }
    if !xs.is_contiguous() {
        candle::bail!("xs has to be contiguous in rope")
    }
    if !cos.is_contiguous() {
        candle::bail!("cos has to be contiguous in rope")
    }
    if !sin.is_contiguous() {
        candle::bail!("sin has to be contiguous in rope")
    }
    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
}
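A hedged usage sketch (ours, not part of the diff): xs is (b, h, t, d) and the cos/sin
tables are (t, d / 2), here built from the usual 10000^(-2i/d) frequencies:

    use candle::{Device, Result, Tensor};

    fn rope_i_demo() -> Result<Tensor> {
        let dev = Device::Cpu;
        let (b, h, t, d) = (1, 2, 4, 8);
        let xs = Tensor::randn(0f32, 1f32, (b, h, t, d), &dev)?;
        let inv_freq: Vec<f32> = (0..d / 2)
            .map(|i| 1f32 / 10000f32.powf(2. * i as f32 / d as f32))
            .collect();
        let inv_freq = Tensor::new(inv_freq.as_slice(), &dev)?.reshape((1, d / 2))?;
        let pos = Tensor::arange(0f32, t as f32, &dev)?.reshape((t, 1))?;
        let freqs = pos.broadcast_mul(&inv_freq)?; // (t, d / 2) angle table
        candle_nn::rotary_emb::rope_i(&xs, &freqs.cos()?, &freqs.sin()?)
    }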

pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
    let cos = cos
        .narrow(0, 0, seq_len)?
        .reshape((seq_len, n_embd / 2, 1))?;
    let sin = sin
        .narrow(0, 0, seq_len)?
        .reshape((seq_len, n_embd / 2, 1))?;
    let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
    let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
    let x0 = x.narrow(D::Minus1, 0, 1)?;
    let x1 = x.narrow(D::Minus1, 1, 1)?;
    let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
    let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
    let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
    let rope = rope.flatten_from(D::Minus2)?;
    Ok(rope)
}
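Since rope_i_slow is the pure tensor-op reference for the fused rope_i path above, the two
can be cross-checked directly (sketch, with xs, cos and sin shaped as in the previous
example):

    let fast = candle_nn::rotary_emb::rope_i(&xs, &cos, &sin)?;
    let slow = candle_nn::rotary_emb::rope_i_slow(&xs, &cos, &sin)?;
    let diff = (fast - slow)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-5);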
|
||||
|
||||
/// Contiguous variant of rope embeddings.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmb;
|
||||
|
||||
impl candle::CustomOp3 for RotaryEmb {
|
||||
fn name(&self) -> &'static str {
|
||||
"rotary-emb"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
fn inner<T: candle::WithDType + num_traits::Float>(
|
||||
src: &[T],
|
||||
l_src: &Layout,
|
||||
cos: &[T],
|
||||
l_cos: &Layout,
|
||||
sin: &[T],
|
||||
l_sin: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
let src = match l_src.contiguous_offsets() {
|
||||
None => candle::bail!("input src has to be contiguous"),
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let cos = match l_cos.contiguous_offsets() {
|
||||
None => candle::bail!("input cos has to be contiguous"),
|
||||
Some((o1, o2)) => &cos[o1..o2],
|
||||
};
|
||||
let sin = match l_sin.contiguous_offsets() {
|
||||
None => candle::bail!("input sin has to be contiguous"),
|
||||
Some((o1, o2)) => &sin[o1..o2],
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el_count = b * h * t * d;
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(t * d)
|
||||
.zip(dst.par_chunks_mut(t * d))
|
||||
.for_each(|(src, dst)| {
|
||||
for i_t in 0..t {
|
||||
for i_d in 0..d / 2 {
|
||||
let i1 = i_t * d + i_d;
|
||||
let i2 = i1 + d / 2;
|
||||
let i_cs = i_t * (d / 2) + i_d;
|
||||
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
|
||||
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
|
||||
}
|
||||
}
|
||||
});
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (b, h, t, d).into()))
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use CpuStorage::{BF16, F16, F32, F64};
|
||||
match (s1, s2, s3) {
|
||||
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
_ => candle::bail!(
|
||||
"unsupported dtype for rope {:?} {:?} {:?}",
|
||||
s1.dtype(),
|
||||
s2.dtype(),
|
||||
s3.dtype()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s1: &candle::CudaStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::CudaStorage,
|
||||
l2: &Layout,
|
||||
s3: &candle::CudaStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
||||
fn inner<T: DeviceRepr + WithDType>(
|
||||
src: &CudaSlice<T>,
|
||||
l_src: &Layout,
|
||||
cos: &CudaSlice<T>,
|
||||
l_cos: &Layout,
|
||||
sin: &CudaSlice<T>,
|
||||
l_sin: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let src = match l_src.contiguous_offsets() {
|
||||
None => candle::bail!("src input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let cos = match l_cos.contiguous_offsets() {
|
||||
None => candle::bail!("cos input has to be contiguous"),
|
||||
Some((o1, o2)) => cos.slice(o1..o2),
|
||||
};
|
||||
let sin = match l_sin.contiguous_offsets() {
|
||||
None => candle::bail!("sin input has to be contiguous"),
|
||||
Some((o1, o2)) => sin.slice(o1..o2),
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&cos,
|
||||
&sin,
|
||||
&dst,
|
||||
(b * h) as u32,
|
||||
(t * d) as u32,
|
||||
d as u32,
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
|
||||
let dev = s1.device();
|
||||
let slice = match (&s1.slice, &s2.slice, &s3.slice) {
|
||||
(BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
(F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
(F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
(F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
|
||||
_ => candle::bail!(
|
||||
"unsupported dtype for rope {:?} {:?} {:?}",
|
||||
s1.dtype(),
|
||||
s2.dtype(),
|
||||
s3.dtype()
|
||||
),
|
||||
};
|
||||
let dst = candle::cuda_backend::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, l1.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
src: &candle::MetalStorage,
|
||||
l_src: &Layout,
|
||||
cos: &candle::MetalStorage,
|
||||
l_cos: &Layout,
|
||||
sin: &candle::MetalStorage,
|
||||
l_sin: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
let device = src.device();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
|
||||
candle::bail!(
|
||||
"dtype mismatch in rope {:?} {:?} {:?}",
|
||||
src.dtype(),
|
||||
cos.dtype(),
|
||||
sin.dtype()
|
||||
)
|
||||
}
|
||||
let name = match src.dtype() {
|
||||
candle::DType::F32 => "rope_f32",
|
||||
candle::DType::F16 => "rope_f16",
|
||||
candle::DType::BF16 => "rope_bf16",
|
||||
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
|
||||
};
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
|
||||
candle_metal_kernels::call_rope(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
b * h,
|
||||
t * d,
|
||||
d,
|
||||
src.buffer(),
|
||||
l_src.start_offset() * src.dtype().size_in_bytes(),
|
||||
cos.buffer(),
|
||||
l_cos.start_offset() * cos.dtype().size_in_bytes(),
|
||||
sin.buffer(),
|
||||
l_sin.start_offset() * sin.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
|
||||
Ok((out, l_src.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
|
||||
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
|
||||
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
|
||||
if cos_n_embd * 2 != n_embd
|
||||
|| sin_n_embd * 2 != n_embd
|
||||
|| seq_len > cos_seq_len
|
||||
|| seq_len > sin_seq_len
|
||||
{
|
||||
candle::bail!(
|
||||
"inconsistent last dim size in rope {:?} {:?} {:?}",
|
||||
xs.shape(),
|
||||
cos.shape(),
|
||||
sin.shape()
|
||||
)
|
||||
}
|
||||
if !xs.is_contiguous() {
|
||||
candle::bail!("xs has to be contiguous in rope")
|
||||
}
|
||||
if !cos.is_contiguous() {
|
||||
candle::bail!("cos has to be contiguous in rope")
|
||||
}
|
||||
if !sin.is_contiguous() {
|
||||
candle::bail!("sin has to be contiguous in rope")
|
||||
}
|
||||
xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
|
||||
}
|
||||
|
||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
|
||||
let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
|
||||
let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
|
||||
let cos = cos.narrow(0, 0, seq_len)?;
|
||||
let sin = sin.narrow(0, 0, seq_len)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
|
||||
x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
|
||||
}
|
||||
|
||||
/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbThd;
|
||||
|
||||
impl candle::CustomOp3 for RotaryEmbThd {
|
||||
fn name(&self) -> &'static str {
|
||||
"rotary-emb"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
fn inner<T: candle::WithDType + num_traits::Float>(
|
||||
src: &[T],
|
||||
l_src: &Layout,
|
||||
cos: &[T],
|
||||
l_cos: &Layout,
|
||||
sin: &[T],
|
||||
l_sin: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
let src = match l_src.contiguous_offsets() {
|
||||
None => candle::bail!("input src has to be contiguous"),
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let cos = match l_cos.contiguous_offsets() {
|
||||
None => candle::bail!("input cos has to be contiguous"),
|
||||
Some((o1, o2)) => &cos[o1..o2],
|
||||
};
|
||||
let sin = match l_sin.contiguous_offsets() {
|
||||
None => candle::bail!("input sin has to be contiguous"),
|
||||
Some((o1, o2)) => &sin[o1..o2],
|
||||
};
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let el_count = b * h * t * d;
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(t * h * d)
|
||||
.zip(dst.par_chunks_mut(t * h * d))
|
||||
.for_each(|(src, dst)| {
|
||||
for i_t in 0..t {
|
||||
for i_d in 0..d / 2 {
|
||||
let i_cs = i_t * (d / 2) + i_d;
|
||||
for i_h in 0..h {
|
||||
let i1 = i_t * h * d + i_h * d + i_d;
|
||||
let i2 = i1 + d / 2;
|
||||
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
|
||||
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (b, t, h, d).into()))
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
use CpuStorage::{BF16, F16, F32, F64};
|
||||
match (s1, s2, s3) {
|
||||
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
|
||||
_ => candle::bail!(
|
||||
"unsupported dtype for rope {:?} {:?} {:?}",
|
||||
s1.dtype(),
|
||||
s2.dtype(),
|
||||
s3.dtype()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s1: &candle::CudaStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::CudaStorage,
|
||||
l2: &Layout,
|
||||
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};

        fn inner<T: DeviceRepr + WithDType>(
            src: &CudaSlice<T>,
            l_src: &Layout,
            cos: &CudaSlice<T>,
            l_cos: &Layout,
            sin: &CudaSlice<T>,
            l_sin: &Layout,
            dev: &CudaDevice,
        ) -> Result<CudaSlice<T>> {
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("src input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("cos input has to be contiguous"),
                Some((o1, o2)) => cos.slice(o1..o2),
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("sin input has to be contiguous"),
                Some((o1, o2)) => sin.slice(o1..o2),
            };
            let (b, t, h, d) = l_src.shape().dims4()?;
            let el = b * h * t * d;
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
            let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
            let params = (
                &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
            );
            // SAFETY: ffi.
            unsafe { func.launch(cfg, params) }.w()?;
            Ok(dst)
        }

        use candle::backend::BackendStorage;
        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
        let dev = s1.device();
        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        };
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        src: &candle::MetalStorage,
        l_src: &Layout,
        cos: &candle::MetalStorage,
        l_cos: &Layout,
        sin: &candle::MetalStorage,
        l_sin: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = src.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
            candle::bail!(
                "dtype mismatch in rope {:?} {:?} {:?}",
                src.dtype(),
                cos.dtype(),
                sin.dtype()
            )
        }
        let name = match src.dtype() {
            candle::DType::F32 => "rope_thd_f32",
            candle::DType::F16 => "rope_thd_f16",
            candle::DType::BF16 => "rope_thd_bf16",
            dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
        };
        let (b, t, h, d) = l_src.shape().dims4()?;
        let el = b * h * t * d;
        let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
        candle_metal_kernels::call_rope_thd(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            b,
            t,
            h,
            d,
            src.buffer(),
            l_src.start_offset() * src.dtype().size_in_bytes(),
            cos.buffer(),
            l_cos.start_offset() * cos.dtype().size_in_bytes(),
            sin.buffer(),
            l_sin.start_offset() * sin.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
        Ok((out, l_src.shape().clone()))
    }
}

pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
    if cos_n_embd * 2 != n_embd
        || sin_n_embd * 2 != n_embd
        || seq_len > cos_seq_len
        || seq_len > sin_seq_len
    {
        candle::bail!(
            "inconsistent last dim size in rope {:?} {:?} {:?}",
            xs.shape(),
            cos.shape(),
            sin.shape()
        )
    }
    if !xs.is_contiguous() {
        candle::bail!("xs has to be contiguous in rope")
    }
    if !cos.is_contiguous() {
        candle::bail!("cos has to be contiguous in rope")
    }
    if !sin.is_contiguous() {
        candle::bail!("sin has to be contiguous in rope")
    }
    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
}
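The new `rope_thd` entry point expects the (batch, time, heads, head_dim) layout, unlike `rope` which takes (batch, heads, time, head_dim), and the cos/sin tables cover half the head dimension. A minimal usage sketch with illustrative shapes only:

// Sketch: calling rope_thd as validated by the checks above.
use candle::{DType, Device, Tensor};

fn rope_thd_example() -> candle::Result<Tensor> {
    let dev = Device::Cpu;
    let (b, t, h, d) = (1, 4, 2, 8);
    // xs is laid out as (batch, time, heads, head_dim), hence "thd".
    let xs = Tensor::zeros((b, t, h, d), DType::F32, &dev)?;
    // cos/sin cover the full sequence length and half the head dimension.
    let cos = Tensor::zeros((t, d / 2), DType::F32, &dev)?;
    let sin = Tensor::zeros((t, d / 2), DType::F32, &dev)?;
    candle_nn::rotary_emb::rope_thd(&xs, &cos, &sin)
}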
@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
        name: &str,
        hints: B::Hints,
    ) -> Result<Tensor> {
        let path = self.path(name);
        self.data
            .backend
            .get(s.into(), &path, hints, self.data.dtype, &self.data.device)
        self.get_with_hints_dtype(s, name, hints, self.data.dtype)
    }

    /// Retrieve the tensor associated with the given name at the current path.
    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
        self.get_with_hints(s, name, Default::default())
    }

    /// Retrieve the tensor associated with the given name & dtype at the current path.
    pub fn get_with_hints_dtype<S: Into<Shape>>(
        &self,
        s: S,
        name: &str,
        hints: B::Hints,
        dtype: DType,
    ) -> Result<Tensor> {
        let path = self.path(name);
        self.data
            .backend
            .get(s.into(), &path, hints, dtype, &self.data.device)
    }
}

struct Zeros;
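A short sketch of the new per-lookup dtype override, assuming an already constructed `vb: VarBuilder`; it loads one tensor in F32 regardless of the builder-wide dtype:

// Sketch only: `vb` and the "weight" name are illustrative assumptions.
use candle::DType;

fn load_norm_weight(vb: &candle_nn::VarBuilder, dim: usize) -> candle::Result<candle::Tensor> {
    // Default::default() is the usual Init hint; the trailing DType overrides
    // the builder-wide dtype for this single lookup.
    vb.get_with_hints_dtype(dim, "weight", Default::default(), DType::F32)
}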
@ -487,6 +498,53 @@ impl<'a> VarBuilder<'a> {
        let pth = candle::pickle::PthTensors::new(p, None)?;
        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
    }

    /// Gets a VarBuilder that applies some renaming function on the tensor names it gets queried
    /// for before passing the new names to the inner VarBuilder.
    ///
    /// ```rust
    /// use candle::{Tensor, DType, Device};
    ///
    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    /// let tensors: std::collections::HashMap<_, _> = [
    ///     ("foo".to_string(), a),
    /// ]
    /// .into_iter()
    /// .collect();
    /// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
    /// assert!(vb.contains_tensor("foo"));
    /// assert!(vb.get((2, 3), "foo").is_ok());
    /// assert!(!vb.contains_tensor("bar"));
    /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
    /// assert!(vb.contains_tensor("bar"));
    /// assert!(vb.contains_tensor("foo"));
    /// assert!(vb.get((2, 3), "bar").is_ok());
    /// assert!(vb.get((2, 3), "foo").is_ok());
    /// assert!(!vb.contains_tensor("baz"));
    /// # Ok::<(), candle::Error>(())
    /// ```
    pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
        let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
        self.rename(f)
    }

    pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
        let dtype = self.dtype();
        let device = self.device().clone();
        let path = self.path.clone();
        let backend = Rename::new(self, renamer);
        let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
        let data = TensorData {
            backend,
            dtype,
            device,
        };
        Self {
            data: Arc::new(data),
            path,
            _phantom: std::marker::PhantomData,
        }
    }
}

pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
@ -607,3 +665,49 @@ impl Backend for ShardedSafeTensors {
        self.0.get(name).is_ok()
    }
}

/// This trait specifies a way to rename the queried names into names that are stored in an inner
/// VarBuilder.
pub trait Renamer {
    /// This is applied to the name obtained by a name call and the resulting name is passed to the
    /// inner VarBuilder.
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}

pub struct Rename<'a, R: Renamer> {
    inner: VarBuilder<'a>,
    renamer: R,
}

impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: crate::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let name = self.renamer.rename(name);
        self.inner
            .get_with_hints_dtype(s, &name, h, dtype)?
            .to_device(dev)
    }

    fn contains_tensor(&self, name: &str) -> bool {
        let name = self.renamer.rename(name);
        self.inner.contains_tensor(&name)
    }
}

impl<'a, R: Renamer> Rename<'a, R> {
    pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
        Self { inner, renamer }
    }
}

impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
        std::borrow::Cow::Owned(self(v))
    }
}
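Beyond the boxed-closure impl above, `Renamer` can also be implemented directly on a custom type; a sketch that strips a hypothetical "model." prefix from queried names:

// Sketch: a Renamer mapping "model.foo" queries onto "foo" in the inner
// VarBuilder. The prefix is purely illustrative.
struct StripPrefix;

impl Renamer for StripPrefix {
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
        match v.strip_prefix("model.") {
            Some(stripped) => std::borrow::Cow::Owned(stripped.to_string()),
            None => std::borrow::Cow::Owned(v.to_string()),
        }
    }
}

// Usage: let vb = vb.rename(StripPrefix);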
@ -86,5 +86,92 @@ fn softmax_numerical_stability() -> Result<()> {
    Ok(())
}

fn ropei(device: &Device) -> Result<()> {
    use rand::{rngs::StdRng, Rng, SeedableRng};

    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
    let rope1 = candle_nn::rotary_emb::rope_i(&src, &cos, &sin)?;
    let rope2 = candle_nn::rotary_emb::rope_i_slow(&src, &cos, &sin)?;
    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    if device.is_cpu() {
        assert_eq!(sum_diff, 0.);
    } else {
        assert!(sum_diff < 1e-4);
    }
    Ok(())
}

fn rope(device: &Device) -> Result<()> {
    use rand::{rngs::StdRng, Rng, SeedableRng};

    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
    let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?;
    let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    if device.is_cpu() {
        assert_eq!(sum_diff, 0.);
    } else {
        assert!(sum_diff < 1e-4);
    }
    Ok(())
}

fn rope_thd(device: &Device) -> Result<()> {
    use rand::{rngs::StdRng, Rng, SeedableRng};

    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
    let rope1 = {
        let src = src.transpose(1, 2)?.contiguous()?;
        candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)?
    };
    let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    if device.is_cpu() {
        assert_eq!(sum_diff, 0.);
    } else {
        assert!(sum_diff < 1e-4);
    }
    Ok(())
}

test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.4.2"
version = "0.5.0"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.2" }
candle-nn = { path = "../candle-nn", version = "0.4.2" }
candle = { path = "../candle-core", package = "candle-core", version = "0.5.0" }
candle-nn = { path = "../candle-nn", version = "0.5.0" }
prost = "0.12.1"

[build-dependencies]
@ -2,7 +2,7 @@ use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::{collections::HashMap, usize};

pub type Value = Tensor;
@ -508,17 +508,33 @@ pub fn simple_eval(
            values.insert(node.output[0].clone(), xs);
        }
        "Gather" => {
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
            let xs = get(&node.input[0])?;
            let indices = get(&node.input[1])?;
            let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
            let axis = xs.normalize_axis(axis)?;
            // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
            // differentiable way.
            let xs = if indices.rank() == 0 {
                let index = indices.to_vec0::<i64>()? as usize;
                xs.narrow(axis, index, 1)?.squeeze(axis)?
            } else {
                todo!("implement gather for {xs:?} {indices:?} axis {axis}")

            // In PyTorch or NumPy this can be done by indexing the xs tensor using the indices
            // tensor directly, but candle does not support tensor indexing at the moment, so
            // some workarounds are needed.
            let xs = match indices.dims() {
                [] => {
                    let index = indices.to_vec0::<i64>()? as usize;
                    xs.narrow(axis, index, 1)?.squeeze(axis)?
                }
                [_] => xs.index_select(indices, axis)?,
                [first, _] => {
                    let mut v = Vec::with_capacity(*first);
                    for i in 0..*first {
                        v.push(xs.index_select(&indices.get(i)?, axis)?)
                    }
                    Tensor::stack(&v, axis)?
                }
                _ => {
                    // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
                    // differentiable way.
                    todo!("implement gather for {xs:?} {indices:?} axis {axis}")
                }
            };
            values.insert(node.output[0].clone(), xs);
        }
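For the 2-D indices case above, each row of `indices` is gathered with `index_select` and the per-row results are stacked back. A standalone sketch of the same idea, using the data from the ONNX Gather example below:

// Sketch of the two-dimensional gather workaround in isolation.
use candle::{Device, Tensor};

fn gather_2d_example() -> candle::Result<Tensor> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[[1.0f32, 1.2], [2.3, 3.4], [4.5, 5.7]], &dev)?;
    let indices = Tensor::new(&[[0u32, 1], [1, 2]], &dev)?;
    let (rows, _) = indices.dims2()?;
    let mut v = Vec::with_capacity(rows);
    for i in 0..rows {
        // indices.get(i) yields a 1-D row of indices for index_select.
        v.push(xs.index_select(&indices.get(i)?, 0)?);
    }
    // The result has shape (2, 2, 2), matching ONNX Gather on axis 0.
    Tensor::stack(&v, 0)
}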
@ -776,6 +792,34 @@ pub fn simple_eval(
            let output = input.reshape(new_shape)?;
            values.insert(node.output[0].clone(), output);
        }
        // https://github.com/onnx/onnx/blob/main/docs/Operators.md#identity
        "Identity" => {
            let input = get(&node.input[0])?;
            values.insert(node.output[0].clone(), input.clone());
        }
        // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
        // TODO: This version is only compatible with ReduceMean V13 and below.
        "ReduceMean" => {
            let input = get(&node.input[0])?;
            let axes = get_attr_opt::<[i64]>(node, "axes")?;
            let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);

            let n_dims = input.dims().len();

            let axes: Vec<usize> = if let Some(axes) = axes {
                axes.iter()
                    .map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
                    .collect()
            } else {
                (0..n_dims).collect()
            };
            let output = if keepdims == 1 {
                input.mean_keepdim(axes)?
            } else {
                input.mean(axes)?
            };
            values.insert(node.output[0].clone(), output);
        }
        op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
    }
}
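The axis normalization above maps a negative axis `e` onto `n_dims + e`, so `axes = [-2]` on a rank-3 input becomes `[1]`. A worked example mirroring the ONNX "keepdims" case tested further down:

// Worked example for the ReduceMean path with keepdims = 1.
use candle::{Device, Tensor};

fn reduce_mean_example() -> candle::Result<()> {
    let dev = Device::Cpu;
    let input = Tensor::new(
        &[
            [[5.0f32, 1.0], [20.0, 2.0]],
            [[30.0, 1.0], [40.0, 2.0]],
            [[55.0, 1.0], [60.0, 2.0]],
        ],
        &dev,
    )?;
    let out = input.mean_keepdim(1)?;
    // out: [[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]], shape (3, 1, 2).
    assert_eq!(out.dims(), &[3, 1, 2]);
    Ok(())
}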
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{Device, Result, Tensor};
use candle::{Device, NdArray, Result, Tensor};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> {
// #[test]

// "Gather"
// #[test]
#[test]
fn test_gather_operation() -> Result<()> {
    // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
    test(
        &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]],
        &[[0i64, 1], [1, 2]],
        0,
        &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]],
    )?;

    // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
    test(
        &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]],
        &[[0i64, 2]],
        1,
        &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]],
    )?;

    // all the tests below are generated from numpy.take, which works like
    // onnx's Gather operation.
    test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?;

    test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?;

    test(
        &[[1.0], [2.0], [3.0], [4.0]],
        &[3i64, 2],
        0,
        &[[4.0], [3.0]],
    )?;

    test(
        &[
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        1i64,
        0,
        &[[5.0, 6.0], [7.0, 8.0]],
    )?;

    test(
        &[
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        &[1i64, 0],
        0,
        &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]],
    )?;

    fn test(
        data: impl NdArray,
        indices: impl NdArray,
        axis: i64,
        expected: impl NdArray,
    ) -> Result<()> {
        let att_axis = AttributeProto {
            name: "axis".to_string(),
            ref_attr_name: "axis".to_string(),
            i: axis,
            doc_string: "axis".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Gather".to_string(),
                domain: "".to_string(),
                attribute: vec![att_axis],
                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
        inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

        let expected = Tensor::new(expected, &Device::Cpu)?;
        match expected.dims().len() {
            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }
    Ok(())
}

// "Shape"
#[test]
@ -1335,3 +1462,180 @@ fn test_relu_operation() -> Result<()> {

// "Cast"
// #[test]

// "ReduceMean"
#[test]
fn test_reduce_mean() -> Result<()> {
    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 default_axes_keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        None,
        1,
        &[[[18.25]]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 do_no_keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1]),
        0,
        &[[12.5, 1.5], [35.0, 1.5], [57.5, 1.5]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1]),
        1,
        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 negative_axes_keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![-2]),
        1,
        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
    )?;

    // All the test data below was generated based on numpy's np.mean
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1, 2]),
        0,
        &[7.0, 18.25, 29.5],
    )?;

    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1, 2]),
        1,
        &[[[7.0]], [[18.25]], [[29.5]]],
    )?;

    test(&[1., 2., 3.], None, 1, &[2.0])?;

    fn test(
        data: impl NdArray,
        axes: Option<Vec<i64>>,
        keepdims: i64,
        expected: impl NdArray,
    ) -> Result<()> {
        let has_axes = axes.is_some();

        let att_axes = AttributeProto {
            name: "axes".to_string(),
            ref_attr_name: "axes".to_string(),
            i: 0,
            doc_string: "axes".to_string(),
            r#type: 7,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: axes.unwrap_or_default(),
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let att_keepdims = AttributeProto {
            name: "keepdims".to_string(),
            ref_attr_name: "keepdims".to_string(),
            i: keepdims,
            doc_string: "keepdims".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "ReduceMean".to_string(),
                domain: "".to_string(),
                attribute: if has_axes {
                    vec![att_axes, att_keepdims]
                } else {
                    vec![att_keepdims]
                },
                input: vec![INPUT_X.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

        let expected = Tensor::new(expected, &Device::Cpu)?;
        match expected.dims().len() {
            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    Ok(())
}
@ -20,10 +20,10 @@ candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }

[build-dependencies]
pyo3-build-config = "0.20"
pyo3-build-config = "0.21"

[features]
default = []

19
candle-pyo3/py_src/candle/nn/__init__.pyi
Normal file
@ -0,0 +1,19 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor

@staticmethod
def silu(tensor: Tensor) -> Tensor:
    """
    Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
    """
    pass

@staticmethod
def softmax(tensor: Tensor, dim: int) -> Tensor:
    """
    Applies the Softmax function to a given tensor.
    """
    pass
@ -60,8 +60,8 @@ impl PyDType {
impl PyDType {
    fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
        use std::str::FromStr;
        if let Ok(dtype) = ob.extract::<&str>(py) {
            let dtype = DType::from_str(dtype)
        if let Ok(dtype) = ob.extract::<String>(py) {
            let dtype = DType::from_str(&dtype)
                .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
            Ok(Self(dtype))
        } else {
@ -116,8 +116,8 @@ impl PyDevice {

impl<'source> FromPyObject<'source> for PyDevice {
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        let device: &str = ob.extract()?;
        let device = match device {
        let device: String = ob.extract()?;
        let device = match device.as_str() {
            "cpu" => PyDevice::Cpu,
            "cuda" => PyDevice::Cuda,
            _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
@ -265,7 +265,7 @@ impl PyTensor {
        } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
            return PyTensor::new(py, numpy);
        } else {
            let ty = data.as_ref(py).get_type();
            let ty = data.bind(py).get_type();
            Err(PyTypeError::new_err(format!(
                "incorrect type {ty} for tensor"
            )))?
@ -322,7 +322,7 @@ impl PyTensor {
    fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
        let candle_values = self.values(py)?;
        let torch_tensor: PyObject = py
            .import("torch")?
            .import_bound("torch")?
            .getattr("tensor")?
            .call1((candle_values,))?
            .extract()?;
@ -333,7 +333,7 @@ impl PyTensor {
    /// Gets the tensor's shape.
    /// &RETURNS&: Tuple[int]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.dims()).to_object(py)
        PyTuple::new_bound(py, self.0.dims()).to_object(py)
    }

    #[getter]
@ -347,7 +347,7 @@ impl PyTensor {
    /// Gets the tensor's strides.
    /// &RETURNS&: Tuple[int]
    fn stride(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.stride()).to_object(py)
        PyTuple::new_bound(py, self.0.stride()).to_object(py)
    }

    #[getter]
@ -527,7 +527,7 @@ impl PyTensor {
    }

    fn extract_indexer(
        py_indexer: &PyAny,
        py_indexer: &Bound<PyAny>,
        current_dim: usize,
        dims: &[usize],
        index_argument_count: usize,
@ -567,7 +567,7 @@ impl PyTensor {
                ),
                current_dim + 1,
            ))
        } else if py_indexer.is_ellipsis() {
        } else if py_indexer.is(&py_indexer.py().Ellipsis()) {
            // Handle '...' e.g. tensor[..., 0]
            if current_dim > 0 {
                return Err(PyTypeError::new_err(
@ -586,7 +586,7 @@ impl PyTensor {
            }
        }

        if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
        if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
            let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();

            if not_none_count > dims.len() {
@ -596,12 +596,12 @@ impl PyTensor {
            let mut current_dim = 0;
            for item in tuple.iter() {
                let (indexer, new_current_dim) =
                    extract_indexer(item, current_dim, dims, not_none_count)?;
                    extract_indexer(&item, current_dim, dims, not_none_count)?;
                current_dim = new_current_dim;
                indexers.push(indexer);
            }
        } else {
            let (indexer, _) = extract_indexer(idx.downcast::<PyAny>(py)?, 0, dims, 1)?;
            let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
            indexers.push(indexer);
        }

@ -652,7 +652,7 @@ impl PyTensor {

    /// Add two tensors.
    /// &RETURNS&: Tensor
    fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
    fn __add__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
@ -663,13 +663,13 @@ impl PyTensor {
        Ok(Self(tensor))
    }

    fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
    fn __radd__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        self.__add__(rhs)
    }

    /// Multiply two tensors.
    /// &RETURNS&: Tensor
    fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
    fn __mul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
@ -680,13 +680,13 @@ impl PyTensor {
        Ok(Self(tensor))
    }

    fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
    fn __rmul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        self.__mul__(rhs)
    }

    /// Subtract two tensors.
    /// &RETURNS&: Tensor
    fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
    fn __sub__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
@ -699,7 +699,7 @@ impl PyTensor {

    /// Divide two tensors.
    /// &RETURNS&: Tensor
    fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
    fn __truediv__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
@ -711,7 +711,7 @@ impl PyTensor {
    }
    /// Rich-compare two tensors.
    /// &RETURNS&: Tensor
    fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
    fn __richcmp__(&self, rhs: &Bound<PyAny>, op: CompareOp) -> PyResult<Self> {
        let compare = |lhs: &Tensor, rhs: &Tensor| {
            let t = match op {
                CompareOp::Eq => lhs.eq(rhs),
@ -957,7 +957,7 @@ impl PyTensor {
    #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
    /// Performs Tensor dtype and/or device conversion.
    /// &RETURNS&: Tensor
    fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
    fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
        let mut device: Option<PyDevice> = None;
        let mut dtype: Option<PyDType> = None;
        let mut other: Option<PyTensor> = None;
@ -1227,7 +1227,7 @@ impl PyQTensor {
    ///Gets the shape of the tensor.
    /// &RETURNS&: Tuple[int]
    fn shape(&self, py: Python<'_>) -> PyObject {
        PyTuple::new(py, self.0.shape().dims()).to_object(py)
        PyTuple::new_bound(py, self.0.shape().dims()).to_object(py)
    }

    fn __repr__(&self) -> String {
@ -1265,7 +1265,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
        .into_iter()
        .map(|(key, value)| (key, PyTensor(value).into_py(py)))
        .collect::<Vec<_>>();
    Ok(res.into_py_dict(py).to_object(py))
    Ok(res.into_py_dict_bound(py).to_object(py))
}

#[pyfunction]
@ -1303,7 +1303,7 @@ fn load_ggml(
        .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
    let tensors = tensors.into_py_dict(py).to_object(py);
    let tensors = tensors.into_py_dict_bound(py).to_object(py);
    let hparams = [
        ("n_vocab", ggml.hparams.n_vocab),
        ("n_embd", ggml.hparams.n_embd),
@ -1313,7 +1313,7 @@ fn load_ggml(
        ("n_rot", ggml.hparams.n_rot),
        ("ftype", ggml.hparams.ftype),
    ];
    let hparams = hparams.into_py_dict(py).to_object(py);
    let hparams = hparams.into_py_dict_bound(py).to_object(py);
    let vocab = ggml
        .vocab
        .token_score_pairs
@ -1351,7 +1351,7 @@ fn load_gguf(
            gguf_file::Value::Bool(x) => x.into_py(py),
            gguf_file::Value::String(x) => x.into_py(py),
            gguf_file::Value::Array(x) => {
                let list = pyo3::types::PyList::empty(py);
                let list = pyo3::types::PyList::empty_bound(py);
                for elem in x.iter() {
                    list.append(gguf_value_to_pyobject(elem, py)?)?;
                }
@ -1371,13 +1371,13 @@ fn load_gguf(
        })
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
    let tensors = tensors.into_py_dict(py).to_object(py);
    let tensors = tensors.into_py_dict_bound(py).to_object(py);
    let metadata = gguf
        .metadata
        .iter()
        .map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
        .collect::<PyResult<Vec<_>>>()?
        .into_py_dict(py)
        .into_py_dict_bound(py)
        .to_object(py);
    Ok((tensors, metadata))
}
@ -1390,7 +1390,7 @@ fn load_gguf(
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
    use ::candle::quantized::gguf_file;

    fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
    fn pyobject_to_gguf_value(v: &Bound<PyAny>, py: Python<'_>) -> PyResult<gguf_file::Value> {
        let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
            gguf_file::Value::U8(x)
        } else if let Ok(x) = v.extract::<i8>() {
@ -1418,7 +1418,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
        } else if let Ok(x) = v.extract::<Vec<PyObject>>() {
            let x = x
                .into_iter()
                .map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
                .map(|f| pyobject_to_gguf_value(f.bind(py), py))
                .collect::<PyResult<Vec<_>>>()?;
            gguf_file::Value::Array(x)
        } else {
@ -1450,7 +1450,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
            Ok((
                key.extract::<String>()
                    .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
                pyobject_to_gguf_value(value, py)?,
                pyobject_to_gguf_value(&value.as_borrowed(), py)?,
            ))
        })
        .collect::<PyResult<Vec<_>>>()?;
@ -1498,7 +1498,7 @@ fn get_num_threads() -> usize {
    ::candle::utils::get_num_threads()
}

fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
    m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
    m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
@ -1579,7 +1579,7 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
    Ok(PyTensor(s))
}

fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(silu, m)?)?;
    m.add_function(wrap_pyfunction!(softmax, m)?)?;
    m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
@ -1591,7 +1591,7 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
}

#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    use onnx::{PyONNXModel, PyONNXTensorDescriptor};
    m.add_class::<PyONNXModel>()?;
    m.add_class::<PyONNXTensorDescriptor>()?;
@ -1599,18 +1599,18 @@ fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
}

#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    let utils = PyModule::new(py, "utils")?;
    candle_utils(py, utils)?;
    m.add_submodule(utils)?;
    let nn = PyModule::new(py, "functional")?;
    candle_functional_m(py, nn)?;
    m.add_submodule(nn)?;
fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let utils = PyModule::new_bound(py, "utils")?;
    candle_utils(py, &utils)?;
    m.add_submodule(&utils)?;
    let nn = PyModule::new_bound(py, "functional")?;
    candle_functional_m(py, &nn)?;
    m.add_submodule(&nn)?;
    #[cfg(feature = "onnx")]
    {
        let onnx = PyModule::new(py, "onnx")?;
        candle_onnx_m(py, onnx)?;
        m.add_submodule(onnx)?;
        let onnx = PyModule::new_bound(py, "onnx")?;
        candle_onnx_m(py, &onnx)?;
        m.add_submodule(&onnx)?;
    }
    m.add_class::<PyTensor>()?;
    m.add_class::<PyQTensor>()?;

@ -39,7 +39,7 @@ impl PyONNXTensorDescriptor {
    /// The shape of the tensor.
    /// &RETURNS&: Tuple[Union[int,str,Any]]
    fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
        let shape = PyList::empty(py);
        let shape = PyList::empty_bound(py);
        if let Some(d) = &self.0.shape {
            for dim in d.dim.iter() {
                if let Some(value) = &dim.value {
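The pyo3 0.21 changes above all follow the same gil-ref to Bound migration: `&PyAny` arguments become `&Bound<PyAny>`, constructors gain a `_bound` suffix, and `.as_ref(py)` turns into `.bind(py)`. A condensed before/after sketch of the pattern, with a hypothetical function name:

// Hypothetical example of the pyo3 0.20 -> 0.21 migration pattern used above.
use pyo3::prelude::*;
use pyo3::types::PyTuple;

// pyo3 0.20 (gil-ref API):
// fn dims_old(py: Python<'_>, dims: &[usize]) -> PyObject {
//     PyTuple::new(py, dims).to_object(py)
// }

// pyo3 0.21 (Bound API):
fn dims_new(py: Python<'_>, dims: &[usize]) -> PyObject {
    PyTuple::new_bound(py, dims).to_object(py)
}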
@ -206,6 +206,8 @@ def write(module, directory, origin, check=False):
    if check:
        with open(filename, "r") as f:
            data = f.read()
            print("generated content")
            print(pyi_content)
            assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
    else:
        with open(filename, "w") as f:
@ -229,6 +231,8 @@ def write(module, directory, origin, check=False):
    if check:
        with open(filename, "r") as f:
            data = f.read()
            print("generated content")
            print(py_content)
            assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
    else:
        with open(filename, "w") as f:
@ -7,6 +7,7 @@ pub enum Sampling {
    All { temperature: f64 },
    TopK { k: usize, temperature: f64 },
    TopP { p: f64, temperature: f64 },
    TopKThenTopP { k: usize, p: f64, temperature: f64 },
}

pub struct LogitsProcessor {
@ -77,7 +78,6 @@ impl LogitsProcessor {
            self.sample_multinomial(prs)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            // Sort by descending probability.
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
@ -86,6 +86,26 @@ impl LogitsProcessor {
        }
    }

    // Top-k sampling restricts the choice to the k tokens with the largest
    // probabilities, then top-p sampling is applied to that subset.
    fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
        if top_k >= prs.len() {
            self.sample_topp(prs, top_p)
        } else {
            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
            let (indices, _, _) =
                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
            let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
            let sum_p = prs.iter().sum::<f32>();
            let index = if top_p <= 0.0 || top_p >= sum_p {
                self.sample_multinomial(&prs)?
            } else {
                self.sample_topp(&mut prs, top_p)?
            };
            Ok(indices[index as usize] as u32)
        }
    }

    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
        self.sample_f(logits, |_| {})
    }
@ -120,6 +140,10 @@ impl LogitsProcessor {
                let mut prs = prs(*temperature)?;
                self.sample_topk(&mut prs, *k)?
            }
            Sampling::TopKThenTopP { k, p, temperature } => {
                let mut prs = prs(*temperature)?;
                self.sample_topk_topp(&mut prs, *k, *p as f32)?
            }
        };
        Ok(next_token)
    }
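A usage sketch for the new `TopKThenTopP` variant, assuming a logits tensor is at hand; the seed and parameter values are illustrative:

// Sketch: sampling with the combined top-k then top-p strategy.
use candle_transformers::generation::{LogitsProcessor, Sampling};

fn sample_next(logits: &candle::Tensor) -> candle::Result<u32> {
    let sampling = Sampling::TopKThenTopP { k: 40, p: 0.9, temperature: 0.8 };
    let mut processor = LogitsProcessor::from_sampling(299792458, sampling);
    processor.sample(logits)
}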
154
candle-transformers/src/models/clip/mod.rs
Normal file
@ -0,0 +1,154 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
use self::{
    text_model::{Activation, ClipTextTransformer},
    vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};

pub mod text_model;
pub mod vision_model;

#[derive(Clone, Debug)]
pub struct ClipModel {
    text_model: ClipTextTransformer,
    vision_model: ClipVisionTransformer,
    visual_projection: candle_nn::Linear,
    text_projection: candle_nn::Linear,
    logit_scale: Tensor,
}

#[derive(Clone, Debug)]
pub enum EncoderConfig {
    Text(text_model::ClipTextConfig),
    Vision(vision_model::ClipVisionConfig),
}

impl EncoderConfig {
    pub fn embed_dim(&self) -> usize {
        match self {
            Self::Text(c) => c.embed_dim,
            Self::Vision(c) => c.embed_dim,
        }
    }

    pub fn num_attention_heads(&self) -> usize {
        match self {
            Self::Text(c) => c.num_attention_heads,
            Self::Vision(c) => c.num_attention_heads,
        }
    }

    pub fn intermediate_size(&self) -> usize {
        match self {
            Self::Text(c) => c.intermediate_size,
            Self::Vision(c) => c.intermediate_size,
        }
    }

    pub fn num_hidden_layers(&self) -> usize {
        match self {
            Self::Text(c) => c.num_hidden_layers,
            Self::Vision(c) => c.num_hidden_layers,
        }
    }

    pub fn activation(&self) -> Activation {
        match self {
            Self::Text(_c) => Activation::QuickGelu,
            Self::Vision(c) => c.activation,
        }
    }
}

#[derive(Clone, Debug)]
pub struct ClipConfig {
    pub text_config: text_model::ClipTextConfig,
    pub vision_config: vision_model::ClipVisionConfig,
    pub logit_scale_init_value: f32,
    pub image_size: usize,
}

impl ClipConfig {
    // base image size is 224, model size is 600MB
    pub fn vit_base_patch32() -> Self {
        let text_config = text_model::ClipTextConfig::vit_base_patch32();
        let vision_config = vision_model::ClipVisionConfig::vit_base_patch32();

        Self {
            text_config,
            vision_config,
            logit_scale_init_value: 2.6592,
            image_size: 224,
        }
    }
}

impl ClipModel {
    pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
        let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;

        let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;

        let visual_projection = candle_nn::linear_no_bias(
            c.vision_config.embed_dim,
            c.vision_config.projection_dim,
            vs.pp("visual_projection"),
        )?;

        let text_projection = candle_nn::linear_no_bias(
            c.text_config.embed_dim,
            c.text_config.projection_dim,
            vs.pp("text_projection"),
        )?;

        // originally nn.Parameter
        let logit_scale = if vs.contains_tensor("logit_scale") {
            vs.get(&[], "logit_scale")?
        } else {
            Tensor::new(&[c.logit_scale_init_value], vs.device())?
        };

        Ok(Self {
            text_model,
            vision_model,
            visual_projection,
            text_projection,
            logit_scale,
        })
    }

    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
        input_ids
            .apply(&self.text_model)?
            .apply(&self.text_projection)
    }

    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        pixel_values
            .apply(&self.vision_model)?
            .apply(&self.visual_projection)
    }

    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
        let image_features = self.get_image_features(pixel_values)?;
        let text_features = self.get_text_features(input_ids)?;
        let image_features_normalized = div_l2_norm(&image_features)?;
        let text_features_normalized = div_l2_norm(&text_features)?;
        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
        let logit_scale = self.logit_scale.exp()?;
        let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
        let logits_per_image = logits_per_text.t()?;
        Ok((logits_per_text, logits_per_image))
    }
}

pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
    let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
    v.broadcast_div(&l2_norm)
}
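`div_l2_norm` rescales each row to unit L2 norm, so the matmul in `forward` computes cosine similarities before the learned `logit_scale` is applied. A small check of that property, assuming the new module is exported as `candle_transformers::models::clip`:

// Sketch: after div_l2_norm, every row has (approximately) unit L2 norm.
use candle::{Device, Tensor, D};

fn unit_norm_check() -> candle::Result<()> {
    let v = Tensor::new(&[[3.0f32, 4.0], [0.0, 2.0]], &Device::Cpu)?;
    let n = candle_transformers::models::clip::div_l2_norm(&v)?;
    // Rows become [3/5, 4/5] and [0, 1], so the norms below are all 1.
    let norms = n.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
    println!("{norms}");
    Ok(())
}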
333
candle-transformers/src/models/clip/text_model.rs
Normal file
@ -0,0 +1,333 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;

use super::EncoderConfig;

#[derive(Debug, Clone, Copy)]
pub enum Activation {
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
        }
    }
}

#[derive(Debug, Clone)]
pub struct ClipTextConfig {
    pub vocab_size: usize,
    pub embed_dim: usize,
    pub activation: Activation,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub pad_with: Option<String>,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    #[allow(dead_code)]
    pub projection_dim: usize,
}

impl ClipTextConfig {
    // The config details can be found in the "text_config" section of this json file:
    // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
    pub fn vit_base_patch32() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 512,
            intermediate_size: 2048,
            max_position_embeddings: 77,
            pad_with: None,
            num_hidden_layers: 12,
            num_attention_heads: 8,
            projection_dim: 512,
            activation: Activation::QuickGelu,
        }
    }
}

// ClipTextEmbeddings mostly based on the existing implementation in the stable diffusion model.
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
#[derive(Clone, Debug)]
struct ClipTextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl ClipTextEmbeddings {
    fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let token_embedding =
            candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
        let position_embedding: nn::Embedding = candle_nn::embedding(
            c.max_position_embeddings,
            c.embed_dim,
            vs.pp("position_embedding"),
        )?;
        let position_ids =
            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
        Ok(ClipTextEmbeddings {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for ClipTextEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_length = input_ids.dim(D::Minus1)?;
        let inputs_embeds = self.token_embedding.forward(input_ids)?;
        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
        let position_embedding = self.position_embedding.forward(&position_ids)?;
        inputs_embeds.broadcast_add(&position_embedding)
    }
}

#[derive(Clone, Debug)]
struct ClipAttention {
    k_proj: candle_nn::Linear,
    v_proj: candle_nn::Linear,
    q_proj: candle_nn::Linear,
    out_proj: candle_nn::Linear,
    head_dim: usize,
    scale: f64,
    num_attention_heads: usize,
}

impl ClipAttention {
    fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
        let embed_dim = c.embed_dim();
        let num_attention_heads = c.num_attention_heads();
        let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
        let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
        let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
        let head_dim = embed_dim / num_attention_heads;
        let scale = (head_dim as f64).powf(-0.5);

        Ok(ClipAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            head_dim,
            scale,
            num_attention_heads,
        })
    }

    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let in_dtype = xs.dtype();
        let (bsz, seq_len, embed_dim) = xs.dims3()?;

        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
        let query_states = self
            .shape(&query_states, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let key_states = self
            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let value_states = self
            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;

        let src_len = key_states.dim(1)?;

        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
            attn_weights
                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
                .broadcast_add(causal_attention_mask)?
                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
        } else {
            attn_weights
        };

        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;

        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
        let attn_output = attn_output
            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((bsz, seq_len, embed_dim))?;
        self.out_proj.forward(&attn_output)
    }
}

#[derive(Clone, Debug)]
struct ClipMlp {
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
    activation: Activation,
}

impl ClipMlp {
    fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
        let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?;
        let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?;

        Ok(ClipMlp {
            fc1,
            fc2,
            activation: c.activation(),
        })
    }
}

impl ClipMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?;
        self.fc2.forward(&self.activation.forward(&xs)?)
    }
}

#[derive(Clone, Debug)]
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: candle_nn::LayerNorm,
    mlp: ClipMlp,
    layer_norm2: candle_nn::LayerNorm,
}

impl ClipEncoderLayer {
    fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
        let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?;
        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
        let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?;

        Ok(ClipEncoderLayer {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = self.layer_norm2.forward(&xs)?;
        let xs = self.mlp.forward(&xs)?;
        xs + residual
    }
}

#[derive(Clone, Debug)]
pub struct ClipEncoder {
    layers: Vec<ClipEncoderLayer>,
}

impl ClipEncoder {
    pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
        let vs = vs.pp("layers");
        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
        for index in 0..c.num_hidden_layers() {
            let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
            layers.push(layer)
        }
        Ok(ClipEncoder { layers })
    }

    pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
        }
        Ok(xs)
    }
}

/// A CLIP transformer based model.
#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
    embeddings: ClipTextEmbeddings,
    encoder: ClipEncoder,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipTextTransformer {
    pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
        let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
        let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
        Ok(ClipTextTransformer {
            embeddings,
            encoder,
            final_layer_norm,
        })
    }

    // TODO: rewrite to newer version
    fn build_causal_attention_mask(
        bsz: usize,
        seq_len: usize,
        mask_after: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| {
                (0..seq_len).map(move |j| {
                    if j > i || j > mask_after {
                        f32::MIN
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        mask.broadcast_as((bsz, 1, seq_len, seq_len))
    }
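For seq_len = 3 and mask_after = usize::MAX, the mask built above is additive and lower-triangular: 0 where attention is allowed, f32::MIN where it is blocked, so the subsequent softmax zeroes out future positions:

    row 0:   0        f32::MIN  f32::MIN
    row 1:   0        0         f32::MIN
    row 2:   0        0         0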
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
|
||||
let (bsz, seq_len) = input_ids.dims2()?;
|
||||
let input_ids = self.embeddings.forward(input_ids)?;
|
||||
let causal_attention_mask =
|
||||
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
|
||||
let input_ids = self
|
||||
.encoder
|
||||
.forward(&input_ids, Some(&causal_attention_mask))?;
|
||||
self.final_layer_norm.forward(&input_ids)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ClipTextTransformer {
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let output = self.forward_with_mask(input_ids, usize::MAX)?;
|
||||
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
|
||||
|
||||
let mut indices = Vec::new();
|
||||
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
|
||||
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
|
||||
indices.push(index);
|
||||
}
|
||||
Tensor::cat(&indices, 0)
|
||||
}
|
||||
}
|
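For orientation, a minimal usage sketch for the text tower follows; the checkpoint file name, the `text_model` weight prefix, the `vit_base_patch32` constructor, and the hard-coded token ids are assumptions rather than part of this diff.

use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};

// Hypothetical helper: load CLIP text weights and embed one tokenized prompt.
fn embed_text() -> candle::Result<Tensor> {
    let device = Device::Cpu;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let cfg = ClipTextConfig::vit_base_patch32();
    let model = ClipTextTransformer::new(vb.pp("text_model"), &cfg)?;
    // Dummy ids standing in for tokenizer output: <start> ... <end-of-text>.
    let input_ids = Tensor::new(&[[49406u32, 320, 1125, 49407]], &device)?;
    // Module::forward pools the hidden state at each sequence's EOS position.
    model.forward(&input_ids)
}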
candle-transformers/src/models/clip/vision_model.rs (new file, 147 lines)
@@ -0,0 +1,147 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

use candle::{IndexOp, Result, Shape, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use nn::Conv2dConfig;

use super::{
    text_model::{Activation, ClipEncoder},
    EncoderConfig,
};

#[derive(Debug, Clone)]
pub struct ClipVisionConfig {
    pub embed_dim: usize,
    pub activation: Activation,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    #[allow(dead_code)]
    pub projection_dim: usize,
    pub num_channels: usize,
    pub image_size: usize,
    pub patch_size: usize,
}

impl ClipVisionConfig {
    // The config details can be found in the "vision_config" section of this json file:
    // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
    pub fn vit_base_patch32() -> Self {
        Self {
            embed_dim: 768,
            activation: Activation::QuickGelu,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            projection_dim: 512,
            num_channels: 3,
            image_size: 224,
            patch_size: 32,
        }
    }
}

// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
#[derive(Clone, Debug)]
struct ClipVisionEmbeddings {
    patch_embedding: candle_nn::Conv2d,
    position_ids: Tensor,
    class_embedding: Tensor,
    position_embedding: candle_nn::Embedding,
}

impl ClipVisionEmbeddings {
    fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
        // originally nn.Parameter
        let class_embedding = if vs.contains_tensor("class_embedding") {
            vs.get(c.embed_dim, "class_embedding")?
        } else {
            Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
        };

        let num_patches = (c.image_size / c.patch_size).pow(2);
        let num_positions = num_patches + 1;
        let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;

        let conv2dconfig = Conv2dConfig {
            stride: c.patch_size,
            ..Default::default()
        };
        let position_embedding =
            candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
        let patch_embedding = candle_nn::conv2d_no_bias(
            c.num_channels,
            c.embed_dim,
            c.patch_size,
            conv2dconfig,
            vs.pp("patch_embedding"),
        )?;
        Ok(Self {
            patch_embedding,
            position_ids,
            class_embedding,
            position_embedding,
        })
    }
}

impl Module for ClipVisionEmbeddings {
    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let batch_size = pixel_values.shape().dims();
        let patch_embeds = self
            .patch_embedding
            .forward(pixel_values)?
            .flatten_from(2)?
            .transpose(1, 2)?;
        let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
        let class_embeds = self.class_embedding.expand(shape)?;
        let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
        embeddings.broadcast_add(&position_embedding)
    }
}

// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
#[derive(Clone, Debug)]
pub struct ClipVisionTransformer {
    embeddings: ClipVisionEmbeddings,
    encoder: ClipEncoder,
    pre_layer_norm: candle_nn::LayerNorm,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipVisionTransformer {
    pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
        let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
        // "pre_layrnorm" mirrors the (misspelled) tensor name in the upstream checkpoints.
        let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
        let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
        let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
        Ok(Self {
            embeddings,
            encoder,
            final_layer_norm,
            pre_layer_norm,
        })
    }
}

impl Module for ClipVisionTransformer {
    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let hidden_states = pixel_values
            .apply(&self.embeddings)?
            .apply(&self.pre_layer_norm)?;

        let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
        // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
        // pooled_output = encoder_outputs[:, 0, :]
        let pooled_output = encoder_outputs.i((.., 0, ..))?;
        self.final_layer_norm.forward(&pooled_output)
    }
}
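The vision tower mirrors the text side; a hedged sketch of running it on an already-preprocessed image (the checkpoint name and the `vision_model` prefix are assumptions):

use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};

// Hypothetical helper: embed a normalized 224x224 image batch.
fn embed_image() -> candle::Result<Tensor> {
    let device = Device::Cpu;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let cfg = ClipVisionConfig::vit_base_patch32();
    let model = ClipVisionTransformer::new(vb.pp("vision_model"), &cfg)?;
    let pixel_values = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    model.forward(&pixel_values) // pooled class-token embedding, layer-normed
}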
@@ -52,8 +52,8 @@ impl Module for Attention {
             .transpose(0, 1)? // 20134
             .transpose(2, 3)?; // 20314
         let q = (qkv.i(0)? * self.scale)?;
-        let k = qkv.i(1)?;
-        let v = qkv.i(2)?;
+        let k = qkv.i(1)?.contiguous()?;
+        let v = qkv.i(2)?.contiguous()?;
         let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
         let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
         self.proj.forward(&attn)
@@ -1,5 +1,6 @@
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
+use serde::Deserialize;
 
 const MAX_SEQ_LEN: usize = 5000;
 
@@ -18,7 +19,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
 }
 
 // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
-#[derive(Debug)]
+#[derive(Clone, Debug, Deserialize)]
 pub struct Config {
     pub vocab_size: usize,
     pub hidden_size: usize,
@@ -119,7 +120,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
     Ok(x21)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct FalconRotaryEmbedding {
     inv_freq: Tensor,
     cache: Option<(usize, Tensor, Tensor)>,
@@ -178,12 +179,14 @@ impl FalconRotaryEmbedding {
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let on_true = Tensor::new(on_true, on_false.device())?
+        .to_dtype(on_false.dtype())?
+        .broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct FalconAttention {
     query_key_value: Linear,
     dense: Linear,
@@ -247,7 +250,7 @@ impl FalconAttention {
         }
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
@@ -267,7 +270,6 @@ impl FalconAttention {
             (query, key)
         };
         let (mut key, mut value) = (key, value);
-        let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
         if self.use_cache {
             if let Some((cache_k, cache_v)) = &self.kv_cache {
                 // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
@@ -293,13 +295,18 @@
 
         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = candle_nn::ops::softmax(
-            &attention_scores
-                .broadcast_add(&mask.squeeze(1)?)?
-                .to_dtype(DType::F32)?,
-            D::Minus1,
-        )?
-        .to_dtype(x.dtype())?;
+        let attention_scores = match mask {
+            None => attention_scores,
+            Some(mask) => {
+                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
+                    .to_dtype(query.dtype())?;
+                attention_scores.broadcast_add(&mask.squeeze(1)?)?
+            }
+        };
+
+        let attention_scores =
+            candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
+                .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
@@ -308,9 +315,13 @@
         let attn_output = self.dense.forward(&attn_output)?;
         Ok(attn_output)
     }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct FalconMlp {
     dense_h_to_4h: Linear,
     dense_4h_to_h: Linear,
@@ -335,7 +346,7 @@ impl FalconMlp {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct FalconDecoderLayer {
     inp_layernorm: LayerNorm,
     self_attention: FalconAttention,
@@ -372,7 +383,7 @@ impl FalconDecoderLayer {
         })
     }
 
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
         let residual = x.clone();
         let ln_attn = self.inp_layernorm.forward(x)?;
        let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
@@ -395,9 +406,13 @@ impl FalconDecoderLayer {
         let output = (mlp_output + residual)?;
         Ok(output)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.self_attention.clear_kv_cache()
+    }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Falcon {
     word_embeddings: Embedding,
     blocks: Vec<FalconDecoderLayer>,
@@ -457,13 +472,23 @@ impl Falcon {
             Some((k, _)) => k.dim(1)?,
             None => 0,
         };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
+        let causal_mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
+            hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
         }
         let hidden_state = self.ln_f.forward(&hidden_state)?;
         let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
         let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
         Ok(logits)
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        for block in self.blocks.iter_mut() {
+            block.clear_kv_cache()
+        }
+    }
 }
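The falcon change above makes the causal mask optional: a 1x1 causal mask masks nothing, and during cached decoding every step after the prompt feeds a single token. A hedged decode-loop sketch (greedy sampling for brevity; `Falcon::forward(&mut self, input_ids)` is assumed from the surrounding code):

use candle::{D, Device, Result, Tensor};

fn generate(model: &mut Falcon, mut tokens: Vec<u32>, steps: usize, device: &Device) -> Result<Vec<u32>> {
    model.clear_kv_cache();
    // Prompt pass: seq_len > 1, so the causal mask is built and applied.
    let mut input = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
    for _ in 0..steps {
        let logits = model.forward(&input)?;
        let next = logits.argmax(D::Minus1)?.squeeze(0)?.to_scalar::<u32>()?;
        tokens.push(next);
        // Cached steps: one token, mask construction skipped entirely.
        input = Tensor::new(&[next], device)?.unsqueeze(0)?;
    }
    Ok(tokens)
}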
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,7 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
@@ -25,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }
 
+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
@@ -126,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
@@ -179,18 +191,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -227,8 +227,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
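The inlined `repeat_kv` helpers deleted here and below now live in `crate::utils`; functionally they broadcast each KV head across its query group so that grouped-query attention can use a plain batched matmul. A standalone sketch of the same logic (free-function form assumed):

use candle::{Result, Tensor};

// Turn (b, num_kv_heads, t, d) into (b, num_kv_heads * n_rep, t, d) by
// repeating each kv head n_rep times, matching the deleted method above.
fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(xs)
    } else {
        let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
        xs.unsqueeze(2)?
            .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
            .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
    }
}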
@@ -240,8 +240,12 @@ impl CausalSelfAttention {
             let k = k.to_dtype(DType::F32)?;
             let v = v.to_dtype(DType::F32)?;
             let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+            let att = if seq_len == 1 {
+                att
+            } else {
+                let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+                masked_fill(&att, &mask, f32::NEG_INFINITY)?
+            };
             let att = candle_nn::ops::softmax(&att, D::Minus1)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
             att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
@@ -252,17 +256,7 @@ impl CausalSelfAttention {
     }
 
     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_attention_heads / self.num_key_value_heads;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
+        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
     }
 
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
@@ -194,8 +194,12 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len <= 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
@@ -1,4 +1,3 @@
-#![allow(unused)]
 /// A fast implementation of mamba for inference only.
 /// This is based on: https://github.com/LaurentMazare/mamba.rs
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@@ -10,10 +9,10 @@ const D_STATE: usize = 16;
 
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
-    d_model: usize,
-    n_layer: usize,
-    vocab_size: usize,
-    pad_vocab_size_multiple: usize,
+    pub d_model: usize,
+    pub n_layer: usize,
+    pub vocab_size: usize,
+    pub pad_vocab_size_multiple: usize,
 }
 
 impl Config {
@@ -38,12 +37,12 @@ pub struct State {
 }
 
 impl State {
-    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
+    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
         let mut hs = Vec::with_capacity(cfg.n_layer);
         let mut prev_xs = Vec::with_capacity(cfg.n_layer);
         for _i in 0..cfg.n_layer {
-            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
-            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
+            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
+            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
             hs.push(h);
             prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
         }
@@ -128,8 +127,8 @@ impl MambaBlock {
         let delta = delta.apply(&self.dt_proj)?;
         // softplus
         let delta = (delta.exp()? + 1.)?.log()?;
-        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
-        let d = self.d.to_dtype(candle::DType::F32)?;
+        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
+        let d = self.d.to_dtype(delta.dtype())?;
 
         // Selective scan part
         // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
@@ -178,6 +177,7 @@ pub struct Model {
     layers: Vec<ResidualBlock>,
     norm_f: RmsNorm,
     lm_head: Linear,
+    dtype: DType,
 }
 
 impl Model {
@@ -196,6 +196,7 @@ impl Model {
             layers,
             norm_f,
             lm_head,
+            dtype: vb.dtype(),
         })
     }
 
@@ -208,4 +209,8 @@ impl Model {
         state.pos += 1;
         xs.apply(&self.norm_f)?.apply(&self.lm_head)
     }
+
+    pub fn dtype(&self) -> DType {
+        self.dtype
+    }
 }
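A hedged sketch of the call-site impact of threading `DType` through `State::new`: the recurrent state now matches the weight dtype instead of being pinned to F32. `Model::forward(&self, input_ids, state)` is assumed from the surrounding code.

use candle::{Device, Result, Tensor};

// Hypothetical single-step decode: build a state in the model's own dtype.
fn decode_one(model: &Model, cfg: &Config, token_id: u32, device: &Device) -> Result<Tensor> {
    let mut state = State::new(1, cfg, model.dtype(), device)?;
    let xs = Tensor::new(&[token_id], device)?;
    model.forward(&xs, &mut state)
}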
@@ -88,13 +88,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let rope_theta = cfg.rope_theta as f32;
@@ -110,7 +103,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,
@@ -126,10 +118,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
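The rewrite above also explains the removed `Tensor::cat(&[&freqs, &freqs], ...)`: the fused `candle_nn::rotary_emb::rope` kernel takes half-width `cos`/`sin` tables of shape (seq_len, head_dim / 2) plus contiguous (b, h, t, d) inputs, and applies the rotation internally. A small shape-check sketch under those assumptions:

use candle::{Device, Result, Tensor};

fn rope_shapes() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (1, 2, 4, 8);
    let q = Tensor::randn(0f32, 1., (b, h, t, d), &dev)?;
    // One row per position, half the head dimension per row.
    let cos = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
    let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
    assert_eq!(q_embed.dims(), &[b, h, t, d]);
    Ok(())
}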
@@ -226,18 +216,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -252,10 +230,12 @@ impl Attention {
 
         let query_states = query_states
             .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let key_states = key_states
             .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let value_states = value_states
             .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
             .transpose(1, 2)?;
@@ -274,8 +254,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@@ -126,18 +126,11 @@ impl Module for Embedding {
     }
 }
 
-fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
-        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
         .collect();
-    Tensor::from_slice(&mask, (size, size), device)
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
+    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
 }
 
 #[derive(Debug, Clone)]
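`get_mask` now returns an additive mask in the model dtype (0 where attention is allowed, -inf above the diagonal), so the boolean mask plus `masked_fill` pair disappears and masking reduces to one `broadcast_add` before the softmax. A self-contained illustration:

use candle::{DType, Device, Result, Tensor};

// Adding -inf drives the masked scores to zero probability after softmax:
// the first row of `probs` comes out as [1.0, 0.0, 0.0].
fn additive_mask_demo() -> Result<()> {
    let dev = Device::Cpu;
    let scores = Tensor::zeros((3, 3), DType::F32, &dev)?;
    let mask: Vec<f32> = (0..3)
        .flat_map(|i| (0..3).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
        .collect();
    let mask = Tensor::from_slice(&mask, (3, 3), &dev)?;
    let probs = candle_nn::ops::softmax_last_dim(&scores.broadcast_add(&mask)?)?;
    println!("{probs}");
    Ok(())
}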
@@ -147,7 +140,7 @@ struct RotaryEmbedding {
 }
 
 impl RotaryEmbedding {
-    fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
+    fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
             .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
@@ -159,8 +152,8 @@ impl RotaryEmbedding {
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
         Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
+            sin: freqs.sin()?.to_dtype(dtype)?,
+            cos: freqs.cos()?.to_dtype(dtype)?,
         })
     }
 
@@ -175,30 +168,14 @@ impl RotaryEmbedding {
         }
         let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
         let rotary_dim = rotary_dim * 2;
-        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
+        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?;
         let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
-        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
+        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?;
         let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
-        let q12 = q_rot.chunk(2, D::Minus1)?;
-        let k12 = k_rot.chunk(2, D::Minus1)?;
-        let (q1, q2) = (&q12[0], &q12[1]);
-        let (k1, k2) = (&k12[0], &k12[1]);
-        let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
-        let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
-        let q_rot = Tensor::cat(
-            &[
-                (q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
-                (q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
-            ],
-            D::Minus1,
-        )?;
-        let k_rot = Tensor::cat(
-            &[
-                (k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
-                (k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
-            ],
-            D::Minus1,
-        )?;
+        let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
+        let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
+        let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?;
+        let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?;
         let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
         let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
         let v = qkv.i((.., .., 2))?;
@@ -212,6 +189,7 @@ struct MLP {
     fc1: Linear,
     fc2: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl MLP {
@@ -223,12 +201,14 @@ impl MLP {
             fc1,
             fc2,
             act: cfg.activation_function,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
        })
     }
 }
 
 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
@@ -263,9 +243,11 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
     n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
+    span_rope: tracing::Span,
+    span_mask: tracing::Span,
     span_softmax: tracing::Span,
 }
 
 impl MHA {
@@ -274,17 +256,20 @@ impl MHA {
         let op_size = cfg.n_embd;
         let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
         let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
-        let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
+        let rotary_emb =
+            RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
         Ok(Self {
             wqkv,
             out_proj,
             head_dim,
             n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
+            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
+            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
             span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }
 
@@ -300,7 +285,10 @@ impl MHA {
             Some((prev_k, _)) => prev_k.dim(1)?,
         };
         // In the python implementation, a single tensor is returned with the third axis of size 3.
-        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (q, k, v) = {
+            let _enter = self.span_rope.enter();
+            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
+        };
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((prev_k, prev_v)) => {
@@ -320,13 +308,15 @@ impl MHA {
         // scores = scores + causal_mask.to(dtype=scores.dtype)
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
-                &mask.broadcast_left(b_size * self.n_head)?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => {
+                let _enter = self.span_mask.enter();
+                attn_weights.broadcast_add(mask)?
+            }
         };
-        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_weights = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn_weights)?
+        };
 
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         // attn_weights: b*h,t,s, v: b*h,s,d
@@ -430,7 +420,7 @@ impl MixFormerSequentialForCausalLM {
         let mask = if seq_len <= 1 {
             None
         } else {
-            Some(get_mask(seq_len, xs.device())?)
+            Some(get_mask(seq_len, xs.dtype(), xs.device())?)
         };
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
@@ -438,6 +428,30 @@ impl MixFormerSequentialForCausalLM {
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
 
+    pub fn forward_with_img(
+        &mut self,
+        bos_token: &Tensor,
+        xs: &Tensor,
+        img_embeds: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let xs = xs.apply(&self.embedding)?;
+        let bos_token = bos_token.apply(&self.embedding)?;
+        // The python implementation orders the sequence as
+        // <bos token embedding><img embedding><rest of the text embedding>:
+        // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
+        let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
+        let (_b_size, seq_len, _embds) = xs.dims3()?;
+        let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
+        for block in self.blocks.iter_mut() {
+            xs = block.forward(&xs, mask.as_ref())?
+        }
+        let xs = xs
+            .narrow(1, seq_len - 1, 1)?
+            .apply(&self.head)?
+            .squeeze(1)?;
+        Ok(xs)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
     }
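The new `span_rope`/`span_mask` fields (alongside the existing `span_softmax`) exist purely for profiling: each sub-step of the attention runs inside a TRACE-level span. A hedged sketch of collecting them with the `tracing-chrome` subscriber, the same pattern candle's examples use:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

// Writes a trace-*.json viewable in chrome://tracing or Perfetto; the rope,
// mask and softmax spans then show up as separate slices per forward pass.
fn init_tracing() -> tracing_chrome::FlushGuard {
    let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    guard // keep alive for the duration of the run
}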
@@ -158,18 +158,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -206,8 +194,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@@ -3,6 +3,7 @@ pub mod bigcode;
 pub mod blip;
 pub mod blip_text;
 pub mod chatglm;
+pub mod clip;
 pub mod convmixer;
 pub mod convnext;
 pub mod dinov2;
@@ -23,6 +24,7 @@ pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
 pub mod mobileone;
+pub mod moondream;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;
@@ -33,12 +35,16 @@ pub mod quantized_llama2_c;
 pub mod quantized_metavoice;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
+pub mod quantized_moondream;
 pub mod quantized_mpt;
+pub mod quantized_recurrent_gemma;
 pub mod quantized_rwkv_v5;
 pub mod quantized_rwkv_v6;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod qwen2;
+pub mod qwen2_moe;
+pub mod recurrent_gemma;
 pub mod repvgg;
 pub mod resnet;
 pub mod rwkv_v5;
candle-transformers/src/models/moondream.rs (new file, 327 lines)
@@ -0,0 +1,327 @@
use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

pub struct Config {
    pub phi_config: PhiConfig,
    pub vision_config: VisionConfig,
}

impl Config {
    pub fn v2() -> Self {
        Self {
            phi_config: PhiConfig::v1_5(),
            vision_config: VisionConfig::v2(),
        }
    }
}

fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}
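`scaled_dot_product_attention` is the textbook softmax(Q Kᵀ / sqrt(d)) V with no masking, which is all the vision tower needs. A shape sketch using the v2 config's numbers (729 patch tokens, 16 heads of width 1152 / 16 = 72):

use candle::{Device, Result, Tensor};

fn sdpa_shapes() -> Result<()> {
    let dev = Device::Cpu;
    // q, k, v: (batch, heads, tokens, head_dim) -> output has the same shape.
    let q = Tensor::randn(0f32, 1., (1, 16, 729, 72), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, 16, 729, 72), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, 16, 729, 72), &dev)?;
    let out = scaled_dot_product_attention(&q, &k, &v)?;
    assert_eq!(out.dims(), &[1, 16, 729, 72]);
    Ok(())
}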
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct VisionConfig {
    pub(crate) image_embedding_dim: usize,
    pub(crate) model_dim: usize,
    pub(crate) hidden_dim: usize,
    pub(crate) hidden_features: usize,
    pub(crate) embed_len: usize,
    pub(crate) embed_dim: usize,
    pub(crate) num_blocks: usize,
    pub(crate) num_heads: usize,
    pub(crate) act: candle_nn::Activation,
}

impl VisionConfig {
    pub fn v2() -> Self {
        Self {
            image_embedding_dim: 1152,
            model_dim: 2048,
            hidden_dim: 2048 * 4,
            hidden_features: 4304,
            embed_len: 729,
            embed_dim: 1152,
            num_blocks: 27,
            num_heads: 16,
            act: candle_nn::Activation::GeluPytorchTanh,
        }
    }
}
#[derive(Debug, Clone)]
struct LinearPatchEmbedding {
    linear: Linear,
}

impl LinearPatchEmbedding {
    fn new(vb: VarBuilder) -> Result<Self> {
        let linear = linear_b(588, 1152, true, vb.pp("linear"))?;
        Ok(Self { linear })
    }
}

impl Module for LinearPatchEmbedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.linear)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    num_heads: usize,
    head_dim: usize,
    qkv: Linear,
    proj: Linear,
    span: tracing::Span,
}

impl Attention {
    pub fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
        let qkv = linear_b(dim, dim * 3, true, vb.pp("qkv"))?;
        let proj = linear_b(dim, dim, true, vb.pp("proj"))?;
        Ok(Self {
            num_heads,
            head_dim: dim / num_heads,
            qkv,
            proj,
            span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
        })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b, n, c) = xs.dims3()?;
        let qkv = xs
            .apply(&self.qkv)?
            .reshape((b, n, 3, self.num_heads, self.head_dim))?
            .permute((2, 0, 3, 1, 4))?;
        let (q, k, v) = (
            qkv.i(0)?.contiguous()?,
            qkv.i(1)?.contiguous()?,
            qkv.i(2)?.contiguous()?,
        );
        scaled_dot_product_attention(&q, &k, &v)?
            .transpose(1, 2)?
            .reshape((b, n, c))?
            .apply(&self.proj)
    }
}

#[derive(Debug, Clone)]
struct VitBlock {
    attn: Attention,
    mlp: Mlp,
    norm1: LayerNorm,
    norm2: LayerNorm,
    span: tracing::Span,
}

impl VitBlock {
    fn new(vb: VarBuilder, dim: usize, num_heads: usize, cfg: &VisionConfig) -> Result<Self> {
        let attn = Attention::new(vb.pp("attn"), dim, num_heads)?;
        let mlp = Mlp::new(vb.pp("mlp"), dim, cfg.hidden_features, dim, cfg.act)?;
        let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
        let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
        Ok(Self {
            attn,
            mlp,
            norm1,
            norm2,
            span: tracing::span!(tracing::Level::TRACE, "vit-block"),
        })
    }
}

impl Module for VitBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
        let xs = (xs + &ys)?;
        let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
        let xs = (&xs + &ys)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct VisionTransformer {
    patch_embed: LinearPatchEmbedding,
    pos_embed: Tensor,
    blocks: Vec<VitBlock>,
    norm: LayerNorm,
    span: tracing::Span,
}

impl VisionTransformer {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let patch_embed = LinearPatchEmbedding::new(vb.pp("patch_embed"))?;
        let pos_embed = vb.get((1, cfg.embed_len, cfg.embed_dim), "pos_embed")?;
        let blocks = (0..cfg.num_blocks)
            .map(|i| {
                VitBlock::new(
                    vb.pp(&format!("blocks.{}", i)),
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg,
                )
            })
            .collect::<Result<_>>()?;
        let norm = layer_norm(cfg.embed_dim, 1e-5, vb.pp("norm"))?;
        Ok(Self {
            patch_embed,
            pos_embed,
            blocks,
            norm,
            span: tracing::span!(tracing::Level::TRACE, "vit"),
        })
    }
}

impl Module for VisionTransformer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
        for block in self.blocks.iter() {
            xs = xs.apply(block)?;
        }
        xs.apply(&self.norm)
    }
}

#[derive(Debug, Clone)]
pub struct Encoder {
    model: VisionTransformer,
}

impl Encoder {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let model = VisionTransformer::new(cfg, vb.pp("model.visual"))?;
        Ok(Self { model })
    }
}

impl Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.model)
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    fc1: Linear,
    act: candle_nn::Activation,
    fc2: Linear,
    span: tracing::Span,
}

impl Mlp {
    fn new(
        vb: VarBuilder,
        in_features: usize,
        hidden_features: usize,
        out_features: usize,
        act: candle_nn::Activation,
    ) -> Result<Self> {
        let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
        let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
        Ok(Self {
            fc1,
            act,
            fc2,
            span: tracing::span!(tracing::Level::TRACE, "mlp"),
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
    }
}

#[derive(Debug, Clone)]
struct VisionProjection {
    mlp: Mlp,
}

impl VisionProjection {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let mlp = Mlp::new(
            vb.pp("mlp"),
            cfg.image_embedding_dim,
            cfg.hidden_dim,
            cfg.model_dim,
            cfg.act,
        )?;
        Ok(Self { mlp })
    }
}

impl Module for VisionProjection {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.mlp)
    }
}

#[derive(Debug, Clone)]
pub struct VisionEncoder {
    encoder: Encoder,
    projection: VisionProjection,
}

impl VisionEncoder {
    pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let projection = VisionProjection::new(cfg, vb.pp("projection"))?;
        Ok(Self {
            encoder,
            projection,
        })
    }
}

impl Module for VisionEncoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, c, hp1, wp2) = xs.dims4()?;
        let (p1, p2) = (14, 14);
        let h = hp1 / p1;
        let w = wp2 / p2;
        // Cut the image into p1 x p2 patches and flatten each patch into a vector.
        // Note: the width factor here must be `w`, not `h`; the two coincide only
        // for square inputs.
        xs.reshape((b, c, h, p1, w, p2))?
            .permute((0, 2, 4, 1, 3, 5))?
            .reshape((b, h * w, c * p1 * p2))?
            .apply(&self.encoder)?
            .apply(&self.projection)
    }
}
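With the v2 defaults, the patchify step in `VisionEncoder::forward` maps a 378x378 RGB image to 27 * 27 = 729 patch vectors of length 3 * 14 * 14 = 588, lining up with `LinearPatchEmbedding`'s 588-wide input and the config's `embed_len` of 729. A standalone check of that rearrangement:

use candle::{DType, Device, Result, Tensor};

// (1, 3, 378, 378) -> (1, 729, 588), mirroring the reshape/permute above.
fn patchify_shapes() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::zeros((1, 3, 378, 378), DType::F32, &dev)?;
    let (b, c, hp1, wp2) = xs.dims4()?;
    let (p1, p2) = (14, 14);
    let (h, w) = (hp1 / p1, wp2 / p2);
    let patches = xs
        .reshape((b, c, h, p1, w, p2))?
        .permute((0, 2, 4, 1, 3, 5))?
        .reshape((b, h * w, c * p1 * p2))?;
    assert_eq!(patches.dims(), &[1, 729, 588]);
    Ok(())
}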
pub struct Model {
    pub text_model: PhiModel,
    pub vision_encoder: VisionEncoder,
}

impl Model {
    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
        let text_model = PhiModel::new_v2(&config.phi_config, vb.pp("text_model"))?;
        let vision_encoder = VisionEncoder::new(&config.vision_config, vb.pp("vision_encoder"))?;
        Ok(Self {
            text_model,
            vision_encoder,
        })
    }

    pub fn vision_encoder(&self) -> &VisionEncoder {
        &self.vision_encoder
    }

    pub fn text_model(&mut self) -> &mut PhiModel {
        &mut self.text_model
    }
}
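Putting the pieces together, a hypothetical end-to-end sketch: encode an image, then run the text model conditioned on it via the `forward_with_img` method added to the mixformer earlier in this diff. The checkpoint name, image preprocessing, and token ids are all assumptions:

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

fn caption_image() -> Result<Tensor> {
    let device = Device::Cpu;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = Model::new(&Config::v2(), vb)?;
    // A resized + normalized image batch is assumed here.
    let image = Tensor::zeros((1, 3, 378, 378), DType::F32, &device)?;
    let img_embeds = image.apply(model.vision_encoder())?;
    // Dummy bos/prompt ids standing in for real tokenizer output.
    let bos_token = Tensor::new(&[[50256u32]], &device)?;
    let prompt = Tensor::new(&[[24361u32, 25]], &device)?;
    model
        .text_model
        .forward_with_img(&bos_token, &prompt, &img_embeds)
}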