Mirror of https://github.com/huggingface/candle.git

Merge branch 'main' into metal-mfa-bfloat
Cargo.toml (27 changed lines)
@@ -9,6 +9,7 @@ members = [
    "candle-transformers",
    "candle-wasm-examples/*",
    "candle-wasm-tests",
    "tensor-tools",
]
exclude = [
    "candle-flash-attn",
@@ -19,7 +20,7 @@ exclude = [
resolver = "2"

[workspace.package]
-version = "0.4.1"
+version = "0.5.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -28,17 +29,18 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[workspace.dependencies]
+ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.4.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.4.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" }
-candle-nn = { path = "./candle-nn", version = "0.4.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.4.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.4.1" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
+candle-nn = { path = "./candle-nn", version = "0.5.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
@@ -46,19 +48,18 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
-image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
-imageproc = { version = "0.23.0", default-features = false }
+image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
+imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
-parquet = { version = "50.0.0" }
+parquet = { version = "51.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
-rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
README.md
@@ -125,10 +125,14 @@ We also provide some command line based examples using state of the art models
- [RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
  generate captions for an image.
+- [CLIP](./candle-examples/examples/clip/): multimodal vision and language
+  model.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
  dedicated submodels for handwriting and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
  model, generates the translated text from the input text.
+- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
+  that can answer real-world questions about images.

Run them using commands like:
```
@@ -172,9 +176,11 @@ And then head over to
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
  serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.

If you have an addition to this list, please submit a pull request.

@@ -205,7 +211,7 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
-- Qwen1.5.
+- Qwen1.5, Qwen1.5 MoE.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
candle-core/benches/bench_main.rs
@@ -2,8 +2,9 @@ mod benchmarks;

use criterion::criterion_main;
criterion_main!(
-    //benchmarks::affine::benches,
+    benchmarks::affine::benches,
    benchmarks::matmul::benches,
-    //benchmarks::random::benches,
-    //benchmarks::where_cond::benches
+    benchmarks::random::benches,
+    benchmarks::where_cond::benches,
+    benchmarks::conv_transpose2d::benches,
);
candle-core/benches/benchmarks/conv_transpose2d.rs (new file, 59 lines)
@@ -0,0 +1,59 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(
    x: &Tensor,
    k: &Tensor,
    padding: usize,
    output_padding: usize,
    stride: usize,
    dilation: usize,
) {
    x.conv_transpose2d(k, padding, output_padding, stride, dilation)
        .unwrap();
}

fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
    let t = Tensor::arange(0.0f32, 10000.0, device)
        .unwrap()
        .reshape((1, 4, 50, 50))
        .unwrap()
        .to_dtype(dtype)
        .unwrap();

    let kernel = Tensor::arange(0.0f32, 100.0, device)
        .unwrap()
        .reshape((4, 1, 5, 5))
        .unwrap()
        .to_dtype(dtype)
        .unwrap();

    let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
        run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
        run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
    }
}

criterion_group!(benches, criterion_benchmark);
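These benches run through candle-core's criterion target, so something like `cargo bench --bench bench_main -- conv_transpose2d_f32` should execute just this group. Note the value named `flops` above is actually a byte count (elements times dtype size), which is consistent with the `Throughput::Bytes` it feeds. For a standalone feel for the op being measured, here is a minimal sketch using only the public candle_core API, with shapes and parameters mirroring the benchmark (error handling via `?` instead of unwrap):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input (batch=1, c_in=4, 50x50) and kernel (c_in=4, c_out=1, 5x5),
    // the same shapes the benchmark constructs with arange + reshape.
    let x = Tensor::arange(0.0f32, 10000.0, &dev)?.reshape((1, 4, 50, 50))?;
    let k = Tensor::arange(0.0f32, 100.0, &dev)?.reshape((4, 1, 5, 5))?;
    // padding=1, output_padding=0, stride=1, dilation=2, as in run() above.
    let y = x.conv_transpose2d(&k, 1, 0, 1, 2)?;
    println!("output shape: {:?}", y.shape());
    Ok(())
}
```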
candle-core/benches/benchmarks/mod.rs
@@ -1,4 +1,5 @@
pub(crate) mod affine;
+pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;
candle-core/src/backend.rs
@@ -98,6 +98,19 @@ pub trait BackendStorage: Sized {
    ) -> Result<Self>;

    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;

+    #[allow(clippy::too_many_arguments)]
+    // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
+    fn copy2d(
+        &self,
+        _: &mut Self,
+        _d1: usize,
+        _d2: usize,
+        _src_stride1: usize,
+        _dst_stride1: usize,
+        _src_offset: usize,
+        _dst_offset: usize,
+    ) -> Result<()>;
}

pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@@ -114,8 +127,16 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {

    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;

+    /// # Safety
+    /// This function is unsafe as it doesn't initialize the underlying data store.
+    /// The caller should ensure that the data is properly initialized as early as possible
+    /// after this call.
+    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;

    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;

+    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;

    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
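For intuition, the addressing contract of `copy2d` can be written as a plain reference loop over slices. This is a sketch, not part of the diff, but the CPU backend's `copy2d_` further down implements exactly this pattern:

```rust
/// Copy `d1` rows of `d2` elements each from `src` into `dst`, where the two
/// buffers may use different row strides. All quantities are in elements,
/// not bytes (unlike cudaMemcpy2D).
fn copy2d_ref<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for i1 in 0..d1 {
        let s = src_offset + i1 * src_stride1;
        let d = dst_offset + i1 * dst_stride1;
        dst[d..d + d2].copy_from_slice(&src[s..s + d2]);
    }
}

fn main() {
    let src = [1, 2, 3, 4, 5, 6];
    let mut dst = [0; 6];
    // Copy a 2x2 sub-block: 2 rows of 2 elements, source stride 3, dest stride 2.
    copy2d_ref(&src, &mut dst, 2, 2, 3, 2, 0, 0);
    assert_eq!(dst, [1, 2, 4, 5, 0, 0]);
}
```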
candle-core/src/backprop.rs
@@ -1,3 +1,4 @@
+/// Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@@ -111,7 +112,8 @@ impl Tensor {
            }
            Op::Unary(_node, UnaryOp::Ceil)
            | Op::Unary(_node, UnaryOp::Floor)
-            | Op::Unary(_node, UnaryOp::Round) => nodes,
+            | Op::Unary(_node, UnaryOp::Round)
+            | Op::Unary(_node, UnaryOp::Sign) => nodes,
            Op::Reshape(node)
            | Op::UpsampleNearest1D { arg: node, .. }
            | Op::UpsampleNearest2D { arg: node, .. }
@@ -310,9 +312,32 @@ impl Tensor {
            Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
                op: "conv-transpose1d",
            })?,
-            Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
-                op: "conv-transpose2d",
-            })?,
+            Op::ConvTranspose2D {
+                arg,
+                kernel,
+                padding,
+                stride,
+                dilation,
+                output_padding: _output_padding,
+            } => {
+                let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
+                let sum_grad = grads.or_insert(arg)?;
+                *sum_grad = sum_grad.add(&grad_arg)?;
+
+                let grad_kernel = grad
+                    .transpose(0, 1)?
+                    .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
+                    .transpose(0, 1)?;
+                let sum_grad = grads.or_insert(kernel)?;
+                let (_, _, k0, k1) = kernel.dims4()?;
+                let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+                let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+                    grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+                } else {
+                    grad_kernel
+                };
+                *sum_grad = sum_grad.add(&grad_kernel)?;
+            }
            Op::AvgPool2D {
                arg,
                kernel_size,
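The new arm exploits conv_transpose2d being the adjoint of conv2d: the gradient with respect to the input is an ordinary conv2d of the incoming gradient with the same kernel (with the transposed convolution's stride and dilation swapping roles), and the kernel gradient is a conv2d between the channel-transposed input and gradient, narrowed back to the kernel size when stride or dilation make it overshoot. A quick end-to-end check that both gradients are now produced — a sketch using the public Var/backward API:

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::from_tensor(&Tensor::randn(0f32, 1.0, (1, 2, 8, 8), &dev)?)?;
    let k = Var::from_tensor(&Tensor::randn(0f32, 1.0, (2, 3, 3, 3), &dev)?)?;
    // padding=1, output_padding=0, stride=1, dilation=1.
    let y = x.as_tensor().conv_transpose2d(k.as_tensor(), 1, 0, 1, 1)?;
    let loss = y.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    // Before this change, backward() failed with BackwardNotSupported here.
    assert!(grads.get(x.as_tensor()).is_some());
    assert!(grads.get(k.as_tensor()).is_some());
    Ok(())
}
```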
@@ -464,7 +489,6 @@ impl Tensor {
                let sum_grad = grads.or_insert(arg)?;
                *sum_grad = sum_grad.add(&grad)?;
            }
-            Op::Cmp(_args, _) => {}
            Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                let node = broadcast_back(arg, node, reduced_dims)?;
                let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -554,20 +578,18 @@ impl Tensor {
                let sum_grad = grads.or_insert(arg)?;
                *sum_grad = sum_grad.add(&arg_grad)?
            }
-            Op::Reduce(_, ReduceOp::ArgMin, _) => {}
-            Op::Reduce(_, ReduceOp::ArgMax, _) => {}
+            Op::Unary(_, UnaryOp::Floor)
+            | Op::Unary(_, UnaryOp::Round)
+            | Op::Reduce(_, ReduceOp::ArgMin, _)
+            | Op::Reduce(_, ReduceOp::ArgMax, _)
+            | Op::Unary(_, UnaryOp::Sign)
+            | Op::Cmp(_, _) => {}
            Op::Reshape(arg) => {
                let arg_grad = grad.reshape(arg.dims())?;
                let sum_grad = grads.or_insert(arg)?;
                *sum_grad = sum_grad.add(&arg_grad)?
            }
            Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
-            Op::Unary(_, UnaryOp::Floor) => {
-                Err(Error::BackwardNotSupported { op: "floor" })?
-            }
-            Op::Unary(_, UnaryOp::Round) => {
-                Err(Error::BackwardNotSupported { op: "round" })?
-            }
            Op::Unary(arg, UnaryOp::Gelu) => {
                let sum_grad = grads.or_insert(arg)?;
                let cube = arg.powf(3.)?;
@@ -690,30 +712,38 @@ impl Tensor {
        }
    }

+/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);

impl GradStore {
+    /// Create a new gradient store
    fn new() -> Self {
        GradStore(HashMap::new())
    }

+    /// Get the gradient tensor corresponding to the given tensor id
    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
        self.0.get(&id)
    }

+    /// Get the gradient tensor associated with the given tensor
    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
        self.0.get(&tensor.id())
    }

+    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
        self.0.remove(&tensor.id())
    }

+    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
        self.0.insert(tensor.id(), grad)
    }

+    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
+    /// insert a tensor of zeroes with the same shape and type as the given tensor, and return it
    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
        use std::collections::hash_map::Entry;
        let grad = match self.0.entry(tensor.id()) {
candle-core/src/cpu_backend/mod.rs
@@ -4,7 +4,13 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;

+mod utils;
+pub use utils::{
+    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+};

const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;

// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -23,102 +29,6 @@ pub enum CpuStorage {
#[derive(Debug, Clone)]
pub struct CpuDevice;

pub trait Map1 {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
        }
    }
}

pub trait Map1Any {
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        vs: &[T],
        layout: &Layout,
        wrap: W,
    ) -> Result<CpuStorage>;

    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
            CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
            CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
            CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
            CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
            CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
        }
    }
}

type C = CpuStorage;
pub trait Map2 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub trait Map2U8 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

struct Cmp(CmpOp);
impl Map2U8 for Cmp {
    const OP: &'static str = "cmp";
@@ -365,275 +275,6 @@ impl<'a> Map1 for ReduceSum<'a> {
    }
}

pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
            [start_offset..start_offset + len]
            .iter()
            .map(|&v| f(v))
            .collect(),
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut result = Vec::with_capacity(layout.shape().elem_count());
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
            } else {
                for index in block_start_index {
                    for offset in 0..block_len {
                        let v = unsafe { vs.get_unchecked(index + offset) };
                        result.push(f(*v))
                    }
                }
            }
            result
        }
    }
}

pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                let mut result = Vec::with_capacity(el_count);
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}

// This function maps over two strided index sequences.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

struct Affine(f64, f64);

impl Map1 for Affine {
@@ -1022,6 +663,26 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
    }
}

+#[allow(clippy::too_many_arguments)]
+fn copy2d_<T: Copy>(
+    src: &[T],
+    dst: &mut [T],
+    d1: usize,
+    d2: usize,
+    src_stride1: usize,
+    dst_stride1: usize,
+    src_offset: usize,
+    dst_offset: usize,
+) {
+    for i1 in 0..d1 {
+        let dst_idx = i1 * dst_stride1 + dst_offset;
+        let src_idx = i1 * src_stride1 + src_offset;
+        let dst = &mut dst[dst_idx..dst_idx + d2];
+        let src = &src[src_idx..src_idx + d2];
+        dst.copy_from_slice(src)
+    }
+}

fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1256,6 +917,34 @@ impl Map1 for Im2Col {
    }
}

+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
+        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+        let mut im = vec![T::zero(); b_size * c_out * l_out];
+        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
+        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
+        for l_in_i in 0..l_in {
+            for k_i in 0..k_size {
+                let l_out_i = l_in_i * stride + k_i;
+                for b_i in 0..b_size {
+                    for c_i in 0..c_out {
+                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
+                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
+                        im[dst_idx] += col[src_idx]
+                    }
+                }
+            }
+        }
+        Ok(im)
+    }
+}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl<'a> Map2 for ConvTranspose1D<'a> {
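Col2Im1D is the scatter inverse of im2col: each column entry (b, l_in_i, c, k_i) accumulates into output position l_in_i * stride + k_i, so overlapping windows sum, which is exactly the transposed-convolution semantics. The output length follows l_out = (l_in - 1) * stride + k_size since this fast path assumes no padding or dilation. A toy check of that shape rule, in plain Rust and not part of the diff:

```rust
fn l_out(l_in: usize, stride: usize, k_size: usize) -> usize {
    (l_in - 1) * stride + k_size
}

fn main() {
    // 4 input positions, stride 2, kernel 3: windows start at 0, 2, 4, 6 and
    // the last one ends at 6 + 3 = 9, so the output holds 9 elements.
    assert_eq!(l_out(4, 2, 3), 9);
}
```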
@@ -1515,6 +1204,30 @@ impl MatMul {
        }))
        .bt()
    }

+    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+        let (_b, m, n, k) = self.0;
+        let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [_, stride] if lhs_l.dims()[0] == 1 => stride,
+            [stride, _] if lhs_l.dims()[1] == 1 => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
+        };
+        let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [_, stride] if rhs_l.dims()[0] == 1 => stride,
+            [stride, _] if rhs_l.dims()[1] == 1 => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
+        };
+        Ok((a_skip, b_skip))
+    }
}

impl Map2 for MatMul {
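ab_skip centralizes the batch-stride computation that previously appeared inline in each of the three matmul backends below: it returns how far apart consecutive (m, k) and (k, n) matrices sit in memory, and the two new match arms additionally accept broadcast-style layouts where a leading dimension has size 1. For a contiguous tensor the expected value is easy to see — a sketch using the public Layout API:

```rust
use candle_core::Layout;

fn main() {
    // A contiguous (b=2, m=3, k=4) tensor has strides [12, 4, 1], so each
    // batch element starts 12 = m * k elements after the previous one;
    // that 12 is exactly the a_skip value ab_skip would compute here.
    let l = Layout::contiguous((2, 3, 4));
    assert_eq!(l.stride(), &[12, 4, 1]);
}
```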
@@ -1548,18 +1261,7 @@ impl Map2 for MatMul {
        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let dst_shape: Shape = (m, n).into();
@@ -1619,20 +1321,8 @@ impl Map2 for MatMul {

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();

-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1640,7 +1330,7 @@ impl Map2 for MatMul {
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

-        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
@@ -1648,7 +1338,7 @@ impl Map2 for MatMul {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
@@ -1722,20 +1412,8 @@ impl Map2 for MatMul {

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
-        let rank = lhs_stride.len();

-        let a_skip: usize = match lhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
-        };
-        let b_skip: usize = match rhs_stride[..rank - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
-        };
+        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
        let c_skip: usize = m * n;

        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1743,7 +1421,7 @@ impl Map2 for MatMul {
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

-        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
@@ -1751,7 +1429,7 @@ impl Map2 for MatMul {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        // The b tensor has dims batching, m, k (lhs)
-        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
@@ -2423,6 +2101,48 @@ impl BackendStorage for CpuStorage {
        }
    }

+    fn copy2d(
+        &self,
+        dst: &mut Self,
+        d1: usize,
+        d2: usize,
+        src_s: usize,
+        dst_s: usize,
+        src_o: usize,
+        dst_o: usize,
+    ) -> Result<()> {
+        match (self, dst) {
+            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
+            (Self::U32(src), Self::U32(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
+            (Self::I64(src), Self::I64(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
+            (Self::BF16(src), Self::BF16(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
+            (Self::F16(src), Self::F16(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
+            (Self::F32(src), Self::F32(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
+            (Self::F64(src), Self::F64(dst)) => {
+                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+            }
+            (_, dst) => {
+                return Err(Error::DTypeMismatchBinaryOp {
+                    lhs: self.dtype(),
+                    rhs: dst.dtype(),
+                    op: "copy2d",
+                }
+                .bt());
+            }
+        }
+        Ok(())
+    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -2491,7 +2211,10 @@ impl BackendStorage for CpuStorage {
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
@@ -2499,7 +2222,7 @@ impl BackendStorage for CpuStorage {
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }
@@ -2511,7 +2234,52 @@ impl BackendStorage for CpuStorage {
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
-        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        if USE_IM2COL_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = l.shape().dims3()?;
+            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+            if !kernel_l.is_contiguous() {
+                crate::bail!(
+                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+                )
+            }
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    l.shape(),
+                    kernel_l.shape()
+                )
+            }
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    kernel_l.start_offset(),
+                );
+                self.matmul(
+                    kernel,
+                    (
+                        b_size,
+                        /* m */ l_in,
+                        /* n */ c_out * k_size,
+                        /* k */ c_in,
+                    ),
+                    &l.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+            Col2Im1D {
+                stride: params.stride,
+            }
+            .map(&col, &col_l)
+        } else {
+            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
+        }
    }

    fn conv2d(
@@ -2545,7 +2313,10 @@ impl BackendStorage for CpuStorage {
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
@@ -2555,7 +2326,7 @@ impl BackendStorage for CpuStorage {
        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
            .transpose(1, 2)?
            .transpose(1, 3)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }
@@ -2678,6 +2449,10 @@ impl BackendDevice for CpuDevice {
        Ok(s.clone())
    }

+    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
+        Ok(s)
+    }

    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }
@@ -2779,6 +2554,53 @@ impl BackendDevice for CpuDevice {
        }
    }

+    #[allow(clippy::uninit_vec)]
+    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
+        let elem_count = shape.elem_count();
+        // The code below is highly unsafe but hopefully not directly unsound as we only consider
+        // types that are Copy, not Drop, and for which all bit patterns are proper values.
+        // It's still pretty risky, see the following for more details:
+        // https://github.com/rust-lang/rust-clippy/issues/4483
+        let storage = match dtype {
+            DType::U8 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::U8(v)
+            }
+            DType::U32 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::U32(v)
+            }
+            DType::I64 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::I64(v)
+            }
+            DType::BF16 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::BF16(v)
+            }
+            DType::F16 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::F16(v)
+            }
+            DType::F32 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::F32(v)
+            }
+            DType::F64 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::F64(v)
+            }
+        };
+        Ok(storage)
+    }
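All seven per-dtype arms of alloc_uninit are identical up to the storage variant; if more dtypes appear, a small macro could fold the repetition. The following is an illustrative sketch only, not part of the diff — the `Storage` stand-in enum and `uninit_storage` name are hypothetical:

```rust
// Illustrative stand-in: the real CpuStorage has U8/U32/I64/BF16/F16/F32/F64 variants.
enum Storage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

macro_rules! uninit_storage {
    ($variant:ident, $ty:ty, $n:expr) => {{
        let mut v: Vec<$ty> = Vec::with_capacity($n);
        unsafe { v.set_len($n) }; // caller must initialize before reading
        Storage::$variant(v)
    }};
}

fn main() {
    let n = 16;
    let s = uninit_storage!(F32, f32, n);
    match s {
        Storage::F32(v) => assert_eq!(v.len(), n),
        _ => unreachable!(),
    }
}
```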
    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
candle-core/src/cpu_backend/utils.rs (new file, 350 lines)
@@ -0,0 +1,350 @@
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};

type C = super::CpuStorage;
pub trait Map1 {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
        match vs {
            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
        }
    }
}

pub trait Map1Any {
    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;

    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
        match vs {
            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
        }
    }
}

pub trait Map2 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub trait Map2U8 {
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}

pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
            [start_offset..start_offset + len]
            .iter()
            .map(|&v| f(v))
            .collect(),
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut result = Vec::with_capacity(layout.shape().elem_count());
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
            } else {
                for index in block_start_index {
                    for offset in 0..block_len {
                        let v = unsafe { vs.get_unchecked(index + offset) };
                        result.push(f(*v))
                    }
                }
            }
            result
        }
    }
}

pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                let mut result = Vec::with_capacity(el_count);
                for index in block_start_index {
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}
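To make the refactor concrete, here is a toy call of the relocated binary_map over two contiguous buffers. This is a sketch; it assumes the cpu_backend module and its re-exports remain publicly visible from candle_core, as they are in candle-core today:

```rust
use candle_core::cpu_backend::binary_map;
use candle_core::Layout;

fn main() {
    let l = Layout::contiguous((2, 2));
    let a = [1.0f32, 2.0, 3.0, 4.0];
    let b = [10.0f32, 20.0, 30.0, 40.0];
    // Both layouts are contiguous, so this takes the fast zip path.
    let sums = binary_map(&l, &l, &a, &b, |x, y| x + y);
    assert_eq!(sums, vec![11.0, 22.0, 33.0, 44.0]);
}
```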
candle-core/src/cuda_backend/device.rs (new file, 410 lines)
@ -0,0 +1,410 @@
|
||||
use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};

use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};

/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}

#[derive(Clone)]
pub struct CudaDevice {
    id: DeviceId,
    device: Arc<cudarc::driver::CudaDevice>,
    pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
    curand: Arc<Mutex<CudaRng>>,
}

impl std::fmt::Debug for CudaDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CudaDevice({:?})", self.id)
    }
}

impl std::ops::Deref for CudaDevice {
    type Target = Arc<cudarc::driver::CudaDevice>;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl CudaDevice {
    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
        self.device.clone()
    }

    pub fn id(&self) -> DeviceId {
        self.id
    }

    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let slice = match dtype {
            DType::U8 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
                let params = (&data, v as u8, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
                let params = (&data, v as u32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
                let params = (&data, v as i64, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
                let params = (&data, bf16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
                let params = (&data, f16::from_f64(v), elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
                let params = (&data, v as f32, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
                let params = (&data, v, elem_count);
                unsafe { func.launch(cfg, params) }.w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
        if !self.has_func(module_name, module_name) {
            // Leaking the string here is a bit sad but we need a &'static str and this is only
            // done once per kernel name.
            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
            self.load_ptx(ptx.into(), module_name, &[static_module_name])
                .map_err(|cuda| CudaError::Load {
                    cuda,
                    module_name: module_name.to_string(),
                })
                .w()?;
        }
        self.get_func(module_name, module_name)
            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
            // able to only build the error value if needed.
            .ok_or(CudaError::MissingKernel {
                module_name: module_name.to_string(),
            })
            .w()
    }
}

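get_or_load_func leaks each module name to obtain a &'static str; the has_func guard ensures this happens at most once per distinct kernel name. The trick in isolation:

// Turn a runtime name into a `&'static str` by leaking it; acceptable when
// each distinct name is leaked at most once for the process lifetime.
fn leak_name(name: &str) -> &'static str {
    Box::leak(name.to_string().into_boxed_str())
}

fn main() {
    let s: &'static str = leak_name("fill_f32");
    assert_eq!(s, "fill_f32");
}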
impl BackendDevice for CudaDevice {
    type Storage = CudaStorage;

    fn new(ordinal: usize) -> Result<Self> {
        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
        Ok(Self {
            id: DeviceId::new(),
            device,
            blas: Arc::new(blas),
            curand: Arc::new(Mutex::new(CudaRng(curand))),
        })
    }

    fn set_seed(&self, seed: u64) -> Result<()> {
        // We do not call set_seed but instead create a new curand object. This ensures that the
        // state will be identical and the same random numbers will be generated.
        let mut curand = self.curand.lock().unwrap();
        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
        Ok(())
    }

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cuda {
            gpu_id: self.device.ordinal(),
        }
    }

    fn same_device(&self, rhs: &Self) -> bool {
        self.id == rhs.id
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc_zeros::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc_zeros::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc_zeros::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc_zeros::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc_zeros::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc_zeros::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        let slice = match dtype {
            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
            // cudarc changes.
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_uniform",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
                curand.0.fill_with_uniform(&mut data).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        let slice = if lo == 0. && up == 1.0 {
            slice
        } else {
            use super::utils::Map1;
            let layout = Layout::contiguous(shape);
            super::Affine(up - lo, lo).map(&slice, self, &layout)?
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
        // cudarc changes.
        let elem_count = shape.elem_count();
        let curand = self.curand.lock().unwrap();
        // curand's normal generator can only produce an even number of values, so round odd
        // element counts up by one.
        // https://github.com/huggingface/candle/issues/734
        let elem_count_round = if elem_count % 2 == 1 {
            elem_count + 1
        } else {
            elem_count
        };
        let slice = match dtype {
            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
                Err(CudaError::UnsupportedDtype {
                    dtype,
                    op: "rand_normal",
                })
                .w()?
            }
            DType::F32 => {
                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
                curand
                    .0
                    .fill_with_normal(&mut data, mean as f32, std as f32)
                    .w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
                curand.0.fill_with_normal(&mut data, mean, std).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
}
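Note how set_seed replaces the RNG instead of mutating its state, so a reseeded device replays exactly the same stream no matter how many values were drawn before. A toy sketch of that design choice, with an LCG standing in for cuRAND:

use std::sync::Mutex;

struct Rng(u64);

impl Rng {
    fn new(seed: u64) -> Self {
        Rng(seed)
    }
    fn next(&mut self) -> u64 {
        self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1);
        self.0
    }
}

struct Device {
    rng: Mutex<Rng>,
}

impl Device {
    fn set_seed(&self, seed: u64) {
        // Replace, don't reseed in place: the state is exactly fresh-from-seed.
        *self.rng.lock().unwrap() = Rng::new(seed);
    }
}

fn main() {
    let d = Device { rng: Mutex::new(Rng::new(0)) };
    d.rng.lock().unwrap().next();
    d.set_seed(42);
    let a = d.rng.lock().unwrap().next();
    d.set_seed(42);
    let b = d.rng.lock().unwrap().next();
    assert_eq!(a, b);
}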
candle-core/src/cuda_backend/error.rs (new file, 62 lines)
@@ -0,0 +1,62 @@
use crate::{DType, Layout};

/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
    #[error(transparent)]
    Cuda(#[from] cudarc::driver::DriverError),

    #[error(transparent)]
    Compiler(#[from] cudarc::nvrtc::CompileError),

    #[error(transparent)]
    Cublas(#[from] cudarc::cublas::result::CublasError),

    #[error(transparent)]
    Curand(#[from] cudarc::curand::result::CurandError),

    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: String },

    #[error("unsupported dtype {dtype:?} for {op}")]
    UnsupportedDtype { dtype: DType, op: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),

    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Layout,
        rhs_stride: Layout,
        mnk: (usize, usize, usize),
    },

    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    #[error("{cuda} when loading {module_name}")]
    Load {
        cuda: cudarc::driver::DriverError,
        module_name: String,
    },
}

impl From<CudaError> for crate::Error {
    fn from(val: CudaError) -> Self {
        crate::Error::Cuda(Box::new(val)).bt()
    }
}

pub trait WrapErr<O> {
    fn w(self) -> std::result::Result<O, crate::Error>;
}

impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
    fn w(self) -> std::result::Result<O, crate::Error> {
        self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
    }
}
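The WrapErr extension trait keeps every fallible cudarc call down to a terse .w(). The same idiom reduced to its essentials, with illustrative names:

#[derive(Debug)]
struct MyError(String);

trait Wrap<O> {
    fn w(self) -> Result<O, MyError>;
}

impl<O, E: std::fmt::Display> Wrap<O> for Result<O, E> {
    // Any displayable error becomes the crate-level error in one short call.
    fn w(self) -> Result<O, MyError> {
        self.map_err(|e| MyError(e.to_string()))
    }
}

fn parse(s: &str) -> Result<i64, MyError> {
    s.parse::<i64>().w()
}

fn main() {
    assert!(parse("42").is_ok());
    assert!(parse("not a number").is_err());
}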
@@ -5,395 +5,41 @@ pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
    CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use half::{bf16, f16};

#[cfg(feature = "cudnn")]
pub mod cudnn;
mod device;
mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};

(removed from this hunk: the CudaFunction and std::sync imports, the CudaError enum, the WrapErr trait, DeviceId, CudaRng, the CudaDevice struct with its Debug/Deref/inherent/BackendDevice impls; all of this was relocated, with minor updates, to the new device.rs and error.rs files above)

enum SlicePtrOrNull<T> {
    Ptr(CudaSlice<T>),
    Null,
}

unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
        match self {
            SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
            SlicePtrOrNull::Null => 0usize.as_kernel_param(),
        }
    }
}

impl SlicePtrOrNull<usize> {
    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
        let ds = if l.is_contiguous() {
            SlicePtrOrNull::Null
        } else {
            SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
        };
        Ok(ds)
    }
}
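The SlicePtrOrNull refactor above threads one idea through every kernel launch: a contiguous tensor does not need its dims/strides table uploaded at all, so the kernel receives a null pointer and indexes linearly. A plain-data sketch of that decision, with a Vec standing in for a device buffer:

enum DimsParam {
    Ptr(Vec<usize>), // dims followed by strides, to be uploaded to the device
    Null,            // contiguous: the kernel needs no layout table
}

fn params_for(dims: &[usize], strides: &[usize], contiguous: bool) -> DimsParam {
    if contiguous {
        DimsParam::Null
    } else {
        DimsParam::Ptr([dims, strides].concat())
    }
}

fn main() {
    match params_for(&[2, 3], &[3, 1], true) {
        DimsParam::Null => (), // no host-to-device copy needed
        DimsParam::Ptr(_) => unreachable!(),
    }
}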
@@ -407,133 +53,6 @@ pub enum CudaStorageSlice {
    F32(CudaSlice<f32>),
    F64(CudaSlice<f64>),
}

(removed from this hunk: the S type alias and the Map1, Map2, Map2InPlace, Map1Any and Map2Any helper traits, relocated verbatim to the new utils.rs file below)

struct Clone;
impl Map1 for Clone {
@@ -564,7 +83,7 @@ impl Map1 for Affine {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
@@ -596,7 +115,7 @@ impl Map1 for Elu {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -719,7 +238,7 @@ impl Map1 for Powf {
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -852,7 +371,7 @@ impl<U: UnaryOpT> Map1 for U {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
@@ -1402,9 +921,14 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
@@ -1431,9 +955,14 @@ impl Map2Any for Cmp {
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let dims_and_strides = dev
-            .htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
-            .w()?;
+        let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            SlicePtrOrNull::Null
+        } else {
+            SlicePtrOrNull::Ptr(
+                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                    .w()?,
+            )
+        };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let name = match self.0 {
@@ -1541,26 +1070,30 @@ fn gemm_config<T>(
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
     // The a tensor has dims batching, k, n (rhs)
-    let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
         (n as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?
     };
     // The b tensor has dims batching, m, k (lhs)
-    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
         (m as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?
     };
@@ -1581,21 +1114,25 @@ fn gemm_config<T>(
 
     let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
         [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+        [_, stride] if lhs_l.dims()[0] == 1 => stride,
+        [stride, _] if lhs_l.dims()[1] == 1 => stride,
         [stride] => stride,
         [] => m * k,
         _ => Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?,
     };
     let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
         [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+        [_, stride] if rhs_l.dims()[0] == 1 => stride,
+        [stride, _] if rhs_l.dims()[1] == 1 => stride,
         [stride] => stride,
         [] => n * k,
         _ => Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?,
    };
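The lda/ldb hunks above relax the contiguity check: for an operand with a singleton dimension, the stride along that dimension can never be observed, so any value is acceptable. A standalone restatement of the rule, where a rows x cols operand has minor strides (m2, m1); the rows/cols naming is mine, generalizing both the rhs (k x n) and lhs (m x k) cases:

fn pick_op(m1: usize, m2: usize, rows: usize, cols: usize) -> Option<&'static str> {
    if (m1 == 1 || cols == 1) && (m2 == cols || rows == 1) {
        Some("N") // row-major contiguous: feed to cublas as-is
    } else if (m1 == rows || cols == 1) && (m2 == 1 || rows == 1) {
        Some("T") // transposed view: let cublas flip it
    } else {
        None // genuinely strided: an explicit copy is needed first
    }
}

fn main() {
    assert_eq!(pick_op(1, 4, 3, 4), Some("N")); // contiguous 3x4, strides (4, 1)
    assert_eq!(pick_op(3, 1, 3, 4), Some("T")); // transposed view, strides (1, 3)
    assert_eq!(pick_op(2, 8, 3, 4), None); // strided slice: copy first
}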
@@ -1640,7 +1177,7 @@ impl BackendStorage for CudaStorage {
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let start_o = layout.start_offset();
         // This returns an i64 rather than a &i64, this is useful to get around some temporary
         // lifetime issue and is safe as long as self.slice does not go out of scope before inp
@@ -1844,7 +1381,10 @@ impl BackendStorage for CudaStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -1852,7 +1392,7 @@ impl BackendStorage for CudaStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
@@ -1909,7 +1449,10 @@ impl BackendStorage for CudaStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -1919,7 +1462,7 @@ impl BackendStorage for CudaStorage {
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
             .transpose(1, 3)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
@@ -2056,7 +1599,7 @@ impl BackendStorage for CudaStorage {
         dim: usize,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
         self.copy_strided_src(&mut acc, 0, l)?;
         ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
         Ok(acc)
@@ -2071,7 +1614,7 @@ impl BackendStorage for CudaStorage {
         dim: usize,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
         self.copy_strided_src(&mut acc, 0, l)?;
         IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
         Ok(acc)
@@ -2145,6 +1688,72 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
+    fn copy2d(
+        &self,
+        dst: &mut Self,
+        d1: usize,
+        d2: usize,
+        src_s: usize,
+        dst_s: usize,
+        src_o: usize,
+        dst_o: usize,
+    ) -> Result<()> {
+        let dev = &self.device;
+        let d1 = d1 as u32;
+        let d2 = d2 as u32;
+        // Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
+        // argument with a null pointer.
+        if d1 == 0 || d2 == 0 {
+            return Ok(());
+        }
+        let dst_s = dst_s as u32;
+        let src_s = src_s as u32;
+        let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
+            (S::U8(s), S::U8(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_u8",
+            ),
+            (S::U32(s), S::U32(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_u32",
+            ),
+            (S::I64(s), S::I64(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_i64",
+            ),
+            (S::BF16(s), S::BF16(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_bf16",
+            ),
+            (S::F16(s), S::F16(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_f16",
+            ),
+            (S::F32(s), S::F32(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_f32",
+            ),
+            (S::F64(s), S::F64(d)) => (
+                *s.slice(src_o..).device_ptr(),
+                *d.slice(dst_o..).device_ptr(),
+                "copy2d_f64",
+            ),
+            _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
+        };
+        let func = dev.get_or_load_func(kname, kernels::FILL)?;
+        let cfg = LaunchConfig::for_num_elems(d1 * d2);
+        let params = (src, dst, d1, d2, src_s, dst_s);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(())
+    }
+
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         let src_shape = src_l.shape();
         let dims = src_shape.dims();
@@ -2154,7 +1763,7 @@ impl BackendStorage for CudaStorage {
         }
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
+        let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
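For reference, the semantics the copy2d_* kernels are expected to implement, written as a CPU sketch: a d1 x d2 block copied between two buffers with independent row strides and starting offsets:

fn copy2d_ref<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,   // number of rows to copy
    d2: usize,   // elements per row
    src_s: usize, // source row stride
    dst_s: usize, // destination row stride
    src_o: usize, // source starting offset
    dst_o: usize, // destination starting offset
) {
    for i in 0..d1 {
        let s = src_o + i * src_s;
        let d = dst_o + i * dst_s;
        dst[d..d + d2].copy_from_slice(&src[s..s + d2]);
    }
}

fn main() {
    let src = [1, 2, 3, 4, 5, 6];
    let mut dst = [0; 4];
    // Copy a 2x2 block out of a 2x3 row-major matrix.
    copy2d_ref(&src, &mut dst, 2, 2, 3, 2, 0, 0);
    assert_eq!(dst, [1, 2, 4, 5]);
}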
candle-core/src/cuda_backend/utils.rs (new file, 134 lines)
@@ -0,0 +1,134 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};

use super::{CudaDevice, CudaError, WrapErr};

pub type S = super::CudaStorageSlice;

pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}

pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        };
        Ok(out)
    }
}

pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()>;

    fn map(
        &self,
        dst: &mut S,
        dst_s: &Shape,
        src: &S,
        src_l: &Layout,
        d: &CudaDevice,
    ) -> Result<()> {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        }
    }
}

pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
        wrap: W,
    ) -> Result<S>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
            S::F64(s) => self.f(s, d, l, S::F64)?,
        };
        Ok(out)
    }
}

pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<S>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}
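These Map traits exist so that each kernel wrapper is written once, generically over the element type, and a single match fans it out over the dtype enum. The shape of that dispatch in miniature, with a two-variant store standing in for CudaStorageSlice:

enum Store {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

trait MapOne {
    // One generic implementation per op...
    fn f<T: Copy>(&self, src: &[T]) -> Vec<T>;

    // ...and one shared match that handles every dtype.
    fn map(&self, s: &Store) -> Store {
        match s {
            Store::F32(v) => Store::F32(self.f(v)),
            Store::F64(v) => Store::F64(self.f(v)),
        }
    }
}

struct Reverse;
impl MapOne for Reverse {
    fn f<T: Copy>(&self, src: &[T]) -> Vec<T> {
        src.iter().rev().copied().collect()
    }
}

fn main() {
    match Reverse.map(&Store::F32(vec![1.0, 2.0, 3.0])) {
        Store::F32(v) => assert_eq!(v, vec![3.0, 2.0, 1.0]),
        _ => unreachable!(),
    }
}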
candle-core/src/custom_op.rs (new file, 377 lines)
@@ -0,0 +1,377 @@
use crate::op::{BackpropOp, Op};
|
||||
use crate::tensor::from_storage;
|
||||
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Unary ops that can be defined in user-land.
|
||||
pub trait CustomOp1 {
|
||||
// Box<dyn> does not support const yet, so use a function to get the name.
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
|
||||
Err(crate::Error::Cuda(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
_storage: &MetalStorage,
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function takes as argument the argument `arg` used in the forward pass, the result
|
||||
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
|
||||
/// The function should return the gradient of the argument.
|
||||
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||
Err(crate::Error::BackwardNotSupported { op: self.name() })
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomOp2 {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
|
||||
/// offsets etc so the associated layout should be used to access it.
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)>;
|
||||
|
||||
    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

impl Tensor {
    /// Applies a unary custom op without backward support
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}
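
To make the plumbing above concrete, here is a minimal sketch of a user-land CustomOp1 wired through Tensor::apply_op1. The `Sin` op, its f32-only handling, and the free function are illustrative assumptions, not part of this diff:

use candle_core::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

// Hypothetical op: elementwise sine on the cpu backend; cuda/metal fall back to
// the default "no implementation" errors provided by the trait.
struct Sin;

impl CustomOp1 for Sin {
    fn name(&self) -> &'static str {
        "sin"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Iterate through the layout so arbitrary strides/offsets are honored.
        let src = storage.as_slice::<f32>()?;
        let out: Vec<f32> = layout.strided_index().map(|i| src[i].sin()).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

fn sin(xs: &Tensor) -> Result<Tensor> {
    // apply_op1 boxes the op behind an Arc so it can be recorded for backprop;
    // calling backward would still error out since bwd keeps its default impl.
    xs.apply_op1(Sin)
}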

// In place ops.

/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

pub trait InplaceOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
        -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

pub trait InplaceOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &mut CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &mut CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

impl Tensor {
    /// Applies a unary custom op in place.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary custom op in place (for the first tensor).
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary custom op in place (for the first tensor).
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}
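
A companion sketch for the in-place variant (again illustrative, not part of the diff): an InplaceOp1 that scales an f32 tensor through Tensor::inplace_op1. No backprop node is recorded since the storage is mutated directly:

use candle_core::{CpuStorage, InplaceOp1, Layout, Result, Tensor};

// Hypothetical op: multiply every element of an f32 tensor by a constant.
struct Scale(f32);

impl InplaceOp1 for Scale {
    fn name(&self) -> &'static str {
        "scale"
    }

    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()> {
        let vs = match storage {
            CpuStorage::F32(vs) => vs,
            _ => candle_core::bail!("scale: only f32 is handled in this sketch"),
        };
        // Only touch the region actually described by the layout.
        match layout.contiguous_offsets() {
            Some((o1, o2)) => vs[o1..o2].iter_mut().for_each(|v| *v *= self.0),
            None => candle_core::bail!("scale: non-contiguous layouts not handled here"),
        }
        Ok(())
    }
}

fn scale_in_place(t: &Tensor, s: f32) -> Result<()> {
    t.inplace_op1(&Scale(s))
}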
@ -289,17 +289,34 @@ impl Device {
        }
    }

    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
        match self {
            Device::Cpu => {
                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                let storage = device.alloc_uninit(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = device.alloc_uninit(shape, dtype)?;
                Ok(Storage::Metal(storage))
            }
        }
    }

    pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
        match self {
            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
            Device::Cuda(device) => {
                let storage = array.to_cpu_storage();
                let storage = device.storage_from_cpu_storage(&storage)?;
                let storage = device.storage_from_cpu_storage_owned(storage)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = array.to_cpu_storage();
                let storage = device.storage_from_cpu_storage(&storage)?;
                let storage = device.storage_from_cpu_storage_owned(storage)?;
                Ok(Storage::Metal(storage))
            }
        }
@ -310,12 +327,12 @@ impl Device {
            Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
            Device::Cuda(device) => {
                let storage = S::to_cpu_storage_owned(data);
                let storage = device.storage_from_cpu_storage(&storage)?;
                let storage = device.storage_from_cpu_storage_owned(storage)?;
                Ok(Storage::Cuda(storage))
            }
            Device::Metal(device) => {
                let storage = S::to_cpu_storage_owned(data);
                let storage = device.storage_from_cpu_storage(&storage)?;
                let storage = device.storage_from_cpu_storage_owned(storage)?;
                Ok(Storage::Metal(storage))
            }
        }
@ -65,12 +65,13 @@ impl std::fmt::Debug for Tensor {
}

/// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    precision: usize,
    threshold: usize,
    edge_items: usize,
    line_width: usize,
    sci_mode: Option<bool>,
    pub precision: usize,
    pub threshold: usize,
    pub edge_items: usize,
    pub line_width: usize,
    pub sci_mode: Option<bool>,
}

static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@ -89,6 +90,10 @@ impl PrinterOptions {
    }
}

pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}

pub fn set_print_options(options: PrinterOptions) {
    *PRINT_OPTS.lock().unwrap() = options
}
@ -117,6 +122,26 @@ pub fn set_print_options_full() {
    }
}

pub fn set_line_width(line_width: usize) {
    PRINT_OPTS.lock().unwrap().line_width = line_width
}

pub fn set_precision(precision: usize) {
    PRINT_OPTS.lock().unwrap().precision = precision
}

pub fn set_edge_items(edge_items: usize) {
    PRINT_OPTS.lock().unwrap().edge_items = edge_items
}

pub fn set_threshold(threshold: usize) {
    PRINT_OPTS.lock().unwrap().threshold = threshold
}

pub fn set_sci_mode(sci_mode: Option<bool>) {
    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}
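
A quick usage sketch for the setters above (illustrative, not part of the diff; the tensor values are arbitrary):

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Global settings, similar in spirit to torch.set_printoptions.
    candle_core::display::set_precision(8);
    candle_core::display::set_sci_mode(Some(true));
    let t = Tensor::new(&[1e-6f32, 2.5e4], &Device::Cpu)?;
    println!("{t}");
    Ok(())
}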

struct FmtSize {
    current_size: usize,
}

@ -23,7 +23,15 @@ pub enum DType {
}

#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError;
pub struct DTypeParseError(String);

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cannot parse '{}' as a dtype", self.0)
    }
}

impl std::error::Error for DTypeParseError {}

impl std::str::FromStr for DType {
    type Err = DTypeParseError;
@ -36,7 +44,7 @@ impl std::str::FromStr for DType {
            "f16" => Ok(Self::F16),
            "f32" => Ok(Self::F32),
            "f64" => Ok(Self::F64),
            _ => Err(DTypeParseError),
            _ => Err(DTypeParseError(s.to_string())),
        }
    }
}
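
With the string payload added above, a failed parse now reports the offending input. A short illustration (assuming the `bf16` arm from the unchanged part of the match):

use candle_core::DType;

fn dtype_parsing() {
    let dt: DType = "bf16".parse().unwrap();
    assert_eq!(dt, DType::BF16);
    let err = "f8".parse::<DType>().unwrap_err();
    assert_eq!(err.to_string(), "cannot parse 'f8' as a dtype");
}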

@ -154,6 +154,19 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn copy2d(
        &self,
        _: &mut Self,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
    ) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
@ -197,10 +210,18 @@ impl crate::backend::BackendDevice for CudaDevice {
        Err(Error::NotCompiledWithCudaSupport)
    }

    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn copy2d(
        &self,
        _: &mut Self,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
    ) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
@ -209,10 +222,18 @@ impl crate::backend::BackendDevice for MetalDevice {
        Err(Error::NotCompiledWithMetalSupport)
    }

    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
        Err(Error::NotCompiledWithMetalSupport)
    }

@ -70,7 +70,7 @@ impl Layout {
        self.shape.is_fortran_contiguous(&self.stride)
    }

    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
        let dims = self.shape().dims();
        if dim >= dims.len() {
            Err(Error::DimOutOfRange {
@ -99,7 +99,7 @@ impl Layout {
        })
    }

    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
        let rank = self.shape.rank();
        if rank <= dim1 || rank <= dim2 {
            Err(Error::UnexpectedNumberOfDims {
@ -120,7 +120,7 @@ impl Layout {
        })
    }

    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
        let is_permutation =
            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
        if !is_permutation {

@ -14,7 +14,7 @@
//!
//! ## Features
//!
//! - Simple syntax (looks and like PyTorch)
//! - Simple syntax (looks and feels like PyTorch)
//! - CPU and Cuda backends (and M1 support)
//! - Enable serverless (CPU) small and fast deployments
//! - Model training
@ -37,14 +37,13 @@
mod accelerate;
pub mod backend;
pub mod backprop;
mod conv;
pub mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
#[cfg(feature = "cudnn")]
pub mod cudnn;
mod custom_op;
mod device;
pub mod display;
mod dtype;
@ -58,7 +57,7 @@ pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
mod op;
pub mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
@ -67,17 +66,21 @@ pub mod shape;
mod storage;
mod strided_index;
mod tensor;
mod tensor_cat;
pub mod test_utils;
pub mod utils;
mod variable;

#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;

pub use cpu_backend::CpuStorage;
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};
287
candle-core/src/metal_backend/device.rs
Normal file
@ -0,0 +1,287 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};

use super::MetalError;

/// Unique identifier for metal devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    pub(crate) fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;

#[derive(Clone)]
pub struct MetalDevice {
    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
    /// the device itself.
    pub(crate) id: DeviceId,

    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
    pub(crate) device: metal::Device,

    /// Single command queue for the entire device.
    pub(crate) command_queue: CommandQueue,
    /// One command buffer at a time.
    /// The scheduler works by allowing multiple
    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
    /// prevents overlapping of CPU and GPU commands (because the command buffer needs to be
    /// committed to start to work).
    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
    /// by their START time, but there's no guarantee that command buffer1 will finish before
    /// command buffer2 starts (or there are metal bugs there).
    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
    /// Keeps track of the current amount of compute command encoders on the current
    /// command buffer.
    /// Arc, RwLock because of the interior mutability.
    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
    pub(crate) compute_per_buffer: usize,
    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
    /// Heavily used by [`candle_metal_kernels`]
    pub(crate) kernels: Arc<Kernels>,
    /// Simple allocator struct.
    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
    /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
    /// (could be linked to FFI communication overhead).
    ///
    /// Whenever a buffer has a strong_count==1 we can reuse it: it means it was dropped in the
    /// graph calculation and only we the allocator kept a reference to it, therefore it's free
    /// to be reused. However, in order for this to work, we need to guarantee the order of
    /// operation, so that this buffer is not being used by another kernel at the same time.
    /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
    ///
    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
    /// (strong_count = 1).
    pub(crate) buffers: AllocatedBuffers,
    /// Seed for random number generation.
    pub(crate) seed: Arc<Mutex<Buffer>>,
}

impl std::fmt::Debug for MetalDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetalDevice({:?})", self.id)
    }
}

impl std::ops::Deref for MetalDevice {
    type Target = metal::DeviceRef;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

impl MetalDevice {
    pub fn id(&self) -> DeviceId {
        self.id
    }

    pub fn metal_device(&self) -> &metal::Device {
        &self.device
    }

    pub fn command_queue(&self) -> &CommandQueue {
        &self.command_queue
    }

    pub fn command_buffer(&self) -> Result<CommandBuffer> {
        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
        let mut command_buffer = command_buffer_lock.to_owned();
        let mut index = self
            .command_buffer_index
            .try_write()
            .map_err(MetalError::from)?;
        if *index > self.compute_per_buffer {
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffer_lock = command_buffer.clone();
            *index = 0;

            self.drop_unused_buffers()?;
        }
        *index += 1;
        Ok(command_buffer)
    }

    pub fn wait_until_completed(&self) -> Result<()> {
        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
        match command_buffer.status() {
            metal::MTLCommandBufferStatus::Committed
            | metal::MTLCommandBufferStatus::Scheduled
            | metal::MTLCommandBufferStatus::Completed => {
                panic!("Already committed");
            }
            _ => {}
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();
        *command_buffer = self.command_queue.new_command_buffer().to_owned();

        Ok(())
    }

    pub fn kernels(&self) -> &Kernels {
        &self.kernels
    }

    pub fn device(&self) -> &metal::Device {
        &self.device
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer data cannot be read on the CPU directly.
    ///
    /// [`name`] is only used to keep track of the resource origin in case of bugs.
    pub fn new_buffer(
        &self,
        element_count: usize,
        dtype: DType,
        name: &str,
    ) -> Result<Arc<Buffer>> {
        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer can be read on the CPU but will require manual
    /// synchronization when the CPU memory is modified.
    /// Used as a bridge to gather data back from the GPU.
    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
    }

    /// Creates a new buffer from data.
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    ///
    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
        let size = core::mem::size_of_val(data) as NSUInteger;
        let new_buffer = self.device.new_buffer_with_data(
            data.as_ptr() as *const c_void,
            size,
            MTLResourceOptions::StorageModeManaged,
        );
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        let subbuffers = buffers
            .entry((size, MTLResourceOptions::StorageModeManaged))
            .or_insert(vec![]);

        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
        let buffer = self.allocate_buffer(
            size_in_bytes as NSUInteger,
            MTLResourceOptions::StorageModePrivate,
            "allocate_zeros",
        )?;
        let command_buffer = self.command_buffer()?;
        command_buffer.set_label("zeros");
        let blit = command_buffer.new_blit_command_encoder();
        blit.fill_buffer(
            &buffer,
            metal::NSRange {
                location: 0,
                length: buffer.length(),
            },
            0,
        );
        blit.end_encoding();
        Ok(buffer)
    }

    fn find_available_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        buffers: &RwLockWriteGuard<BufferMap>,
    ) -> Option<Arc<Buffer>> {
        let mut best_buffer: Option<&Arc<Buffer>> = None;
        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
                for sub in subbuffers {
                    if Arc::strong_count(sub) == 1 {
                        best_buffer = Some(sub);
                        best_buffer_size = *buffer_size;
                    }
                }
            }
        }
        best_buffer.cloned()
    }

    fn drop_unused_buffers(&self) -> Result<()> {
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        for subbuffers in buffers.values_mut() {
            let newbuffers = subbuffers
                .iter()
                .filter(|s| Arc::strong_count(*s) > 1)
                .map(Arc::clone)
                .collect();
            *subbuffers = newbuffers;
        }
        Ok(())
    }

    /// The critical allocator algorithm
    fn allocate_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        _name: &str,
    ) -> Result<Arc<Buffer>> {
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
            // Cloning also ensures we increment the strong count
            return Ok(b.clone());
        }

        let size = buf_size(size);
        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());

        Ok(new_buffer)
    }

    /// Create a metal GPU capture trace on [`path`].
    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let capture = metal::CaptureManager::shared();
        let descriptor = metal::CaptureDescriptor::new();
        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
        descriptor.set_capture_device(self);
        descriptor.set_output_url(path);

        capture
            .start_capture(&descriptor)
            .map_err(MetalError::from)?;
        Ok(())
    }
}

fn buf_size(size: NSUInteger) -> NSUInteger {
    // Round the request up to the next power of two so that allocations of similar
    // sizes share a bucket and a bucket is never smaller than the request.
    size.next_power_of_two()
}
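
A minimal sketch of the resulting bucket boundaries, assuming the rounding above (illustrative only, not part of the diff):

#[cfg(test)]
mod buf_size_tests {
    use super::buf_size;

    #[test]
    fn rounds_up_to_power_of_two_buckets() {
        assert_eq!(buf_size(1), 1);
        assert_eq!(buf_size(12), 16);
        assert_eq!(buf_size(16), 16);
        assert_eq!(buf_size(17), 32);
    }
}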
File diff suppressed because it is too large
@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;

@ -66,6 +66,7 @@ pub enum UnaryOp {
    Floor,
    Ceil,
    Round,
    Sign,
}

#[derive(Clone)]
@ -161,168 +162,23 @@ pub enum Op {
    Permute(Tensor, Vec<usize>),
    Elu(Tensor, f64),
    Powf(Tensor, f64),
    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
    CustomOp1(
        Tensor,
        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
    ),
    CustomOp2(
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
    ),
    CustomOp3(
        Tensor,
        Tensor,
        Tensor,
        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
    ),
}

/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait UnaryOpT {
    const NAME: &'static str;
    const KERNEL: &'static str;
@ -399,6 +255,7 @@ pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;

macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -602,6 +459,13 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
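// For reference, the approximation implemented below computes
//   gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// with sqrt(2/pi) taken from the constants above.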

/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@ -614,7 +478,7 @@ impl UnaryOpT for Gelu {
            * v
            * (bf16::ONE
                + bf16::tanh(
                    (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
                    bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
                        * v
                        * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
                ))
@ -625,22 +489,18 @@
            * v
            * (f16::ONE
                + f16::tanh(
                    (f16::from_f32_const(2.0) / f16::PI).sqrt()
                    f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
                        * v
                        * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
                ))
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        0.5 * v
            * (1.0
                + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
        0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        0.5 * v
            * (1.0
                + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
        0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
    }
    #[inline(always)]
    fn u8(_: u8) -> u8 {
@ -1067,3 +927,37 @@ impl std::ops::Deref for BackpropOp {
        &self.0
    }
}

impl UnaryOpT for Sign {
    const NAME: &'static str = "sign";
    const KERNEL: &'static str = "usign";
    const V: Self = Sign;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        f32::from(v > 0.) - f32::from(v < 0.)
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        f64::from(v > 0.) - f64::from(v < 0.)
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        u8::min(1, v)
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        u32::min(1, v)
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        (v > 0) as i64 - (v < 0) as i64
    }
}

@ -1,22 +1,62 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};

use cudarc::driver::{CudaSlice, DeviceSlice};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};

#[derive(Clone, Debug)]
pub struct QCudaStorage {
    data: CudaSlice<u8>,
    dtype: GgmlDType,
    device: CudaDevice,
}

static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

pub fn set_force_dmmv(f: bool) {
    FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}

pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;

fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

fn pad(p: usize, q: usize) -> usize {
    ceil_div(p, q) * q
}
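
// For instance, ceil_div(300, 256) == 2 and pad(300, 512) == 512: rows get
// padded up to a multiple of MATRIX_ROW_PADDING before being quantized.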

fn quantize_q8_1(
    src: &CudaView<f32>,
    dst: &mut CudaSlice<u8>,
    elem_count: usize,
    dev: &CudaDevice,
) -> Result<()> {
    use cudarc::driver::LaunchAsync;

    let kx = elem_count;
    let kx_padded = pad(kx, MATRIX_ROW_PADDING);
    let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
    let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks as u32, 1, 1),
        block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
        shared_mem_bytes: 0,
    };
    let params = (src, dst, kx as i32, kx_padded as i32);
    unsafe { func.launch(cfg, params) }.w()?;
    Ok(())
}

fn dequantize(
    data: &CudaSlice<u8>,
@ -30,26 +70,18 @@ fn dequantize(
    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
        GgmlDType::Q5_0 => {
            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
            (
                "dequantize_block_q5_0",
                false,
                CUDA_DEQUANTIZE_BLOCK_SIZE,
                nb,
            )
        }
        GgmlDType::Q5_1 => {
            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
            (
                "dequantize_block_q5_1",
                false,
                CUDA_DEQUANTIZE_BLOCK_SIZE,
                nb,
            )
        }
        GgmlDType::Q5_0 => (
            "dequantize_block_q5_0",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
        ),
        GgmlDType::Q5_1 => (
            "dequantize_block_q5_1",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
        ),
        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
@ -60,7 +92,7 @@ fn dequantize(
        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
    let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
    // See e.g.
    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
    let cfg = cudarc::driver::LaunchConfig {
@ -83,9 +115,9 @@ fn dequantize(
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mut_mal_vec(
fn dequantize_mul_mat_vec(
    data: &CudaSlice<u8>,
    y: &cudarc::driver::CudaView<f32>,
    y: &CudaView<f32>,
    dtype: GgmlDType,
    ncols: usize,
    nrows: usize,
@ -93,6 +125,13 @@ fn dequantize_mut_mal_vec(
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
    if data_elems < ncols * nrows {
        crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
    }
    if y.len() != ncols {
        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
    }
    let kernel_name = match dtype {
        GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
        GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
@ -107,8 +146,8 @@ fn dequantize_mut_mal_vec(
        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(nrows).w()?;
    let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
    let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (block_num_y as u32, 1, 1),
        block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
@ -120,9 +159,66 @@ fn dequantize_mut_mal_vec(
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn mul_mat_vec_via_q8_1(
    data: &CudaSlice<u8>,
    y: &CudaView<f32>,
    dtype: GgmlDType,
    ncols: usize,
    nrows: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
    if data_elems < ncols * nrows {
        crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
    }
    if y.len() != ncols {
        crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
    }
    // Start by quantizing y
    let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
    let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
    quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;

    let kernel_name = match dtype {
        GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
        GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
        GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
        GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
        GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
        GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
        GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
        GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
        GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
        GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (nrows as u32, 1, 1),
        block_dim: (WARP_SIZE as u32, 4, 1),
        shared_mem_bytes: 0,
    };

    let params = (
        data,
        &y_q8_1,
        &dst,
        /* ncols_x */ ncols as i32,
        /* nrows_x */ nrows as i32,
        /* nrows_y */ ncols as i32,
        /* nrows_dst */ nrows as i32,
    );
    unsafe { func.launch(cfg, params) }.w()?;
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

impl QCudaStorage {
    pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
        let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
        let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
        let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
        Ok(QCudaStorage {
            data,
@ -140,6 +236,12 @@ impl QCudaStorage {
    }

    pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
        fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
            let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
            let vec = slice.to_vec();
            T::to_float(&vec, dst)
        }

        let fast_kernel = matches!(
            self.dtype,
            GgmlDType::Q4_0
@ -158,69 +260,25 @@ impl QCudaStorage {
            return dequantize(&self.data, self.dtype, elem_count, self.device());
        }
        // Run the dequantization on cpu.
        use crate::quantized::k_quants::GgmlType;

        let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
        let mut out = vec![0.0; elem_count];
        let block_len = elem_count / self.dtype.block_size();
        match self.dtype {
            GgmlDType::F32 => {
                let slice =
                    unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
                out.copy_from_slice(slice)
            }
            GgmlDType::F16 => {
                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
                half::f16::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_0 => {
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_1 => {
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_0 => {
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_1 => {
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_0 => {
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_1 => {
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q2K => {
                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q3K => {
                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4K => {
                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5K => {
                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q6K => {
                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8K => {
                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
            }
            GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
            GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
            GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
            GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
            GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
            GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
            GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
            GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
            GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
            GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
            GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
            GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
            GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
            GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
        }

        self.device
@ -285,8 +343,11 @@ impl QCudaStorage {
            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
        }

        let out =
            dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
        let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
            dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
        } else {
            mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
        };
        let out_shape = if with_batch {
            vec![1, 1, nrows]
        } else {
@ -313,7 +374,7 @@ impl QCudaStorage {
        }

        let data_f32 = self.dequantize(n * k)?;
        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
        let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
        let mut out_shape = layout.shape().dims().to_vec();
        out_shape.pop();
@ -322,11 +383,6 @@ impl QCudaStorage {
    }
}

fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
    let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
    slice.to_vec()
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    device: &CudaDevice,
    data: &[T],
@ -341,3 +397,60 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
        dtype: T::DTYPE,
    }))
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn cuda_quantize_q8_1() -> Result<()> {
        let dev = CudaDevice::new(0)?;
        let el = 256;
        let el_padded = pad(el, MATRIX_ROW_PADDING);
        let y_size_in_bytes =
            el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
        let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
        let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
        let y = dev.htod_sync_copy(&vs).w()?;
        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
        Ok(())
    }

    #[test]
    fn cuda_mmv_q8_1() -> Result<()> {
        let dev = CudaDevice::new(0)?;
        let ncols = 256;
        let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
        let y = dev.htod_sync_copy(&vs).w()?;
        let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
        let cuda_storage = mul_mat_vec_via_q8_1(
            &xs.data,
            &y.slice(..),
            /* dtype */ GgmlDType::Q4_0,
            /* ncols */ ncols,
            /* nrows */ 1,
            &dev,
        )?;
        let vs = cuda_storage.as_cuda_slice::<f32>()?;
        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
        assert_eq!(vs.len(), 1);
        // For n = 255, n*(n+1)*(2n+1)/6 = 5559680.
        // Q8 means 1/256 precision.
        assert_eq!(vs[0], 5561664.5);

        let cuda_storage = dequantize_mul_mat_vec(
            &xs.data,
            &y.slice(..),
            /* dtype */ GgmlDType::Q4_0,
            /* ncols */ ncols,
            /* nrows */ 1,
            &dev,
        )?;
        let vs = cuda_storage.as_cuda_slice::<f32>()?;
        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
        assert_eq!(vs.len(), 1);
        assert_eq!(vs[0], 5561851.0);
        Ok(())
    }
}

@ -149,8 +149,11 @@ impl QMetalStorage {
        let (n, k) = self_shape.dims2()?;
        let mut dst_shape = src_shape.dims().to_vec();

        // We always use a single batch dimension and stack all the tensors in the batch on the
        // second dimension as the implementation in candle-metal-kernels doesn't handle batch
        // properly.
        let (b, m) = match dst_shape.len() {
            3 => (dst_shape[0], dst_shape[1]),
            3 => (1, dst_shape[0] * dst_shape[1]),
            2 => (1, dst_shape[0]),
            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
        };

@ -398,7 +398,7 @@ impl QMatMul {
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
            let tensor = qtensor.dequantize(&Device::Cpu)?;
            let tensor = qtensor.dequantize(&qtensor.device())?;
            Self::Tensor(tensor)
        } else {
            Self::QTensor(qtensor)

@ -171,7 +171,7 @@ impl Shape {
    }
    let mut acc = 1;
    for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
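        // A dimension of size 1 never advances the iteration, so its stride is
        // irrelevant; the updated check below skips the stride test for such dims.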
|
||||
if stride != acc {
|
||||
if dim > 1 && stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
@ -186,7 +186,7 @@ impl Shape {
|
||||
}
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
||||
if stride != acc {
|
||||
if dim > 1 && stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
|
@ -1,6 +1,7 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::op::{self, CmpOp, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -43,9 +44,19 @@ impl Storage {
    }

    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
        let lhs = self.device().location();
        let rhs = rhs.device().location();
        if lhs != rhs {
        let lhs_device = self.device();
        let rhs_device = rhs.device();
        let lhs = lhs_device.location();
        let rhs = rhs_device.location();
        let same_device = if self.device().is_metal() {
            // On metal, we require the device to be exactly the same rather than
            // having the same location. In cuda this is not necessary as all CudaDevice on the
            // same GPU will use the same cuda stream.
            lhs_device.same_device(&rhs_device)
        } else {
            lhs == rhs
        };
        if !same_device {
            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
        } else {
            Ok(())
@ -252,6 +263,51 @@ impl Storage {
        }
    }

    pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
        match self {
            Self::Cpu(storage) => c.cpu_fwd(storage, l),
            Self::Cuda(storage) => c.cuda_fwd(storage, l),
            Self::Metal(storage) => c.metal_fwd(storage, l),
        }
    }

    pub(crate) fn inplace_op2(
        &mut self,
        l1: &Layout,
        t2: &Self,
        l2: &Layout,
        c: &dyn InplaceOp2,
    ) -> Result<()> {
        self.same_device(t2, c.name())?;
        match (self, t2) {
            (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
            (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
            (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
            _ => unreachable!(),
        }
    }

    pub(crate) fn inplace_op3(
        &mut self,
        l1: &Layout,
        t2: &Self,
        l2: &Layout,
        t3: &Self,
        l3: &Layout,
        c: &dyn InplaceOp3,
    ) -> Result<()> {
        self.same_device(t2, c.name())?;
        self.same_device(t3, c.name())?;
        match (self, t2, t3) {
            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
                c.metal_fwd(s1, l1, s2, l2, s3, l3)
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
@ -701,4 +757,32 @@ impl Storage {
            .bt()),
        }
    }

    #[allow(clippy::too_many_arguments)]
    pub(crate) fn copy2d(
        &self,
        dst: &mut Self,
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::Cuda(src), Self::Cuda(dst)) => {
                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
            }
            (Self::Metal(src), Self::Metal(dst)) => {
                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "copy2d",
            }
            .bt()),
        }
    }
}
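For reference, a plain-slice sketch of the copy semantics that copy2d dispatches to on each backend; the parameter names mirror the signature above, the helper itself is illustrative only.

```rust
// Illustrative CPU-only copy2d: copy d1 rows of d2 elements each, with the
// source rows src_s elements apart (starting at src_o) and the destination
// rows dst_s elements apart (starting at dst_o).
fn copy2d(
    src: &[f32],
    dst: &mut [f32],
    d1: usize,
    d2: usize,
    src_s: usize,
    dst_s: usize,
    src_o: usize,
    dst_o: usize,
) {
    for i in 0..d1 {
        let so = src_o + i * src_s;
        let d_o = dst_o + i * dst_s;
        dst[d_o..d_o + d2].copy_from_slice(&src[so..so + d2]);
    }
}

fn main() {
    // Scatter two 2-element rows into a destination with row stride 4.
    let src = [1f32, 2., 3., 4.];
    let mut dst = [0f32; 8];
    copy2d(&src, &mut dst, 2, 2, 2, 4, 0, 0);
    assert_eq!(dst, [1., 2., 0., 0., 3., 4., 0., 0.]);
}
```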
@ -1,9 +1,7 @@
//! Tensors are N-dimensional matrices of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
    BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
@ -512,6 +510,7 @@ impl Tensor {
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);
    unary_op!(sign, Sign);

    /// Round element of the input tensor to the nearest integer.
    ///
@ -666,7 +665,7 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
        if dim >= self.dims().len() {
            Err(Error::DimOutOfRange {
                shape: self.shape().clone(),
@ -1351,7 +1350,7 @@ impl Tensor {
            }
            .bt())?
        }
        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
        self.storage()
            .copy_strided_src(&mut storage, 0, self.layout())?;
        let offset = start * src.dims()[1..].iter().product::<usize>();
@ -2001,7 +2000,7 @@ impl Tensor {
            Ok(self.clone())
        } else {
            let shape = self.shape();
            let mut storage = self.device().zeros(shape, self.dtype())?;
            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
            self.storage()
                .copy_strided_src(&mut storage, 0, self.layout())?;
            let op = BackpropOp::new1(self, Op::Copy);
@ -2009,11 +2008,21 @@ impl Tensor {
        }
    }

    /// Returns a tensor that is in row major order. This always makes a copy.
    pub fn force_contiguous(&self) -> Result<Tensor> {
        let shape = self.shape();
        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
        self.storage()
            .copy_strided_src(&mut storage, 0, self.layout())?;
        let op = BackpropOp::new1(self, Op::Copy);
        Ok(from_storage(storage, shape.clone(), op, false))
    }

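A hypothetical usage sketch of force_contiguous above: unlike contiguous(), it copies even when the tensor is already contiguous, which the mm_layout test further down relies on to rebuild a fresh row-major layout.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    // Already contiguous, but force_contiguous still returns a fresh copy.
    let t2 = t.force_contiguous()?;
    assert_eq!(t2.to_vec2::<f32>()?, t.to_vec2::<f32>()?);
    Ok(())
}
```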
    /// Create a variable based on the values currently stored in a tensor. The storage is always
    /// copied.
    pub(crate) fn make_var(&self) -> Result<Tensor> {
        let shape = self.shape().clone();
        let mut storage = self.device().zeros(&shape, self.dtype())?;
        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
        self.storage()
            .copy_strided_src(&mut storage, 0, self.layout())?;
        Ok(from_storage(storage, shape, BackpropOp::none(), true))
@ -2066,7 +2075,7 @@ impl Tensor {
            };
            Ok(Tensor(Arc::new(tensor_)))
        } else {
            let mut storage = self.device().zeros(&shape, self.dtype())?;
            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
            self.storage()
                .copy_strided_src(&mut storage, 0, self.layout())?;
            Ok(from_storage(storage, shape, op, false))
@ -2093,8 +2102,19 @@ impl Tensor {
        let dim = dim.to_index(self.shape(), "squeeze")?;
        if dims[dim] == 1 {
            let mut dims = dims.to_vec();
            let mut strides = self.stride().to_vec();
            dims.remove(dim);
            self.reshape(dims)
            strides.remove(dim);
            let tensor_ = Tensor_ {
                id: TensorId::new(),
                storage: self.storage.clone(),
                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
                op: BackpropOp::new1(self, Op::Reshape),
                is_variable: false,
                dtype: self.dtype,
                device: self.device.clone(),
            };
            Ok(Tensor(Arc::new(tensor_)))
        } else {
            Ok(self.clone())
        }
@ -2115,10 +2135,24 @@ impl Tensor {
    /// ```
    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
        let mut dims = self.dims().to_vec();
        let mut strides = self.stride().to_vec();
        let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
        // Cannot panic because to_index_plus_one already checks dimensions
        dims.insert(dim, 1);
        self.reshape(dims)
        // Any stride would work here, but we pick one so as to maximize the probability to remain
        // C contiguous.
        let stride = if dim < strides.len() { strides[dim] } else { 1 };
        strides.insert(dim, stride);
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage: self.storage.clone(),
            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
            op: BackpropOp::new1(self, Op::Reshape),
            is_variable: false,
            dtype: self.dtype,
            device: self.device.clone(),
        };
        Ok(Tensor(Arc::new(tensor_)))
    }

    /// Stacks two or more tensors along a particular dimension.
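A minimal stride-level sketch of the squeeze/unsqueeze rewrites above (standalone code, not from the diff): both now build a view over the same storage by editing dims and strides instead of calling reshape, which avoids a copy for non-contiguous tensors.

```rust
fn main() {
    // squeeze: dropping a size-1 dim removes the matching stride.
    let (mut dims, mut strides) = (vec![2usize, 1, 3], vec![3usize, 3, 1]);
    dims.remove(1);
    strides.remove(1);
    assert_eq!((dims, strides), (vec![2, 3], vec![3, 1]));

    // unsqueeze: insert the displaced dim's stride so that a contiguous
    // tensor stays C contiguous, e.g. (2, 3) with strides [3, 1] becomes
    // (2, 1, 3) with strides [3, 1, 1].
    let (mut dims, mut strides) = (vec![2usize, 3], vec![3usize, 1]);
    let dim = 1;
    let stride = if dim < strides.len() { strides[dim] } else { 1 };
    dims.insert(dim, 1);
    strides.insert(dim, stride);
    assert_eq!((dims, strides), (vec![2, 1, 3], vec![3, 1, 1]));
}
```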
@ -2149,152 +2183,6 @@ impl Tensor {
        Self::cat(&args, dim)
    }

    /// Concatenates two or more tensors along a particular dimension.
    ///
    /// All tensors must be of the same rank, and the output will have
    /// the same rank
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    ///
    /// let c = Tensor::cat(&[&a, &b], 0)?;
    /// assert_eq!(c.shape().dims(), &[4, 3]);
    ///
    /// let c = Tensor::cat(&[&a, &b], 1)?;
    /// assert_eq!(c.shape().dims(), &[2, 6]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let dim = dim.to_index(arg0.shape(), "cat")?;
        for arg in args {
            arg.as_ref().check_dim(dim, "cat")?;
        }
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg0.rank() != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: arg0.rank(),
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        if dim == 0 {
            Self::cat0(args)
        } else {
            // TODO: Avoid these transpositions and have an implementation that works
            // for dim != 0...
            let args: Vec<Tensor> = args
                .iter()
                .map(|a| a.as_ref().transpose(0, dim))
                .collect::<Result<Vec<_>>>()?;
            let cat = Self::cat0(&args)?;
            cat.transpose(0, dim)
        }
    }

    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[0] = 0;
        let mut offsets = vec![0usize];
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == 0 {
                    cat_dims[0] += v2;
                }
                if dim_idx != 0 && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
            let next_offset = offsets.last().unwrap() + arg.elem_count();
            offsets.push(next_offset);
        }
        let shape = Shape::from(cat_dims);
        let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
        let mut storage = device.zeros(&shape, dtype)?;
        for (arg, &offset) in args.iter().zip(offsets.iter()) {
            let arg = arg.as_ref();
            arg.storage()
                .copy_strided_src(&mut storage, offset, arg.layout())?;
        }
        Ok(from_storage(storage, shape, op, false))
    }

    /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
    /// input tensor values and `right` elements after.
    pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
@ -2377,6 +2265,10 @@ impl Tensor {
        self.storage.read().unwrap()
    }

    pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
        self.storage.write().unwrap()
    }

    // If we extend the visibility of this function to be usable outside of this crate, we should
    // make it unsafe.
    pub(crate) fn storage_mut_and_layout(
@ -2398,96 +2290,6 @@ impl Tensor {
        std::ptr::eq(lhs, rhs)
    }

    /// Applies a unary custom op without backward support
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }

    /// Normalize a 'relative' axis value: positive values are kept, negative
    /// values means counting the dimensions from the back.
    pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
238 candle-core/src/tensor_cat.rs (new file)
@ -0,0 +1,238 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};

impl Tensor {
    /// Concatenates two or more tensors along a particular dimension.
    ///
    /// All tensors must be of the same rank, and the output will have
    /// the same rank
    ///
    /// ```rust
    /// # use candle_core::{Tensor, DType, Device};
    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    ///
    /// let c = Tensor::cat(&[&a, &b], 0)?;
    /// assert_eq!(c.shape().dims(), &[4, 3]);
    ///
    /// let c = Tensor::cat(&[&a, &b], 1)?;
    /// assert_eq!(c.shape().dims(), &[2, 6]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```
    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let dim = dim.to_index(arg0.shape(), "cat")?;
        for arg in args {
            arg.as_ref().check_dim(dim, "cat")?;
        }
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg0.rank() != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: arg0.rank(),
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
        if all_contiguous {
            Self::cat_contiguous(args, dim)
        } else if dim == 0 {
            Self::cat0(args)
        } else {
            let args: Vec<Tensor> = args
                .iter()
                .map(|a| a.as_ref().transpose(0, dim))
                .collect::<Result<Vec<_>>>()?;
            let cat = Self::cat0(&args)?;
            cat.transpose(0, dim)
        }
    }

    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[0] = 0;
        let mut offsets = vec![0usize];
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == 0 {
                    cat_dims[0] += v2;
                }
                if dim_idx != 0 && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
            let next_offset = offsets.last().unwrap() + arg.elem_count();
            offsets.push(next_offset);
        }
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
        for (arg, &offset) in args.iter().zip(offsets.iter()) {
            let arg = arg.as_ref();
            arg.storage()
                .copy_strided_src(&mut storage, offset, arg.layout())?;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }

    fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
        if args.is_empty() {
            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
        }
        let arg0 = args[0].as_ref();
        if args.len() == 1 {
            return Ok(arg0.clone());
        }
        let rank = arg0.rank();
        let device = arg0.device();
        let dtype = arg0.dtype();
        let first_dims = arg0.shape().dims();
        let mut cat_dims = first_dims.to_vec();
        cat_dims[dim] = 0;
        for (arg_idx, arg) in args.iter().enumerate() {
            let arg = arg.as_ref();
            if arg.dtype() != dtype {
                Err(Error::DTypeMismatchBinaryOp {
                    lhs: dtype,
                    rhs: arg.dtype(),
                    op: "cat",
                }
                .bt())?
            }
            if arg.device().location() != device.location() {
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: device.location(),
                    rhs: arg.device().location(),
                    op: "cat",
                }
                .bt())?
            }
            if rank != arg.rank() {
                Err(Error::UnexpectedNumberOfDims {
                    expected: rank,
                    got: arg.rank(),
                    shape: arg.shape().clone(),
                }
                .bt())?
            }
            for (dim_idx, (v1, v2)) in arg0
                .shape()
                .dims()
                .iter()
                .zip(arg.shape().dims().iter())
                .enumerate()
            {
                if dim_idx == dim {
                    cat_dims[dim] += v2;
                }
                if dim_idx != dim && v1 != v2 {
                    Err(Error::ShapeMismatchCat {
                        dim: dim_idx,
                        first_shape: arg0.shape().clone(),
                        n: arg_idx + 1,
                        nth_shape: arg.shape().clone(),
                    }
                    .bt())?
                }
            }
        }
        let cat_target_dim_len = cat_dims[dim];
        let block_size: usize = cat_dims.iter().skip(1 + dim).product();
        let shape = Shape::from(cat_dims);
        let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
        let mut dst_o = 0;
        for arg in args.iter() {
            let arg = arg.as_ref();
            let arg_dims = arg.shape().dims();
            let d1: usize = arg_dims.iter().take(dim).product();
            let d2 = block_size * arg_dims[dim];
            let dst_s = block_size * cat_target_dim_len;
            let src_o = arg.layout().start_offset();
            arg.storage().copy2d(
                &mut storage,
                d1,
                d2,
                /* src_s */ d2,
                dst_s,
                src_o,
                dst_o,
            )?;
            dst_o += d2;
        }
        Ok(crate::tensor::from_storage(storage, shape, op, false))
    }
}
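A short usage example for the new fast path (using the same public API as the doc comment above): when all inputs are contiguous, cat now goes through cat_contiguous and copy2d for any dim, without the transpose round-trip.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    let b = Tensor::new(&[[4f32, 5.], [6., 7.]], &Device::Cpu)?;
    // Both inputs are contiguous, so this takes the copy2d-based fast path
    // even though dim != 0.
    let c = Tensor::cat(&[&a, &b], 1)?;
    assert_eq!(c.to_vec2::<f32>()?, [[0., 1., 4., 5.], [2., 3., 6., 7.]]);
    Ok(())
}
```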
@ -53,26 +53,31 @@ fn conv1d(dev: &Device) -> Result<()> {
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7]);
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
            4.7076, -5.9745, -0.8276, 1.621
        ],
    );
    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
    assert_eq!(res.dims(), [1, 4, 7]);
    assert_eq!(
        test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
        [
            [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
            [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
            [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
            [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
        ]
    );

    let w = w.transpose(0, 1)?;
    // The CPU kernels applied in the contiguous and non contiguous cases are different.
    for w in [w.clone(), w.contiguous()?] {
        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
        assert_eq!(res.dims(), [1, 2, 7]);
        assert_eq!(
            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
            [
                0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
                4.7076, -5.9745, -0.8276, 1.621
            ],
        );
        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
        assert_eq!(res.dims(), [1, 4, 7]);
        assert_eq!(
            test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
            [
                [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
                [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
                [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
                [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
            ]
        );
    }
    Ok(())
}

@ -130,7 +135,7 @@ fn conv2d(dev: &Device) -> Result<()> {
            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
            -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
        ],
        dev,
    )?;
@ -158,7 +163,9 @@ fn conv2d(dev: &Device) -> Result<()> {
            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
        ]
    );

    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;

    assert_eq!(res.dims(), [1, 2, 7, 7]);
    assert_eq!(
        test_utils::to_vec3_round(&res.i(0)?, 4)?,
@ -183,6 +190,7 @@ fn conv2d(dev: &Device) -> Result<()> {
            ]
        ]
    );

    // Dilations.
    let res = t.conv2d(&w, 0, 1, 2, 1)?;
    assert_eq!(res.dims(), [1, 2, 1, 1]);
@ -221,6 +229,7 @@ fn conv2d(dev: &Device) -> Result<()> {
            ]
        ]
    );

    Ok(())
}

@ -267,13 +276,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
    assert_eq!(
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
            0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
            -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
            -0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
            3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0
        ]
    );

    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    assert_eq!(
@ -375,6 +384,7 @@ print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
    // conv-transposes are not implemented for metal
    use candle_core::Var;
    let t = Var::from_slice(
        &[
@ -387,7 +397,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
            -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
        ],
        (1, 4, 5, 5),
        dev,
@ -572,6 +582,154 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
        ]
    );

    // Conv Transpose 2d Test
    // Tested against the following python:

    // import torch
    // torch.manual_seed(4242)
    // padding = 4
    // outpadding = 2
    // dilation = 3
    // stride = 3
    // input = torch.randn((1, 4, 7, 5), requires_grad=True)
    // kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
    // print("input", input.flatten())
    // print("kernel", kernel.flatten())
    // res = torch.nn.functional.conv_transpose2d(
    //     input,
    //     kernel,
    //     stride=stride,
    //     padding=padding,
    //     dilation=dilation,
    //     output_padding=outpadding,
    // )
    // res.retain_grad()
    // print(res.shape)
    // loss = (res**2).sum()
    // print(loss)
    // loss.backward()
    // print(input.grad.shape)
    // print("input grad", torch.round(input.grad, decimals=1))
    // print(kernel.grad.shape)
    // print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))

    let padding = 4;
    let outpadding = 2;
    let dilation = 3;
    let stride = 3;

    let t = Var::from_slice(
        &[
            0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
            3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
            0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
            -0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
            1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
            1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
            0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
            -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
            0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
            -0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
            -0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
            1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
            -0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
            -2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
            1.6475, 0.2219,
        ],
        (1, 4, 7, 5),
        dev,
    )?;

    #[rustfmt::skip]
    let w = Var::from_slice(
        &[
            -1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
            -0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
            0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
            0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
            0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
            0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
            0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
            0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
            0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
            -0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
            -1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
            0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
        ],
        (4, 2, 3, 5),
        dev,
    )?;
    let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
    let grads = loss.backward()?;

    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
    assert_eq!(grad_w.dims(), [4, 2, 3, 5]);

    assert_eq!(
        test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
        [
            // torch gets 89.1
            -89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
            -15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
            52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
            106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
            -27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
            -10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
            -52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
            -20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
            92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
            -28.4, 85.0, -18.3, 107.0, 28.3, -71.8
        ]
    );

    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
        [
            [
                [32.3, -41.6, -24.0, 14.1, 17.6],
                [-11.8, 72.5, 87.6, 46.4, 61.5],
                [115.0, 108.5, -48.6, -63.4, -50.0],
                [51.3, 5.4, 31.3, 91.1, -30.9],
                [52.7, 92.8, -68.0, -47.0, 83.0],
                // pytorch gets -107.1
                [-10.2, -107.0, -5.4, 213.1, -31.4],
                [-2.4, 65.1, 9.2, -146.2, -24.2]
            ],
            [
                [-72.6, -63.9, -61.9, 45.3, 33.0],
                [79.3, -0.5, -26.2, 78.2, 42.7],
                [90.9, 141.6, 40.1, -62.7, 37.0],
                [32.8, 198.2, -0.8, -31.1, 27.3],
                // torch gets 48.0
                [34.5, 34.9, -47.9, 127.6, -12.3],
                [-61.4, -3.2, -2.9, -10.9, -16.6],
                [74.6, 60.1, -68.9, 34.5, -50.4]
            ],
            [
                [37.5, -56.9, -43.6, -13.5, -9.9],
                [40.0, 97.3, 28.6, 14.2, -30.1],
                [-22.3, -126.3, -68.8, -8.2, 26.1],
                [-32.9, 37.3, 108.5, -54.8, 29.6],
                [34.9, -176.9, -125.0, -28.3, -13.9],
                [-54.9, 142.6, 62.1, -80.4, -65.6],
                [7.4, -91.1, -67.6, 35.0, 39.7]
            ],
            [
                [-57.2, -40.9, -10.1, 32.6, 29.4],
                [18.7, -18.0, 29.5, -1.2, 59.2],
                [-14.0, -74.4, 19.8, -117.0, 58.2],
                [-21.8, 163.5, -71.1, -99.0, 80.9],
                [-58.9, -10.9, 93.8, -139.6, 98.0],
                // torch gets 54.5
                [-54.4, 135.3, 6.0, -79.1, 134.6],
                [27.5, -76.0, 43.4, -2.8, -7.8]
            ]
        ]
    );
    Ok(())
}

@ -112,3 +112,34 @@ fn custom_op1_with_backward() -> Result<()> {

    Ok(())
}

impl candle_core::InplaceOp1 for Elu {
    fn name(&self) -> &'static str {
        "elu"
    }

    fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
        let alpha = self.alpha;
        match s {
            CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
            CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
            CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
            CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
            _ => candle_core::bail!("unsupported dtype for inplace elu"),
        }
        Ok(())
    }
}

|
||||
fn inplace_op1() -> Result<()> {
|
||||
let cpu = &Device::Cpu;
|
||||
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
|
||||
let t = (t - 5.)?;
|
||||
t.inplace_op1(&Elu { alpha: 1. })?;
|
||||
assert_eq!(
|
||||
to_vec1_round(&t, 4)?,
|
||||
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};

@ -96,24 +97,24 @@ fn unary_grad(device: &Device) -> Result<()> {
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        y.to_vec1::<f32>()?,
        [20.085537, 2.7182817, 54.59815, 1.1618342]
        test_utils::to_vec1_round(&y, 4)?,
        [20.0855, 2.7183, 54.5982, 1.1618]
    );
    assert_eq!(
        grad_x.to_vec1::<f32>()?,
        [20.085537, 2.7182817, 54.59815, 1.1618342]
        test_utils::to_vec1_round(grad_x, 4)?,
        [20.0855, 2.7183, 54.5982, 1.1618]
    );
    let y = x.exp()?.sqr()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        y.to_vec1::<f32>()?,
        [403.4288, 7.3890557, 2980.9578, 1.3498588]
        test_utils::to_vec1_round(&y, 3)?,
        [403.429, 7.389, 2980.958, 1.35]
    );
    // exp(x)^2 = exp(2*x)
    assert_eq!(
        grad_x.to_vec1::<f32>()?,
        [806.8576, 14.778111, 5961.9155, 2.6997175]
        test_utils::to_vec1_round(grad_x, 2)?,
        [806.86, 14.78, 5961.92, 2.7]
    );
    let y = x.sin()?;
    let grads = y.backward()?;
@ -261,6 +262,7 @@ fn unary_grad(device: &Device) -> Result<()> {
    let y = elu_x.elu(2.)?;
    let grads = y.backward()?;
    let grad_x = grads.get(&elu_x).context("no grad for x")?;

    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [-1.2642, 0.0000, -1.7293, 3.0000]
@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
        }
    };
    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
    let tensor = tensor.i((.., 1))?;
    let tensor = tensor.i((.., 1))?.contiguous()?;
    match tensor.strided_blocks() {
        candle::StridedBlocks::SingleBlock { start_offset, len } => {
            assert_eq!(start_offset, 0);
@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> {
        }
    };
    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
    let tensor = tensor.i((.., 1))?;
    match tensor.strided_blocks() {
        candle::StridedBlocks::SingleBlock { .. } => {
            panic!("unexpected block structure")
        }
        candle::StridedBlocks::MultipleBlocks {
            block_len,
            block_start_index,
        } => {
            assert_eq!(block_len, 4);
            assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
        }
    };
    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
    match tensor.t()?.strided_blocks() {
        candle::StridedBlocks::SingleBlock { .. } => {
            panic!("unexpected block structure")
        }
106 candle-core/tests/matmul_tests.rs (new file)
@ -0,0 +1,106 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};

fn matmul(device: &Device) -> Result<()> {
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let a = Tensor::from_slice(&data, (2, 2), device)?;
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let b = Tensor::from_slice(&data, (2, 2), device)?;

    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);

    let data = vec![1.0f32, 2.0];
    let a = Tensor::from_slice(&data, (2, 1), device)?;
    let data = vec![3.0f32, 4.0];
    let b = Tensor::from_slice(&data, (1, 2), device)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
    let a = Tensor::from_slice(&data, (2, 3), device)?;
    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
    let b = Tensor::from_slice(&data, (3, 2), device)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];

    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec3::<f32>()?, &expected);

    // Also perform the matmul on contiguous transposed versions.
    let a_tt = a.t()?.contiguous()?.t()?;
    assert!(!a_tt.is_contiguous());
    assert_eq!(a.dims(), a_tt.dims());
    assert_eq!(a_tt.stride(), &[6, 1, 2]);

    let b_tt = b.t()?.contiguous()?.t()?;
    assert!(!b_tt.is_contiguous());
    assert_eq!(b.dims(), b_tt.dims());
    assert_eq!(b_tt.stride(), &[6, 1, 3]);

    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
    Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[3, 6, 4, 2]);
    for idx1 in 0..3 {
        for idx2 in 0..6 {
            let out = out.i((idx1, idx2))?;
            let lhs = lhs.i((idx1, 0))?;
            let rhs = rhs.i(idx2)?;
            let out2 = lhs.matmul(&rhs);
            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
            // With cuda, we see errors of up to ~1e-12.
            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
        }
    }
    Ok(())
}

// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
    let seq_len = 8_usize;
    let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
    let x = a.i((.., seq_len - 1, ..))?;
    let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
    let x = x.matmul(&w)?;
    assert_eq!(x.dims(), &[1, 32]);
    Ok(())
}

// https://github.com/huggingface/candle/issues/1992
fn mm_layout(device: &Device) -> Result<()> {
    let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
    let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
    let mm1 = a.matmul(&b)?;
    // Forces the layout to be:
    // shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
    // This is still a contiguous matrix, but the matmul layout checks only consider the last two
    // dimensions and may be reluctant to handle it.
    let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
    let mm2 = a.matmul(&b)?;
    let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert_eq!(diff, 0.);
    Ok(())
}

test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
    broadcast_matmul,
    broadcast_matmul_cpu,
    broadcast_matmul_gpu,
    broadcast_matmul_metal
);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);
@ -43,6 +43,9 @@ res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
    if dev.is_metal() {
        return Ok(());
    }
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
@ -106,6 +106,9 @@ fn unary_op(device: &Device) -> Result<()> {
            [2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
        ]
    );
    let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
    let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
    assert!(max_diff.to_vec0::<f32>()? < 5e-3);
    assert_eq!(
        test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
        [
@ -148,6 +151,14 @@ fn unary_op(device: &Device) -> Result<()> {
        test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
        [3000.0, 300.]
    );
    let tensor = Tensor::new(
        &[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
        device,
    )?;
    assert_eq!(
        tensor.sign()?.to_vec1::<f32>()?,
        [-1., -1., -1., 0., 0., 1., 1., 1., 1.]
    );
    Ok(())
}

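The expected values in the sign test above follow the usual convention; a scalar sketch (illustrative only, not the library's kernel):

```rust
// Scalar version of the sign semantics checked above: both 0.0 and -0.0
// map to 0, everything else to +/-1.
fn sign(v: f32) -> f32 {
    if v > 0.0 {
        1.0
    } else if v < 0.0 {
        -1.0
    } else {
        0.0
    }
}

fn main() {
    let xs = [-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1];
    let got: Vec<f32> = xs.iter().map(|&v| sign(v)).collect();
    assert_eq!(got, [-1., -1., -1., 0., 0., 1., 1., 1., 1.]);
}
```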
@ -672,6 +683,31 @@ fn cat(device: &Device) -> Result<()> {
            [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
        ]
    );

    // 3D
    let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
    let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
    let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;

    let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;

    let t1 = t1.t()?.contiguous()?.t()?;
    let t2 = t2.t()?.contiguous()?.t()?;
    let t3 = t3.t()?.contiguous()?.t()?;
    let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;

    let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
    assert_eq!(diff.to_vec0::<f32>()?, 104.0);
    assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
    assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
    assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
    assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
    assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
    assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
    assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
    assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
    assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
    assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
    Ok(())
}

@ -682,6 +718,8 @@ fn embeddings(device: &Device) -> Result<()> {
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
    let hs = t.index_select(&ids, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
    let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
    Ok(())
}

@ -709,44 +747,47 @@ fn index_select(device: &Device) -> Result<()> {
            [9.0, 10.0, 11.0]
        ]
    );
    let hs = t.index_select(&ids, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [0.0, 2.0, 1.0],
            [3.0, 5.0, 4.0],
            [6.0, 8.0, 7.0],
            [9.0, 11.0, 10.0]
        ]
    );
    let hs = t.index_select(&ids, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
    );
    // Prior to https://github.com/huggingface/candle/pull/1022
    // There would be a bug where the last values in the result tensor would be set to 0.
    let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
    let hs = t.index_select(&ids, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [6.0, 7.0, 8.0],
            [3.0, 4.0, 5.0],
            [0.0, 1.0, 2.0],
            [6.0, 7.0, 8.0],
            [3.0, 4.0, 5.0],
        ]
    );
    for dtype in [DType::U8, DType::U32, DType::I64] {
        let ids = ids.to_dtype(dtype)?;
        let hs = t.index_select(&ids, 1)?;
        assert_eq!(
            hs.to_vec2::<f32>()?,
            &[
                [0.0, 2.0, 1.0],
                [3.0, 5.0, 4.0],
                [6.0, 8.0, 7.0],
                [9.0, 11.0, 10.0]
            ]
        );
        let hs = t.index_select(&ids, 0)?;
        assert_eq!(
            hs.to_vec2::<f32>()?,
            &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
        );
        // Prior to https://github.com/huggingface/candle/pull/1022
        // There would be a bug where the last values in the result tensor would be set to 0.
        let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
        let hs = t.index_select(&ids, 0)?;
        assert_eq!(
            hs.to_vec2::<f32>()?,
            &[
                [0.0, 1.0, 2.0],
                [6.0, 7.0, 8.0],
                [3.0, 4.0, 5.0],
                [0.0, 1.0, 2.0],
                [6.0, 7.0, 8.0],
                [3.0, 4.0, 5.0],
            ]
        );

    // Test when selecting dim > 0 with ids size different from elem count of
    // target dim in source/input.
    let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
    let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
    assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
    let hs = t.index_select(&ids, 1)?;
    assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
        // Test when selecting dim > 0 with ids size different from elem count of
        // target dim in source/input.
        let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
        let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
        assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
        let hs = t.index_select(&ids, 1)?;
        assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
    }

    Ok(())
}
@ -908,74 +949,6 @@ fn gather(device: &Device) -> Result<()> {
    Ok(())
}

fn matmul(device: &Device) -> Result<()> {
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let a = Tensor::from_slice(&data, (2, 2), device)?;
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let b = Tensor::from_slice(&data, (2, 2), device)?;

    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);

    let data = vec![1.0f32, 2.0];
    let a = Tensor::from_slice(&data, (2, 1), device)?;
    let data = vec![3.0f32, 4.0];
    let b = Tensor::from_slice(&data, (1, 2), device)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
    let a = Tensor::from_slice(&data, (2, 3), device)?;
    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
    let b = Tensor::from_slice(&data, (3, 2), device)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];

    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec3::<f32>()?, &expected);

    // Also perform the matmul on contiguous transposed versions.
    let a_tt = a.t()?.contiguous()?.t()?;
    assert!(!a_tt.is_contiguous());
    assert_eq!(a.dims(), a_tt.dims());
    assert_eq!(a_tt.stride(), &[6, 1, 2]);

    let b_tt = b.t()?.contiguous()?.t()?;
    assert!(!b_tt.is_contiguous());
    assert_eq!(b.dims(), b_tt.dims());
    assert_eq!(b_tt.stride(), &[6, 1, 3]);

    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
    Ok(())
}

fn broadcast_matmul(device: &Device) -> Result<()> {
    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[3, 6, 4, 2]);
    for idx1 in 0..3 {
        for idx2 in 0..6 {
            let out = out.i((idx1, idx2))?;
            let lhs = lhs.i((idx1, 0))?;
            let rhs = rhs.i(idx2)?;
            let out2 = lhs.matmul(&rhs);
            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
            // With cuda, we see errors of up to ~1e-12.
            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
        }
    }
    Ok(())
}

fn broadcasting(device: &Device) -> Result<()> {
    let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
    let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1080,8 +1053,33 @@ fn broadcasting(device: &Device) -> Result<()> {
fn randn(device: &Device) -> Result<()> {
    let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
    assert_eq!(tensor.dims(), [5, 3]);
    // Check that the seed gets updated by checking that
    // a new series of numbers is generated each time
    let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
    assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
    let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
    assert_eq!(tensor.dims(), [5, 3]);
    // Check that the seed gets updated by checking that
    // a new series of numbers is generated each time
    let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
    assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
    // We do not expect deterministic elements at any index.
    // There once was a bug that had a deterministic zero element in evenly sized tensors.
    const N: usize = 2;
    let v = (0..100)
        .map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
        .collect::<Result<Vec<_>>>()?;
    assert!(
        (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
        "There are deterministic values in the randn tensors"
    );
    let v = (0..100)
        .map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
        .collect::<Result<Vec<_>>>()?;
    assert!(
        (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
        "There are deterministic values in the rand tensors"
    );
    Ok(())
}

@ -1104,13 +1102,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
    broadcast_matmul,
    broadcast_matmul_cpu,
    broadcast_matmul_gpu,
    broadcast_matmul_metal
);
test_device!(
    broadcasting,
    broadcasting_cpu,
@ -1267,8 +1258,8 @@ fn pow() -> Result<()> {
    let rhs = (&lhs - 2.)?;
    let res = lhs.pow(&rhs)?;
    assert_eq!(
        test_utils::to_vec2_round(&res, 4)?,
        [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
        test_utils::to_vec2_round(&res, 3)?,
        [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
    );
    Ok(())
}

@ -25,8 +25,9 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -41,7 +42,7 @@ clap = { workspace = true }
imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
ab_glyph = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
@ -63,6 +64,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]

[[example]]
name = "llama_multiprocess"
@ -98,6 +100,4 @@ required-features = ["candle-datasets"]

[[example]]
name = "encodec"
required-features = ["symphonia"]
required-features = ["encodec"]
candle-examples/examples/clip/README.md (new file, 46 lines)
@ -0,0 +1,46 @@
# Contrastive Language-Image Pre-Training

Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.

https://github.com/openai/CLIP

https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

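At inference time the model produces one image-text similarity logit per
(image, prompt) pair; a softmax over the text dimension turns these into the
probabilities printed below. A minimal sketch of that step, with dummy logits
standing in for the `ClipModel::forward` output used in `main.rs`:

```rust
use candle::{Device, Tensor};
use candle_nn::ops::softmax;

fn main() -> candle::Result<()> {
    // Hypothetical logits for 2 images x 3 prompts, playing the role of the
    // logits_per_image tensor returned by ClipModel::forward.
    let logits_per_image =
        Tensor::new(&[[0.2f32, 0.1, 9.6], [11.3f32, 0.0, 0.1]], &Device::Cpu)?;
    // Softmax over the text dimension yields per-image probabilities.
    let probs = softmax(&logits_per_image, 1)?;
    println!("{:?}", probs.to_vec2::<f32>()?);
    Ok(())
}
```
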
## Running an example on CPU

```
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

## Running an example with the Metal feature (macOS)

```
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --sequences "a cycling race","a photo of two cats","a robot holding a candle"


Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg

INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle

Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg

INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```
candle-examples/examples/clip/main.rs (new file, 202 lines)
@ -0,0 +1,202 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;

use tokenizers::Tokenizer;
use tracing::info;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long, use_value_delimiter = true)]
    images: Option<Vec<String>>,

    #[arg(long)]
    cpu: bool,

    #[arg(long, use_value_delimiter = true)]
    sequences: Option<Vec<String>>,
}

fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?.decode()?;
    let (height, width) = (image_size, image_size);
    let img = img.resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::Triangle,
    );

    let img = img.to_rgb8();

    let img = img.into_raw();
    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?;
    // .unsqueeze(0)?;
    Ok(img)
}

fn load_images<T: AsRef<std::path::Path>>(
    paths: &Vec<T>,
    image_size: usize,
) -> anyhow::Result<Tensor> {
    let mut images = vec![];

    for path in paths {
        let tensor = load_image(path, image_size)?;
        images.push(tensor);
    }

    let images = Tensor::stack(&images, 0)?;

    Ok(images)
}

pub fn main() -> anyhow::Result<()> {
    // std::env::set_var("RUST_BACKTRACE", "full");

    let args = Args::parse();

    tracing_subscriber::fmt::init();

    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;

            let api = api.repo(hf_hub::Repo::with_revision(
                "openai/clip-vit-base-patch32".to_string(),
                hf_hub::RepoType::Model,
                "refs/pr/15".to_string(),
            ));

            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let tokenizer = get_tokenizer(args.tokenizer)?;

    let config = clip::ClipConfig::vit_base_patch32();

    let device = candle_examples::device(args.cpu)?;

    let vec_imgs = match args.images {
        Some(imgs) => imgs,
        None => vec![
            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
        ],
    };

    // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
    let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;

    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };

    let model = clip::ClipModel::new(vb, &config)?;

    let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;

    let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;

    let softmax_image = softmax(&logits_per_image, 1)?;

    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;

    info!("softmax_image_vec: {:?}", softmax_image_vec);

    let probability_vec = softmax_image_vec
        .iter()
        .map(|v| v * 100.0)
        .collect::<Vec<f32>>();

    let probability_per_image = probability_vec.len() / vec_imgs.len();

    for (i, img) in vec_imgs.iter().enumerate() {
        let start = i * probability_per_image;
        let end = start + probability_per_image;
        let prob = &probability_vec[start..end];
        info!("\n\nResults for image: {}\n", img);

        for (i, p) in prob.iter().enumerate() {
            info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
        }
    }

    Ok(())
}

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
    let tokenizer = match tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.repo(hf_hub::Repo::with_revision(
                "openai/clip-vit-base-patch32".to_string(),
                hf_hub::RepoType::Model,
                "refs/pr/15".to_string(),
            ));
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
    };

    Tokenizer::from_file(tokenizer).map_err(E::msg)
}

pub fn tokenize_sequences(
    sequences: Option<Vec<String>>,
    tokenizer: &Tokenizer,
    device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
    let pad_id = *tokenizer
        .get_vocab(true)
        .get("<|endoftext|>")
        .ok_or(E::msg("No pad token"))?;

    let vec_seq = match sequences {
        Some(seq) => seq,
        None => vec![
            "a cycling race".to_string(),
            "a photo of two cats".to_string(),
            "a robot holding a candle".to_string(),
        ],
    };

    let mut tokens = vec![];

    for seq in vec_seq.clone() {
        let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
        tokens.push(encoding.get_ids().to_vec());
    }

    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);

    // Pad the sequences to have the same length
    for token_vec in tokens.iter_mut() {
        let len_diff = max_len - token_vec.len();
        if len_diff > 0 {
            token_vec.extend(vec![pad_id; len_diff]);
        }
    }

    let input_ids = Tensor::new(tokens, device)?;

    Ok((input_ids, vec_seq))
}
@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -93,7 +93,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -66,7 +66,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -13,8 +13,13 @@ cargo run --example encodec --features symphonia --release -- code-to-audio \
```

This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
an output wav file containing the audio data.

Instead of `code-to-audio` one can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
  containing EnCodec tokens for the input audio file.

If the audio output file name is set to `-`, the audio content directly gets
played on the default audio output device. If the audio input file is set to `-`, the audio
gets recorded from the default audio input.
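
As a side note, the first step of the `code-to-audio` path is just a
safetensors lookup; a minimal sketch mirroring `main.rs` further below, with
the `jfk-codes.safetensors` file from the command above:

```rust
use candle::Device;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // Load all tensors from the file and pick the EnCodec tokens stored
    // under the "codes" key, exactly as the example does.
    let codes = candle::safetensors::load("jfk-codes.safetensors", &device)?;
    let codes = codes.get("codes").expect("no codes in input file").clone();
    println!("codes shape: {:?}", codes.shape());
    Ok(())
}
```
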
candle-examples/examples/encodec/audio_io.rs (new file, 275 lines)
@ -0,0 +1,275 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

pub const SAMPLE_RATE: usize = 24_000;

pub(crate) struct AudioOutputData_ {
    resampled_data: std::collections::VecDeque<f32>,
    resampler: rubato::FastFixedIn<f32>,
    output_buffer: Vec<f32>,
    input_buffer: Vec<f32>,
    input_len: usize,
}

impl AudioOutputData_ {
    pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
        use rubato::Resampler;

        let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
        let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
        let resampler = rubato::FastFixedIn::new(
            resample_ratio,
            f64::max(resample_ratio, 1.0),
            rubato::PolynomialDegree::Septic,
            1024,
            1,
        )?;
        let input_buffer = resampler.input_buffer_allocate(true).remove(0);
        let output_buffer = resampler.output_buffer_allocate(true).remove(0);
        Ok(Self {
            resampled_data,
            resampler,
            input_buffer,
            output_buffer,
            input_len: 0,
        })
    }

    pub fn reset(&mut self) {
        use rubato::Resampler;
        self.output_buffer.fill(0.);
        self.input_buffer.fill(0.);
        self.resampler.reset();
        self.resampled_data.clear();
    }

    pub(crate) fn take_all(&mut self) -> Vec<f32> {
        let mut data = Vec::with_capacity(self.resampled_data.len());
        while let Some(elem) = self.resampled_data.pop_back() {
            data.push(elem);
        }
        data
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.resampled_data.is_empty()
    }

    // Assumes that the input buffer is large enough.
    fn push_input_buffer(&mut self, samples: &[f32]) {
        self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
        self.input_len += samples.len()
    }

    pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
        use rubato::Resampler;

        let mut pos_in = 0;
        loop {
            let rem = self.input_buffer.len() - self.input_len;
            let pos_end = usize::min(pos_in + rem, samples.len());
            self.push_input_buffer(&samples[pos_in..pos_end]);
            pos_in = pos_end;
            if self.input_len < self.input_buffer.len() {
                break;
            }
            let (_, out_len) = self.resampler.process_into_buffer(
                &[&self.input_buffer],
                &mut [&mut self.output_buffer],
                None,
            )?;
            for &elem in self.output_buffer[..out_len].iter() {
                self.resampled_data.push_front(elem)
            }
            self.input_len = 0;
        }
        Ok(())
    }
}

type AudioOutputData = Arc<Mutex<AudioOutputData_>>;

pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};

    println!("Setup audio output stream!");
    let host = cpal::default_host();
    let device = host
        .default_output_device()
        .context("no output device available")?;
    let mut supported_configs_range = device.supported_output_configs()?;
    let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
        // On macOS, it's commonly the case that there are only stereo outputs.
        None => device
            .supported_output_configs()?
            .next()
            .context("no audio output available")?,
        Some(config_range) => config_range,
    };
    let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
        config_range.min_sample_rate(),
        config_range.max_sample_rate(),
    );
    let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
    let channels = config.channels as usize;
    println!(
        "cpal device: {} {} {config:?}",
        device.name().unwrap_or_else(|_| "unk".to_string()),
        config.sample_rate.0
    );
    let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
        SAMPLE_RATE,
        config.sample_rate.0 as usize,
    )?));
    let ad = audio_data.clone();
    let stream = device.build_output_stream(
        &config,
        move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
            data.fill(0.);
            let mut ad = ad.lock().unwrap();
            let mut last_elem = 0f32;
            for (idx, elem) in data.iter_mut().enumerate() {
                if idx % channels == 0 {
                    match ad.resampled_data.pop_back() {
                        None => break,
                        Some(v) => {
                            last_elem = v;
                            *elem = v
                        }
                    }
                } else {
                    *elem = last_elem
                }
            }
        },
        move |err| eprintln!("cpal error: {err}"),
        None, // None=blocking, Some(Duration)=timeout
    )?;
    stream.play()?;
    Ok((stream, audio_data))
}

pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};

    println!("Setup audio input stream!");
    let host = cpal::default_host();
    let device = host
        .default_input_device()
        .context("no input device available")?;
    let mut supported_configs_range = device.supported_input_configs()?;
    let config_range = supported_configs_range
        .find(|c| c.channels() == 1)
        .context("no audio input available")?;
    let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
        config_range.min_sample_rate(),
        config_range.max_sample_rate(),
    );
    let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
    println!(
        "cpal device: {} {} {config:?}",
        device.name().unwrap_or_else(|_| "unk".to_string()),
        config.sample_rate.0
    );
    let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
        config.sample_rate.0 as usize,
        SAMPLE_RATE,
    )?));
    let ad = audio_data.clone();
    let stream = device.build_input_stream(
        &config,
        move |data: &[f32], _: &cpal::InputCallbackInfo| {
            let mut ad = ad.lock().unwrap();
            if let Err(err) = ad.push_samples(data) {
                eprintln!("error processing audio input {err:?}")
            }
        },
        move |err| eprintln!("cpal error: {err}"),
        None, // None=blocking, Some(Duration)=timeout
    )?;
    stream.play()?;
    Ok((stream, audio_data))
}

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
    use rubato::Resampler;

    let mut pcm_out =
        Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);

    let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
    let mut output_buffer = resampler.output_buffer_allocate(true);
    let mut pos_in = 0;
    while pos_in + resampler.input_frames_next() < pcm_in.len() {
        let (in_len, out_len) =
            resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
        pos_in += in_len;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    if pos_in < pcm_in.len() {
        let (_in_len, out_len) = resampler.process_partial_into_buffer(
            Some(&[&pcm_in[pos_in..]]),
            &mut output_buffer,
            None,
        )?;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    Ok(pcm_out)
}
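
// A usage sketch for the resample helper above (illustrative, not part of the
// original file): downsample one second of a 44.1kHz sine wave to the 24kHz
// rate that EnCodec expects.
fn resample_sketch() -> Result<()> {
    let sr_in = 44_100usize;
    let pcm_in: Vec<f32> = (0..sr_in)
        .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr_in as f32).sin())
        .collect();
    let pcm_out = resample(&pcm_in, sr_in, SAMPLE_RATE)?;
    // The output length lands close to SAMPLE_RATE, up to one resampler chunk.
    assert!((pcm_out.len() as i64 - SAMPLE_RATE as i64).unsigned_abs() < 2048);
    Ok(())
}
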
@ -11,59 +11,7 @@ use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}
mod audio_io;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
@ -109,14 +57,36 @@ fn main() -> Result<()> {
    let codes = match args.action {
        Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
            let codes = codes.get("codes").expect("no codes in input file").i(0)?;
            codes
            codes.get("codes").expect("no codes in input file").clone()
        }
        Action::AudioToCode | Action::AudioToAudio => {
            let (pcm, sample_rate) = pcm_decode(args.in_file)?;
            if sample_rate != 24_000 {
                println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
            }
            let pcm = if args.in_file == "-" {
                println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
                let (stream, input_audio) = audio_io::setup_input_stream()?;
                let mut pcms = vec![];
                let stdin = std::thread::spawn(|| {
                    let mut s = String::new();
                    std::io::stdin().read_line(&mut s)
                });
                while !stdin.is_finished() {
                    let input = input_audio.lock().unwrap().take_all();
                    if input.is_empty() {
                        std::thread::sleep(std::time::Duration::from_millis(100));
                        continue;
                    }
                    pcms.push(input)
                }
                drop(stream);
                pcms.concat()
            } else {
                let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
                if sample_rate != 24_000 {
                    println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}, resampling...");
                    audio_io::resample(&pcm, sample_rate as usize, 24_000)?
                } else {
                    pcm
                }
            };
            let pcm_len = pcm.len();
            let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
            println!("input pcm shape: {:?}", pcm.shape());
@ -135,8 +105,26 @@ fn main() -> Result<()> {
            let pcm = pcm.i(0)?.i(0)?;
            let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
            let pcm = pcm.to_vec1::<f32>()?;
            let mut output = std::fs::File::create(&args.out_file)?;
            candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
            if args.out_file == "-" {
                let (stream, ad) = audio_io::setup_output_stream()?;
                {
                    let mut ad = ad.lock().unwrap();
                    ad.push_samples(&pcm)?;
                }
                loop {
                    let ad = ad.lock().unwrap();
                    if ad.is_empty() {
                        break;
                    }
                    // That's very weird, calling thread::sleep here triggers the stream to stop
                    // playing (the callback doesn't seem to be called anymore).
                    // std::thread::sleep(std::time::Duration::from_millis(100));
                }
                drop(stream)
            } else {
                let mut output = std::fs::File::create(&args.out_file)?;
                candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
            }
        }
    }
    Ok(())

@ -1,4 +1,4 @@
# candle-mistral: 2b and 7b LLMs from Google DeepMind
# candle-gemma: 2b and 7b LLMs from Google DeepMind

[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.

@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "2b")]
    Base2B,
    #[value(name = "7b")]
    Base7B,
    #[value(name = "2b-it")]
    Instruct2B,
    #[value(name = "7b-it")]
    Instruct7B,
    #[value(name = "1.1-2b-it")]
    InstructV1_1_2B,
    #[value(name = "1.1-7b-it")]
    InstructV1_1_7B,
    #[value(name = "code-2b")]
    CodeBase2B,
    #[value(name = "code-7b")]
    CodeBase7B,
    #[value(name = "code-2b-it")]
    CodeInstruct2B,
    #[value(name = "code-7b-it")]
    CodeInstruct7B,
}

struct TextGeneration {
    model: Model,
    device: Device,
@ -165,6 +189,10 @@ struct Args {
    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model to use.
    #[arg(long, default_value = "2b")]
    which: Which,
}

fn main() -> Result<()> {
@ -196,14 +224,19 @@ fn main() -> Result<()> {
    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => match model_id.as_str() {
            "7b-it" => "google/gemma-7b-it".to_string(),
            "7b" => "google/gemma-7b".to_string(),
            "2b-it" => "google/gemma-2b-it".to_string(),
            "2b" => "google/gemma-2b".to_string(),
            _ => model_id.to_string(),
        Some(model_id) => model_id.to_string(),
        None => match args.which {
            Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
            Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
            Which::Base2B => "google/gemma-2b".to_string(),
            Which::Base7B => "google/gemma-7b".to_string(),
            Which::Instruct2B => "google/gemma-2b-it".to_string(),
            Which::Instruct7B => "google/gemma-7b-it".to_string(),
            Which::CodeBase2B => "google/codegemma-2b".to_string(),
            Which::CodeBase7B => "google/codegemma-7b".to_string(),
            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
        },
        None => "google/gemma-2b".to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,

@ -54,6 +54,7 @@ impl TextGeneration {
    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let dtype = self.model.dtype();
        let mut tokens = self
            .tokenizer
            .tokenizer()
@ -66,7 +67,7 @@ impl TextGeneration {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let mut state = State::new(1, &self.config, &self.device)?;
        let mut state = State::new(1, &self.config, dtype, &self.device)?;
        let mut next_logits = None;
        for &t in tokens.iter() {
            let input = Tensor::new(&[t], &self.device)?;
@ -84,7 +85,7 @@ impl TextGeneration {
            Some(logits) => logits,
            None => anyhow::bail!("cannot work on an empty prompt"),
        };
        let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
        let logits = logits.squeeze(0)?.to_dtype(dtype)?;
        let logits = if self.repeat_penalty == 1. {
            logits
        } else {
@ -210,6 +211,9 @@ struct Args {
    #[arg(long)]
    config_file: Option<String>,

    #[arg(long, default_value = "f32")]
    dtype: String,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
@ -220,6 +224,7 @@ struct Args {
}

fn main() -> Result<()> {
    use std::str::FromStr;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

@ -279,7 +284,8 @@ fn main() -> Result<()> {
    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let dtype = DType::from_str(&args.dtype)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb.pp("backbone"))?;
    println!("loaded the model in {:?}", start.elapsed());

@ -11,6 +11,7 @@ use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;

use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
@ -26,6 +27,11 @@ enum ArgDType {
    Bf16,
}

enum Transformer {
    Normal(transformer::Model),
    Quantized(qtransformer::Model),
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -40,6 +46,10 @@ struct Args {
    #[arg(long)]
    prompt: String,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// The guidance scale.
    #[arg(long, default_value_t = 3.0)]
    guidance_scale: f64,
@ -116,11 +126,7 @@ fn main() -> Result<()> {
    };
    let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;

    let first_stage_weights = match &args.first_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("first_stage.safetensors")?,
    };
    let second_stage_weights = match &args.first_stage_weights {
    let second_stage_weights = match &args.second_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("second_stage.safetensors")?,
    };
@ -135,10 +141,27 @@ fn main() -> Result<()> {
        ArgDType::F16 => DType::F16,
        ArgDType::Bf16 => DType::BF16,
    };
    let first_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };

    let first_stage_config = transformer::Config::cfg1b_v0_1();
    let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
    let mut first_stage_model = if args.quantized {
        let filename = match &args.first_stage_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("first_stage_q4k.gguf")?,
        };
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
        Transformer::Quantized(first_stage_model)
    } else {
        let first_stage_weights = match &args.first_stage_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("first_stage.safetensors")?,
        };
        let first_stage_vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
        let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
        Transformer::Normal(first_stage_model)
    };

    let second_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
@ -178,7 +201,12 @@ fn main() -> Result<()> {
        let ctxt = &tokens[start_pos..];
        let input = Tensor::new(ctxt, &device)?;
        let input = Tensor::stack(&[&input, &input], 0)?;
        let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
        let logits = match &mut first_stage_model {
            Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
            Transformer::Quantized(m) => {
                m.forward(&input, &spk_emb, tokens.len() - context_size)?
            }
        };
        let logits0 = logits.i((0, 0))?;
        let logits1 = logits.i((1, 0))?;
        let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;

@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

@ -39,11 +39,26 @@ impl TextGeneration {
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        top_k: Option<usize>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        let logits_processor = {
            let temperature = temp.unwrap_or(0.);
            let sampling = if temperature <= 0. {
                Sampling::ArgMax
            } else {
                match (top_k, top_p) {
                    (None, None) => Sampling::All { temperature },
                    (Some(k), None) => Sampling::TopK { k, temperature },
                    (None, Some(p)) => Sampling::TopP { p, temperature },
                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
                }
            };
            LogitsProcessor::from_sampling(seed, sampling)
        };

        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
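
// Aside (not part of the diff): a self-contained sketch of the sampling
// strategies introduced above, using the LogitsProcessor API shown in this
// hunk. Top-k keeps the k most likely tokens, top-p keeps the smallest set
// of tokens whose probabilities sum to p, and TopKThenTopP applies both.
fn sampling_sketch() -> candle::Result<()> {
    use candle::{Device, Tensor};
    use candle_transformers::generation::{LogitsProcessor, Sampling};

    let sampling = Sampling::TopKThenTopP { k: 40, p: 0.9, temperature: 0.8 };
    let mut processor = LogitsProcessor::from_sampling(299792458, sampling);
    let logits = Tensor::new(&[0.1f32, 2.0, 0.5, 3.0], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    println!("sampled token id: {token}");
    Ok(())
}
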
@ -122,6 +137,18 @@ impl TextGeneration {
    }
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "7b-v0.1")]
    Mistral7bV01,
    #[value(name = "7b-v0.2")]
    Mistral7bV02,
    #[value(name = "7b-instruct-v0.1")]
    Mistral7bInstructV01,
    #[value(name = "7b-instruct-v0.2")]
    Mistral7bInstructV02,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -147,6 +174,10 @@ struct Args {
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
@ -155,6 +186,10 @@ struct Args {
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    /// The model size to use.
    #[arg(long, default_value = "7b-v0.1")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

@ -164,6 +199,9 @@ struct Args {
    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

@ -177,6 +215,10 @@ struct Args {
    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// Use the slower dmmv cuda kernel.
    #[arg(long)]
    force_dmmv: bool,
}

fn main() -> Result<()> {
@ -184,6 +226,9 @@ fn main() -> Result<()> {
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    #[cfg(feature = "cuda")]
    candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
@ -211,9 +256,17 @@ fn main() -> Result<()> {
        Some(model_id) => model_id,
        None => {
            if args.quantized {
                if args.which != Which::Mistral7bV01 {
                    anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
                }
                "lmz/candle-mistral".to_string()
            } else {
                "mistralai/Mistral-7B-v0.1".to_string()
                match args.which {
                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
                }
            }
        }
    };
@ -243,7 +296,17 @@ fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = Config::config_7b_v0_1(args.use_flash_attn);
    let config = match args.config_file {
        Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
        None => {
            if args.quantized {
                Config::config_7b_v0_1(args.use_flash_attn)
            } else {
                let config_file = repo.get("config.json")?;
                serde_json::from_slice(&std::fs::read(config_file)?)?
            }
        }
    };
    let device = candle_examples::device(args.cpu)?;
    let (model, device) = if args.quantized {
        let filename = &filenames[0];
@ -270,6 +333,7 @@ fn main() -> Result<()> {
        args.seed,
        args.temperature,
        args.top_p,
        args.top_k,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,

@ -63,7 +63,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

candle-examples/examples/moondream/README.md (new file, 26 lines)
@ -0,0 +1,26 @@
# candle-moondream

[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model that can answer real-world questions about images. It's tiny by today's standards, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.

## Running some examples
First download an example image
```bash
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
```

<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">

Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"

avavx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
retrieved the files in 3.395583ms
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded the model in 5.485493792s
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
starting the inference loop
The girl is eating a hamburger.<
9 tokens generated (0.68 token/s)
```
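
The example resizes images to 378x378 and maps pixel values into [-1, 1]
(mean and std of 0.5, see `load_image` in `main.rs`); a minimal sketch of that
normalization on a few dummy pixel values:

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // Three dummy u8 pixels; main.rs applies the same (x / 255 - 0.5) / 0.5
    // transform to the full (3, 378, 378) image tensor.
    let data = Tensor::new(&[0u8, 128, 255], &Device::Cpu)?;
    let data = (data.to_dtype(DType::F32)? / 255.)?;
    let normalized = ((data - 0.5)? / 0.5)?;
    println!("{:?}", normalized.to_vec1::<f32>()?); // ~ [-1.0, 0.004, 1.0]
    Ok(())
}
```
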
candle-examples/examples/moondream/main.rs (new file, 343 lines)
@ -0,0 +1,343 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
    generation::LogitsProcessor,
    models::{moondream, quantized_moondream},
};
use tokenizers::Tokenizer;

enum Model {
    Moondream(moondream::Model),
    Quantized(quantized_moondream::Model),
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: Tokenizer,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
    verbose_prompt: bool,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        verbose_prompt: bool,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer,
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            verbose_prompt,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
        if tokens.is_empty() {
            anyhow::bail!("Empty prompts are not supported in the Moondream model.")
        }
        if self.verbose_prompt {
            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                println!("{id:7} -> '{token}'");
            }
        }

        let mut tokens = tokens.get_ids().to_vec();
        let mut generated_tokens = 0usize;

        // The Moondream tokenizer uses "<|endoftext|>" as both bos_token and eos_token:
        // https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
        let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
            Some(token) => *token,
            None => anyhow::bail!("cannot find the special token"),
        };
        let (bos_token, eos_token) = (special_token, special_token);

        let start_gen = std::time::Instant::now();
        let mut load_t = std::time::Duration::from_secs_f64(0f64);
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = if index > 0 {
                match self.model {
                    Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
                    Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
                }
            } else {
                let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
                let logits = match self.model {
                    Model::Moondream(ref mut model) => {
                        model
                            .text_model
                            .forward_with_img(&bos_token, &input, image_embeds)?
                    }
                    Model::Quantized(ref mut model) => {
                        model
                            .text_model
                            .forward_with_img(&bos_token, &input, image_embeds)?
                    }
                };
                load_t = start_gen.elapsed();
                println!("load_t: {:?}", load_t);
                logits
            };
            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
                break;
            }
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
        }

        let dt = start_gen.elapsed() - load_t;
        println!(
            "\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
            dt.as_secs_f64(),
            (generated_tokens - 1) as f64 / dt.as_secs_f64()
        );

        Ok(())
    }
}

#[derive(Parser)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long)]
    prompt: String,

    #[arg(long)]
    image: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 0)]
    seed: u64,

    #[arg(long, default_value_t = 5000)]
    sample_len: usize,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    quantized: bool,

    /// Use f16 precision for all the computations rather than f32.
    #[arg(long)]
    f16: bool,

    #[arg(long)]
    model_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,
}

/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 378, 378).
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
    let img = image::io::Reader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
    (data.to_dtype(candle::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = hf_hub::api::tokio::Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => {
            if args.quantized {
                "santiagomed/candle-moondream".to_string()
            } else {
                "vikhyatk/moondream2".to_string()
            }
        }
    };
    let repo = api.repo(hf_hub::Repo::with_revision(
        model_id,
        hf_hub::RepoType::Model,
        args.revision,
    ));
    let model_file = match args.model_file {
        Some(m) => m.into(),
        None => {
            if args.quantized {
                repo.get("model-q4_0.gguf").await?
            } else {
                repo.get("model.safetensors").await?
            }
        }
    };
    let tokenizer = match args.tokenizer_file {
        Some(m) => m.into(),
        None => repo.get("tokenizer.json").await?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let config = moondream::Config::v2();
    let dtype = if args.quantized {
        if args.f16 {
            anyhow::bail!("Quantized model does not support f16");
        }
        DType::F32
    } else if device.is_cuda() || args.f16 {
        DType::F16
    } else {
        DType::F32
    };
    let model = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            &model_file,
            &device,
        )?;
        let model = quantized_moondream::Model::new(&config, vb)?;
        Model::Quantized(model)
    } else {
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
        let model = moondream::Model::new(&config, vb)?;
        Model::Moondream(model)
    };
    println!("loaded the model in {:?}", start.elapsed());

    let start = std::time::Instant::now();
    let image = load_image(args.image)?
        .to_device(&device)?
        .to_dtype(dtype)?;
    let image_embeds = image.unsqueeze(0)?;
    let image_embeds = match model {
        Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
        Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
    };
    println!(
        "loaded and encoded the image {image:?} in {:?}",
        start.elapsed()
    );

    let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        args.verbose_prompt,
        &device,
    );
    pipeline.run(&prompt, &image_embeds, args.sample_len)?;

    Ok(())
}
@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

## Using custom models

@ -10,7 +10,7 @@ use tokenizers::Tokenizer;

use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
@ -200,6 +200,10 @@ struct Args {
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
@ -235,6 +239,10 @@ struct Args {
    /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
    #[arg(long)]
    gqa: Option<usize>,

    /// Use the slower dmmv cuda kernel.
    #[arg(long)]
    force_dmmv: bool,
}

impl Args {
@ -341,11 +349,10 @@ fn main() -> anyhow::Result<()> {
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let temperature = if args.temperature == 0. {
        None
    } else {
        Some(args.temperature)
    };

    #[cfg(feature = "cuda")]
    candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
@ -492,7 +499,20 @@ fn main() -> anyhow::Result<()> {
        prompt_tokens
    };
    let mut all_tokens = vec![];
    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
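For reference, the `Sampling` enum introduced above can also be driven directly when reusing this logic elsewhere. A minimal sketch, assuming only the `candle_transformers::generation` API exactly as it appears in this hunk (the helper name is ours):

```rust
use candle_transformers::generation::{LogitsProcessor, Sampling};

// Pick greedy decoding for a non-positive temperature, otherwise top-k
// sampling, mirroring the selection logic in the hunk above.
fn processor_for(seed: u64, temperature: f64, top_k: Option<usize>) -> LogitsProcessor {
    let sampling = if temperature <= 0. {
        Sampling::ArgMax
    } else {
        match top_k {
            Some(k) => Sampling::TopK { k, temperature },
            None => Sampling::All { temperature },
        }
    };
    LogitsProcessor::from_sampling(seed, sampling)
}
```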
candle-examples/examples/qwen/README.md (new file, 27 lines)
@ -0,0 +1,27 @@
# candle-qwen: large language model series from Alibaba Cloud

Qwen 1.5 is a series of large language models that provide strong performance
on English and Chinese.

- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
  mixture-of-experts (MoE) variant.

## Running the example

```bash
$ cargo run --example qwen --release -- --prompt "Hello there "
```

Various model sizes are available via the `--model` argument, including the MoE
variant.

```bash
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
def print_prime(n: int): # n is the number of primes to be printed
    for i in range(2, n + 1):
        if all(i % j != 0 for j in range(2, i)):
            print(i)
```
@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config, Model};
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    Base(ModelBase),
    Moe(ModelMoe),
}

impl Model {
    fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
        match self {
            Self::Moe(ref mut m) => m.forward(xs, s),
            Self::Base(ref mut m) => m.forward(xs, s),
        }
    }
}

struct TextGeneration {
    model: Model,
    device: Device,
@ -127,6 +142,8 @@ enum WhichModel {
    W14b,
    #[value(name = "72b")]
    W72b,
    #[value(name = "moe-a2.7b")]
    MoeA27b,
}

#[derive(Parser, Debug)]
@ -224,6 +241,7 @@ fn main() -> Result<()> {
                WhichModel::W7b => "7B",
                WhichModel::W14b => "14B",
                WhichModel::W72b => "72B",
                WhichModel::MoeA27b => "MoE-A2.7B",
            };
            format!("Qwen/Qwen1.5-{size}")
        }
@ -244,7 +262,11 @@ fn main() -> Result<()> {
            .collect::<Vec<_>>(),
        None => match args.model {
            WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
            WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
            WhichModel::W4b
            | WhichModel::W7b
            | WhichModel::W14b
            | WhichModel::W72b
            | WhichModel::MoeA27b => {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        },
@ -254,7 +276,6 @@ fn main() -> Result<()> {

    let start = std::time::Instant::now();
    let config_file = repo.get("config.json")?;
    let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
@ -262,7 +283,16 @@ fn main() -> Result<()> {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;
    let model = match args.model {
        WhichModel::MoeA27b => {
            let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Moe(ModelMoe::new(&config, vb)?)
        }
        _ => {
            let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
            Model::Base(ModelBase::new(&config, vb)?)
        }
    };

    println!("loaded the model in {:?}", start.elapsed());

candle-examples/examples/reinforcement-learning/dqn.rs (new file, 118 lines)
@ -0,0 +1,118 @@
use std::collections::VecDeque;

use rand::distributions::Uniform;
use rand::{thread_rng, Rng};

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};

use crate::gym_env::GymEnv;

const DEVICE: Device = Device::Cpu;
const EPISODES: usize = 200;
const BATCH_SIZE: usize = 64;
const GAMMA: f64 = 0.99;
const LEARNING_RATE: f64 = 0.01;

pub fn run() -> Result<()> {
    let env = GymEnv::new("CartPole-v1")?;

    // Build the model that predicts the estimated rewards given a specific state.
    let var_map = VarMap::new();
    let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
    let observation_space = *env.observation_space().first().unwrap();

    let model = seq()
        .add(linear(observation_space, 64, vb.pp("linear_in"))?)
        .add(Activation::Relu)
        .add(linear(64, env.action_space(), vb.pp("linear_out"))?);

    let mut optimizer = AdamW::new_lr(var_map.all_vars(), LEARNING_RATE)?;

    // Initialize the model's memory.
    let mut memory = VecDeque::with_capacity(10000);

    // Start the training loop.
    let mut state = env.reset(0)?;
    let mut episode = 0;
    let mut accumulate_rewards = 0.0;
    while episode < EPISODES {
        // Given the current state, predict the estimated rewards, and take the
        // action that is expected to return the most rewards.
        let estimated_rewards = model.forward(&state.unsqueeze(0)?)?;
        let action: u32 = estimated_rewards.squeeze(0)?.argmax(0)?.to_scalar()?;

        // Take that action in the environment, and memorize the outcome:
        // - the state for which the action was taken
        // - the action taken
        // - the new state resulting from taking that action
        // - the actual reward of taking that action
        // - whether the environment reached a terminal state or not (e.g. game over)
        let step = env.step(action)?;
        accumulate_rewards += step.reward;
        memory.push_back((
            state,
            action,
            step.state.clone(),
            step.reward,
            step.terminated || step.truncated,
        ));
        state = step.state;

        // If there are enough entries in the memory, perform a learning step, where
        // BATCH_SIZE transitions will be sampled from the memory and will be
        // fed to the model so that it performs a backward pass.
        if memory.len() > BATCH_SIZE {
            // Sample randomly from the memory.
            let batch = thread_rng()
                .sample_iter(Uniform::from(0..memory.len()))
                .take(BATCH_SIZE)
                .map(|i| memory.get(i).unwrap().clone())
                .collect::<Vec<_>>();

            // Group all the samples together into tensors with the appropriate shape.
            let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect();
            let states = Tensor::stack(&states, 0)?;

            let actions = batch.iter().map(|e| e.1);
            let actions = Tensor::from_iter(actions, &DEVICE)?.unsqueeze(1)?;

            let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect();
            let next_states = Tensor::stack(&next_states, 0)?;

            let rewards = batch.iter().map(|e| e.3 as f32);
            let rewards = Tensor::from_iter(rewards, &DEVICE)?.unsqueeze(1)?;

            let non_final_mask = batch.iter().map(|e| !e.4 as u8 as f32);
            let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)?.unsqueeze(1)?;

            // Get the estimated rewards for the actions that were taken at each step.
            let estimated_rewards = model.forward(&states)?;
            let x = estimated_rewards.gather(&actions, 1)?;

            // Get the maximum expected rewards for the next state, apply the discount
            // rate GAMMA to them, and add the rewards that were actually gathered on
            // the current state. If the next state is a terminal state, just omit the
            // maximum estimated rewards for that state.
            let expected_rewards = model.forward(&next_states)?.detach();
            let y = expected_rewards.max_keepdim(1)?;
            let y = (y * GAMMA * non_final_mask + rewards)?;

            // Compare the estimated rewards with the maximum expected rewards and
            // perform the backward step.
            let loss = mse(&x, &y)?;
            optimizer.backward_step(&loss)?;
        }

        // If we are in a terminal state, reset the environment and log how it went.
        if step.terminated || step.truncated {
            episode += 1;
            println!("Episode {episode} | Rewards {}", accumulate_rewards as i64);
            state = env.reset(0)?;
            accumulate_rewards = 0.0;
        }
    }

    Ok(())
}
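The learning step above is the standard one-step Q-learning update: with discount factor $\gamma$ (`GAMMA`) and the non-final mask $m$ zeroing the bootstrap term on terminal transitions, the lines computing `x`, `y`, and `loss` correspond to

$$
x = Q(s, a), \qquad y = r + \gamma \, m \, \max_{a'} Q(s', a'), \qquad \mathcal{L} = \operatorname{MSE}(x, y).
$$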
@ -42,7 +42,7 @@ impl GymEnv {
    /// Creates a new session of the specified OpenAI Gym environment.
    pub fn new(name: &str) -> Result<GymEnv> {
        Python::with_gil(|py| {
            let gym = py.import("gymnasium")?;
            let gym = py.import_bound("gymnasium")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name,))?;
            let action_space = env.getattr("action_space")?;
@ -66,10 +66,10 @@ impl GymEnv {
    /// Resets the environment, returning the observation tensor.
    pub fn reset(&self, seed: u64) -> Result<Tensor> {
        let state: Vec<f32> = Python::with_gil(|py| {
            let kwargs = PyDict::new(py);
            let kwargs = PyDict::new_bound(py);
            kwargs.set_item("seed", seed)?;
            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
            state.as_ref(py).get_item(0)?.extract()
            let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
            state.bind(py).get_item(0)?.extract()
        })
        .map_err(w)?;
        Tensor::new(state, &Device::Cpu)
@ -81,8 +81,10 @@ impl GymEnv {
        action: A,
    ) -> Result<Step<A>> {
        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
            let step = step.as_ref(py);
            let step = self
                .env
                .call_method_bound(py, "step", (action.clone(),), None)?;
            let step = step.bind(py);
            let state: Vec<f32> = step.get_item(0)?.extract()?;
            let reward: f64 = step.get_item(1)?.extract()?;
            let terminated: bool = step.get_item(2)?.extract()?;

@ -13,6 +13,7 @@ mod gym_env;
mod vec_gym_env;

mod ddpg;
mod dqn;
mod policy_gradient;

#[derive(Parser)]
@ -25,6 +26,7 @@ struct Args {
enum Command {
    Pg,
    Ddpg,
    Dqn,
}

fn main() -> Result<()> {
@ -32,6 +34,7 @@ fn main() -> Result<()> {
    match args.command {
        Command::Pg => policy_gradient::run()?,
        Command::Ddpg => ddpg::run()?,
        Command::Dqn => dqn::run()?,
    }
    Ok(())
}

@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
impl VecGymEnv {
    pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
        Python::with_gil(|py| {
            let sys = py.import("sys")?;
            let sys = py.import_bound("sys")?;
            let path = sys.getattr("path")?;
            let _ = path.call_method1(
                "append",
                ("candle-examples/examples/reinforcement-learning",),
            )?;
            let gym = py.import("atari_wrappers")?;
            let gym = py.import_bound("atari_wrappers")?;
            let make = gym.getattr("make")?;
            let env = make.call1((name, img_dir, nprocesses))?;
            let action_space = env.getattr("action_space")?;
@ -60,10 +60,10 @@ impl VecGymEnv {

    pub fn step(&self, action: Vec<usize>) -> Result<Step> {
        let (obs, reward, is_done) = Python::with_gil(|py| {
            let step = self.env.call_method(py, "step", (action,), None)?;
            let step = step.as_ref(py);
            let step = self.env.call_method_bound(py, "step", (action,), None)?;
            let step = step.bind(py);
            let obs = step.get_item(0)?.call_method("flatten", (), None)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
            let obs: Vec<u8> = obs_buffer.to_vec(py)?;
            let reward: Vec<f32> = step.get_item(1)?.extract()?;
            let is_done: Vec<f32> = step.get_item(2)?.extract()?;

@ -78,7 +78,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -45,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -141,7 +141,7 @@ impl std::fmt::Display for Which {
impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
            Self::Eagle7b => "RWKV/v5-Eagle-7B-HF",
            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
            Self::World3b => "RWKV/rwkv-5-world-3b",
            Self::World6_1b6 => "paperfun/rwkv",

@ -5,7 +5,7 @@ use candle_transformers::models::segformer::{
    Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;

@ -96,6 +96,10 @@ struct Args {
    /// information.
    #[arg(long, default_value_t = 0.8)]
    img2img_strength: f64,

    /// The seed to use when generating random samples.
    #[arg(long)]
    seed: Option<u64>,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -288,6 +292,13 @@ fn text_embeddings(
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    if tokens.len() > sd_config.clip.max_position_embeddings {
        anyhow::bail!(
            "the prompt is too long, {} > max-tokens ({})",
            tokens.len(),
            sd_config.clip.max_position_embeddings
        )
    }
    while tokens.len() < sd_config.clip.max_position_embeddings {
        tokens.push(pad_id)
    }
@ -315,6 +326,13 @@ fn text_embeddings(
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        if uncond_tokens.len() > sd_config.clip.max_position_embeddings {
            anyhow::bail!(
                "the negative prompt is too long, {} > max-tokens ({})",
                uncond_tokens.len(),
                sd_config.clip.max_position_embeddings
            )
        }
        while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
            uncond_tokens.push(pad_id)
        }
@ -374,6 +392,7 @@ fn run(args: Args) -> Result<()> {
        use_flash_attn,
        img2img,
        img2img_strength,
        seed,
        ..
    } = args;

@ -427,6 +446,9 @@ fn run(args: Args) -> Result<()> {

    let scheduler = sd_config.build_scheduler(n_steps)?;
    let device = candle_examples::device(cpu)?;
    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }
    let use_guide_scale = guidance_scale > 1.0;

    let which = match sd_version {

@ -10,11 +10,6 @@ order to be able to use it.

Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.

StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.

## Running some example

```bash

@ -239,14 +239,7 @@ fn main() -> Result<()> {
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => match args.which {
            Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
                repo.get("tokenizer.json")?
            }
            Which::V2 | Which::V2Zephyr => api
                .model("lmz/candle-stablelm".to_string())
                .get("tokenizer-gpt4.json")?,
        },
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
@ -295,12 +288,12 @@ fn main() -> Result<()> {
    };

    let device = candle_examples::device(args.cpu)?;
    let (model, device) = if args.quantized {
    let model = if args.quantized {
        let filename = &filenames[0];
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        let model = QStableLM::new(&config, vb)?;
        (Model::Quantized(model), Device::Cpu)
        Model::Quantized(model)
    } else {
        let dtype = if device.is_cuda() {
            DType::BF16
@ -309,7 +302,7 @@ fn main() -> Result<()> {
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = StableLM::new(&config, vb)?;
        (Model::StableLM(model), device)
        Model::StableLM(model)
    };

    println!("loaded the model in {:?}", start.elapsed());

@ -12,12 +12,23 @@ use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

const DTYPE: DType = DType::F32;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    T5Base,
    T5Small,
    T5Large,
    T5_3B,
    Mt5Base,
    Mt5Small,
    Mt5Large,
}

#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -36,6 +47,15 @@ struct Args {
    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    model_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    /// Enable decoding.
    #[arg(long)]
    decode: bool,
@ -71,6 +91,10 @@ struct Args {
    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model to be used.
    #[arg(long, default_value = "t5-small")]
    which: Which,
}

struct T5ModelBuilder {
@ -82,8 +106,17 @@ struct T5ModelBuilder {
impl T5ModelBuilder {
    pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
        let device = candle_examples::device(args.cpu)?;
        let default_model = "t5-small".to_string();
        let default_revision = "refs/pr/15".to_string();
        let (default_model, default_revision) = match args.which {
            Which::T5Base => ("t5-base", "main"),
            Which::T5Small => ("t5-small", "refs/pr/15"),
            Which::T5Large => ("t5-large", "main"),
            Which::T5_3B => ("t5-3b", "main"),
            Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
            Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
            Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
        };
        let default_model = default_model.to_string();
        let default_revision = default_revision.to_string();
        let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
            (Some(model_id), Some(revision)) => (model_id, revision),
            (Some(model_id), None) => (model_id, "main".to_string()),
@ -93,14 +126,35 @@ impl T5ModelBuilder {

        let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
        let api = Api::new()?;
        let api = api.repo(repo);
        let config_filename = api.get("config.json")?;
        let tokenizer_filename = api.get("tokenizer.json")?;
        let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
        {
            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
        } else {
            vec![api.get("model.safetensors")?]
        let repo = api.repo(repo);
        let config_filename = match &args.config_file {
            None => repo.get("config.json")?,
            Some(f) => f.into(),
        };
        let tokenizer_filename = match &args.tokenizer_file {
            None => match args.which {
                Which::Mt5Base => api
                    .model("lmz/mt5-tokenizers".into())
                    .get("mt5-base.tokenizer.json")?,
                Which::Mt5Small => api
                    .model("lmz/mt5-tokenizers".into())
                    .get("mt5-small.tokenizer.json")?,
                Which::Mt5Large => api
                    .model("lmz/mt5-tokenizers".into())
                    .get("mt5-large.tokenizer.json")?,
                _ => repo.get("tokenizer.json")?,
            },
            Some(f) => f.into(),
        };
        let weights_filename = match &args.model_file {
            Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
            None => {
                if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
                    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
                } else {
                    vec![repo.get("model.safetensors")?]
                }
            }
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: t5::Config = serde_json::from_str(&config)?;

@ -33,7 +33,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;

    println!("loaded image {image:?}");

@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@ -34,6 +34,7 @@ from the hub.
- `--timestamps`: enable the timestamp mode where some timestamps are reported
  for each recognized audio extract.
- `--model`: the model to be used. Models that do not end with `-en` are
  multilingual models; the other ones are English-only models. The supported models
  are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
  `medium.en`, `large`, and `large-v2`.
  multilingual models; the other ones are English-only models. The supported OpenAI
  Whisper models are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`,
  `medium`, `medium.en`, `large`, `large-v2` and `large-v3`. The supported
  Distil-Whisper models are `distil-medium.en`, `distil-large-v2` and `distil-large-v3`.

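For example, selecting one of the Distil-Whisper checkpoints listed above looks like this (a sketch; the other model names are passed the same way):

```bash
$ cargo run --example whisper --release -- --model distil-large-v3
```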
@ -374,6 +374,8 @@ enum WhichModel {
    DistilMediumEn,
    #[value(name = "distil-large-v2")]
    DistilLargeV2,
    #[value(name = "distil-large-v3")]
    DistilLargeV3,
}

impl WhichModel {
@ -386,7 +388,8 @@ impl WhichModel {
            | Self::Large
            | Self::LargeV2
            | Self::LargeV3
            | Self::DistilLargeV2 => true,
            | Self::DistilLargeV2
            | Self::DistilLargeV3 => true,
            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
                false
            }
@ -408,6 +411,7 @@ impl WhichModel {
            Self::LargeV3 => ("openai/whisper-large-v3", "main"),
            Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
            Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
            Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"),
        }
    }
}

candle-examples/examples/yolo-v8/assets/bike.pp.jpg (new binary file, 175 KiB; not shown)
@ -99,7 +99,7 @@ pub fn report_detect(
    let h_ratio = initial_h as f32 / h as f32;
    let mut img = img.to_rgb8();
    let font = Vec::from(include_bytes!("roboto-mono-stripped.ttf") as &[u8]);
    let font = rusttype::Font::try_from_vec(font);
    let font = ab_glyph::FontRef::try_from_slice(&font).map_err(candle::Error::wrap)?;
    for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
        for b in bboxes_for_class.iter() {
            println!(
@ -119,27 +119,28 @@ pub fn report_detect(
            );
        }
        if legend_size > 0 {
            if let Some(font) = font.as_ref() {
                imageproc::drawing::draw_filled_rect_mut(
                    &mut img,
                    imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
                    image::Rgb([170, 0, 0]),
                );
                let legend = format!(
                    "{} {:.0}%",
                    candle_examples::coco_classes::NAMES[class_index],
                    100. * b.confidence
                );
                imageproc::drawing::draw_text_mut(
                    &mut img,
                    image::Rgb([255, 255, 255]),
                    xmin,
                    ymin,
                    rusttype::Scale::uniform(legend_size as f32 - 1.),
                    font,
                    &legend,
                )
            }
            imageproc::drawing::draw_filled_rect_mut(
                &mut img,
                imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
                image::Rgb([170, 0, 0]),
            );
            let legend = format!(
                "{} {:.0}%",
                candle_examples::coco_classes::NAMES[class_index],
                100. * b.confidence
            );
            imageproc::drawing::draw_text_mut(
                &mut img,
                image::Rgb([255, 255, 255]),
                xmin,
                ymin,
                ab_glyph::PxScale {
                    x: legend_size as f32 - 1.,
                    y: legend_size as f32 - 1.,
                },
                &font,
                &legend,
            )
        }
    }
}

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.1"
version = "0.5.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.1"
version = "0.5.0"
edition = "2021"

description = "CUDA kernels for Candle"

@ -1,5 +1,8 @@
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-changed=src/compatibility.cuh");
    println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
    println!("cargo:rerun-if-changed=src/binary_op_macros.cuh");

    let builder = bindgen_cuda::Builder::default();
    println!("cargo:info={builder:?}");

@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    if (is_contiguous(num_dims, dims, strides)) { \
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            TYPENAME x = inp ? inp[i] : out[i]; \
            out[i] = x * mul + add; \

@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
    const size_t *dims = dims_and_strides; \
    const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
    const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
    bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
    bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
    bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
    bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
    if (lhs_cont && rhs_cont) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            TYPENAME x = lhs[i]; \

@ -11,7 +11,7 @@ __device__ void cast_(
) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;
    if (is_contiguous(num_dims, dims, strides)) {
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
            out[i] = inp[i];
        }
@ -34,7 +34,7 @@ __device__ void cast_through(
) {
    const size_t *dims = info;
    const size_t *strides = info + num_dims;
    if (is_contiguous(num_dims, dims, strides)) {
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
            out[i] = static_cast<T>(static_cast<I>(inp[i]));
        }
@ -83,6 +83,18 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
#else
#include <cuda.h>
#if CUDA_VERSION >= 11000
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16)
CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16)
#endif
#endif

#if __CUDA_ARCH__ >= 530

@ -14,7 +14,7 @@ __device__ bool is_contiguous(
    size_t acc = 1;
    for (unsigned int d = 0; d < num_dims; d++) {
        unsigned int dim_idx = num_dims - 1 - d;
        if (acc != strides[dim_idx]) {
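        // A size-1 dimension can carry an arbitrary stride without changing the
        // memory layout, so the stride check is skipped for it in the new version.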
        if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
            return false;
        }
        acc *= dims[dim_idx];

@ -10,11 +10,39 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }

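// Each thread copies one element of a d1 x d2 block; consecutive rows are
// src_s elements apart in src and dst_s elements apart in dst.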
template<typename T>
__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) {
    uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= d1 * d2) {
        return;
    }
    uint32_t idx1 = idx / d2;
    uint32_t idx2 = idx - d2 * idx1;
    dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2];
}

#define COPY2D_OP(TYPENAME, FNNAME) \
extern "C" __global__ \
void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \
    copy2d(src, dst, d1, d2, src_s, dst_s); \
} \

COPY2D_OP(float, copy2d_f32)
COPY2D_OP(double, copy2d_f64)
COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
#endif

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif

@ -168,8 +168,10 @@ IS_OP(__half, uint8_t, is_u8_f16)
GATHER_OP(__half, int64_t, gather_i64_f16)
GATHER_OP(__half, uint32_t, gather_u32_f16)
GATHER_OP(__half, uint8_t, gather_u8_f16)
IA_OP(__half, int64_t, ia_i64_f16)
IA_OP(__half, uint32_t, ia_u32_f16)
IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
#endif

(A file diff was suppressed here because it is too large.)
@ -2,6 +2,7 @@
#include <cmath>
#include <stdint.h>

#define WARP_SIZE 32
const int BLOCK_SIZE = 1024;

// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
@ -49,6 +50,59 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
    dst[dst_id] = shr[0];
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
    const int block_size = blockDim.x;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = static_cast<float>(x[row*ncols + col]);
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    if (alpha == nullptr) {
        for (int col = tid; col < ncols; col += block_size) {
            dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]));
        }
    }
    else {
        for (int col = tid; col < ncols; col += block_size) {
            float a = static_cast<float>(alpha[col]);
            dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]) * a);
        }
    }
}
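In the notation of the kernel above, each row of `x` is scaled by the reciprocal root-mean-square of its entries, with an optional per-column weight `alpha`:

$$
\mathrm{dst}_{r,c} = \frac{x_{r,c}}{\sqrt{\tfrac{1}{n}\sum_{i=1}^{n} x_{r,i}^2 + \varepsilon}} \cdot \alpha_c
$$

where $n$ is `ncols` and the $\alpha_c$ factor is dropped when `alpha` is null.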

// Softmax implementation adapted from ggml.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
template <typename T, typename ACC>
@ -93,6 +147,65 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
    }
}

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

    uint32_t rope_idx = idx % (td / 2);
    T c = cos[rope_idx];
    T s = sin[rope_idx];

    dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
    dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}

template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= bh * td) return;

    uint32_t i_bh = idx / (td / 2);
    uint32_t i_td = idx - (td / 2) * i_bh;
    uint32_t i_t = i_td / (d / 2);
    uint32_t i_d = i_td - (d / 2) * i_t;
    uint32_t i1 = i_bh * td + i_t * d + i_d;
    uint32_t i2 = i1 + d / 2;
    uint32_t i_cs = i_t * (d / 2) + i_d;
    T c = cos[i_cs];
    T s = sin[i_cs];

    dst[i1] = src[i1] * c - src[i2] * s;
    dst[i2] = src[i1] * s + src[i2] * c;
}

template <typename T>
__device__ void rope_thd(
    const T * src,
    const T * cos,
    const T * sin,
    T * dst,
    const uint32_t b,
    const uint32_t t,
    const uint32_t h,
    const uint32_t d
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (2 * idx >= b * t * h * d) return;

    uint32_t i_bth = idx / (d / 2);
    uint32_t i_d = idx - (d / 2) * i_bth;
    uint32_t i_t = (i_bth / h) % t;
    uint32_t i1 = i_bth * d + i_d;
    uint32_t i2 = i1 + d / 2;
    uint32_t i_cs = i_t * (d / 2) + i_d;
    T c = cos[i_cs];
    T s = sin[i_cs];

    dst[i1] = src[i1] * c - src[i2] * s;
    dst[i2] = src[i1] * s + src[i2] * c;
}
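All three rope variants apply the same rotation to each element pair $(\mathrm{src}_{i_1}, \mathrm{src}_{i_2})$, with $\cos\theta$ and $\sin\theta$ read from the precomputed tables:

$$
\begin{pmatrix} \mathrm{dst}_{i_1} \\ \mathrm{dst}_{i_2} \end{pmatrix}
=
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} \mathrm{src}_{i_1} \\ \mathrm{src}_{i_2} \end{pmatrix}
$$

The variants differ only in how the pair indices and the table index are derived from the tensor layout: interleaved pairs for `ropei`, a split at `d / 2` over a `(b*h, t, d)` layout for `rope`, and a `(b, t, h, d)` layout for `rope_thd`.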

template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@ -341,14 +454,57 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
    softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
} \

#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
    const int n_cols, const float eps) { \
    rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \

#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td) { \
    ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
} \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t bh, \
    const uint32_t td, \
    const uint32_t d) { \
    rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
} \
extern "C" __global__ void FN_NAME_THD( \
    const TYPENAME *src, \
    const TYPENAME *cos, \
    const TYPENAME *sin, \
    TYPENAME *dst, \
    const uint32_t b, \
    const uint32_t t, \
    const uint32_t h, \
    const uint32_t d) { \
    rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
} \

#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@ -358,6 +514,10 @@ SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)

FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)

@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    if (is_contiguous(num_dims, dims, strides)) { \
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            TYPENAME x = inp ? inp[i] : out[i]; \
            out[i] = FUNC; \
@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    if (is_contiguous(num_dims, dims, strides)) { \
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            TYPENAME x = inp ? inp[i] : out[i]; \
            out[i] = FUNC; \
@ -86,6 +86,11 @@ extern "C" __global__ void FN_NAME( \
    } \
} \

template<typename T>
__device__ T sign_(T t) {
    return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}


#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
@ -110,6 +115,7 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif

#if __CUDA_ARCH__ >= 530
@ -135,6 +141,7 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
UNARY_OP(__half, usign_f16, sign_(x))
#endif

UNARY_OP(uint8_t, ucopy_u8, x)
@ -184,3 +191,5 @@ UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))
UNARY_OP(float, usign_f32, sign_(x))
UNARY_OP(double, usign_f64, sign_(x))

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.1"
version = "0.5.0"
edition = "2021"

description = "Metal kernels for Candle"

@ -89,7 +89,7 @@ kernel void FN_NAME( \
        return; \
    } \
    const TYPENAME x = input[id]; \
    output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
    output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
} \
kernel void FN_NAME##_strided( \
    constant size_t &dim, \

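The corrected expression matches the usual ELU definition, with `mul` playing the role of $\alpha$; the previous code computed $\alpha e^{x-1}$ instead of $\alpha(e^x - 1)$ on the negative branch:

$$
\mathrm{elu}(x) = \begin{cases} x & x > 0 \\ \alpha\,(e^{x} - 1) & x \le 0 \end{cases}
$$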
@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);

#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);

#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);

#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);

#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);

BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)

BFLOAT_BINARY_OP_OUT(eq, x == y)
BFLOAT_BINARY_OP_OUT(ne, x != y)
BFLOAT_BINARY_OP_OUT(le, x <= y)
BFLOAT_BINARY_OP_OUT(lt, x < y)
BFLOAT_BINARY_OP_OUT(ge, x >= y)
BFLOAT_BINARY_OP_OUT(gt, x > y)
#endif

@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \
    output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \

// u32
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)

CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
#endif

// u8
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
#endif

// f16
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif

// i64
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif

// f32
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif

// bf16
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)

CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
@ -1,3 +1,9 @@
#include <metal_stdlib>

using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))

template <typename T>
METAL_FUNC void im2col(
    constant size_t &dst_numel,
@ -200,14 +206,331 @@ kernel void FN_NAME( \
    upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
} \

template <typename T, typename A>
METAL_FUNC void avg_pool2d(
    constant size_t &w_k,
    constant size_t &h_k,
    constant size_t &w_stride,
    constant size_t &h_stride,
    constant size_t *src_dims,
    constant size_t *src_strides,
    device const T *src,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    const size_t c = src_dims[1];
    const size_t w_in = src_dims[2];
    const size_t h_in = src_dims[3];

    const size_t w_out = (w_in - w_k) / w_stride + 1;
    const size_t h_out = (h_in - h_k) / h_stride + 1;
    if (tid >= src_dims[0] * c * w_out * h_out) {
        return;
    }

    const size_t b_idx = tid / (w_out * h_out * c);
    const size_t c_idx = (tid / (w_out * h_out)) % c;
    const size_t dst_w = (tid / h_out) % w_out;
    const size_t dst_h = tid % h_out;

    const size_t src_idx0 = b_idx * src_strides[0];
    A d = 0;
    for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
        size_t src_w = w_stride * dst_w + w_offset;
        if (src_w >= w_in){
            continue;
        }
        for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
            size_t src_h = h_stride * dst_h + h_offset;
            if (src_h >= h_in) {
                continue;
            }
            const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3];
            d += static_cast<A>(src[src_idx]);
        }
    }
    dst[tid] = static_cast<T>(d / (w_k * h_k));
}

#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &w_k, \
    constant size_t &h_k, \
    constant size_t &w_s, \
    constant size_t &h_s, \
    constant size_t *src_dims, \
    constant size_t *src_s, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    avg_pool2d<TYPENAME, TYPEACC>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
} \

template <typename T>
METAL_FUNC void max_pool2d(
    constant size_t &w_k,
    constant size_t &h_k,
    constant size_t &w_stride,
    constant size_t &h_stride,
    constant size_t *src_dims,
    constant size_t *src_strides,
    device const T *src,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    const size_t c = src_dims[1];
    const size_t w_in = src_dims[2];
    const size_t h_in = src_dims[3];

    const size_t w_out = (w_in - w_k) / w_stride + 1;
    const size_t h_out = (h_in - h_k) / h_stride + 1;
    if (tid >= src_dims[0] * c * w_out * h_out) {
        return;
    }

    const size_t b_idx = tid / (w_out * h_out * c);
    const size_t c_idx = (tid / (w_out * h_out)) % c;
    const size_t dst_w = (tid / h_out) % w_out;
    const size_t dst_h = tid % h_out;

    const size_t src_idx0 = b_idx * src_strides[0];
    T d = 0;
    bool set = false;
    for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
        size_t src_w = w_stride * dst_w + w_offset;
        if (src_w >= w_in){
            continue;
        }
        for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
            size_t src_h = h_stride * dst_h + h_offset;
            if (src_h >= h_in) {
                continue;
            }
            const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3];
            if (set) {
                d = MAX(d, src[src_idx]);
            }
            else {
                d = src[src_idx];
                set = true;
            }
        }
    }
    dst[tid] = d;
}

#define MAXPOOL2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &w_k, \
    constant size_t &h_k, \
    constant size_t &w_s, \
    constant size_t &h_s, \
    constant size_t *src_dims, \
    constant size_t *src_s, \
    device const TYPENAME *src, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
} \
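Both pooling kernels use valid (unpadded) windows, so the output sizes that bound the thread index are

$$
w_{\text{out}} = \left\lfloor \frac{w_{\text{in}} - w_k}{w_{\text{stride}}} \right\rfloor + 1,
\qquad
h_{\text{out}} = \left\lfloor \frac{h_{\text{in}} - h_k}{h_{\text{stride}}} \right\rfloor + 1.
$$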
|
||||
|
||||
// Naive implementation of conv_transpose1d.
|
||||
template <typename T, typename A>
METAL_FUNC void conv_transpose1d(
    constant size_t &l_out,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &out_padding,
    constant size_t &dilation,
    constant size_t *src_dims,
    constant size_t *src_strides,
    constant size_t *k_dims,
    constant size_t *k_strides,
    device const T *src,
    device const T *k,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    // src: (b_size, c_in, l_in)
    // kernel: (c_in, c_out, l_k)
    const size_t l_k = k_dims[2];
    const size_t c_out = k_dims[1];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];
    if (tid >= src_dims[0] * c_out * l_out) {
        return;
    }

    const size_t b_idx = tid / (l_out * c_out);
    const size_t dst_c_idx = (tid / l_out) % c_out;
    const size_t out_x = tid % l_out;

    const size_t src_idx0 = b_idx * src_strides[0];
    A d = 0;
    for (int k_x = 0; k_x < (int)l_k; ++k_x) {
        // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
        int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
        if (inp_x_stride < 0 || inp_x_stride % stride) {
            continue;
        }
        int inp_x = inp_x_stride / stride;
        if (inp_x >= l_in) continue;
        for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
            const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2];
            const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2];
            d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
        }
    }
    dst[tid] = static_cast<T>(d);
}

#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &l_out, \
    constant size_t &stride, \
    constant size_t &padding, \
    constant size_t &out_padding, \
    constant size_t &dilation, \
    constant size_t *src_dims, \
    constant size_t *src_strides, \
    constant size_t *k_dims, \
    constant size_t *k_strides, \
    device const TYPENAME *src, \
    device const TYPENAME *k, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
} \

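// conv_transpose2d applies the same stride/dilation inversion as the 1d case
// independently along x and y, then accumulates over all input channels for
// the selected output channel.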
template <typename T, typename A>
METAL_FUNC void conv_transpose2d(
    constant size_t &w_out,
    constant size_t &h_out,
    constant size_t &stride,
    constant size_t &padding,
    constant size_t &out_padding,
    constant size_t &dilation,
    constant size_t *input_dims,
    constant size_t *input_stride,
    constant size_t *k_dims,
    constant size_t *k_stride,
    device const T *src,
    device const T *k,
    device T *dst,
    uint tid [[ thread_position_in_grid ]]
) {
    const size_t h_k = k_dims[2];
    const size_t w_k = k_dims[3];
    const size_t c_out = k_dims[1];
    const size_t c_in = input_dims[1];
    const size_t h_in = input_dims[2];
    const size_t w_in = input_dims[3];

    if (tid >= input_dims[0] * c_out * w_out * h_out) {
        return;
    }

    const size_t b_idx = tid / (w_out * h_out * c_out);
    const size_t dst_c_idx = (tid / (w_out * h_out)) % c_out;
    const size_t out_y = (tid / w_out) % h_out;
    const size_t out_x = tid % w_out;

    const size_t src_idx0 = b_idx * input_stride[0];

    A d = 0;
    for (int k_x = 0; k_x < (int)w_k; ++k_x) {
        const int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
        if (inp_x_stride < 0 || inp_x_stride % stride) {
            continue;
        }
        const int inp_x = inp_x_stride / stride;
        if (inp_x >= w_in) continue;
        for (int k_y = 0; k_y < (int)h_k; ++k_y) {
            const int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
            if (inp_y_stride < 0 || inp_y_stride % stride) {
                continue;
            }
            const int inp_y = inp_y_stride / stride;
            if (inp_y >= h_in) continue;
            for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
                const size_t src_idx = src_idx0 + src_c_idx * input_stride[1] + inp_y * input_stride[2] + inp_x * input_stride[3];
                const size_t k_idx = src_c_idx * k_stride[0] + dst_c_idx * k_stride[1] + k_y * k_stride[2] + k_x * k_stride[3];
                d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
            }
        }
    }
    dst[tid] = static_cast<T>(d);
}

#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &w_out, \
    constant size_t &h_out, \
    constant size_t &stride, \
    constant size_t &padding, \
    constant size_t &out_padding, \
    constant size_t &dilation, \
    constant size_t *input_dims, \
    constant size_t *input_stride, \
    constant size_t *k_dims, \
    constant size_t *k_stride, \
    device const TYPENAME *src, \
    device const TYPENAME *k, \
    device TYPENAME *dst, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, input_dims, input_stride, k_dims, k_stride, src, k, dst, tid); \
} \

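// Concrete kernel instantiations. bfloat variants are only compiled when the
// Metal toolchain defines __HAVE_BFLOAT__; half and bfloat kernels that take
// an accumulator type use float for it.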
IM2COL_OP(float, im2col_f32)
IM2COL_OP(half, im2col_f16)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
#if defined(__HAVE_BFLOAT__)
IM2COL_OP(bfloat, im2col_bf16)
#endif

IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
#if defined(__HAVE_BFLOAT__)
UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16)
#endif

MAXPOOL2D_OP(float, max_pool2d_f32)
MAXPOOL2D_OP(half, max_pool2d_f16)
MAXPOOL2D_OP(uint32_t, max_pool2d_u32)
MAXPOOL2D_OP(uint8_t, max_pool2d_u8)
#if defined(__HAVE_BFLOAT__)
MAXPOOL2D_OP(bfloat, max_pool2d_bf16)
#endif

AVGPOOL2D_OP(float, float, avg_pool2d_f32)
AVGPOOL2D_OP(half, float, avg_pool2d_f16)
AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
#if defined(__HAVE_BFLOAT__)
AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
#endif

CONVT1D_OP(float, float, conv_transpose1d_f32)
CONVT1D_OP(half, float, conv_transpose1d_f16)
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
#endif

CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(half, float, conv_transpose2d_f16)
#if defined(__HAVE_BFLOAT__)
CONVT2D_OP(bfloat, float, conv_transpose2d_bf16)
#endif
@ -1,20 +1,38 @@
#include <metal_stdlib>
using namespace metal;

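// Maps a linear row-major element index onto a strided layout by peeling one
// dimension at a time from the right. E.g. for dims (2, 3) with strides
// (1, 2), idx 4 (row 1, col 1) maps to 1 * 1 + 1 * 2 = 3.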
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

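// index implements index-select along one dimension: each output element
// resolves its source row through input_ids (clamped to src_dim_size - 1).
// Contiguous sources use the computed linear offset directly; otherwise the
// offset is re-mapped through the source dims/strides via get_strided_index.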
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
    constant size_t &dst_size,
    constant size_t &left_size,
    constant size_t &src_dim_size,
    constant size_t &right_size,
    constant size_t &ids_size,
    constant bool &contiguous,
    constant size_t *src_dims,
    constant size_t *src_strides,
    const device TYPENAME *input,
    const device INDEX_TYPENAME *input_ids,
    device TYPENAME *output,
    uint tid [[ thread_position_in_grid ]]
) {
    if (tid >= dst_size) {
        return;
    }
    const size_t id_i = (tid / right_size) % ids_size;
    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
@ -26,7 +44,8 @@ METAL_FUNC void index(
    // No need to check for zero we're only allowing unsized.
    */
    const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
    const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
    output[tid] = input[strided_src_i];
}

#define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -36,12 +55,15 @@ kernel void NAME( \
    constant size_t &src_dim_size, \
    constant size_t &right_size, \
    constant size_t &ids_size, \
    constant bool &contiguous, \
    constant size_t *src_dims, \
    constant size_t *src_strides, \
    const device TYPENAME *input, \
    const device INDEX_TYPENAME *input_ids, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \
}

@ -165,37 +187,68 @@ kernel void NAME( \
}

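// Kernel instantiations, grouped by index dtype (i64 / u32 / u8); bfloat
// variants are gated on __HAVE_BFLOAT__.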
INDEX_OP(is_i64_f32, int64_t, float)
INDEX_OP(is_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif

INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
#endif

INDEX_OP(is_u8_f32, uint8_t, float)
INDEX_OP(is_u8_f16, uint8_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#endif

GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif

SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif

// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
#endif

// u32
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
#endif

// u8
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif

@ -123,16 +123,20 @@ template<typename T> METAL_FUNC void rand_uniform(
        return;
    }

    // Evenly sized vectors need an offset when writing the mirror element.
    uint off = 1 - size % 2;
    float diff = abs(min - max);
    uint s = atomic_load_explicit(seed, memory_order_relaxed);
    HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
    out[tid] = static_cast<T>(rng.rand() * diff + min);
    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0 && off == 0, otherwise we will write to out[size].
        if (off == 0)
            return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - off - tid] = static_cast<T>(rng.rand() * diff + min);
}

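// rand_uniform above and normal below share the same trick: each thread
// writes out[tid] and the mirrored element out[size - off - tid], so a
// dispatch only needs threads for half the buffer. For odd sizes off == 0
// and thread 0's mirror slot would be out[size], hence its early return;
// thread 0 also stores a fresh seed for the next launch before returning.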
// Create Gaussian normal distribution using Box-Muller transform:
@ -148,7 +152,10 @@ template<typename T> METAL_FUNC void normal(
    if (tid >= size) {
        return;
    }
    // Evenly sized vectors need an offset when writing the mirror element.
    uint off = 1 - size % 2;
    uint s = atomic_load_explicit(seed, memory_order_relaxed);
    HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
    float u1 = rng.rand();
    float u2 = rng.rand();

@ -162,11 +169,12 @@ template<typename T> METAL_FUNC void normal(

    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0 && off == 0, otherwise we will write to out[size].
        if (off == 0)
            return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - off - tid] = static_cast<T>(z1);
}

#define UNIFORM_OP(NAME, T) \