Compare commits

..

1 Commits

Author SHA1 Message Date
101a4c8389 Moondream first bits. 2024-03-17 17:49:56 +01:00
137 changed files with 2926 additions and 12446 deletions

View File

@ -9,7 +9,6 @@ members = [
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
]
exclude = [
"candle-flash-attn",
@ -20,7 +19,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.5.0"
version = "0.4.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -29,18 +28,17 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[workspace.dependencies]
ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
candle-nn = { path = "./candle-nn", version = "0.5.0" }
candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" }
candle-datasets = { path = "./candle-datasets", version = "0.4.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" }
candle-kernels = { path = "./candle-kernels", version = "0.4.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" }
candle-nn = { path = "./candle-nn", version = "0.4.2" }
candle-onnx = { path = "./candle-onnx", version = "0.4.2" }
candle-transformers = { path = "./candle-transformers", version = "0.4.2" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
@ -48,18 +46,19 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "51.0.0" }
parquet = { version = "50.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"

View File

@ -125,14 +125,10 @@ We also provide a some command line based examples using state of the art models
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
model.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
that can answer real-world questions about images.
Run them using commands like:
```
@ -176,7 +172,6 @@ And then head over to
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
@ -211,7 +206,7 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5, Qwen1.5 MoE.
- Qwen1.5.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.

View File

@ -5,6 +5,5 @@ criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::where_cond::benches
);

View File

@ -1,59 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(
x: &Tensor,
k: &Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) {
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
.unwrap();
}
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let t = Tensor::arange(0.0f32, 10000.0, device)
.unwrap()
.reshape((1, 4, 50, 50))
.unwrap()
.to_dtype(dtype)
.unwrap();
let kernel = Tensor::arange(0.0f32, 100.0, device)
.unwrap()
.reshape((4, 1, 5, 5))
.unwrap()
.to_dtype(dtype)
.unwrap();
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,5 +1,4 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;

View File

@ -1,5 +1,5 @@
use candle::quantized::{gguf_file, GgmlDType, QTensor};
use candle::{Device, Result};
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
@ -117,24 +117,6 @@ enum Command {
verbose: bool,
},
Print {
file: std::path::PathBuf,
names: Vec<String>,
/// The file format to use, if unspecified infer from the file extension.
#[arg(long, value_enum)]
format: Option<Format>,
/// Print the whole content of each tensor.
#[arg(long)]
full: bool,
/// Line width for printing the tensors.
#[arg(long)]
line_width: Option<usize>,
},
Quantize {
/// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
@ -168,105 +150,6 @@ struct Args {
command: Command,
}
fn run_print(
file: &std::path::PathBuf,
names: Vec<String>,
format: Option<Format>,
full: bool,
line_width: Option<usize>,
device: &Device,
) -> Result<()> {
if full {
candle::display::set_print_options_full();
}
if let Some(line_width) = line_width {
candle::display::set_line_width(line_width)
}
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
Some(format) => format,
None => {
println!(
"{file:?}: cannot infer format from file extension, use the --format flag"
);
return Ok(());
}
},
};
match format {
Format::Npz => {
let tensors = candle::npy::NpzTensors::new(file)?;
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name)? {
Some(tensor) => println!("{tensor}"),
None => println!("not found"),
}
}
}
Format::Safetensors => {
use candle::safetensors::Load;
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name) {
Some(tensor_view) => {
let tensor = tensor_view.load(device)?;
println!("{tensor}")
}
None => println!("not found"),
}
}
}
Format::Pth => {
let pth_file = candle::pickle::PthTensors::new(file, None)?;
for name in names.iter() {
println!("==== {name} ====");
match pth_file.get(name)? {
Some(tensor) => {
println!("{tensor}")
}
None => println!("not found"),
}
}
}
Format::Pickle => {
candle::bail!("pickle format is not supported for print")
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
for name in names.iter() {
println!("==== {name} ====");
match content.tensors.get(name) {
Some(tensor) => {
let tensor = tensor.dequantize(device)?;
println!("{tensor}")
}
None => println!("not found"),
}
}
}
Format::Gguf => {
let mut file = std::fs::File::open(file)?;
let content = gguf_file::Content::read(&mut file)?;
for name in names.iter() {
println!("==== {name} ====");
match content.tensor(&mut file, name, device) {
Ok(tensor) => {
let tensor = tensor.dequantize(device)?;
println!("{tensor}")
}
Err(_) => println!("not found"),
}
}
}
}
Ok(())
}
fn run_ls(
file: &std::path::PathBuf,
format: Option<Format>,
@ -287,7 +170,7 @@ fn run_ls(
};
match format {
Format::Npz => {
let tensors = candle::npy::NpzTensors::new(file)?;
let tensors = candle_core::npy::NpzTensors::new(file)?;
let mut names = tensors.names();
names.sort();
for name in names {
@ -299,12 +182,12 @@ fn run_ls(
}
}
Format::Safetensors => {
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
let dtype = view.dtype();
let dtype = match candle::DType::try_from(dtype) {
let dtype = match candle_core::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"),
};
@ -313,7 +196,7 @@ fn run_ls(
}
}
Format::Pth => {
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
@ -330,7 +213,7 @@ fn run_ls(
Format::Pickle => {
let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file);
let mut stack = candle::pickle::Stack::empty();
let mut stack = candle_core::pickle::Stack::empty();
stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}");
@ -338,7 +221,7 @@ fn run_ls(
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
@ -374,7 +257,7 @@ fn run_quantize_safetensors(
let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() {
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors)
}
println!("tensors: {}", tensors.len());
@ -416,7 +299,7 @@ fn run_dequantize(
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle::safetensors::save(&tensors, out_file)?;
candle_core::safetensors::save(&tensors, out_file)?;
Ok(())
}
@ -428,11 +311,11 @@ fn run_quantize(
device: &Device,
) -> Result<()> {
if in_files.is_empty() {
candle::bail!("no specified input files")
candle_core::bail!("no specified input files")
}
if let Some(extension) = out_file.extension() {
if extension == "safetensors" {
candle::bail!("the generated file cannot use the safetensors extension")
candle_core::bail!("the generated file cannot use the safetensors extension")
}
}
if let Some(extension) = in_files[0].extension() {
@ -442,7 +325,7 @@ fn run_quantize(
}
if in_files.len() != 1 {
candle::bail!("only a single in-file can be used when quantizing gguf files")
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
}
// Open the out file early so as to fail directly on missing directories etc.
@ -494,13 +377,6 @@ fn main() -> anyhow::Result<()> {
run_ls(file, format.clone(), verbose, &device)?
}
}
Command::Print {
file,
names,
format,
full,
line_width,
} => run_print(&file, names, format, full, line_width, &device)?,
Command::Quantize {
in_file,
out_file,

View File

@ -127,16 +127,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible
/// after this call.
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;

View File

@ -1,4 +1,3 @@
/// Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@ -112,8 +111,7 @@ impl Tensor {
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round)
| Op::Unary(_node, UnaryOp::Sign) => nodes,
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
@ -312,32 +310,9 @@ impl Tensor {
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D {
arg,
kernel,
padding,
stride,
dilation,
output_padding: _output_padding,
} => {
let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = grad
.transpose(0, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
Op::AvgPool2D {
arg,
kernel_size,
@ -489,6 +464,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
@ -578,18 +554,20 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Floor)
| Op::Unary(_, UnaryOp::Round)
| Op::Reduce(_, ReduceOp::ArgMin, _)
| Op::Reduce(_, ReduceOp::ArgMax, _)
| Op::Unary(_, UnaryOp::Sign)
| Op::Cmp(_, _) => {}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
@ -712,38 +690,30 @@ impl Tensor {
}
}
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
/// Create a new gradient store
fn new() -> Self {
GradStore(HashMap::new())
}
/// Get the gradient tensor corresponding to the given tensor id
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
self.0.get(&id)
}
/// Get the gradient tensor associated with the given tensor
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
self.0.get(&tensor.id())
}
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
self.0.remove(&tensor.id())
}
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
self.0.insert(tensor.id(), grad)
}
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
use std::collections::hash_map::Entry;
let grad = match self.0.entry(tensor.id()) {

View File

@ -4,11 +4,6 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;
mod utils;
pub use utils::{
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
};
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
@ -29,6 +24,102 @@ pub enum CpuStorage {
#[derive(Debug, Clone)]
pub struct CpuDevice;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
&self,
vs: &[T],
layout: &Layout,
wrap: W,
) -> Result<CpuStorage>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
}
}
}
type C = CpuStorage;
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(
&self,
v1: &CpuStorage,
l1: &Layout,
v2: &CpuStorage,
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(
&self,
v1: &CpuStorage,
l1: &Layout,
v2: &CpuStorage,
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
struct Cmp(CmpOp);
impl Map2U8 for Cmp {
const OP: &'static str = "cmp";
@ -275,6 +366,275 @@ impl<'a> Map1 for ReduceSum<'a> {
}
}
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}
// This function maps over two strided index sequences.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
struct Affine(f64, f64);
impl Map1 for Affine {
@ -1204,30 +1564,6 @@ impl MatMul {
}))
.bt()
}
fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (_b, m, n, k) = self.0;
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[_, stride] if lhs_l.dims()[0] == 1 => stride,
[stride, _] if lhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[_, stride] if rhs_l.dims()[0] == 1 => stride,
[stride, _] if rhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
Ok((a_skip, b_skip))
}
}
impl Map2 for MatMul {
@ -1261,7 +1597,18 @@ impl Map2 for MatMul {
let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2];
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
let dst_shape: Shape = (m, n).into();
@ -1321,8 +1668,20 @@ impl Map2 for MatMul {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1330,7 +1689,7 @@ impl Map2 for MatMul {
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
@ -1338,7 +1697,7 @@ impl Map2 for MatMul {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
@ -1412,8 +1771,20 @@ impl Map2 for MatMul {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1421,7 +1792,7 @@ impl Map2 for MatMul {
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
@ -1429,7 +1800,7 @@ impl Map2 for MatMul {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
@ -2211,10 +2582,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@ -2222,7 +2590,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@ -2313,10 +2681,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@ -2326,7 +2691,7 @@ impl BackendStorage for CpuStorage {
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@ -2449,10 +2814,6 @@ impl BackendDevice for CpuDevice {
Ok(s.clone())
}
fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
Ok(s)
}
fn new(_: usize) -> Result<Self> {
Ok(Self)
}
@ -2554,53 +2915,6 @@ impl BackendDevice for CpuDevice {
}
}
#[allow(clippy::uninit_vec)]
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
// The code below is highly unsafe but hopefully not directly unsound as we only consider
// types that are Copy, not Drop, and for which all bit patterns are proper values.
// It's still pretty risky, see the following for more details:
// https://github.com/rust-lang/rust-clippy/issues/4483
let storage = match dtype {
DType::U8 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::U8(v)
}
DType::U32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::U32(v)
}
DType::I64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::I64(v)
}
DType::BF16 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::BF16(v)
}
DType::F16 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F16(v)
}
DType::F32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F32(v)
}
DType::F64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F64(v)
}
};
Ok(storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {

View File

@ -1,350 +0,0 @@
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};
type C = super::CpuStorage;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
}
}
}
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}

View File

@ -5,41 +5,395 @@ pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
#[cfg(feature = "cudnn")]
pub mod cudnn;
mod device;
mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
#[error(transparent)]
Cuda(#[from] cudarc::driver::DriverError),
enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
Null,
#[error(transparent)]
Compiler(#[from] cudarc::nvrtc::CompileError),
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
}
unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
fn as_kernel_param(&self) -> *mut std::ffi::c_void {
match self {
SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
SlicePtrOrNull::Null => 0usize.as_kernel_param(),
}
impl From<CudaError> for crate::Error {
fn from(val: CudaError) -> Self {
crate::Error::Cuda(Box::new(val)).bt()
}
}
impl SlicePtrOrNull<usize> {
fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
let ds = if l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.device
}
}
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(ds)
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
let layout = Layout::contiguous(shape);
Affine(up - lo, lo).map(&slice, self, &layout)?
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
}
@ -53,6 +407,133 @@ pub enum CudaStorageSlice {
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
type S = CudaStorageSlice;
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
}
struct Clone;
impl Map1 for Clone {
@ -83,7 +564,7 @@ impl Map1 for Affine {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
@ -115,7 +596,7 @@ impl Map1 for Elu {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -238,7 +719,7 @@ impl Map1 for Powf {
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -371,7 +852,7 @@ impl<U: UnaryOpT> Map1 for U {
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -921,14 +1402,9 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let dims_and_strides = dev
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?;
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
@ -955,14 +1431,9 @@ impl Map2Any for Cmp {
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let dims_and_strides = dev
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?;
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let name = match self.0 {
@ -1070,30 +1541,26 @@ fn gemm_config<T>(
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs)
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
// The b tensor has dims batching, m, k (lhs)
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
@ -1114,25 +1581,21 @@ fn gemm_config<T>(
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[_, stride] if lhs_l.dims()[0] == 1 => stride,
[stride, _] if lhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => m * k,
_ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[_, stride] if rhs_l.dims()[0] == 1 => stride,
[stride, _] if rhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => n * k,
_ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
};
@ -1177,7 +1640,7 @@ impl BackendStorage for CudaStorage {
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let start_o = layout.start_offset();
// This returns an i64 rather than a &i64, this is useful to get around some temporary
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
@ -1381,10 +1844,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@ -1392,7 +1852,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@ -1449,10 +1909,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@ -1462,7 +1919,7 @@ impl BackendStorage for CudaStorage {
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@ -1599,7 +2056,7 @@ impl BackendStorage for CudaStorage {
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
@ -1614,7 +2071,7 @@ impl BackendStorage for CudaStorage {
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
@ -1701,11 +2158,6 @@ impl BackendStorage for CudaStorage {
let dev = &self.device;
let d1 = d1 as u32;
let d2 = d2 as u32;
// Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
// argument with a null pointer.
if d1 == 0 || d2 == 0 {
return Ok(());
}
let dst_s = dst_s as u32;
let src_s = src_s as u32;
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
@ -1763,7 +2215,7 @@ impl BackendStorage for CudaStorage {
}
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
let ds = dev.htod_copy([dims, src_l.stride()].concat()).w()?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);

View File

@ -1,410 +0,0 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
use super::utils::Map1;
let layout = Layout::contiguous(shape);
super::Affine(up - lo, lo).map(&slice, self, &layout)?
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
}

View File

@ -1,62 +0,0 @@
use crate::{DType, Layout};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
#[error(transparent)]
Cuda(#[from] cudarc::driver::DriverError),
#[error(transparent)]
Compiler(#[from] cudarc::nvrtc::CompileError),
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Layout,
rhs_stride: Layout,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
}
impl From<CudaError> for crate::Error {
fn from(val: CudaError) -> Self {
crate::Error::Cuda(Box::new(val)).bt()
}
}
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
}
}

View File

@ -1,134 +0,0 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
use super::{CudaDevice, CudaError, WrapErr};
pub type S = super::CudaStorageSlice;
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
}

View File

@ -1,377 +0,0 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
impl Tensor {
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
// In place ops.
/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
-> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &mut CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &mut CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
impl Tensor {
/// Applies a unary custom op in place.
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
self.storage_mut().inplace_op1(self.layout(), c)
}
/// Applies a unary custom op in place (for the first tensor).
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
self.storage_mut()
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
}
/// Applies a ternary custom op in place (for the first tensor).
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
self.storage_mut().inplace_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)
}
}

View File

@ -289,34 +289,17 @@ impl Device {
}
}
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?;
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?;
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
@ -327,12 +310,12 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?;
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?;
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}

View File

@ -210,18 +210,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -222,18 +222,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -14,7 +14,7 @@
//!
//! ## Features
//!
//! - Simple syntax (looks and feels like PyTorch)
//! - Simple syntax (looks and like PyTorch)
//! - CPU and Cuda backends (and M1 support)
//! - Enable serverless (CPU) small and fast deployments
//! - Model training
@ -37,13 +37,14 @@
mod accelerate;
pub mod backend;
pub mod backprop;
pub mod conv;
mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
mod custom_op;
#[cfg(feature = "cudnn")]
pub mod cudnn;
mod device;
pub mod display;
mod dtype;
@ -57,7 +58,7 @@ pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
pub mod op;
mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
@ -71,16 +72,13 @@ pub mod test_utils;
pub mod utils;
mod variable;
#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;
pub use cpu_backend::CpuStorage;
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};

View File

@ -2,15 +2,14 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::CallConvTranspose2dCfg;
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
/// Simple way to catch lock error without
/// depending on T
@ -38,6 +37,13 @@ pub enum MetalError {
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("{0:?}")]
LockError(LockError),
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@ -54,6 +60,263 @@ impl From<String> for MetalError {
}
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
device: metal::Device,
/// Single command queue for the entire device.
command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
/// (could be linked to FFI communication overhead).
///
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
/// to be reused. However, in order for this to work, we need to guarantee the order of
/// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.device.registry_id())
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
pub fn id(&self) -> NSUInteger {
self.registry_id()
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource origin in case of bugs
pub fn new_buffer(
&self,
element_count: usize,
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
return best_buffer.map(|b| b.clone());
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
descriptor.set_output_url(path);
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
/// The actual buffer containing the data.
@ -346,41 +609,28 @@ impl BackendStorage for MetalStorage {
let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::U8, DType::F16) => "cast_u8_f16",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::BF16) => "cast_u8_bf16",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F32, DType::I64) => "cast_f32_i64",
(DType::F32, DType::U32) => "cast_f32_u32",
(DType::F32, DType::U8) => "cast_f32_u8",
(DType::F32, DType::BF16) => "cast_f32_bf16",
(DType::I64, DType::BF16) => "cast_i64_bf16",
(DType::I64, DType::F16) => "cast_i64_f16",
(DType::I64, DType::F32) => "cast_i64_f32",
(DType::I64, DType::U32) => "cast_i64_u32",
(DType::I64, DType::U8) => "cast_i64_u8",
(DType::F16, DType::BF16) => "cast_f16_bf16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::F16, DType::I64) => "cast_f16_i64",
(DType::F16, DType::U32) => "cast_f16_u32",
(DType::F16, DType::U8) => "cast_f16_u8",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
(DType::BF16, DType::I64) => "cast_bf16_i64",
(DType::BF16, DType::U32) => "cast_bf16_u32",
(DType::BF16, DType::U8) => "cast_bf16_u8",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
@ -443,64 +693,42 @@ impl BackendStorage for MetalStorage {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous::abs::HALF,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
("uceil", DType::F16) => contiguous::ceil::HALF,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("ucos", DType::F32) => contiguous::cos::FLOAT,
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
("uerf", DType::F16) => contiguous::erf::HALF,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
("uexp", DType::F16) => contiguous::exp::HALF,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
("ufloor", DType::F16) => contiguous::floor::HALF,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
("ulog", DType::F16) => contiguous::log::HALF,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ulog", DType::BF16) => contiguous::log::BFLOAT,
("uneg", DType::F16) => contiguous::neg::HALF,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
("urecip", DType::F16) => contiguous::recip::HALF,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
("urelu", DType::F16) => contiguous::relu::HALF,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
("uround", DType::F16) => contiguous::round::HALF,
("uround", DType::F32) => contiguous::round::FLOAT,
("uround", DType::BF16) => contiguous::round::BFLOAT,
("usilu", DType::F16) => contiguous::silu::HALF,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
("usin", DType::F16) => contiguous::sin::HALF,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usin", DType::BF16) => contiguous::sin::BFLOAT,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous::tanh::HALF,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
("usign", DType::F16) => contiguous::sign::HALF,
("usign", DType::F32) => contiguous::sign::FLOAT,
("usign", DType::BF16) => contiguous::sign::BFLOAT,
("usign", DType::I64) => contiguous::sign::I64,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("usilu", DType::F16) => contiguous::silu::HALF,
("uabs", DType::F16) => contiguous::abs::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
@ -707,50 +935,12 @@ impl BackendStorage for MetalStorage {
fn conv_transpose1d(
&self,
layout: &Layout,
k: &Self,
k_layout: &Layout,
params: &ParamsConvTranspose1D,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose1D,
) -> Result<Self> {
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
crate::bail!("Metal conv_transpose1d not implemented")
}
fn conv2d(
@ -781,10 +971,6 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "im2col_f32",
DType::F16 => "im2col_f16",
DType::BF16 => "im2col_bf16",
DType::U8 => "im2col_u8",
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_im2col_strided(
@ -837,150 +1023,20 @@ impl BackendStorage for MetalStorage {
fn conv_transpose2d(
&self,
l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &ParamsConvTranspose2D,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose2D,
) -> Result<Self> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let (out_w, out_h) = (params.out_w(), params.out_h());
let dst_el = params.c_out * out_w * out_h * params.b_size;
let dims = l.dims();
if dims.len() != 4 {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4")
}
let k_dims = kernel_l.dims();
if k_dims.len() != 4 {
crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4")
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose2d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose2d_f32",
DType::F16 => "conv_transpose2d_f16",
DType::BF16 => "conv_transpose2d_bf16",
dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose2d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
CallConvTranspose2dCfg {
dilation: params.dilation,
stride: params.stride,
padding: params.padding,
output_padding: params.output_padding,
c_out: params.c_out,
out_h,
out_w,
b_size: params.b_size,
input_dims: l.dims(),
input_stride: l.stride(),
kernel_dims: kernel_l.dims(),
kernel_stride: kernel_l.stride(),
input_offset: l.start_offset() * self.dtype.size_in_bytes(),
kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(),
},
&self.buffer,
&kernel.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
crate::bail!("Metal conv_tranpose2d not implemented")
}
fn avg_pool2d(
&self,
inp_l: &Layout,
(w_k, h_k): (usize, usize),
(w_stride, h_stride): (usize, usize),
) -> Result<Self> {
let shape = inp_l.shape();
let (b_size, channels, width, height) = shape.dims4()?;
let strides = inp_l.stride();
let name = match self.dtype {
DType::F32 => "avg_pool2d_f32",
DType::F16 => "avg_pool2d_f16",
DType::BF16 => "avg_pool2d_bf16",
DType::U8 => "avg_pool2d_u8",
DType::U32 => "avg_pool2d_u32",
dtype => crate::bail!("Metal avg_pool2d {dtype:?} not implemented"),
};
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;
let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?;
let command_buffers = self.device.command_buffer()?;
candle_metal_kernels::call_pool2d(
&self.device.device,
&command_buffers,
&self.device.kernels,
name,
inp_l.dims(),
strides,
out_w,
out_h,
w_k,
h_k,
w_stride,
h_stride,
&self.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
crate::bail!("Metal avg_pool2d not implemented")
}
fn max_pool2d(
&self,
inp_l: &Layout,
(w_k, h_k): (usize, usize),
(w_stride, h_stride): (usize, usize),
) -> Result<Self> {
let shape = inp_l.shape();
let (b_size, channels, width, height) = shape.dims4()?;
let strides = inp_l.stride();
let name = match self.dtype {
DType::F32 => "max_pool2d_f32",
DType::F16 => "max_pool2d_f16",
DType::BF16 => "max_pool2d_bf16",
DType::U8 => "max_pool2d_u8",
DType::U32 => "max_pool2d_u32",
dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"),
};
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;
let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?;
let command_buffers = self.device.command_buffer()?;
candle_metal_kernels::call_pool2d(
&self.device.device,
&command_buffers,
&self.device.kernels,
name,
inp_l.dims(),
strides,
out_w,
out_h,
w_k,
h_k,
w_stride,
h_stride,
&self.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
crate::bail!("Metal max_pool2d not implemented")
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
@ -997,10 +1053,6 @@ impl BackendStorage for MetalStorage {
}
let name = match self.dtype {
DType::F32 => "upsample_nearest2d_f32",
DType::F16 => "upsample_nearest2d_f16",
DType::BF16 => "upsample_nearest2d_bf16",
DType::U8 => "upsample_nearest2d_u8",
DType::U32 => "upsample_nearest2d_u32",
dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
};
@ -1115,8 +1167,12 @@ impl BackendStorage for MetalStorage {
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
if !ids_l.is_contiguous() {
crate::bail!("Metal index_select requires contiguous ids")
if !(src_l.is_contiguous()
&& src_l.start_offset() == 0
&& ids_l.is_contiguous()
&& ids_l.start_offset() == 0)
{
crate::bail!("Metal strided index_select not implemented");
}
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
@ -1127,17 +1183,11 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
@ -1151,13 +1201,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1184,29 +1229,9 @@ impl BackendStorage for MetalStorage {
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16",
(DType::I64, DType::F16) => "ia_i64_f16",
(DType::I64, DType::F32) => "ia_i64_f32",
(DType::I64, DType::I64) => "ia_i64_i64",
(DType::I64, DType::U32) => "ia_i64_u32",
(DType::I64, DType::U8) => "ia_i64_u8",
(DType::U32, DType::BF16) => "ia_u32_bf16",
(DType::U32, DType::F16) => "ia_u32_f16",
(DType::U32, DType::F32) => "ia_u32_f32",
(DType::U32, DType::I64) => "ia_u32_i64",
(DType::U32, DType::U32) => "ia_u32_u32",
(DType::U32, DType::U8) => "ia_u32_u8",
(DType::U8, DType::BF16) => "ia_u8_bf16",
(DType::U8, DType::F16) => "ia_u8_f16",
(DType::U8, DType::F32) => "ia_u8_f32",
(DType::U8, DType::I64) => "ia_u8_i64",
(DType::U8, DType::U32) => "ia_u8_u32",
(DType::U8, DType::U8) => "ia_u8_u8",
_ => Err(MetalError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i64",
msg: "index-add ids should be u32",
expected: DType::U32,
got: ids.dtype(),
})?,
@ -1639,7 +1664,6 @@ impl BackendDevice for MetalDevice {
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self {
id: DeviceId::new(),
device,
command_queue,
command_buffer,
@ -1658,17 +1682,7 @@ impl BackendDevice for MetalDevice {
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
self.device.registry_id() == rhs.device.registry_id()
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
@ -1706,10 +1720,6 @@ impl BackendDevice for MetalDevice {
))
}
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<Self::Storage> {
self.storage_from_cpu_storage(&storage)
}
fn rand_uniform(
&self,
shape: &Shape,
@ -1798,6 +1808,10 @@ impl BackendDevice for MetalDevice {
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());

View File

@ -1,287 +0,0 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use super::MetalError;
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
pub(crate) fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
/// the device itself.
pub(crate) id: DeviceId,
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
/// (could be linked to FFI communication overhead).
///
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
/// to be reused. However, in order for this to work, we need to guarantee the order of
/// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.id)
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
pub fn id(&self) -> DeviceId {
self.id
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource origin in case of bugs
pub fn new_buffer(
&self,
element_count: usize,
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
descriptor.set_output_url(path);
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
}

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::Tensor;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -66,7 +66,6 @@ pub enum UnaryOp {
Floor,
Ceil,
Round,
Sign,
}
#[derive(Clone)]
@ -162,23 +161,168 @@ pub enum Op {
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
),
}
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
@ -255,7 +399,6 @@ pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -459,13 +602,6 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@ -478,7 +614,7 @@ impl UnaryOpT for Gelu {
* v
* (bf16::ONE
+ bf16::tanh(
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
@ -489,18 +625,22 @@ impl UnaryOpT for Gelu {
* v
* (f16::ONE
+ f16::tanh(
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
(f16::from_f32_const(2.0) / f16::PI).sqrt()
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn f64(v: f64) -> f64 {
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn u8(_: u8) -> u8 {
@ -927,37 +1067,3 @@ impl std::ops::Deref for BackpropOp {
&self.0
}
}
impl UnaryOpT for Sign {
const NAME: &'static str = "sign";
const KERNEL: &'static str = "usign";
const V: Self = Sign;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
f32::from(v > 0.) - f32::from(v < 0.)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
f64::from(v > 0.) - f64::from(v < 0.)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
u8::min(1, v)
}
#[inline(always)]
fn u32(v: u32) -> u32 {
u32::min(1, v)
}
#[inline(always)]
fn i64(v: i64) -> i64 {
(v > 0) as i64 - (v < 0) as i64
}
}

View File

@ -1,62 +1,22 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
use cudarc::driver::{CudaSlice, DeviceSlice};
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn ceil_div(p: usize, q: usize) -> usize {
(p + q - 1) / q
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn dequantize(
data: &CudaSlice<u8>,
@ -70,18 +30,26 @@ fn dequantize(
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_0 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q5_1 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
@ -92,7 +60,7 @@ fn dequantize(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -115,9 +83,9 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
fn dequantize_mut_mal_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
y: &cudarc::driver::CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
@ -125,13 +93,6 @@ fn dequantize_mul_mat_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
@ -146,8 +107,8 @@ fn dequantize_mul_mat_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
@ -159,66 +120,9 @@ fn dequantize_mul_mat_vec(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nrows as u32, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
@ -236,12 +140,6 @@ impl QCudaStorage {
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
let vec = slice.to_vec();
T::to_float(&vec, dst)
}
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
@ -260,25 +158,69 @@ impl QCudaStorage {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
GgmlDType::F32 => {
let slice =
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
out.copy_from_slice(slice)
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
self.device
@ -343,11 +285,8 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
};
let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
@ -383,6 +322,11 @@ impl QCudaStorage {
}
}
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
slice.to_vec()
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
@ -397,60 +341,3 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
dtype: T::DTYPE,
}))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
Ok(())
}
#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);
let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}
}

View File

@ -171,7 +171,7 @@ impl Shape {
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
if dim > 1 && stride != acc {
if stride != acc {
return false;
}
acc *= dim;
@ -186,7 +186,7 @@ impl Shape {
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
if dim > 1 && stride != acc {
if stride != acc {
return false;
}
acc *= dim;

View File

@ -1,7 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp};
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -44,19 +43,9 @@ impl Storage {
}
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs_device = self.device();
let rhs_device = rhs.device();
let lhs = lhs_device.location();
let rhs = rhs_device.location();
let same_device = if self.device().is_metal() {
// On metal, we require the device to be exactly the same rather than
// having the same location. In cuda this is not necessary as all CudaDevice on the
// same GPU will use the same cuda stream.
lhs_device.same_device(&rhs_device)
} else {
lhs == rhs
};
if !same_device {
let lhs = self.device().location();
let rhs = rhs.device().location();
if lhs != rhs {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
} else {
Ok(())
@ -263,51 +252,6 @@ impl Storage {
}
}
pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
match self {
Self::Cpu(storage) => c.cpu_fwd(storage, l),
Self::Cuda(storage) => c.cuda_fwd(storage, l),
Self::Metal(storage) => c.metal_fwd(storage, l),
}
}
pub(crate) fn inplace_op2(
&mut self,
l1: &Layout,
t2: &Self,
l2: &Layout,
c: &dyn InplaceOp2,
) -> Result<()> {
self.same_device(t2, c.name())?;
match (self, t2) {
(Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
(Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
(Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
_ => unreachable!(),
}
}
pub(crate) fn inplace_op3(
&mut self,
l1: &Layout,
t2: &Self,
l2: &Layout,
t3: &Self,
l3: &Layout,
c: &dyn InplaceOp3,
) -> Result<()> {
self.same_device(t2, c.name())?;
self.same_device(t3, c.name())?;
match (self, t2, t3) {
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
c.metal_fwd(s1, l1, s2, l2, s3, l3)
}
_ => unreachable!(),
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
match self {
Storage::Cpu(storage) => {

View File

@ -1,7 +1,9 @@
//! Tensors are N-dimensional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
@ -510,7 +512,6 @@ impl Tensor {
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
unary_op!(sign, Sign);
/// Round element of the input tensor to the nearest integer.
///
@ -1350,7 +1351,7 @@ impl Tensor {
}
.bt())?
}
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
@ -2000,7 +2001,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
@ -2008,21 +2009,11 @@ impl Tensor {
}
}
/// Returns a tensor that is in row major order. This always makes a copy.
pub fn force_contiguous(&self) -> Result<Tensor> {
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
Ok(from_storage(storage, shape.clone(), op, false))
}
/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
@ -2075,7 +2066,7 @@ impl Tensor {
};
Ok(Tensor(Arc::new(tensor_)))
} else {
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))
@ -2102,19 +2093,8 @@ impl Tensor {
let dim = dim.to_index(self.shape(), "squeeze")?;
if dims[dim] == 1 {
let mut dims = dims.to_vec();
let mut strides = self.stride().to_vec();
dims.remove(dim);
strides.remove(dim);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
self.reshape(dims)
} else {
Ok(self.clone())
}
@ -2135,24 +2115,10 @@ impl Tensor {
/// ```
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
let mut dims = self.dims().to_vec();
let mut strides = self.stride().to_vec();
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
// Cannot panic because to_index_plus_one already checks dimensions
dims.insert(dim, 1);
// Any stride would work here, but we pick one so as to maximize the probability to remain
// C contiguous.
let stride = if dim < strides.len() { strides[dim] } else { 1 };
strides.insert(dim, stride);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
self.reshape(dims)
}
/// Stacks two or more tensors along a particular dimension.
@ -2265,10 +2231,6 @@ impl Tensor {
self.storage.read().unwrap()
}
pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
self.storage.write().unwrap()
}
// If we extend the visibility of this function to be usable outside of this crate, we should
// make it unsafe.
pub(crate) fn storage_mut_and_layout(
@ -2290,6 +2252,96 @@ impl Tensor {
std::ptr::eq(lhs, rhs)
}
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {

View File

@ -58,18 +58,20 @@ impl Tensor {
}
}
}
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else if dim == 0 {
if dim == 0 {
Self::cat0(args)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
}
@ -139,7 +141,7 @@ impl Tensor {
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
@ -213,7 +215,7 @@ impl Tensor {
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let mut storage = device.zeros(&shape, dtype)?;
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();

View File

@ -54,6 +54,11 @@ fn conv1d(dev: &Device) -> Result<()> {
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
// conv-transposes are not implemented for metal.
if dev.is_metal() {
return Ok(());
}
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
for w in [w.clone(), w.contiguous()?] {
@ -135,7 +140,7 @@ fn conv2d(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
dev,
)?;
@ -163,34 +168,33 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
if !dev.is_metal() {
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
],
[
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
[
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
],
[
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
]
]
);
);
}
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
@ -199,37 +203,44 @@ fn conv2d(dev: &Device) -> Result<()> {
[2.45, -2.3504],
);
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
if !dev.is_metal() {
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
[
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[
-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
-3.5024
],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[
-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
1.0171
]
]
]
]
);
);
}
Ok(())
}
@ -276,13 +287,19 @@ fn conv2d_small(dev: &Device) -> Result<()> {
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
-0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
// conv-transposes are not implemented for metal
if dev.is_metal() {
return Ok(());
}
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
@ -385,6 +402,9 @@ print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
// conv-transposes are not implemented for metal
if dev.is_metal() {
return Ok(());
}
use candle_core::Var;
let t = Var::from_slice(
&[
@ -397,7 +417,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
(1, 4, 5, 5),
dev,
@ -582,154 +602,6 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
);
// Conv Transpose 2d Test
//tested against following python
// import torch
// torch.manual_seed(4242)
// padding = 4
// outpadding = 2
// dilation = 3
// stride = 3
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
// print("input", input.flatten())
// print("kernel", kernel.flatten())
// res = torch.nn.functional.conv_transpose2d(
// input,
// kernel,
// stride=stride,
// padding=padding,
// dilation=dilation,
// output_padding=outpadding,
// )
// res.retain_grad()
// print(res.shape)
// loss = (res**2).sum()
// print(loss)
// loss.backward()
// print(input.grad.shape)
// print("input grad", torch.round(input.grad, decimals=1))
// print(kernel.grad.shape)
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
let padding = 4;
let outpadding = 2;
let dilation = 3;
let stride = 3;
let t = Var::from_slice(
&[
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
1.6475, 0.2219,
],
(1, 4, 7, 5),
dev,
)?;
#[rustfmt::skip]
let w = Var::from_slice(
&[
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
],
(4, 2, 3, 5),
dev,
)?;
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
[
// torch gets 89.1
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
[
[
[32.3, -41.6, -24.0, 14.1, 17.6],
[-11.8, 72.5, 87.6, 46.4, 61.5],
[115.0, 108.5, -48.6, -63.4, -50.0],
[51.3, 5.4, 31.3, 91.1, -30.9],
[52.7, 92.8, -68.0, -47.0, 83.0],
// pytorch gets -107.1
[-10.2, -107.0, -5.4, 213.1, -31.4],
[-2.4, 65.1, 9.2, -146.2, -24.2]
],
[
[-72.6, -63.9, -61.9, 45.3, 33.0],
[79.3, -0.5, -26.2, 78.2, 42.7],
[90.9, 141.6, 40.1, -62.7, 37.0],
[32.8, 198.2, -0.8, -31.1, 27.3],
// torch gets 48.0
[34.5, 34.9, -47.9, 127.6, -12.3],
[-61.4, -3.2, -2.9, -10.9, -16.6],
[74.6, 60.1, -68.9, 34.5, -50.4]
],
[
[37.5, -56.9, -43.6, -13.5, -9.9],
[40.0, 97.3, 28.6, 14.2, -30.1],
[-22.3, -126.3, -68.8, -8.2, 26.1],
[-32.9, 37.3, 108.5, -54.8, 29.6],
[34.9, -176.9, -125.0, -28.3, -13.9],
[-54.9, 142.6, 62.1, -80.4, -65.6],
[7.4, -91.1, -67.6, 35.0, 39.7]
],
[
[-57.2, -40.9, -10.1, 32.6, 29.4],
[18.7, -18.0, 29.5, -1.2, 59.2],
[-14.0, -74.4, 19.8, -117.0, 58.2],
[-21.8, 163.5, -71.1, -99.0, 80.9],
[-58.9, -10.9, 93.8, -139.6, 98.0],
// torch gets 54.5
[-54.4, 135.3, 6.0, -79.1, 134.6],
[27.5, -76.0, 43.4, -2.8, -7.8]
]
]
);
Ok(())
}

View File

@ -112,34 +112,3 @@ fn custom_op1_with_backward() -> Result<()> {
Ok(())
}
impl candle_core::InplaceOp1 for Elu {
fn name(&self) -> &'static str {
"elu"
}
fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
let alpha = self.alpha;
match s {
CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
_ => candle_core::bail!("unsupported dtype for inplace elu"),
}
Ok(())
}
}
#[test]
fn inplace_op1() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
t.inplace_op1(&Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
);
Ok(())
}

View File

@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?.contiguous()?;
let tensor = tensor.i((.., 1))?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { start_offset, len } => {
assert_eq!(start_offset, 0);
@ -100,20 +100,6 @@ fn strided_blocks() -> Result<()> {
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")
}
candle::StridedBlocks::MultipleBlocks {
block_len,
block_start_index,
} => {
assert_eq!(block_len, 4);
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
match tensor.t()?.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")

View File

@ -1,106 +0,0 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);
// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
let x = a.i((.., seq_len - 1, ..))?;
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
let x = x.matmul(&w)?;
assert_eq!(x.dims(), &[1, 32]);
Ok(())
}
// https://github.com/huggingface/candle/issues/1992
fn mm_layout(device: &Device) -> Result<()> {
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
let mm1 = a.matmul(&b)?;
// Forces the layout to be:
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
// This is still a contiguous matrix but matmul checks are only the two last dimensions have
// non 1 sizes but matmul check may be reluctant to handle it.
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
let mm2 = a.matmul(&b)?;
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);

View File

@ -2,6 +2,9 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
@ -19,6 +22,9 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
}
fn max_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];

View File

@ -106,9 +106,6 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
@ -151,14 +148,6 @@ fn unary_op(device: &Device) -> Result<()> {
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
let tensor = Tensor::new(
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
device,
)?;
assert_eq!(
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
Ok(())
}
@ -718,8 +707,6 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}
@ -747,47 +734,44 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0]
]
);
for dtype in [DType::U8, DType::U32, DType::I64] {
let ids = ids.to_dtype(dtype)?;
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
}
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
Ok(())
}
@ -949,6 +933,74 @@ fn gather(device: &Device) -> Result<()> {
Ok(())
}
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);
// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1102,6 +1154,13 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
@ -1258,8 +1317,8 @@ fn pow() -> Result<()> {
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 3)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}

View File

@ -25,9 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -42,7 +41,7 @@ clap = { workspace = true }
imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
ab_glyph = { workspace = true }
rusttype = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
@ -64,7 +63,6 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
[[example]]
name = "llama_multiprocess"
@ -100,4 +98,6 @@ required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["encodec"]
required-features = ["symphonia"]

View File

@ -1,46 +0,0 @@
Contrastive Language-Image Pre-Training
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.
https://github.com/openai/CLIP
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
## Running on an example on cpu
```
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```
## Running on an example with metal feature (mac)
```
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -1,202 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::io::Reader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
info!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = *tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -13,13 +13,8 @@ cargo run --example encodec --features symphonia --release -- code-to-audio \
```
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data.
Instead of `code-to-audio` one can use:
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
containing EnCodec tokens for the input audio file.
If the audio output file name is set to `-`, the audio content directly gets
played on default audio output device. If the audio input file is set to `-`, the audio
gets recorded from the default audio input.

View File

@ -1,275 +0,0 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -11,7 +11,59 @@ use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
@ -57,36 +109,14 @@ fn main() -> Result<()> {
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
codes.get("codes").expect("no codes in input file").clone()
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
codes
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
}
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
@ -105,26 +135,8 @@ fn main() -> Result<()> {
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
Ok(())

View File

@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
@ -39,26 +39,11 @@ impl TextGeneration {
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: Option<usize>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = {
let temperature = temp.unwrap_or(0.);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
@ -137,18 +122,6 @@ impl TextGeneration {
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "7b-v0.1")]
Mistral7bV01,
#[value(name = "7b-v0.2")]
Mistral7bV02,
#[value(name = "7b-instruct-v0.1")]
Mistral7bInstructV01,
#[value(name = "7b-instruct-v0.2")]
Mistral7bInstructV02,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -174,10 +147,6 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -186,10 +155,6 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "7b-v0.1")]
which: Which,
#[arg(long)]
model_id: Option<String>,
@ -199,9 +164,6 @@ struct Args {
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
@ -215,10 +177,6 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
}
fn main() -> Result<()> {
@ -226,9 +184,6 @@ fn main() -> Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -256,17 +211,9 @@ fn main() -> Result<()> {
Some(model_id) => model_id,
None => {
if args.quantized {
if args.which != Which::Mistral7bV01 {
anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
}
"lmz/candle-mistral".to_string()
} else {
match args.which {
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
}
"mistralai/Mistral-7B-v0.1".to_string()
}
}
};
@ -296,17 +243,7 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.config_file {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
if args.quantized {
Config::config_7b_v0_1(args.use_flash_attn)
} else {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
}
};
let config = Config::config_7b_v0_1(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
@ -333,7 +270,6 @@ fn main() -> Result<()> {
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,

View File

@ -1,26 +0,0 @@
# candle-moondream
[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model can answer real-world questions about images. It's tiny by today's models, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.
## Running some examples
First download an example image
```bash
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
```
<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
avavx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
retrieved the files in 3.395583ms
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded the model in 5.485493792s
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
starting the inference loop
The girl is eating a hamburger.<
9 tokens generated (0.68 token/s)
```

View File

@ -1,343 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::LogitsProcessor,
models::{moondream, quantized_moondream},
};
use tokenizers::Tokenizer;
enum Model {
Moondream(moondream::Model),
Quantized(quantized_moondream::Model),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the Moondream model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
// Moondream tokenizer bos_token and eos_token is "<|endoftext|>"
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the special token"),
};
let (bos_token, eos_token) = (special_token, special_token);
let start_gen = std::time::Instant::now();
let mut load_t = std::time::Duration::from_secs_f64(0f64);
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = if index > 0 {
match self.model {
Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
}
} else {
let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
let logits = match self.model {
Model::Moondream(ref mut model) => {
model
.text_model
.forward_with_img(&bos_token, &input, image_embeds)?
}
Model::Quantized(ref mut model) => {
model
.text_model
.forward_with_img(&bos_token, &input, image_embeds)?
}
};
load_t = start_gen.elapsed();
println!("load_t: {:?}", load_t);
logits
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed() - load_t;
println!(
"\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
dt.as_secs_f64(),
(generated_tokens - 1) as f64 / dt.as_secs_f64()
);
Ok(())
}
}
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
#[arg(long)]
image: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 378, 378).
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::tokio::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
"santiagomed/candle-moondream".to_string()
} else {
"vikhyatk/moondream2".to_string()
}
}
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
args.revision,
));
let model_file = match args.model_file {
Some(m) => m.into(),
None => {
if args.quantized {
repo.get("model-q4_0.gguf").await?
} else {
repo.get("model.safetensors").await?
}
}
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json").await?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = moondream::Config::v2();
let dtype = if args.quantized {
if args.f16 {
anyhow::bail!("Quantized model does not support f16");
}
DType::F32
} else if device.is_cuda() || args.f16 {
DType::F16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&model_file,
&device,
)?;
let model = quantized_moondream::Model::new(&config, vb)?;
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let model = moondream::Model::new(&config, vb)?;
Model::Moondream(model)
};
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
let image = load_image(args.image)?
.to_device(&device)?
.to_dtype(dtype)?;
let image_embeds = image.unsqueeze(0)?;
let image_embeds = match model {
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
};
println!(
"loaded and encoded the image {image:?} in {:?}",
start.elapsed()
);
let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&prompt, &image_embeds, args.sample_len)?;
Ok(())
}

View File

@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
## Using custom models

View File

@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
@ -200,10 +200,6 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -239,10 +235,6 @@ struct Args {
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
#[arg(long)]
gqa: Option<usize>,
/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
}
impl Args {
@ -349,10 +341,11 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -499,20 +492,7 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {

View File

@ -1,27 +0,0 @@
# candle-qwen: large language model series from Alibaba Cloud
Qwen 1.5 is a series of large language models that provide strong performances
on English and Chinese.
- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
mixture-of-experts (MoE) variant.
## Running the example
```bash
$ cargo run --example qwen --release -- --prompt "Hello there "
```
Various model sizes are available via the `--model` argument, including the MoE
variant.
```bash
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
def print_prime(n: int): # n is the number of primes to be printed
for i in range(2, n + 1):
if all(i % j != 0 for j in range(2, i)):
print(i)
```

View File

@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -17,20 +16,6 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
Base(ModelBase),
Moe(ModelMoe),
}
impl Model {
fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
match self {
Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -142,8 +127,6 @@ enum WhichModel {
W14b,
#[value(name = "72b")]
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
}
#[derive(Parser, Debug)]
@ -241,7 +224,6 @@ fn main() -> Result<()> {
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
WhichModel::MoeA27b => "MoE-A2.7B",
};
format!("Qwen/Qwen1.5-{size}")
}
@ -262,11 +244,7 @@ fn main() -> Result<()> {
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W4b
| WhichModel::W7b
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::MoeA27b => {
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@ -276,6 +254,7 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
@ -283,16 +262,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.model {
WhichModel::MoeA27b => {
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe(ModelMoe::new(&config, vb)?)
}
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)
}
};
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,118 +0,0 @@
use std::collections::VecDeque;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
use crate::gym_env::GymEnv;
const DEVICE: Device = Device::Cpu;
const EPISODES: usize = 200;
const BATCH_SIZE: usize = 64;
const GAMMA: f64 = 0.99;
const LEARNING_RATE: f64 = 0.01;
pub fn run() -> Result<()> {
let env = GymEnv::new("CartPole-v1")?;
// Build the model that predicts the estimated rewards given a specific state.
let var_map = VarMap::new();
let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
let observation_space = *env.observation_space().first().unwrap();
let model = seq()
.add(linear(observation_space, 64, vb.pp("linear_in"))?)
.add(Activation::Relu)
.add(linear(64, env.action_space(), vb.pp("linear_out"))?);
let mut optimizer = AdamW::new_lr(var_map.all_vars(), LEARNING_RATE)?;
// Initialize the model's memory.
let mut memory = VecDeque::with_capacity(10000);
// Start the training loop.
let mut state = env.reset(0)?;
let mut episode = 0;
let mut accumulate_rewards = 0.0;
while episode < EPISODES {
// Given the current state, predict the estimated rewards, and take the
// action that is expected to return the most rewards.
let estimated_rewards = model.forward(&state.unsqueeze(0)?)?;
let action: u32 = estimated_rewards.squeeze(0)?.argmax(0)?.to_scalar()?;
// Take that action in the environment, and memorize the outcome:
// - the state for which the action was taken
// - the action taken
// - the new state resulting of taking that action
// - the actual rewards of taking that action
// - whether the environment reached a terminal state or not (e.g. game over)
let step = env.step(action)?;
accumulate_rewards += step.reward;
memory.push_back((
state,
action,
step.state.clone(),
step.reward,
step.terminated || step.truncated,
));
state = step.state;
// If there's enough entries in the memory, perform a learning step, where
// BATCH_SIZE transitions will be sampled from the memory and will be
// fed to the model so that it performs a backward pass.
if memory.len() > BATCH_SIZE {
// Sample randomly from the memory.
let batch = thread_rng()
.sample_iter(Uniform::from(0..memory.len()))
.take(BATCH_SIZE)
.map(|i| memory.get(i).unwrap().clone())
.collect::<Vec<_>>();
// Group all the samples together into tensors with the appropriate shape.
let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect();
let states = Tensor::stack(&states, 0)?;
let actions = batch.iter().map(|e| e.1);
let actions = Tensor::from_iter(actions, &DEVICE)?.unsqueeze(1)?;
let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect();
let next_states = Tensor::stack(&next_states, 0)?;
let rewards = batch.iter().map(|e| e.3 as f32);
let rewards = Tensor::from_iter(rewards, &DEVICE)?.unsqueeze(1)?;
let non_final_mask = batch.iter().map(|e| !e.4 as u8 as f32);
let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)?.unsqueeze(1)?;
// Get the estimated rewards for the actions that where taken at each step.
let estimated_rewards = model.forward(&states)?;
let x = estimated_rewards.gather(&actions, 1)?;
// Get the maximum expected rewards for the next state, apply them a discount rate
// GAMMA and add them to the rewards that were actually gathered on the current state.
// If the next state is a terminal state, just omit maximum estimated
// rewards for that state.
let expected_rewards = model.forward(&next_states)?.detach();
let y = expected_rewards.max_keepdim(1)?;
let y = (y * GAMMA * non_final_mask + rewards)?;
// Compare the estimated rewards with the maximum expected rewards and
// perform the backward step.
let loss = mse(&x, &y)?;
optimizer.backward_step(&loss)?;
}
// If we are on a terminal state, reset the environment and log how it went.
if step.terminated || step.truncated {
episode += 1;
println!("Episode {episode} | Rewards {}", accumulate_rewards as i64);
state = env.reset(0)?;
accumulate_rewards = 0.0;
}
}
Ok(())
}

View File

@ -42,7 +42,7 @@ impl GymEnv {
/// Creates a new session of the specified OpenAI Gym environment.
pub fn new(name: &str) -> Result<GymEnv> {
Python::with_gil(|py| {
let gym = py.import_bound("gymnasium")?;
let gym = py.import("gymnasium")?;
let make = gym.getattr("make")?;
let env = make.call1((name,))?;
let action_space = env.getattr("action_space")?;
@ -66,10 +66,10 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new_bound(py);
let kwargs = PyDict::new(py);
kwargs.set_item("seed", seed)?;
let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
state.bind(py).get_item(0)?.extract()
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(state, &Device::Cpu)
@ -81,10 +81,8 @@ impl GymEnv {
action: A,
) -> Result<Step<A>> {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self
.env
.call_method_bound(py, "step", (action.clone(),), None)?;
let step = step.bind(py);
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let terminated: bool = step.get_item(2)?.extract()?;

View File

@ -13,7 +13,6 @@ mod gym_env;
mod vec_gym_env;
mod ddpg;
mod dqn;
mod policy_gradient;
#[derive(Parser)]
@ -26,7 +25,6 @@ struct Args {
enum Command {
Pg,
Ddpg,
Dqn,
}
fn main() -> Result<()> {
@ -34,7 +32,6 @@ fn main() -> Result<()> {
match args.command {
Command::Pg => policy_gradient::run()?,
Command::Ddpg => ddpg::run()?,
Command::Dqn => dqn::run()?,
}
Ok(())
}

View File

@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {
let sys = py.import_bound("sys")?;
let sys = py.import("sys")?;
let path = sys.getattr("path")?;
let _ = path.call_method1(
"append",
("candle-examples/examples/reinforcement-learning",),
)?;
let gym = py.import_bound("atari_wrappers")?;
let gym = py.import("atari_wrappers")?;
let make = gym.getattr("make")?;
let env = make.call1((name, img_dir, nprocesses))?;
let action_space = env.getattr("action_space")?;
@ -60,10 +60,10 @@ impl VecGymEnv {
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let step = self.env.call_method_bound(py, "step", (action,), None)?;
let step = step.bind(py);
let step = self.env.call_method(py, "step", (action,), None)?;
let step = step.as_ref(py);
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
let reward: Vec<f32> = step.get_item(1)?.extract()?;
let is_done: Vec<f32> = step.get_item(2)?.extract()?;

View File

@ -5,7 +5,7 @@ use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use imageproc::image::Rgb;
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;

View File

@ -292,13 +292,6 @@ fn text_embeddings(
.map_err(E::msg)?
.get_ids()
.to_vec();
if tokens.len() > sd_config.clip.max_position_embeddings {
anyhow::bail!(
"the prompt is too long, {} > max-tokens ({})",
tokens.len(),
sd_config.clip.max_position_embeddings
)
}
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
@ -326,13 +319,6 @@ fn text_embeddings(
.map_err(E::msg)?
.get_ids()
.to_vec();
if uncond_tokens.len() > sd_config.clip.max_position_embeddings {
anyhow::bail!(
"the negative prompt is too long, {} > max-tokens ({})",
uncond_tokens.len(),
sd_config.clip.max_position_embeddings
)
}
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}

View File

@ -288,12 +288,12 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
let model = if args.quantized {
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
Model::Quantized(model)
(Model::Quantized(model), Device::Cpu)
} else {
let dtype = if device.is_cuda() {
DType::BF16
@ -302,7 +302,7 @@ fn main() -> Result<()> {
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = StableLM::new(&config, vb)?;
Model::StableLM(model)
(Model::StableLM(model), device)
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -12,23 +12,12 @@ use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
T5Base,
T5Small,
T5Large,
T5_3B,
Mt5Base,
Mt5Small,
Mt5Large,
}
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -47,15 +36,6 @@ struct Args {
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Enable decoding.
#[arg(long)]
decode: bool,
@ -91,10 +71,6 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to be used.
#[arg(long, default_value = "t5-small")]
which: Which,
}
struct T5ModelBuilder {
@ -106,17 +82,8 @@ struct T5ModelBuilder {
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?;
let (default_model, default_revision) = match args.which {
Which::T5Base => ("t5-base", "main"),
Which::T5Small => ("t5-small", "refs/pr/15"),
Which::T5Large => ("t5-large", "main"),
Which::T5_3B => ("t5-3b", "main"),
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
@ -126,35 +93,14 @@ impl T5ModelBuilder {
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let api = Api::new()?;
let repo = api.repo(repo);
let config_filename = match &args.config_file {
None => repo.get("config.json")?,
Some(f) => f.into(),
};
let tokenizer_filename = match &args.tokenizer_file {
None => match args.which {
Which::Mt5Base => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-base.tokenizer.json")?,
Which::Mt5Small => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-small.tokenizer.json")?,
Which::Mt5Large => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-large.tokenizer.json")?,
_ => repo.get("tokenizer.json")?,
},
Some(f) => f.into(),
};
let weights_filename = match &args.model_file {
Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
None => {
if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
} else {
vec![repo.get("model.safetensors")?]
}
}
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
{
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} else {
vec![api.get("model.safetensors")?]
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;

View File

@ -34,7 +34,6 @@ from the hub.
- `--timestamps`: enable the timestamp mode where some timestamps are reported
for each recognized audio extracts.
- `--model`: the model to be used. Models that do not end with `-en` are
multilingual models, other ones are English only models. The supported OpenAI
Whisper models are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`,
`medium`, `medium.en`, `large`, `large-v2` and `large-v3`. The supported
Distil-Whisper models are `distil-medium.en`, `distil-large-v2` and `distil-large-v3`.
multilingual models, other ones are English only models. The supported models
are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
`medium.en`, `large`, and `large-v2`.

View File

@ -374,8 +374,6 @@ enum WhichModel {
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
#[value(name = "distil-large-v3")]
DistilLargeV3,
}
impl WhichModel {
@ -388,8 +386,7 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2
| Self::DistilLargeV3 => true,
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
}
@ -411,7 +408,6 @@ impl WhichModel {
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"),
}
}
}

View File

@ -99,7 +99,7 @@ pub fn report_detect(
let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8();
let font = Vec::from(include_bytes!("roboto-mono-stripped.ttf") as &[u8]);
let font = ab_glyph::FontRef::try_from_slice(&font).map_err(candle::Error::wrap)?;
let font = rusttype::Font::try_from_vec(font);
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!(
@ -119,28 +119,27 @@ pub fn report_detect(
);
}
if legend_size > 0 {
imageproc::drawing::draw_filled_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
image::Rgb([170, 0, 0]),
);
let legend = format!(
"{} {:.0}%",
candle_examples::coco_classes::NAMES[class_index],
100. * b.confidence
);
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([255, 255, 255]),
xmin,
ymin,
ab_glyph::PxScale {
x: legend_size as f32 - 1.,
y: legend_size as f32 - 1.,
},
&font,
&legend,
)
if let Some(font) = font.as_ref() {
imageproc::drawing::draw_filled_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
image::Rgb([170, 0, 0]),
);
let legend = format!(
"{} {:.0}%",
candle_examples::coco_classes::NAMES[class_index],
100. * b.confidence
);
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([255, 255, 255]),
xmin,
ymin,
rusttype::Scale::uniform(legend_size as f32 - 1.),
font,
&legend,
)
}
}
}
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.5.0"
version = "0.4.2"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.2" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.5.0"
version = "0.4.2"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,8 +1,5 @@
fn main() {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/compatibility.cuh");
println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
println!("cargo:rerun-if-changed=src/binary_op_macros.cuh");
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");

View File

@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = inp ? inp[i] : out[i]; \
out[i] = x * mul + add; \

View File

@ -12,8 +12,8 @@ extern "C" __global__ void FN_NAME( \
const size_t *dims = dims_and_strides; \
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
if (lhs_cont && rhs_cont) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = lhs[i]; \

View File

@ -11,7 +11,7 @@ __device__ void cast_(
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = inp[i];
}
@ -34,7 +34,7 @@ __device__ void cast_through(
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(static_cast<I>(inp[i]));
}
@ -83,18 +83,6 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
#else
#include <cuda.h>
#if CUDA_VERSION >= 11000
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16)
CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16)
#endif
#endif
#if __CUDA_ARCH__ >= 530

View File

@ -14,7 +14,7 @@ __device__ bool is_contiguous(
size_t acc = 1;
for (unsigned int d = 0; d < num_dims; d++) {
unsigned int dim_idx = num_dims - 1 - d;
if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
if (acc != strides[dim_idx]) {
return false;
}
acc *= dims[dim_idx];

View File

@ -168,10 +168,8 @@ IS_OP(__half, uint8_t, is_u8_f16)
GATHER_OP(__half, int64_t, gather_i64_f16)
GATHER_OP(__half, uint32_t, gather_u32_f16)
GATHER_OP(__half, uint8_t, gather_u8_f16)
IA_OP(__half, int64_t, ia_i64_f16)
IA_OP(__half, uint32_t, ia_u32_f16)
IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
#endif

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,6 @@
#include <cmath>
#include <stdint.h>
#define WARP_SIZE 32
const int BLOCK_SIZE = 1024;
// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
@ -50,59 +49,6 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = static_cast<float>(x[row*ncols + col]);
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
if (alpha == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]));
}
}
else {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
dst[row*ncols + col] = static_cast<T>(scale * static_cast<float>(x[row*ncols + col]) * a);
}
}
}
// Softmax implementation adapted from ggml.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
template <typename T, typename ACC>
@ -147,65 +93,6 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}
}
template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
uint32_t rope_idx = idx % (td / 2);
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s;
dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c;
}
template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= bh * td) return;
uint32_t i_bh = idx / (td / 2);
uint32_t i_td = idx - (td / 2) * i_bh;
uint32_t i_t = i_td / (d / 2);
uint32_t i_d = i_td - (d / 2) * i_t;
uint32_t i1 = i_bh * td + i_t * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
template <typename T>
__device__ void rope_thd(
const T * src,
const T * cos,
const T * sin,
T * dst,
const uint32_t b,
const uint32_t t,
const uint32_t h,
const uint32_t d
) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (2 * idx >= b * t * h * d) return;
uint32_t i_bth = idx / (d / 2);
uint32_t i_d = idx - (d / 2) * i_bth;
uint32_t i_t = (i_bth / h) % t;
uint32_t i1 = i_bth * d + i_d;
uint32_t i2 = i1 + d / 2;
uint32_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
@ -454,57 +341,14 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
} \
#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const int n_cols, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
const TYPENAME *cos, \
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td) { \
ropei<TYPENAME>(src, cos, sin, dst, bh, td); \
} \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, \
const TYPENAME *cos, \
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t bh, \
const uint32_t td, \
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
} \
extern "C" __global__ void FN_NAME_THD( \
const TYPENAME *src, \
const TYPENAME *cos, \
const TYPENAME *sin, \
TYPENAME *dst, \
const uint32_t b, \
const uint32_t t, \
const uint32_t h, \
const uint32_t d) { \
rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
} \
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@ -514,10 +358,6 @@ SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)

View File

@ -13,7 +13,7 @@ extern "C" __global__ void FN_NAME( \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = inp ? inp[i] : out[i]; \
out[i] = FUNC; \
@ -71,7 +71,7 @@ extern "C" __global__ void FN_NAME( \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
TYPENAME x = inp ? inp[i] : out[i]; \
out[i] = FUNC; \
@ -86,11 +86,6 @@ extern "C" __global__ void FN_NAME( \
} \
} \
template<typename T>
__device__ T sign_(T t) {
return static_cast<T>(t > static_cast<T>(0)) - static_cast<T>(t < static_cast<T>(0));
}
#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
@ -115,7 +110,6 @@ UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
#endif
#if __CUDA_ARCH__ >= 530
@ -141,7 +135,6 @@ UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
UNARY_OP(__half, usign_f16, sign_(x))
#endif
UNARY_OP(uint8_t, ucopy_u8, x)
@ -191,5 +184,3 @@ UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))
UNARY_OP(float, usign_f32, sign_(x))
UNARY_OP(double, usign_f64, sign_(x))

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.5.0"
version = "0.4.2"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -60,24 +60,21 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
@ -115,11 +112,4 @@ BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)
BFLOAT_BINARY_OP_OUT(eq, x == y)
BFLOAT_BINARY_OP_OUT(ne, x != y)
BFLOAT_BINARY_OP_OUT(le, x <= y)
BFLOAT_BINARY_OP_OUT(lt, x < y)
BFLOAT_BINARY_OP_OUT(ge, x >= y)
BFLOAT_BINARY_OP_OUT(gt, x > y)
#endif

View File

@ -72,60 +72,27 @@ kernel void FN_NAME_STRIDED( \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
// u32
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
#endif
// u8
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
#endif
// f16
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
// i64
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif
// f32
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif
// bf16
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif

View File

@ -1,9 +1,3 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
template <typename T>
METAL_FUNC void im2col(
constant size_t &dst_numel,
@ -206,331 +200,14 @@ kernel void FN_NAME( \
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
} \
template <typename T, typename A>
METAL_FUNC void avg_pool2d(
constant size_t &w_k,
constant size_t &h_k,
constant size_t &w_stride,
constant size_t &h_stride,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
const size_t c = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];
const size_t w_out = (w_in - w_k) / w_stride + 1;
const size_t h_out = (h_in - h_k) / h_stride + 1;
if (tid >= src_dims[0] * c * w_out * h_out) {
return;
}
const size_t b_idx = tid / (w_out * h_out * c);
const size_t c_idx = (tid / (w_out * h_out)) % c;
const size_t dst_w = (tid / h_out) % w_out;
const size_t dst_h = tid % h_out;
const size_t src_idx0 = b_idx * src_strides[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = w_stride * dst_w + w_offset;
if (src_w >= w_in){
continue;
}
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = h_stride * dst_h + h_offset;
if (src_h >= h_in) {
continue;
}
const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3];
d += static_cast<A>(src[src_idx]);
}
}
dst[tid] = static_cast<T>(d / (w_k * h_k));
}
#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_k, \
constant size_t &h_k, \
constant size_t &w_s, \
constant size_t &h_s, \
constant size_t *src_dims, \
constant size_t *src_s, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
avg_pool2d<TYPENAME, TYPEACC>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
} \
template <typename T>
METAL_FUNC void max_pool2d(
constant size_t &w_k,
constant size_t &h_k,
constant size_t &w_stride,
constant size_t &h_stride,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
const size_t c = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];
const size_t w_out = (w_in - w_k) / w_stride + 1;
const size_t h_out = (h_in - h_k) / h_stride + 1;
if (tid >= src_dims[0] * c * w_out * h_out) {
return;
}
const size_t b_idx = tid / (w_out * h_out * c);
const size_t c_idx = (tid / (w_out * h_out)) % c;
const size_t dst_w = (tid / h_out) % w_out;
const size_t dst_h = tid % h_out;
const size_t src_idx0 = b_idx * src_strides[0];
T d = 0;
bool set = false;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = w_stride * dst_w + w_offset;
if (src_w >= w_in){
continue;
}
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = h_stride * dst_h + h_offset;
if (src_h >= h_in) {
continue;
}
const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3];
if (set) {
d = MAX(d, src[src_idx]);
}
else {
d = src[src_idx];
set = true;
}
}
}
dst[tid] = d;
}
#define MAXPOOL2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_k, \
constant size_t &h_k, \
constant size_t &w_s, \
constant size_t &h_s, \
constant size_t *src_dims, \
constant size_t *src_s, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
} \
// Naive implementation of conv_transpose1d.
template <typename T, typename A>
METAL_FUNC void conv_transpose1d(
constant size_t &l_out,
constant size_t &stride,
constant size_t &padding,
constant size_t &out_padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
constant size_t *k_dims,
constant size_t *k_strides,
device const T *src,
device const T *k,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// src: (b_size, c_in, l_in)
// kernel: (c_in, c_out, l_k)
const size_t l_k = k_dims[2];
const size_t c_out = k_dims[1];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
if (tid >= src_dims[0] * c_out * l_out) {
return;
}
const size_t b_idx = tid / (l_out * c_out);
const size_t dst_c_idx = (tid / l_out) % c_out;
const size_t out_x = tid % l_out;
const size_t src_idx0 = b_idx * src_strides[0];
A d = 0;
for (int k_x = 0; k_x < (int)l_k; ++k_x) {
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
int inp_x = inp_x_stride / stride;
if (inp_x >= l_in) continue;
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2];
const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2];
d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
}
}
dst[tid] = static_cast<T>(d);
}
#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
constant size_t &l_out, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &out_padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
constant size_t *k_dims, \
constant size_t *k_strides, \
device const TYPENAME *src, \
device const TYPENAME *k, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
} \
template <typename T, typename A>
METAL_FUNC void conv_transpose2d(
constant size_t &w_out,
constant size_t &h_out,
constant size_t &stride,
constant size_t &padding,
constant size_t &out_padding,
constant size_t &dilation,
constant size_t *input_dims,
constant size_t *input_stride,
constant size_t *k_dims,
constant size_t *k_stride,
device const T *src,
device const T *k,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
const size_t h_k = k_dims[2];
const size_t w_k = k_dims[3];
const size_t c_out = k_dims[1];
const size_t c_in = input_dims[1];
const size_t h_in = input_dims[2];
const size_t w_in = input_dims[3];
if (tid >= input_dims[0] * c_out * w_out * h_out) {
return;
}
const size_t b_idx = tid / (w_out * h_out * c_out);
const size_t dst_c_idx = (tid / (w_out * h_out)) % c_out;
const size_t out_y = (tid / w_out) % h_out;
const size_t out_x = tid % w_out;
const size_t src_idx0 = b_idx * input_stride[0];
A d = 0;
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
const int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
const int inp_x = inp_x_stride / stride;
if (inp_x >= w_in) continue;
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
const int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
if (inp_y_stride < 0 || inp_y_stride % stride) {
continue;
}
const int inp_y = inp_y_stride / stride;
if (inp_y >= h_in) continue;
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
const size_t src_idx = src_idx0 + src_c_idx * input_stride[1] + inp_y * input_stride[2] + inp_x * input_stride[3];
const size_t k_idx = src_c_idx * k_stride[0] + dst_c_idx * k_stride[1] + k_y * k_stride[2] + k_x * k_stride[3];
d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
}
}
}
dst[tid] = static_cast<T>(d);
}
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
constant size_t &h_out, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &out_padding, \
constant size_t &dilation, \
constant size_t *input_dims, \
constant size_t *input_stride, \
constant size_t *k_dims, \
constant size_t *k_stride, \
device const TYPENAME *src, \
device const TYPENAME *k, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, input_dims, input_stride, k_dims, k_stride, src, k, dst, tid); \
} \
IM2COL_OP(float, im2col_f32)
IM2COL_OP(half, im2col_f16)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
#if defined(__HAVE_BFLOAT__)
IM2COL_OP(bfloat, im2col_bf16)
#endif
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
#if defined(__HAVE_BFLOAT__)
UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16)
#endif
MAXPOOL2D_OP(float, max_pool2d_f32)
MAXPOOL2D_OP(half, max_pool2d_f16)
MAXPOOL2D_OP(uint32_t, max_pool2d_u32)
MAXPOOL2D_OP(uint8_t, max_pool2d_u8)
#if defined(__HAVE_BFLOAT__)
MAXPOOL2D_OP(bfloat, max_pool2d_bf16)
#endif
AVGPOOL2D_OP(float, float, avg_pool2d_f32)
AVGPOOL2D_OP(half, float, avg_pool2d_f16)
AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
#if defined(__HAVE_BFLOAT__)
AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
#endif
CONVT1D_OP(float, float, conv_transpose1d_f32)
CONVT1D_OP(half, float, conv_transpose1d_f16)
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
#endif
CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(half, float, conv_transpose2d_f16)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
#endif

View File

@ -1,21 +1,6 @@
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
@ -23,9 +8,6 @@ METAL_FUNC void index(
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
constant bool &contiguous,
constant size_t *src_dims,
constant size_t *src_strides,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
@ -44,8 +26,7 @@ METAL_FUNC void index(
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
output[tid] = input[src_i];
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -55,15 +36,12 @@ kernel void NAME( \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
constant bool &contiguous, \
constant size_t *src_dims, \
constant size_t *src_strides, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
@ -187,23 +165,8 @@ kernel void NAME( \
}
INDEX_OP(is_i64_f32, int64_t, float)
INDEX_OP(is_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif
INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
#endif
INDEX_OP(is_u8_f32, uint8_t, float)
INDEX_OP(is_u8_f16, uint8_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#endif
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
@ -214,38 +177,34 @@ SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif
// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
#endif
// u32
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
#endif
// u8
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)

View File

@ -193,7 +193,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu, sign
tanh, recip, silu
);
}
pub mod binary {
@ -406,7 +406,7 @@ pub fn call_copy2d(
);
let width: usize = d1 * d2;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
@ -750,193 +750,6 @@ pub fn call_last_softmax(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
eps: f32,
input: &Buffer,
input_offset: usize,
alpha: &Buffer,
alpha_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
length,
elements_to_sum,
(input, input_offset),
output,
(alpha, alpha_offset),
eps
)
);
let out_length = length / elements_to_sum;
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_i(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
td: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
cos_offset: usize,
sin: &Buffer,
sin_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
bh,
td,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
encoder.use_resource(src, metal::MTLResourceUsage::Read);
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_thd(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
b: usize,
t: usize,
h: usize,
d: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
cos_offset: usize,
sin: &Buffer,
sin_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
b,
t,
h,
d,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2);
encoder.use_resource(src, metal::MTLResourceUsage::Read);
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
td: usize,
d: usize,
src: &Buffer,
src_offset: usize,
cos: &Buffer,
cos_offset: usize,
sin: &Buffer,
sin_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
bh,
td,
d,
(src, src_offset),
(cos, cos_offset),
(sin, sin_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
encoder.use_resource(src, metal::MTLResourceUsage::Read);
encoder.use_resource(cos, metal::MTLResourceUsage::Read);
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
@ -1196,13 +1009,8 @@ pub fn call_index_select(
shape: &[usize],
ids_size: usize,
dim: usize,
contiguous: bool,
src_dims: &[usize],
src_strides: &[usize],
input: &Buffer,
src_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1224,11 +1032,8 @@ pub fn call_index_select(
src_dim_size,
right_size,
ids_size,
contiguous,
src_dims,
src_strides,
(input, src_offset),
(ids, ids_offset),
input,
ids,
output
)
);
@ -1496,12 +1301,9 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
} else if lhs_m1 == m && lhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@ -1510,10 +1312,9 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
} else if rhs_m1 == k && rhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@ -2025,149 +1826,5 @@ fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
#[allow(clippy::too_many_arguments)]
pub fn call_pool2d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
out_w: usize,
out_h: usize,
w_k: usize,
h_k: usize,
w_stride: usize,
h_stride: usize,
input: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let dst_el = out_w * out_h * shape[0] * shape[1];
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(w_k, h_k, w_stride, h_stride, shape, strides, input, output)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
dilation: usize,
stride: usize,
padding: usize,
out_padding: usize,
c_out: usize,
l_out: usize,
b_size: usize,
src_shape: &[usize],
src_strides: &[usize],
kernel_shape: &[usize],
kernel_strides: &[usize],
input: &Buffer,
input_offset: usize,
kernel: &Buffer,
kernel_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let dst_el = c_out * l_out * b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
l_out,
stride,
padding,
out_padding,
dilation,
src_shape,
src_strides,
kernel_shape,
kernel_strides,
(input, input_offset),
(kernel, kernel_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub struct CallConvTranspose2dCfg<'a> {
pub dilation: usize,
pub stride: usize,
pub padding: usize,
pub output_padding: usize,
pub c_out: usize,
pub out_w: usize,
pub out_h: usize,
pub b_size: usize,
pub input_dims: &'a [usize],
pub input_stride: &'a [usize],
pub kernel_dims: &'a [usize],
pub kernel_stride: &'a [usize],
pub input_offset: usize,
pub kernel_offset: usize,
}
pub fn call_conv_transpose2d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
cfg: CallConvTranspose2dCfg,
input: &Buffer,
kernel: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
cfg.out_w,
cfg.out_h,
cfg.stride,
cfg.padding,
cfg.output_padding,
cfg.dilation,
cfg.input_dims,
cfg.input_stride,
cfg.kernel_dims,
cfg.kernel_stride,
(input, cfg.input_offset),
(kernel, cfg.kernel_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests;

View File

@ -21,52 +21,6 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 2048;
template<typename T>
METAL_FUNC void argmin(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device uint *dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup T *shared_memory,
threadgroup uint *shared_indices
) {
bool notset = true;
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || src[strided_i] < shared_memory[tid]) {
shared_memory[tid] = src[strided_i];
/* Assume that the reduction takes place over the last dimension which is contiguous. */
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
shared_memory[tid] = shared_memory[tid + s];
} \
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
@ -81,63 +35,53 @@ kernel void NAME( \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
argmin<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
template<typename T>
METAL_FUNC void argmax(
constant size_t & num_dims,
constant size_t * dims,
constant size_t * strides,
constant size_t & el_to_sum_per_block,
device const T * src,
device uint * dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup T * shared_memory,
threadgroup uint * shared_indices
) {
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
bool notset = true;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || shared_memory[tid] < src[strided_i]) {
shared_memory[tid] = src[strided_i];
shared_indices[tid] = idx % dims[num_dims - 1];
notset = false;
}
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
shared_memory[tid] = shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_none);
}
// Thread 0 writes the result of the reduction
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
@ -151,337 +95,170 @@ kernel void NAME( \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
argmax<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
template<typename T>
METAL_FUNC void reduce(
constant size_t & num_dims,
constant size_t * dims,
constant size_t * strides,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup T * shared_memory,
T (*fn)(T, T)
) {
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
// TODO: Fast version for the contiguous case.
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
T x = shared_memory[tid];
T y = src[strided_i];
shared_memory[tid] = fn(x, y);
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
// reduction in shared memory
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
T x = shared_memory[tid];
T y = shared_memory[tid + s];
shared_memory[tid] = fn(x, y);
}
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0) {
dst[dst_id] = shared_memory[0];
}
}
#define REDUCE(FN, NAME, T, START) \
METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = START; \
reduce<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
T x = shared_memory[tid]; \
T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
dst[dst_id] = shared_memory[0]; \
} \
template<typename T>
METAL_FUNC void softmax(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
float tmp = -INFINITY;
while (idx < stop_idx) {
tmp = MAX(tmp, float(src[idx]));
idx += block_dim;
}
shared_memory[tid] = tmp;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);
float _max = shared_memory[0];
/* prevent tid=0 from overwriting _max before other threads have written it */
threadgroup_barrier(mem_flags::mem_threadgroup);
shared_memory[tid] = 0;
idx = start_idx + tid;
while (idx < stop_idx) {
const float val = exp(float(src[idx]) - _max);
dst[idx] = T(val);
shared_memory[tid] += val;
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] += shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
const T inv_acc = T(1.0 / shared_memory[0]);
idx = start_idx + tid;
while (idx < stop_idx) {
dst[idx] *= inv_acc;
idx += block_dim;
}
}
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void rmsnorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
float tmp = 0;
while (idx < stop_idx) {
tmp = tmp + float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);
float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
float inv_norm = 1.0f / norm;
idx = start_idx + tid;
while (idx < stop_idx) {
float val = float(src[idx]) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
constant size_t &td,
device const T *src,
device const T *cos,
device const T *sin,
device T *dst,
uint tid
) {
if (2 * tid >= bh * td) {
return;
}
size_t rope_idx = tid % (td / 2);
T c = cos[rope_idx];
T s = sin[rope_idx];
dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c;
}
template<typename T>
METAL_FUNC void rope(
constant size_t &bh,
constant size_t &td,
constant size_t &d,
device const T *src,
device const T *cos,
device const T *sin,
device T *dst,
uint idx
) {
if (2 * idx >= bh * td) {
return;
}
size_t i_bh = idx / (td / 2);
size_t i_td = idx - (td / 2) * i_bh;
size_t i_t = i_td / (d / 2);
size_t i_d = i_td - (d / 2) * i_t;
size_t i1 = i_bh * td + i_t * d + i_d;
size_t i2 = i1 + d / 2;
size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
template<typename T>
METAL_FUNC void rope_thd(
constant size_t &b,
constant size_t &t,
constant size_t &h,
constant size_t &d,
device const T *src,
device const T *cos,
device const T *sin,
device T *dst,
uint idx
) {
if (2 * idx >= b * t * h * d) {
return;
}
const size_t i_bth = idx / (d / 2);
const size_t i_d = idx - (d / 2) * i_bth;
const size_t i_t = (i_bth / h) % t;
const size_t i1 = i_bth * d + i_d;
const size_t i2 = i1 + d / 2;
const size_t i_cs = i_t * (d / 2) + i_d;
T c = cos[i_cs];
T s = sin[i_cs];
dst[i1] = src[i1] * c - src[i2] * s;
dst[i2] = src[i1] * s + src[i2] * c;
}
#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
constant size_t &td, \
constant size_t &d, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
}\
kernel void FN_NAME_THD( \
constant size_t &b, \
constant size_t &t, \
constant size_t &h, \
constant size_t &d, \
device const TYPENAME *src, \
device const TYPENAME *cos, \
device const TYPENAME *sin, \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
}\
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
\
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = -INFINITY; \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
\
float tmp = -INFINITY; \
while (idx < stop_idx) { \
tmp = MAX(tmp, float(src[idx])); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
} \
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
@ -509,10 +286,6 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@ -524,16 +297,10 @@ ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x + y, fast_sum_bf16_strided, half, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif

File diff suppressed because it is too large Load Diff

View File

@ -112,7 +112,6 @@ kernel void FN_NAME( \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
tid *= 4; \
if (tid >= d1 * d2) { \
return; \
} \
@ -121,9 +120,6 @@ kernel void FN_NAME( \
size_t src_idx = idx1 * src_s + idx2; \
size_t dst_idx = idx1 * dst_s + idx2; \
output[dst_idx] = input[src_idx]; \
output[dst_idx+1] = input[src_idx+1]; \
output[dst_idx+2] = input[src_idx+2]; \
output[dst_idx+3] = input[src_idx+3]; \
}
COPY2D(copy2d_f32, float)
@ -149,7 +145,6 @@ UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -179,7 +174,6 @@ BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

View File

@ -25,8 +25,6 @@ candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
rand = { workspace = true }
criterion = { workspace = true }
[features]
default = []
@ -34,7 +32,3 @@ accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
[[bench]]
name = "bench_main"
harness = false

View File

@ -1,4 +0,0 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches);

View File

@ -1,54 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Module, Tensor};
use candle_nn::{Conv2d, Conv2dConfig};
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;
const B: usize = 1;
const C: usize = 1;
const M: usize = 128;
const K: usize = 128;
const K_SIZE: usize = 3;
fn run(input: Tensor, weight: Tensor, bias: Tensor, config: Conv2dConfig) {
Conv2d::new(weight, Some(bias), config)
.forward(&input)
.unwrap();
}
fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let weight = Tensor::ones((1, 1, K_SIZE, K_SIZE), dtype, device)
.unwrap()
.to_dtype(dtype)
.unwrap();
let bias = Tensor::zeros(K, dtype, device).unwrap();
let input = Tensor::ones((B, C, M, K), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name(name));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box(input.clone()),
black_box(weight.clone()),
black_box(bias.clone()),
Default::default(),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32");
run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,48 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Module, Tensor};
use candle_nn::LayerNorm;
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;
fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) {
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input);
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
fn run_layer_norm_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let elements = B * M * K;
let weight = Tensor::arange(0.0, elements as f32, device)
.unwrap()
.to_dtype(dtype)
.unwrap();
let bias = weight.ones_like().unwrap();
let input = weight.ones_like().unwrap();
let mut group = c.benchmark_group(device.bench_name(name));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&input), black_box(&weight), black_box(&bias));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_layer_norm_benchmark(c, &d, DType::F32, "layer_norm_f32");
run_layer_norm_benchmark(c, &d, DType::BF16, "layer_norm_bf16");
run_layer_norm_benchmark(c, &d, DType::F16, "layer_norm_f16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,64 +0,0 @@
pub(crate) mod conv;
pub(crate) mod layer_norm;
use candle::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
fn sync(&self) -> Result<()> {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
}
}
}
fn bench_name<S: Into<String>>(&self, name: S) -> String {
match self {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
};
format!("{}_{}", cpu_type, name.into())
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
}
}
}
struct BenchDeviceHandler {
devices: Vec<Device>,
}
impl BenchDeviceHandler {
pub fn new() -> Result<Self> {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);
Ok(Self { devices })
}
}

View File

@ -12,7 +12,6 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
pub mod rotary_emb;
pub mod sequential;
pub mod var_builder;
pub mod var_map;

View File

@ -1,4 +1,4 @@
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@ -180,10 +180,11 @@ impl candle::CustomOp1 for SoftmaxLastDim {
block_dim: (1, 32, 1),
shared_mem_bytes: 0,
};
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, n_cols as i32);
let params = (src, &dst, n_cols as i32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -206,7 +207,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::{backend::BackendStorage, DType};
let device = storage.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
@ -236,7 +237,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
.unwrap();
let newstorage =
candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
Ok((newstorage, layout.shape().clone()))
@ -247,215 +248,6 @@ pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(&SoftmaxLastDim)
}
#[derive(Debug, Clone)]
struct RmsNorm {
eps: f32,
}
impl candle::CustomOp2 for RmsNorm {
fn name(&self) -> &'static str {
"rms-norm"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)> {
use candle::backend::BackendStorage;
let eps = self.eps;
fn inner<
T: candle::WithDType
+ num_traits::Float
+ num_traits::AsPrimitive<f32>
+ num_traits::FromPrimitive,
>(
src: &[T],
layout: &Layout,
alpha: &[T],
alpha_layout: &Layout,
eps: f32,
) -> Result<(CpuStorage, Shape)> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => &alpha[o1..o2],
};
let el_count = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let mut dst = vec![T::zero(); el_count];
src.par_chunks(dim_m1)
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let sum2 = src
.iter()
.map(|&v| {
let v = v.as_();
v * v
})
.sum::<f32>();
let m = (sum2 / dim_m1 as f32 + eps).sqrt();
let m = T::from_f32(m).unwrap_or_else(T::nan);
for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
*d = *s / m * *alpha
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, Shape::from_dims(dims)))
}
use CpuStorage as C;
match (s1, s2) {
(C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
(C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
(C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
_ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
use candle::{CudaDevice, WithDType};
struct S {
eps: f32,
}
impl Map2 for S {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
layout: &Layout,
alpha: &CudaSlice<T>,
alpha_layout: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => alpha.slice(o1..o2),
};
let el = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
use candle::backend::BackendStorage;
let dev = s1.device();
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
s1: &candle::MetalStorage,
l1: &Layout,
s2: &candle::MetalStorage,
l2: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = s1.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match (s1.dtype(), s2.dtype()) {
(DType::F32, DType::F32) => "rmsnorm_f32",
(DType::F16, DType::F16) => "rmsnorm_f16",
(DType::BF16, DType::BF16) => "rmsnorm_bf16",
(dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
};
if !(l1.is_contiguous() && l2.is_contiguous()) {
candle::bail!("Non contiguous rmsnorm is not implemented");
}
let last_dim = l1.dims()[l1.shape().rank() - 1];
let elem_count = l1.shape().elem_count();
let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
candle_metal_kernels::call_rms_norm(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
self.eps,
s1.buffer(),
l1.start_offset() * s1.dtype().size_in_bytes(),
s2.buffer(),
l2.start_offset() * s2.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
Ok((newstorage, l1.shape().clone()))
}
}
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(candle::D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
let hidden_size_xs = xs.dim(candle::D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
if hidden_size_xs != hidden_size_alpha {
candle::bail!(
"shape mismatch in rms-norm {:?} {:?}",
xs.shape(),
alpha.shape()
)
}
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;

View File

@ -31,7 +31,7 @@ pub trait RNN {
let (_b_size, seq_len, _features) = input.dims3()?;
let mut output = Vec::with_capacity(seq_len);
for seq_index in 0..seq_len {
let input = input.i((.., seq_index, ..))?.contiguous()?;
let input = input.i((.., seq_index, ..))?;
let state = if seq_index == 0 {
self.step(&input, init_state)?
} else {

View File

@ -1,730 +0,0 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Interleaved variant of rotary embeddings.
/// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
/// The resulting y0 and y1 are also interleaved with:
/// y0 = x0*cos - x1*sin
/// y1 = x0*sin + x1*cos
#[derive(Debug, Clone)]
struct RotaryEmbI;
impl candle::CustomOp3 for RotaryEmbI {
fn name(&self) -> &'static str {
"rotary-emb-int"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
fn inner<T: candle::WithDType + num_traits::Float>(
src: &[T],
l_src: &Layout,
cos: &[T],
l_cos: &Layout,
sin: &[T],
l_sin: &Layout,
) -> Result<(CpuStorage, Shape)> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("input src has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("input cos has to be contiguous"),
Some((o1, o2)) => &cos[o1..o2],
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("input sin has to be contiguous"),
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
for i_over_2 in 0..t * d / 2 {
let i = 2 * i_over_2;
dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b, h, t, d).into()))
}
use candle::backend::BackendStorage;
use CpuStorage::{BF16, F16, F32, F64};
match (s1, s2, s3) {
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
fn inner<T: DeviceRepr + WithDType>(
src: &CudaSlice<T>,
l_src: &Layout,
cos: &CudaSlice<T>,
l_cos: &Layout,
sin: &CudaSlice<T>,
l_sin: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("src input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("cos input has to be contiguous"),
Some((o1, o2)) => cos.slice(o1..o2),
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("sin input has to be contiguous"),
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
use candle::backend::BackendStorage;
use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
let dev = s1.device();
let slice = match (&s1.slice, &s2.slice, &s3.slice) {
(BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
};
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
src: &candle::MetalStorage,
l_src: &Layout,
cos: &candle::MetalStorage,
l_cos: &Layout,
sin: &candle::MetalStorage,
l_sin: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = src.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
candle::bail!(
"dtype mismatch in rope-i {:?} {:?} {:?}",
src.dtype(),
cos.dtype(),
sin.dtype()
)
}
let name = match src.dtype() {
candle::DType::F32 => "rope_i_f32",
candle::DType::F16 => "rope_i_f16",
candle::DType::BF16 => "rope_i_bf16",
dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope_i(
device.metal_device(),
&command_buffer,
kernels,
name,
b * h,
t * d,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
l_cos.start_offset() * cos.dtype().size_in_bytes(),
sin.buffer(),
l_sin.start_offset() * sin.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
Ok((out, l_src.shape().clone()))
}
}
pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = cos.dims2()?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
|| seq_len > sin_seq_len
{
candle::bail!(
"inconsistent last dim size in rope {:?} {:?} {:?}",
xs.shape(),
cos.shape(),
sin.shape()
)
}
if !xs.is_contiguous() {
candle::bail!("xs has to be contiguous in rope")
}
if !cos.is_contiguous() {
candle::bail!("cos has to be contiguous in rope")
}
if !sin.is_contiguous() {
candle::bail!("sin has to be contiguous in rope")
}
xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
}
pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = cos
.narrow(0, 0, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let sin = sin
.narrow(0, 0, seq_len)?
.reshape((seq_len, n_embd / 2, 1))?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope)
}
/// Contiguous variant of rope embeddings.
#[derive(Debug, Clone)]
struct RotaryEmb;
impl candle::CustomOp3 for RotaryEmb {
fn name(&self) -> &'static str {
"rotary-emb"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
fn inner<T: candle::WithDType + num_traits::Float>(
src: &[T],
l_src: &Layout,
cos: &[T],
l_cos: &Layout,
sin: &[T],
l_sin: &Layout,
) -> Result<(CpuStorage, Shape)> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("input src has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("input cos has to be contiguous"),
Some((o1, o2)) => &cos[o1..o2],
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("input sin has to be contiguous"),
Some((o1, o2)) => &sin[o1..o2],
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * d)
.zip(dst.par_chunks_mut(t * d))
.for_each(|(src, dst)| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i1 = i_t * d + i_d;
let i2 = i1 + d / 2;
let i_cs = i_t * (d / 2) + i_d;
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b, h, t, d).into()))
}
use candle::backend::BackendStorage;
use CpuStorage::{BF16, F16, F32, F64};
match (s1, s2, s3) {
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
fn inner<T: DeviceRepr + WithDType>(
src: &CudaSlice<T>,
l_src: &Layout,
cos: &CudaSlice<T>,
l_cos: &Layout,
sin: &CudaSlice<T>,
l_sin: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("src input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("cos input has to be contiguous"),
Some((o1, o2)) => cos.slice(o1..o2),
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("sin input has to be contiguous"),
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&cos,
&sin,
&dst,
(b * h) as u32,
(t * d) as u32,
d as u32,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
use candle::backend::BackendStorage;
use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
let dev = s1.device();
let slice = match (&s1.slice, &s2.slice, &s3.slice) {
(BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
};
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
src: &candle::MetalStorage,
l_src: &Layout,
cos: &candle::MetalStorage,
l_cos: &Layout,
sin: &candle::MetalStorage,
l_sin: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = src.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
candle::bail!(
"dtype mismatch in rope {:?} {:?} {:?}",
src.dtype(),
cos.dtype(),
sin.dtype()
)
}
let name = match src.dtype() {
candle::DType::F32 => "rope_f32",
candle::DType::F16 => "rope_f16",
candle::DType::BF16 => "rope_bf16",
dtype => candle::bail!("rope is not implemented for {dtype:?}"),
};
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-i")?;
candle_metal_kernels::call_rope(
device.metal_device(),
&command_buffer,
kernels,
name,
b * h,
t * d,
d,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
l_cos.start_offset() * cos.dtype().size_in_bytes(),
sin.buffer(),
l_sin.start_offset() * sin.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
Ok((out, l_src.shape().clone()))
}
}
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
|| seq_len > sin_seq_len
{
candle::bail!(
"inconsistent last dim size in rope {:?} {:?} {:?}",
xs.shape(),
cos.shape(),
sin.shape()
)
}
if !xs.is_contiguous() {
candle::bail!("xs has to be contiguous in rope")
}
if !cos.is_contiguous() {
candle::bail!("cos has to be contiguous in rope")
}
if !sin.is_contiguous() {
candle::bail!("sin has to be contiguous in rope")
}
xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
let cos = cos.narrow(0, 0, seq_len)?;
let sin = sin.narrow(0, 0, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
}
/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings.
#[derive(Debug, Clone)]
struct RotaryEmbThd;
impl candle::CustomOp3 for RotaryEmbThd {
fn name(&self) -> &'static str {
"rotary-emb"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
fn inner<T: candle::WithDType + num_traits::Float>(
src: &[T],
l_src: &Layout,
cos: &[T],
l_cos: &Layout,
sin: &[T],
l_sin: &Layout,
) -> Result<(CpuStorage, Shape)> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("input src has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("input cos has to be contiguous"),
Some((o1, o2)) => &cos[o1..o2],
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("input sin has to be contiguous"),
Some((o1, o2)) => &sin[o1..o2],
};
let (b, t, h, d) = l_src.shape().dims4()?;
let el_count = b * h * t * d;
let mut dst = vec![T::zero(); el_count];
src.par_chunks(t * h * d)
.zip(dst.par_chunks_mut(t * h * d))
.for_each(|(src, dst)| {
for i_t in 0..t {
for i_d in 0..d / 2 {
let i_cs = i_t * (d / 2) + i_d;
for i_h in 0..h {
let i1 = i_t * h * d + i_h * d + i_d;
let i2 = i1 + d / 2;
dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
}
}
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (b, t, h, d).into()))
}
use candle::backend::BackendStorage;
use CpuStorage::{BF16, F16, F32, F64};
match (s1, s2, s3) {
(BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
(F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
fn inner<T: DeviceRepr + WithDType>(
src: &CudaSlice<T>,
l_src: &Layout,
cos: &CudaSlice<T>,
l_cos: &Layout,
sin: &CudaSlice<T>,
l_sin: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match l_src.contiguous_offsets() {
None => candle::bail!("src input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let cos = match l_cos.contiguous_offsets() {
None => candle::bail!("cos input has to be contiguous"),
Some((o1, o2)) => cos.slice(o1..o2),
};
let sin = match l_sin.contiguous_offsets() {
None => candle::bail!("sin input has to be contiguous"),
Some((o1, o2)) => sin.slice(o1..o2),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
use candle::backend::BackendStorage;
use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
let dev = s1.device();
let slice = match (&s1.slice, &s2.slice, &s3.slice) {
(BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
(F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
_ => candle::bail!(
"unsupported dtype for rope {:?} {:?} {:?}",
s1.dtype(),
s2.dtype(),
s3.dtype()
),
};
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
src: &candle::MetalStorage,
l_src: &Layout,
cos: &candle::MetalStorage,
l_cos: &Layout,
sin: &candle::MetalStorage,
l_sin: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = src.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
candle::bail!(
"dtype mismatch in rope {:?} {:?} {:?}",
src.dtype(),
cos.dtype(),
sin.dtype()
)
}
let name = match src.dtype() {
candle::DType::F32 => "rope_thd_f32",
candle::DType::F16 => "rope_thd_f16",
candle::DType::BF16 => "rope_thd_bf16",
dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
};
let (b, t, h, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
candle_metal_kernels::call_rope_thd(
device.metal_device(),
&command_buffer,
kernels,
name,
b,
t,
h,
d,
src.buffer(),
l_src.start_offset() * src.dtype().size_in_bytes(),
cos.buffer(),
l_cos.start_offset() * cos.dtype().size_in_bytes(),
sin.buffer(),
l_sin.start_offset() * sin.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
Ok((out, l_src.shape().clone()))
}
}
pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
let (sin_seq_len, sin_n_embd) = sin.dims2()?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len
|| seq_len > sin_seq_len
{
candle::bail!(
"inconsistent last dim size in rope {:?} {:?} {:?}",
xs.shape(),
cos.shape(),
sin.shape()
)
}
if !xs.is_contiguous() {
candle::bail!("xs has to be contiguous in rope")
}
if !cos.is_contiguous() {
candle::bail!("cos has to be contiguous in rope")
}
if !sin.is_contiguous() {
candle::bail!("sin has to be contiguous in rope")
}
xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
}

View File

@ -178,27 +178,16 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
let path = self.path(name);
self.data
.backend
.get(s.into(), &path, hints, self.data.dtype, &self.data.device)
}
/// Retrieve the tensor associated with the given name at the current path.
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
self.get_with_hints(s, name, Default::default())
}
/// Retrieve the tensor associated with the given name & dtype at the current path.
pub fn get_with_hints_dtype<S: Into<Shape>>(
&self,
s: S,
name: &str,
hints: B::Hints,
dtype: DType,
) -> Result<Tensor> {
let path = self.path(name);
self.data
.backend
.get(s.into(), &path, hints, dtype, &self.data.device)
}
}
struct Zeros;

View File

@ -4,9 +4,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle::{test_utils::to_vec3_round, Device, Result, Tensor};
fn softmax(device: &Device) -> Result<()> {
#[test]
fn softmax() -> Result<()> {
let device = &Device::Cpu;
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
@ -52,31 +54,6 @@ fn softmax(device: &Device) -> Result<()> {
Ok(())
}
fn rms_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
assert_eq!(
to_vec3_round(&t, 4)?,
&[
[[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
[[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
]
);
let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
assert_eq!(
to_vec3_round(&t2, 4)?,
&[
[[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]],
[[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
]
);
let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@ -85,93 +62,3 @@ fn softmax_numerical_stability() -> Result<()> {
assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
Ok(())
}
fn ropei(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope_i(&src, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_i_slow(&src, &cos, &sin)?;
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cpu() {
assert_eq!(sum_diff, 0.);
} else {
assert!(sum_diff < 1e-4);
}
Ok(())
}
fn rope(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?;
let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cpu() {
assert_eq!(sum_diff, 0.);
} else {
assert!(sum_diff < 1e-4);
}
Ok(())
}
fn rope_thd(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
let el_count = b_size * num_head * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let cos: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let sin: Vec<f32> = (0..seq_len * head_dim / 2)
.map(|_| rng.gen::<f32>())
.collect();
let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
let rope1 = {
let src = src.transpose(1, 2)?.contiguous()?;
candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)?
};
let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cpu() {
assert_eq!(sum_diff, 0.);
} else {
assert!(sum_diff < 1e-4);
}
Ok(())
}
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.5.0"
version = "0.4.2"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.5.0" }
candle-nn = { path = "../candle-nn", version = "0.5.0" }
candle = { path = "../candle-core", package = "candle-core", version = "0.4.2" }
candle-nn = { path = "../candle-nn", version = "0.4.2" }
prost = "0.12.1"
[build-dependencies]

View File

@ -776,11 +776,6 @@ pub fn simple_eval(
let output = input.reshape(new_shape)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#identity
"Identity" => {
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}

View File

@ -20,10 +20,10 @@ candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies]
pyo3-build-config = "0.21"
pyo3-build-config = "0.20"
[features]
default = []

View File

@ -324,12 +324,6 @@ class Tensor:
"""
pass
def gather(self, index, dim):
"""
Gathers values along an axis specified by dim.
"""
pass
def get(self, index: int) -> Tensor:
"""
Gets the value at the specified index.

View File

@ -1,19 +0,0 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
@staticmethod
def silu(tensor: Tensor) -> Tensor:
"""
Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
"""
pass
@staticmethod
def softmax(tensor: Tensor, dim: int) -> Tensor:
"""
Applies the Softmax function to a given tensor.#
"""
pass

View File

@ -60,8 +60,8 @@ impl PyDType {
impl PyDType {
fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
use std::str::FromStr;
if let Ok(dtype) = ob.extract::<String>(py) {
let dtype = DType::from_str(&dtype)
if let Ok(dtype) = ob.extract::<&str>(py) {
let dtype = DType::from_str(dtype)
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
Ok(Self(dtype))
} else {
@ -116,8 +116,8 @@ impl PyDevice {
impl<'source> FromPyObject<'source> for PyDevice {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let device: String = ob.extract()?;
let device = match device.as_str() {
let device: &str = ob.extract()?;
let device = match device {
"cpu" => PyDevice::Cpu,
"cuda" => PyDevice::Cuda,
_ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
@ -265,7 +265,7 @@ impl PyTensor {
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
let ty = data.bind(py).get_type();
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
"incorrect type {ty} for tensor"
)))?
@ -322,7 +322,7 @@ impl PyTensor {
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
.import_bound("torch")?
.import("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
@ -333,7 +333,7 @@ impl PyTensor {
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new_bound(py, self.0.dims()).to_object(py)
PyTuple::new(py, self.0.dims()).to_object(py)
}
#[getter]
@ -347,7 +347,7 @@ impl PyTensor {
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
fn stride(&self, py: Python<'_>) -> PyObject {
PyTuple::new_bound(py, self.0.stride()).to_object(py)
PyTuple::new(py, self.0.stride()).to_object(py)
}
#[getter]
@ -448,12 +448,6 @@ impl PyTensor {
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}
/// Gathers values along an axis specified by dim.
fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?))
}
#[pyo3(text_signature = "(self, rhs:Tensor)")]
/// Performs a matrix multiplication between the two tensors.
/// &RETURNS&: Tensor
@ -527,7 +521,7 @@ impl PyTensor {
}
fn extract_indexer(
py_indexer: &Bound<PyAny>,
py_indexer: &PyAny,
current_dim: usize,
dims: &[usize],
index_argument_count: usize,
@ -567,7 +561,7 @@ impl PyTensor {
),
current_dim + 1,
))
} else if py_indexer.is(&py_indexer.py().Ellipsis()) {
} else if py_indexer.is_ellipsis() {
// Handle '...' e.g. tensor[..., 0]
if current_dim > 0 {
return Err(PyTypeError::new_err(
@ -586,7 +580,7 @@ impl PyTensor {
}
}
if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
if not_none_count > dims.len() {
@ -596,12 +590,12 @@ impl PyTensor {
let mut current_dim = 0;
for item in tuple.iter() {
let (indexer, new_current_dim) =
extract_indexer(&item, current_dim, dims, not_none_count)?;
extract_indexer(item, current_dim, dims, not_none_count)?;
current_dim = new_current_dim;
indexers.push(indexer);
}
} else {
let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
let (indexer, _) = extract_indexer(idx.downcast::<PyAny>(py)?, 0, dims, 1)?;
indexers.push(indexer);
}
@ -652,7 +646,7 @@ impl PyTensor {
/// Add two tensors.
/// &RETURNS&: Tensor
fn __add__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@ -663,13 +657,13 @@ impl PyTensor {
Ok(Self(tensor))
}
fn __radd__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
self.__add__(rhs)
}
/// Multiply two tensors.
/// &RETURNS&: Tensor
fn __mul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@ -680,13 +674,13 @@ impl PyTensor {
Ok(Self(tensor))
}
fn __rmul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
self.__mul__(rhs)
}
/// Subtract two tensors.
/// &RETURNS&: Tensor
fn __sub__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@ -699,7 +693,7 @@ impl PyTensor {
/// Divide two tensors.
/// &RETURNS&: Tensor
fn __truediv__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@ -711,7 +705,7 @@ impl PyTensor {
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
fn __richcmp__(&self, rhs: &Bound<PyAny>, op: CompareOp) -> PyResult<Self> {
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
@ -957,7 +951,7 @@ impl PyTensor {
#[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
/// Performs Tensor dtype and/or device conversion.
/// &RETURNS&: Tensor
fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut device: Option<PyDevice> = None;
let mut dtype: Option<PyDType> = None;
let mut other: Option<PyTensor> = None;
@ -1227,7 +1221,7 @@ impl PyQTensor {
///Gets the shape of the tensor.
/// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new_bound(py, self.0.shape().dims()).to_object(py)
PyTuple::new(py, self.0.shape().dims()).to_object(py)
}
fn __repr__(&self) -> String {
@ -1265,7 +1259,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
.into_iter()
.map(|(key, value)| (key, PyTensor(value).into_py(py)))
.collect::<Vec<_>>();
Ok(res.into_py_dict_bound(py).to_object(py))
Ok(res.into_py_dict(py).to_object(py))
}
#[pyfunction]
@ -1303,7 +1297,7 @@ fn load_ggml(
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
let tensors = tensors.into_py_dict_bound(py).to_object(py);
let tensors = tensors.into_py_dict(py).to_object(py);
let hparams = [
("n_vocab", ggml.hparams.n_vocab),
("n_embd", ggml.hparams.n_embd),
@ -1313,7 +1307,7 @@ fn load_ggml(
("n_rot", ggml.hparams.n_rot),
("ftype", ggml.hparams.ftype),
];
let hparams = hparams.into_py_dict_bound(py).to_object(py);
let hparams = hparams.into_py_dict(py).to_object(py);
let vocab = ggml
.vocab
.token_score_pairs
@ -1351,7 +1345,7 @@ fn load_gguf(
gguf_file::Value::Bool(x) => x.into_py(py),
gguf_file::Value::String(x) => x.into_py(py),
gguf_file::Value::Array(x) => {
let list = pyo3::types::PyList::empty_bound(py);
let list = pyo3::types::PyList::empty(py);
for elem in x.iter() {
list.append(gguf_value_to_pyobject(elem, py)?)?;
}
@ -1371,13 +1365,13 @@ fn load_gguf(
})
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
let tensors = tensors.into_py_dict_bound(py).to_object(py);
let tensors = tensors.into_py_dict(py).to_object(py);
let metadata = gguf
.metadata
.iter()
.map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
.collect::<PyResult<Vec<_>>>()?
.into_py_dict_bound(py)
.into_py_dict(py)
.to_object(py);
Ok((tensors, metadata))
}
@ -1390,7 +1384,7 @@ fn load_gguf(
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
use ::candle::quantized::gguf_file;
fn pyobject_to_gguf_value(v: &Bound<PyAny>, py: Python<'_>) -> PyResult<gguf_file::Value> {
fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
gguf_file::Value::U8(x)
} else if let Ok(x) = v.extract::<i8>() {
@ -1418,7 +1412,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
} else if let Ok(x) = v.extract::<Vec<PyObject>>() {
let x = x
.into_iter()
.map(|f| pyobject_to_gguf_value(f.bind(py), py))
.map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
.collect::<PyResult<Vec<_>>>()?;
gguf_file::Value::Array(x)
} else {
@ -1450,7 +1444,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
Ok((
key.extract::<String>()
.map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
pyobject_to_gguf_value(&value.as_borrowed(), py)?,
pyobject_to_gguf_value(value, py)?,
))
})
.collect::<PyResult<Vec<_>>>()?;
@ -1498,7 +1492,7 @@ fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
@ -1579,7 +1573,7 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
Ok(PyTensor(s))
}
fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
@ -1591,7 +1585,7 @@ fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
}
#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
@ -1599,18 +1593,18 @@ fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
}
#[pymodule]
fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
let utils = PyModule::new_bound(py, "utils")?;
candle_utils(py, &utils)?;
m.add_submodule(&utils)?;
let nn = PyModule::new_bound(py, "functional")?;
candle_functional_m(py, &nn)?;
m.add_submodule(&nn)?;
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new_bound(py, "onnx")?;
candle_onnx_m(py, &onnx)?;
m.add_submodule(&onnx)?;
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;

View File

@ -39,7 +39,7 @@ impl PyONNXTensorDescriptor {
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
let shape = PyList::empty_bound(py);
let shape = PyList::empty(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {

View File

@ -206,8 +206,6 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
print("generated content")
print(pyi_content)
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
@ -231,8 +229,6 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
print("generated content")
print(py_content)
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:

View File

@ -1,36 +1,24 @@
use candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
ArgMax,
All { temperature: f64 },
TopK { k: usize, temperature: f64 },
TopP { p: f64, temperature: f64 },
TopKThenTopP { k: usize, p: f64, temperature: f64 },
}
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
sampling: Sampling,
temperature: Option<f64>,
top_p: Option<f64>,
}
impl LogitsProcessor {
pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
let rng = rand::rngs::StdRng::seed_from_u64(seed);
Self { rng, sampling }
}
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
let sampling = match temperature {
None => Sampling::ArgMax,
Some(temperature) => match top_p {
None => Sampling::All { temperature },
Some(p) => Sampling::TopP { p, temperature },
},
let temperature = if temperature.map_or(true, |v| v < 1e-7) {
None
} else {
temperature
};
Self::from_sampling(seed, sampling)
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
top_p,
}
}
fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
@ -50,14 +38,14 @@ impl LogitsProcessor {
Ok(next_token)
}
/// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
/// probability top_p. This way we never sample tokens that have very low probabilities and are
/// less likely to go "off the rails".
fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability top_p. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
// Sort by descending probability.
argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
// Clamp smaller probabilities to zero.
let mut cumsum = 0.;
@ -72,78 +60,23 @@ impl LogitsProcessor {
self.sample_multinomial(prs)
}
// top-k sampling samples from the k tokens with the largest probabilities.
fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
if top_k >= prs.len() {
self.sample_multinomial(prs)
} else {
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
let (indices, _, _) =
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
let index = self.sample_multinomial(&prs)?;
Ok(indices[index as usize] as u32)
}
}
// top-k sampling samples from the k tokens with the largest probabilities.
// then top-p sampling.
fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
if top_k >= prs.len() {
self.sample_topp(prs, top_p)
} else {
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
let (indices, _, _) =
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
let sum_p = prs.iter().sum::<f32>();
let index = if top_p <= 0.0 || top_p >= sum_p {
self.sample_multinomial(&prs)?
} else {
self.sample_topp(&mut prs, top_p)?
};
Ok(indices[index as usize] as u32)
}
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
self.sample_f(logits, |_| {})
}
pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let prs = |temperature: f64| -> Result<Vec<f32>> {
let logits = (&logits / temperature)?;
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
let mut prs = prs.to_vec1()?;
f(&mut prs);
Ok(prs)
};
let next_token = match &self.sampling {
Sampling::ArgMax => self.sample_argmax(logits)?,
Sampling::All { temperature } => {
let prs = prs(*temperature)?;
self.sample_multinomial(&prs)?
}
Sampling::TopP { p, temperature } => {
let mut prs = prs(*temperature)?;
if *p <= 0.0 || *p >= 1.0 {
let next_token = match self.temperature {
None => self.sample_argmax(logits)?,
Some(temperature) => {
let logits = &(&logits / temperature)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?;
let mut prs: Vec<f32> = prs.to_vec1()?;
let top_p = self.top_p.unwrap_or(1.);
if top_p <= 0.0 || top_p >= 1.0 {
// simply sample from the predicted probability distribution
self.sample_multinomial(&prs)?
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
self.sample_topp(&mut prs, *p as f32)?
self.sample_topp(&mut prs, top_p as f32)?
}
}
Sampling::TopK { k, temperature } => {
let mut prs = prs(*temperature)?;
self.sample_topk(&mut prs, *k)?
}
Sampling::TopKThenTopP { k, p, temperature } => {
let mut prs = prs(*temperature)?;
self.sample_topk_topp(&mut prs, *k, *p as f32)?
}
};
Ok(next_token)
}

View File

@ -1,154 +0,0 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
use self::{
text_model::{Activation, ClipTextTransformer},
vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};
pub mod text_model;
pub mod vision_model;
#[derive(Clone, Debug)]
pub struct ClipModel {
text_model: ClipTextTransformer,
vision_model: ClipVisionTransformer,
visual_projection: candle_nn::Linear,
text_projection: candle_nn::Linear,
logit_scale: Tensor,
}
#[derive(Clone, Debug)]
pub enum EncoderConfig {
Text(text_model::ClipTextConfig),
Vision(vision_model::ClipVisionConfig),
}
impl EncoderConfig {
pub fn embed_dim(&self) -> usize {
match self {
Self::Text(c) => c.embed_dim,
Self::Vision(c) => c.embed_dim,
}
}
pub fn num_attention_heads(&self) -> usize {
match self {
Self::Text(c) => c.num_attention_heads,
Self::Vision(c) => c.num_attention_heads,
}
}
pub fn intermediate_size(&self) -> usize {
match self {
Self::Text(c) => c.intermediate_size,
Self::Vision(c) => c.intermediate_size,
}
}
pub fn num_hidden_layers(&self) -> usize {
match self {
Self::Text(c) => c.num_hidden_layers,
Self::Vision(c) => c.num_hidden_layers,
}
}
pub fn activation(&self) -> Activation {
match self {
Self::Text(_c) => Activation::QuickGelu,
Self::Vision(c) => c.activation,
}
}
}
#[derive(Clone, Debug)]
pub struct ClipConfig {
pub text_config: text_model::ClipTextConfig,
pub vision_config: vision_model::ClipVisionConfig,
pub logit_scale_init_value: f32,
pub image_size: usize,
}
impl ClipConfig {
// base image size is 224, model size is 600Mb
pub fn vit_base_patch32() -> Self {
let text_config = text_model::ClipTextConfig::vit_base_patch32();
let vision_config = vision_model::ClipVisionConfig::vit_base_patch32();
Self {
text_config,
vision_config,
logit_scale_init_value: 2.6592,
image_size: 224,
}
}
}
impl ClipModel {
pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
let visual_projection = candle_nn::linear_no_bias(
c.vision_config.embed_dim,
c.vision_config.projection_dim,
vs.pp("visual_projection"),
)?;
let text_projection = candle_nn::linear_no_bias(
c.text_config.embed_dim,
c.text_config.projection_dim,
vs.pp("text_projection"),
)?;
// originally nn.Parameter
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
Ok(Self {
text_model,
vision_model,
visual_projection,
text_projection,
logit_scale,
})
}
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
input_ids
.apply(&self.text_model)?
.apply(&self.text_projection)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
pixel_values
.apply(&self.vision_model)?
.apply(&self.visual_projection)
}
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
v.broadcast_div(&l2_norm)
}

View File

@ -1,333 +0,0 @@
//! Contrastive Language-Image Pre-Training
//!
//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use super::EncoderConfig;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
QuickGelu,
}
impl Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
}
}
}
#[derive(Debug, Clone)]
pub struct ClipTextConfig {
pub vocab_size: usize,
pub embed_dim: usize,
pub activation: Activation,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub pad_with: Option<String>,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
#[allow(dead_code)]
pub projection_dim: usize,
}
impl ClipTextConfig {
// The config details can be found in the "text_config" section of this json file:
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
pub fn vit_base_patch32() -> Self {
Self {
vocab_size: 49408,
embed_dim: 512,
intermediate_size: 2048,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 8,
projection_dim: 512,
activation: Activation::QuickGelu,
}
}
}
// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
#[derive(Clone, Debug)]
struct ClipTextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
position_ids: Tensor,
}
impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding: nn::Embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
position_ids,
})
}
}
impl Module for ClipTextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
let inputs_embeds = self.token_embedding.forward(input_ids)?;
let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
let position_embedding = self.position_embedding.forward(&position_ids)?;
inputs_embeds.broadcast_add(&position_embedding)
}
}
#[derive(Clone, Debug)]
struct ClipAttention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
q_proj: candle_nn::Linear,
out_proj: candle_nn::Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl ClipAttention {
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
let embed_dim = c.embed_dim();
let num_attention_heads = c.num_attention_heads();
let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(ClipAttention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scale)?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&query_states, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?
.reshape((bsz * self.num_attention_heads, seq_len, src_len))?
} else {
attn_weights
};
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
self.out_proj.forward(&attn_output)
}
}
#[derive(Clone, Debug)]
struct ClipMlp {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
activation: Activation,
}
impl ClipMlp {
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
let fc1 = candle_nn::linear(c.embed_dim(), c.intermediate_size(), vs.pp("fc1"))?;
let fc2 = candle_nn::linear(c.intermediate_size(), c.embed_dim(), vs.pp("fc2"))?;
Ok(ClipMlp {
fc1,
fc2,
activation: c.activation(),
})
}
}
impl ClipMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&self.activation.forward(&xs)?)
}
}
#[derive(Clone, Debug)]
struct ClipEncoderLayer {
self_attn: ClipAttention,
layer_norm1: candle_nn::LayerNorm,
mlp: ClipMlp,
layer_norm2: candle_nn::LayerNorm,
}
impl ClipEncoderLayer {
fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
let layer_norm1 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm1"))?;
let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
let layer_norm2 = candle_nn::layer_norm(c.embed_dim(), 1e-5, vs.pp("layer_norm2"))?;
Ok(ClipEncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
xs + residual
}
}
#[derive(Clone, Debug)]
pub struct ClipEncoder {
layers: Vec<ClipEncoderLayer>,
}
impl ClipEncoder {
pub fn new(vs: candle_nn::VarBuilder, c: &EncoderConfig) -> Result<Self> {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers() {
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
}
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
Ok(xs)
}
}
/// A CLIP transformer based model.
#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
final_layer_norm: candle_nn::LayerNorm,
}
impl ClipTextTransformer {
pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,
final_layer_norm,
})
}
// TODO: rewrrite to newer version
fn build_causal_attention_mask(
bsz: usize,
seq_len: usize,
mask_after: usize,
device: &Device,
) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| {
(0..seq_len).map(move |j| {
if j > i || j > mask_after {
f32::MIN
} else {
0.
}
})
})
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, 1, seq_len, seq_len))
}
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
let (bsz, seq_len) = input_ids.dims2()?;
let input_ids = self.embeddings.forward(input_ids)?;
let causal_attention_mask =
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
let input_ids = self
.encoder
.forward(&input_ids, Some(&causal_attention_mask))?;
self.final_layer_norm.forward(&input_ids)
}
}
impl Module for ClipTextTransformer {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let output = self.forward_with_mask(input_ids, usize::MAX)?;
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
let mut indices = Vec::new();
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
indices.push(index);
}
Tensor::cat(&indices, 0)
}
}

Some files were not shown because too many files have changed in this diff Show More