Compare commits

..

3 Commits

SHA1 Message Date
ec6d7ca773 Cudarc static-linking enabled. 2025-03-29 09:27:53 +01:00
2c0f6b008e Fixing order. 2025-03-28 11:43:33 +01:00
9862cd3ba2 Splitting the features to enable different mkl linking. 2025-03-28 10:13:13 +01:00
117 changed files with 2432 additions and 5070 deletions

.github/workflows/book-cd.yml (new vendored file, 40 lines added)

@ -0,0 +1,40 @@
name: Deploy Rust book
on:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir mdbook
curl -sSL $url | tar -xz --directory=./mdbook
echo `pwd`/mdbook >> $GITHUB_PATH
- name: Deploy GitHub Pages
run: |
# This assumes your book is in the root of your repository.
# Just add a `cd` here if you need to change to another directory.
cd candle-book
mdbook build
git worktree add gh-pages
git config user.name "Deploy from CI"
git config user.email ""
cd gh-pages
# Delete the ref to avoid keeping history.
git update-ref -d refs/heads/gh-pages
rm -rf *
mv ../book/* .
git add .
git commit -m "Deploy $GITHUB_SHA to gh-pages"
git push --force --set-upstream origin gh-pages

.github/workflows/book.yml (new vendored file, 29 lines added)

@ -0,0 +1,29 @@
name: CI
on:
pull_request:
jobs:
test:
name: Test candle-book
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@master
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir bin
curl -sSL $url | tar -xz --directory=bin
echo "$(pwd)/bin" >> $GITHUB_PATH
- name: Run tests
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
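For context on the test step above: `mdbook test -L ../target/debug/deps/` compiles and runs the book's Rust code blocks against the dependencies produced by the preceding `cargo build`. A block of the kind that gate would exercise might look like the following (a hypothetical minimal example, not taken from candle-book):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Build a small 2x2 tensor on the CPU and multiply it with itself.
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    let b = a.matmul(&a)?;
    println!("{b}");
    Ok(())
}
```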

View File

@ -3,6 +3,7 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
@ -11,7 +12,6 @@ members = [
"tensor-tools",
]
exclude = [
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0-alpha.2"
version = "0.8.4"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.2" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.2" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.2" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.2" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.2" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
candle-nn = { path = "./candle-nn", version = "0.8.4" }
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
@ -51,7 +51,7 @@ half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_di
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
intel-mkl-src = { version = "0.8.1" }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.3.1"
ug-cuda = "0.3.1"
ug-metal = "0.3.1"
ug = "0.1.0"
ug-cuda = "0.1.0"
ug-metal = "0.1.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -15,7 +15,7 @@ byteorder = { workspace = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true}
gemm = { workspace = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@ -43,9 +43,11 @@ criterion = { workspace = true }
[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
_cuda = ["dep:cudarc", "dep:candle-kernels", "dep:ug-cuda"]
# cuda = ["_cuda", "cudarc?/cuda-version-from-build-system", "cudarc?/dynamic-linking"]
cudnn = ["_cuda", "cudarc?/cudnn"]
_mkl = ["dep:libc", "dep:intel-mkl-src"]
mkl = ["_mkl", "intel-mkl-src?/mkl-static-lp64-iomp"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
@ -56,7 +58,3 @@ harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]
[[example]]
name = "cuda_basics"
required-features = ["cuda"]
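In line with the "Splitting the features to enable different mkl linking" commit in this compare, one side of the [features] hunk above splits each backend feature in two: an internal `_cuda`/`_mkl` feature that the `#[cfg]` gates throughout the crate key off, and a public `cuda`/`mkl` feature that forwards to it and additionally selects a linking mode (`intel-mkl-src?/mkl-static-lp64-iomp` for MKL; the forwarding to `cudarc?/cuda-version-from-build-system` and `cudarc?/dynamic-linking` is left commented out for CUDA). A minimal sketch of how code is then gated, as I read the pattern (not code from the diff; the `cfg!` probing mirrors what the benchmark helper later in this compare does):

```rust
// Implementation code keys off the internal features, so the default
// `mkl`/`cuda` features and any alternative-linking variants that also
// enable `_mkl`/`_cuda` share the same code paths.
fn backend() -> &'static str {
    if cfg!(feature = "_cuda") {
        "cuda"
    } else if cfg!(feature = "_mkl") {
        "mkl"
    } else {
        "cpu"
    }
}

fn main() {
    println!("compiled backend: {}", backend());
}
```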

View File

@ -20,11 +20,9 @@ impl BenchDevice for Device {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device
.synchronize()
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
#[cfg(not(feature = "cuda"))]
#[cfg(feature = "_cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "_cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
@ -41,7 +39,7 @@ impl BenchDevice for Device {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
} else if cfg!(feature = "_mkl") {
"mkl"
} else {
"cpu"
@ -63,7 +61,7 @@ impl BenchDeviceHandler {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
} else if cfg!(feature = "_cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,7 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
use anyhow::Result;

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,7 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
use anyhow::Result;

View File

@ -14,7 +14,6 @@ pub struct ParamsConv1D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv1D {
@ -175,7 +174,6 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo: Some(CudnnFwdAlgo::ImplicitGemm),
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)

View File

@ -1246,7 +1246,7 @@ impl MatMul {
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
#[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
#[cfg(all(not(feature = "_mkl"), not(feature = "accelerate")))]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
@ -1289,15 +1289,6 @@ impl Map2 for MatMul {
} else {
Parallelism::None
};
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
// a_skip and c_skip should be updated but step is always 0 so
// it wouldn't matter.
(1, b * m, n, k)
} else if a_skip == 0 && b_skip == n * k {
(1, m, b * n, k)
} else {
(b, m, n, k)
};
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
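An aside on the `(b, m, n, k)` block that appears on only one side of the hunk above: it folds a batched matmul into a single gemm when the strides allow it, i.e. when the rhs is shared across the batch (`b_skip == 0`) and the lhs batches are packed back to back (`a_skip == m * k`), or symmetrically for the lhs. A rough standalone illustration with made-up shapes (not code from the diff):

```rust
// Returns the (batch, m, n, k) actually handed to gemm.
fn fold_batch(b: usize, m: usize, n: usize, k: usize, a_skip: usize, b_skip: usize) -> (usize, usize, usize, usize) {
    if b_skip == 0 && a_skip == m * k {
        // b separate (m, k) x (k, n) products with a shared rhs are one
        // (b * m, k) x (k, n) product.
        (1, b * m, n, k)
    } else if a_skip == 0 && b_skip == n * k {
        (1, m, b * n, k)
    } else {
        (b, m, n, k)
    }
}

fn main() {
    // 8 batched (4, 16) x (16, 32) products sharing the rhs collapse into
    // a single (32, 16) x (16, 32) product.
    assert_eq!(fold_batch(8, 4, 32, 16, 4 * 16, 0), (1, 32, 32, 16));
}
```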
@ -1420,7 +1411,7 @@ impl Map2 for MatMul {
Ok(dst)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],

View File

@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
let c = Cudnn::new(dev.cuda_device());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,
@ -122,104 +122,3 @@ pub(crate) fn launch_conv2d<
}
Ok(())
}
pub(crate) fn launch_conv1d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv1D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, 0],
/* stride */ [params.stride as i32, 1],
/* dilation */ [params.dilation as i32, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
// > to 1.
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.l_in as i32,
1,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
};
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_size as i32,
1,
],
)?;
let l_out = params.l_out() as i32;
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, l_out, 1],
)?;
let conv1d = ConvForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv1d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv1d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv1d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}

View File

@ -2,9 +2,8 @@ use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@ -25,17 +24,10 @@ impl DeviceId {
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
pub struct ModuleStore {
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
context: Arc<cudarc::driver::CudaContext>,
modules: Arc<std::sync::RwLock<ModuleStore>>,
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
stream: Arc<cudarc::driver::CudaStream>,
device: Arc<cudarc::driver::CudaDevice>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
@ -46,102 +38,17 @@ impl std::fmt::Debug for CudaDevice {
}
}
impl CudaDevice {
#[allow(clippy::missing_safety_doc)]
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc::<T>(len).w()
}
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc_zeros::<T>(len).w()
}
pub fn memcpy_htod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_htod(src, dst).w()
}
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
&self,
src: &Src,
) -> Result<Vec<T>> {
self.stream.memcpy_dtov(src).w()
}
pub fn memcpy_dtod<
T,
Src: cudarc::driver::DevicePtr<T>,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_dtod(src, dst).w()
}
pub fn memcpy_stod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
>(
&self,
src: &Src,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.memcpy_stod(src).w()
}
}
pub struct CudaFunc {
func: CudaFunction,
stream: Arc<cudarc::driver::CudaStream>,
}
impl std::ops::Deref for CudaFunc {
type Target = CudaFunction;
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.func
}
}
impl CudaFunc {
pub fn into_cuda_function(self) -> CudaFunction {
self.func
}
}
#[macro_export]
macro_rules! builder_arg {
($b:ident, $($arg:expr),*) => {
$(
let __arg = $arg;
$b.arg(&__arg);
)*
};
}
impl CudaFunc {
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
self.stream.launch_builder(&self.func)
&self.device
}
}
impl CudaDevice {
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
self.stream.clone()
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
#[cfg(not(target_arch = "wasm32"))]
@ -149,7 +56,7 @@ impl CudaDevice {
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunc> {
) -> Result<CudaFunction> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
@ -158,12 +65,12 @@ impl CudaDevice {
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
let module = self.context.load_module(ptx).w()?;
let func = module.load_function(func_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
}
pub fn id(&self) -> DeviceId {
@ -176,85 +83,58 @@ impl CudaDevice {
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
};
@ -264,69 +144,38 @@ impl CudaDevice {
})
}
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
module_name: &str,
ptx: &str,
) -> Result<CudaFunc> {
let ms = self.custom_modules.read().unwrap();
if let Some(mdl) = ms.get(module_name).as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
}
drop(ms);
let mut ms = self.custom_modules.write().unwrap();
let cuda_module = self.context.load_module(ptx.into()).w()?;
ms.insert(module_name.to_string(), cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
let ms = self.modules.read().unwrap();
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
drop(ms);
let mut ms = self.modules.write().unwrap();
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
ms.mdls[mdl.index()] = Some(cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.new_stream().w()?;
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
context,
stream,
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
}
@ -335,21 +184,14 @@ impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.default_stream();
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
context,
stream,
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
@ -357,13 +199,13 @@ impl BackendDevice for CudaDevice {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.context.ordinal(),
gpu_id: self.device.ordinal(),
}
}
@ -375,31 +217,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count)?;
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count)?;
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count)?;
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count)?;
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count)?;
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count)?;
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
@ -423,12 +265,12 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
@ -467,7 +309,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -475,7 +317,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -494,31 +336,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count)?;
let data = self.alloc::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count)?;
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count)?;
let data = self.alloc::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count)?;
let data = self.alloc::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count)?;
let data = self.alloc::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count)?;
let data = self.alloc::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count)?;
let data = self.alloc::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
@ -531,31 +373,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -568,31 +410,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -605,31 +447,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -640,7 +482,7 @@ impl BackendDevice for CudaDevice {
}
fn synchronize(&self) -> Result<()> {
self.stream.synchronize().map_err(crate::Error::wrap)?;
self.device.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}
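For readers skimming this long device.rs rewrite, the core mechanical difference between the two sides is how a kernel launch is expressed. The sketch below only restates what the hunks above already show; the helper signatures and the exact error type are my assumptions, and the two functions target different cudarc major versions, so they would not compile together in one crate:

```rust
// Builder style (the cudarc 0.15.x side): arguments are pushed one by one
// onto a stream-owned launch builder.
fn launch_builder_style(
    stream: &std::sync::Arc<cudarc::driver::CudaStream>,
    func: &cudarc::driver::CudaFunction,
    cfg: cudarc::driver::LaunchConfig,
    data: &cudarc::driver::CudaSlice<f32>,
    v: f32,
    n: usize,
) -> Result<(), cudarc::driver::DriverError> {
    use cudarc::driver::PushKernelArg;
    let mut builder = stream.launch_builder(func);
    builder.arg(data);
    builder.arg(&v);
    builder.arg(&n);
    unsafe { builder.launch(cfg) }?;
    Ok(())
}

// Tuple style (the cudarc 0.13.x side): the function is launched directly
// with a tuple of kernel parameters.
fn launch_tuple_style(
    func: cudarc::driver::CudaFunction,
    cfg: cudarc::driver::LaunchConfig,
    data: &cudarc::driver::CudaSlice<f32>,
    v: f32,
    n: usize,
) -> Result<(), cudarc::driver::DriverError> {
    use cudarc::driver::LaunchAsync;
    unsafe { func.launch(cfg, (data, v, n)) }?;
    Ok(())
}
```

Module loading follows the same split: one side caches `Arc<CudaModule>` handles behind `RwLock`s and calls `load_module`/`load_function`, while the other goes through `load_ptx`/`get_func` keyed by a (leaked) module name.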

File diff suppressed because it is too large.

View File

@ -378,7 +378,7 @@ impl Tensor {
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
@ -392,14 +392,11 @@ impl UgIOp1 {
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self {
name,
func: func.into_cuda_function(),
})
Ok(Self { name, func })
}
#[cfg(feature = "metal")]
{
@ -407,7 +404,7 @@ impl UgIOp1 {
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
#[cfg(not(any(feature = "_cuda", feature = "metal")))]
{
Ok(Self { name })
}
@ -459,19 +456,19 @@ impl InplaceOp1 for UgIOp1 {
Ok(())
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::PushKernelArg;
use cudarc::driver::LaunchAsync;
let elem_count = layout.shape().elem_count();
let stream = sto.device.cuda_stream();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
@ -482,9 +479,7 @@ impl InplaceOp1 for UgIOp1 {
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&self.func);
builder.arg(&sto);
unsafe { builder.launch(cfg) }.w()?;
unsafe { self.func.clone().launch(cfg, params) }.w()?;
Ok(())
}
}

View File

@ -55,7 +55,7 @@ pub mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
pub mod cuda_backend;
mod custom_op;
mod device;
@ -68,7 +68,7 @@ mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
mod mkl;
pub mod npy;
pub mod op;
@ -104,10 +104,10 @@ pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
pub use cuda_backend as cuda;
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
pub use dummy_cuda_backend as cuda;
pub use cuda::{CudaDevice, CudaStorage};
@ -118,7 +118,7 @@ pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -294,16 +294,16 @@ macro_rules! bin_op {
$e(v1, v2)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs1, xs2, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs1, xs2, ys)
@ -418,16 +418,16 @@ macro_rules! unary_op {
todo!("no unary function for i64")
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs, ys)
@ -518,19 +518,19 @@ impl UnaryOpT for Gelu {
}
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_gelu(xs, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
@ -625,19 +625,19 @@ impl UnaryOpT for Silu {
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
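The `F32_VEC`/`F64_VEC` constants switched to the `_mkl` gate above are how these macros advertise a vectorized slice implementation; when the flag is true, callers hand whole slices to `f32_vec`/`f64_vec` instead of mapping the scalar op element by element. A simplified, self-contained sketch of that dispatch idea (assumed semantics, not the actual candle trait):

```rust
trait UnaryF32 {
    // True when a vectorized slice kernel (e.g. an MKL VML routine) exists.
    const F32_VEC: bool = false;
    fn f32(v: f32) -> f32;
    // Fallback: apply the scalar op element by element.
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        for (y, x) in ys.iter_mut().zip(xs) {
            *y = Self::f32(*x);
        }
    }
}

struct Sqr;
impl UnaryF32 for Sqr {
    fn f32(v: f32) -> f32 {
        v * v
    }
}

fn main() {
    // A real backend would branch on Sqr::F32_VEC to pick the MKL path.
    let xs = [1.0f32, 2.0, 3.0];
    let mut ys = [0.0f32; 3];
    Sqr::f32_vec(&xs, &mut ys);
    assert_eq!(ys, [1.0, 4.0, 9.0]);
}
```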

View File

@ -816,7 +816,7 @@ impl PthTensors {
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,

View File

@ -1,10 +1,10 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
use crate::{CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
@ -50,20 +50,19 @@ fn quantize_q8_1(
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(src);
builder.arg(dst);
barg!(builder, kx as i32, kx_padded as i32);
unsafe { builder.launch(cfg) }.w()?;
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
@ -73,7 +72,9 @@ fn dequantize_f32(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let nb = elem_count.div_ceil(256);
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
@ -98,8 +99,8 @@ fn dequantize_f32(
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -109,20 +110,15 @@ fn dequantize_f32(
};
if is_k {
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -133,7 +129,9 @@ fn dequantize_f16(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let nb = elem_count.div_ceil(256);
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
@ -158,8 +156,8 @@ fn dequantize_f16(
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -169,20 +167,15 @@ fn dequantize_f16(
};
if is_k {
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -195,6 +188,8 @@ fn dequantize_mul_mat_vec(
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@ -215,8 +210,8 @@ fn dequantize_mul_mat_vec(
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows)? };
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
@ -224,12 +219,8 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(y);
builder.arg(&dst);
barg!(builder, ncols as i32, nrows as i32);
unsafe { builder.launch(cfg) }.w()?;
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -242,6 +233,8 @@ fn mul_mat_vec_via_q8_1(
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@ -256,7 +249,7 @@ fn mul_mat_vec_via_q8_1(
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
@ -273,13 +266,13 @@ fn mul_mat_vec_via_q8_1(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32).div_ceil(2), 4),
5..=8 => ((nrows as u32).div_ceil(2), 2),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
@ -288,18 +281,16 @@ fn mul_mat_vec_via_q8_1(
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&y_q8_1);
builder.arg(&dst);
barg!(
builder,
let params = (
&data.inner,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32
/* nrows_dst */ nrows as i32,
);
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -314,6 +305,8 @@ fn mul_mat_via_q8_1(
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
@ -329,7 +322,7 @@ fn mul_mat_via_q8_1(
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
@ -345,8 +338,8 @@ fn mul_mat_via_q8_1(
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
@ -357,19 +350,17 @@ fn mul_mat_via_q8_1(
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(/* vx */ &data.inner);
builder.arg(/* vy */ &y_q8_1);
builder.arg(/* dst */ &dst);
barg!(
builder,
let params = (
/* vx */ &data.inner,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32
/* nrows_dst */ x_rows as i32,
);
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -378,7 +369,7 @@ impl QCudaStorage {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -425,7 +416,8 @@ impl QCudaStorage {
let buffer = self
.device
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -456,7 +448,9 @@ impl QCudaStorage {
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
@ -466,9 +460,10 @@ impl QCudaStorage {
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
@ -602,8 +597,10 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.w()?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -625,9 +622,9 @@ mod test {
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -637,7 +634,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
@ -650,7 +647,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
@ -665,7 +662,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
@ -676,7 +673,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -690,7 +687,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@ -717,7 +714,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -731,7 +728,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
Ok(())
}
}
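One recurring mechanical change in this file swaps `usize::div_ceil` for the manual rounding `(n + 255) / 256` (and `(nrows + 1) / 2` for `div_ceil(2)`) when computing block counts. For the sizes involved the two forms are interchangeable, as this small check illustrates:

```rust
fn main() {
    for n in [0usize, 1, 255, 256, 257, 4096, 100_000] {
        assert_eq!(n.div_ceil(256), (n + 255) / 256);
    }
    for n in [1u32, 2, 3, 7, 8, 9] {
        assert_eq!(n.div_ceil(2), (n + 1) / 2);
    }
}
```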

View File

@ -16,9 +16,9 @@ pub mod metal;
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}

View File

@ -52,11 +52,11 @@ impl ArgSort {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
@ -69,33 +69,27 @@ mod cuda {
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
use cudarc::driver::PushKernelArg;
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
let stream = dev.cuda_stream();
let mut builder = stream.launch_builder(&func);
let ncols = ncols as i32;
let ncols_pad = ncols_pad as i32;
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
@ -124,7 +118,7 @@ impl crate::CustomOp1 for ArgSort {
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,

View File

@ -2580,28 +2580,6 @@ impl Tensor {
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
/// This function makes a copy of the tensors data.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
/// let t_flipped = t.flip(&[0])?;
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
let mut result = self.clone();
for &dim in dims.iter() {
let size = result.dim(dim)?;
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
result = result.index_select(&indices_tensor, dim)?;
}
Ok(result)
}
}
macro_rules! bin_trait {

View File

@ -10,7 +10,7 @@ macro_rules! test_device {
$fn_name(&Device::Cpu)
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
#[test]
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
@ -24,15 +24,6 @@ macro_rules! test_device {
};
}
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
assert_eq!(t1.shape(), t2.shape());
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
let all_equal = eq_tensor.sum_all()?;
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
Ok(())
}
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;

View File

@ -17,11 +17,11 @@ pub fn has_accelerate() -> bool {
}
pub fn has_mkl() -> bool {
cfg!(feature = "mkl")
cfg!(feature = "_mkl")
}
pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
cfg!(feature = "_cuda")
}
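These helpers now test `_mkl` / `_cuda` rather than the public feature names. The matching `Cargo.toml` wiring is not part of this excerpt; one plausible layout (shown purely as an assumption, with illustrative feature names beyond `_mkl`) is a private gate that several public linking features forward to:

```toml
# Hypothetical sketch only; not taken from this diff.
[features]
_mkl = []                                   # internal gate checked by cfg!(feature = "_mkl")
mkl = ["_mkl", "dep:intel-mkl-src"]         # one public linking flavour
mkl-dynamic = ["_mkl", "dep:intel-mkl-src"] # an alternative flavour sharing the same gate
```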
pub fn metal_is_available() -> bool {

View File

@ -53,20 +53,6 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
@ -177,22 +163,6 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv2d(&w, 0, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;

View File

@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> {
Ok(())
}
#[cfg(any(feature = "cuda", feature = "metal"))]
#[cfg(any(feature = "_cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {

View File

@ -1,6 +1,6 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
fn simple_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
@ -505,36 +505,6 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn test_flip_backprop() -> Result<()> {
let device = &Device::Cpu;
// Create a tensor (leaf node) that requires gradients
let x = Var::ones((2, 2), DType::F64, device)?;
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
let y = x.matmul(&weights)?;
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
let z = y.flip(&[1])?;
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
let loss = z.sum_all()?;
let grad_store = loss.backward()?;
let grad_x = grad_store.get_id(x.id()).unwrap();
let flipped_weights = weights.flip(&[1])?;
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,

View File

@ -1682,54 +1682,3 @@ fn pow() -> Result<()> {
);
Ok(())
}
#[test]
fn test_flip_1d() -> Result<()> {
// 1D: [0, 1, 2, 3, 4]
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
let flipped = t.flip(&[0])?;
// Expected: [4, 3, 2, 1, 0]
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_2d() -> Result<()> {
// 2D:
// [[0, 1, 2],
// [3, 4, 5]]
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
let flipped = t.flip(&[0, 1])?;
// Expected:
// [[5, 4, 3],
// [2, 1, 0]]
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_3d_channels() -> Result<()> {
// 3D:
// [[[0,1,2],
// [3,4,5]],
//
// [[6,7,8],
// [9,10,11]]]
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
let flipped = t.flip(&[2])?;
// Expected:
// [[[2,1,0],
// [5,4,3]],
//
// [[8,7,6],
// [11,10,9]]]
let expected = Tensor::from_vec(
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
(2, 2, 3),
&Device::Cpu,
)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}

View File

@ -72,8 +72,6 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
// image-rs crate convention is to load in (width, height, channels) order
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
@ -83,10 +81,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
}
}
}
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
.to_dtype(DType::F32)?
.permute((0, 3, 2, 1))?
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
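For reference, the layout change above removes an explicit channel-last to channel-first `permute`; a minimal standalone sketch of that reordering on a tiny synthetic tensor (not the CIFAR data itself, and using the published `candle_core` crate name):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // One sample, a 2x2 image with 3 channels, stored channel-last (NHWC).
    let nhwc = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((1, 2, 2, 3))?;
    // Reorder to the channel-first (NCHW) layout expected by convolution kernels.
    let nchw = nhwc.permute((0, 3, 1, 2))?;
    assert_eq!(nchw.dims(), [1, 3, 2, 2]);
    Ok(())
}
```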

View File

@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
@ -69,7 +69,6 @@ metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@ -108,10 +107,6 @@ required-features = ["candle-datasets"]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "snac"
required-features = ["snac"]
[[example]]
name = "encodec"
required-features = ["encodec"]

View File

@ -1,13 +0,0 @@
# candle-chatglm
Uses `THUDM/chatglm3-6b` to generate Chinese text. It will usually not generate English text.
## Text Generation
```bash
cargo run --example chatglm --release -- --prompt "部署门槛较低等众多优秀特 "
> 部署门槛较低等众多优秀特 点使得其成为了一款备受欢迎的AI助手。
>
> 作为一款人工智能助手ChatGLM3-6B
```

View File

@ -1,42 +0,0 @@
# candle-chinese-clip
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts. This variant is trained on Chinese text instead of English.
## Running on cpu
```bash
$ cargo run --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```
## Running on metal
```bash
$ cargo run --features metal --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```

View File

@ -1,17 +0,0 @@
# candle-convmixer
A lightweight CNN architecture that processes image patches similarly to a vision transformer, with separate spatial and channel convolutions.
ConvMixer from [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792) and [ConvMixer](https://github.com/locuslab/convmixer).
## Running an example
```bash
$ cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 61.75%
> unicycle, monocycle : 5.73%
> moped : 3.66%
> bicycle-built-for-two, tandem bicycle, tandem: 3.51%
> crash helmet : 0.85%
```

View File

@ -1,14 +0,0 @@
# Conversational Speech Model (CSM)
CSM is a speech generation model from Sesame,
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
It can generate conversational speech between two different speakers.
Speaker turns are delimited by the `|` character in the prompt.
```bash
cargo run --example csm --features cuda -r -- \
--voices candle-examples/examples/csm/voices.safetensors \
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
```

View File

@ -1,243 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::csm::{Config, Model};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1b")]
Csm1b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The prompt to be used for the generation, use a | to separate the speakers.
#[arg(long, default_value = "Hey how are you doing today?")]
prompt: String,
/// The voices to be used, in safetensors format.
#[arg(long)]
voices: String,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.7)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "1b")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
config: Option<String>,
#[arg(long)]
weights: Option<String>,
/// The mimi model weight file, in safetensor format.
#[arg(long)]
mimi_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let name = match args.which {
Which::Csm1b => "sesame/csm-1b",
};
name.to_string()
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let filenames = match args.weights {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("meta-llama/Llama-3.2-1B".to_string())
.get("tokenizer.json")?,
};
let mimi_filename = match args.mimi_weights {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let device = candle_examples::device(args.cpu)?;
let (mut model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
(model, device)
};
let mut mimi_model = {
use candle_transformers::models::mimi;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
let config = mimi::Config::v0_1(Some(32));
mimi::Model::new(config, vb)?
};
let cb = config.audio_num_codebooks;
println!("loaded the model in {:?}", start.elapsed());
let voices = candle::safetensors::load(args.voices, &device)?;
let mut lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
None,
);
let tokens = voices
.get("tokens")
.expect("no tokens in prompt")
.to_dtype(DType::U32)?;
let mask = voices.get("mask").expect("no mask in prompt").clone();
let mut pos = 0;
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let mut all_pcms = vec![];
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
println!("{prompt:?}");
let speaker_idx = turn_idx % 2;
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
let mut generated_tokens = vec![];
loop {
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let is_done = frame.iter().all(|&x| x == 0);
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
print!("\rframe {pos}");
if is_done {
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
break;
}
generated_tokens.push(tokens.clone());
}
println!();
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
let pcm = mimi_model.decode(&generated_tokens)?;
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
all_pcms.push(pcm);
}
let pcm = Tensor::cat(&all_pcms, 0)?;
let pcm = pcm.to_vec1::<f32>()?;
println!("writing output file {}", args.out_file);
let mut output = std::fs::File::create(args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

View File

@ -1,17 +0,0 @@
# candle-custom-ops
This example illustrates how to implement forward and backward passes for custom operations on the CPU and GPU.
The custom op in this example implements RMS normalization for both the CPU and CUDA backends.
## Running an example
```bash
$ cargo run --example custom-ops
> [[ 0., 1., 2., 3., 4., 5., 6.],
> [ 7., 8., 9., 10., 11., 12., 13.]]
> Tensor[[2, 7], f32]
> [[0.0000, 0.2773, 0.5547, 0.8320, 1.1094, 1.3867, 1.6641],
> [0.6864, 0.7845, 0.8825, 0.9806, 1.0786, 1.1767, 1.2748]]
> Tensor[[2, 7], f32]
```
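For orientation, the RMS normalization computed by this op can also be written with plain tensor ops; a rough reference sketch (assuming the published `candle_core` crate name, and not the fused CPU/CUDA kernels the example actually ships):

```rust
use candle_core::{Device, Result, Tensor};

// Reference computation: x / sqrt(mean(x^2) + eps), row-wise over the last dimension.
fn rms_norm_reference(x: &Tensor, eps: f64) -> Result<Tensor> {
    let dim = x.dims().len() - 1;
    let n = x.dim(dim)? as f64;
    let mean_sq = (x.sqr()?.sum_keepdim(dim)? / n)?;
    x.broadcast_div(&(mean_sq + eps)?.sqrt()?)
}

fn main() -> Result<()> {
    let x = Tensor::arange(0f32, 14f32, &Device::Cpu)?.reshape((2, 7))?;
    // The first row should come out roughly as [0.0000, 0.2773, 0.5547, ...],
    // matching the example output quoted above.
    println!("{}", rms_norm_reference(&x, 1e-5)?);
    Ok(())
}
```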

View File

@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
use candle::cuda_backend::WrapErr;
let (d1, d2) = layout.shape().dims2()?;
let d1 = d1 as u32;
@ -68,19 +68,15 @@ impl CustomOp1 for LayerNorm {
Some((o1, o2)) => slice.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
let func =
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
let params = (&dst, &slice, self.eps, d1, d2);
let cfg = LaunchConfig {
grid_dim: (d1, 1, 1),
block_dim: (d2, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(&dst);
builder.arg(&slice);
candle::builder_arg!(builder, self.eps, d1, d2);
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, layout.shape().clone()))

View File

DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
@ -20,25 +20,3 @@ $ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> Tensor[[1, 7, 768], f32]
```
## Masked Token
DistilBert is used to compute the top K choices for a masked token.
```bash
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
> Input: The capital of France is [MASK].
> Predictions for [MASK] at position 6:
> 1: marseille (probability: 12.14%)
> 2: paris (probability: 10.84%)
> 3: toulouse (probability: 8.57%)
> 4: lyon (probability: 7.61%)
> 5: montpellier (probability: 5.18%)
> 6: bordeaux (probability: 4.88%)
> 7: nantes (probability: 4.82%)
> 8: lille (probability: 4.07%)
> 9: strasbourg (probability: 3.12%)
> 10: cannes (probability: 3.04%)
```

View File

@ -3,48 +3,15 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
};
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use anyhow::{Context, Error as E, Result};
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::{Parser, ValueEnum};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;
use tokenizers::Tokenizer;
enum ModelType {
Masked(DistilBertForMaskedLM),
UnMasked(DistilBertModel),
}
impl ModelType {
fn device(&self) -> &Device {
match self {
ModelType::Masked(model) => &model.bert.device,
ModelType::UnMasked(model) => &model.device,
}
}
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
match self {
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "distilbert")]
DistilBert,
#[value(name = "distilbertformaskedlm")]
DistilbertForMaskedLM,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -56,14 +23,10 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "distilbert")]
model: Which,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
/// Revision or branch
#[arg(long)]
revision: Option<String>,
@ -79,246 +42,94 @@ struct Args {
#[arg(long, default_value = "1")]
n: usize,
/// Number of top predictions to show for each mask
#[arg(long, default_value = "5")]
top_k: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let (model_id, revision) = self.resolve_model_and_revision();
let (config_path, tokenizer_path, weights_path) =
self.download_model_files(&model_id, &revision)?;
let config = std::fs::read_to_string(config_path)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
let vb = self.load_variables(&weights_path, &device)?;
let model = self.create_model(&config, vb)?;
Ok((model, tokenizer))
}
fn resolve_model_and_revision(&self) -> (String, String) {
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
match (self.model_id.clone(), self.revision.clone()) {
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, default_revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
}
}
fn download_model_files(
&self,
model_id: &str,
revision: &str,
) -> Result<(PathBuf, PathBuf, PathBuf)> {
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
Ok((config, tokenizer, weights))
}
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
if self.use_pth {
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
}
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
match self.model {
Which::DistilbertForMaskedLM => {
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
}
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
}
}
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
}
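As a sanity check, `get_mask` builds a strictly upper-triangular byte matrix (1 wherever the key index `j` is ahead of the query index `i`). A hypothetical test, assuming the helper above is in scope:

```rust
#[test]
fn causal_mask_layout() {
    let mask = get_mask(3, &candle::Device::Cpu);
    assert_eq!(
        mask.to_vec2::<u8>().unwrap(),
        vec![vec![0u8, 1, 1], vec![0, 0, 1], vec![0, 0, 0]]
    );
}
```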
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = setup_tracing(&args);
let (model, tokenizer) = args.build_model_and_tokenizer()?;
let device = model.device();
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
let output = model.forward(&token_ids, &mask)?;
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
Ok(())
}
fn setup_tracing(args: &Args) -> Option<impl Drop> {
if args.tracing {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
}
}
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
let mut binding = tokenizer.clone();
let tokenizer_configured = binding
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer_configured
.encode(args.prompt.clone(), true)
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
let mask = match args.model {
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
Which::DistilBert => attention_mask(tokens.len(), device)?,
};
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
Ok((token_ids, mask))
}
fn process_output(
model: &ModelType,
output: &Tensor,
token_ids: &Tensor,
tokenizer: &Tokenizer,
args: &Args,
) -> Result<()> {
match model {
ModelType::UnMasked(_) => {
println!("embeddings");
println!("{output}");
}
ModelType::Masked(_) => {
process_masked_output(output, token_ids, tokenizer, args)?;
}
}
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
Ok(())
}
fn process_masked_output(
output: &Tensor,
token_ids: &Tensor,
tokenizer: &Tokenizer,
args: &Args,
) -> Result<()> {
let input_ids_vec = token_ids.to_vec2::<u32>()?;
let mask_token_id = tokenizer
.token_to_id("[MASK]")
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
println!("\nInput: {}", args.prompt);
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
if token_id == mask_token_id {
println!("Predictions for [MASK] at position {}:", token_idx);
let pos_logits = output.get(0)?.get(token_idx)?;
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
let values = top_values.to_vec1::<f32>()?;
let indices = top_indices.to_vec1::<u32>()?;
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
println!(
" {}: {:15} (probability: {:.2}%)",
i + 1,
token,
prob * 100.0
);
}
}
}
Ok(())
}
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
let n = tensor.dims().iter().product::<usize>();
let k = std::cmp::min(k, n);
let values = tensor.to_vec1::<f32>()?;
let mut value_indices: Vec<(f32, usize)> = values
.into_iter()
.enumerate()
.map(|(idx, val)| (val, idx))
.collect();
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
let top_k_indices: Vec<u32> = value_indices
.iter()
.take(k)
.map(|(_, idx)| *idx as u32)
.collect();
let device = tensor.device();
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
Ok((top_values, top_indices))
}
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Ok(Tensor::from_slice(&mask, (size, size), device)?)
}
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
let seq_len = tokens.get_attention_mask().to_vec().len();
let mask_token_id = tokenizer
.token_to_id("[MASK]")
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
let ids = tokens.get_ids();
for _ in 0..seq_len {
for id in ids.iter() {
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
attention_mask_vec.push(mask_value);
}
}
let shape = (1, 1, seq_len, seq_len);
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
Ok(mask)
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
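A quick note on `normalize_l2`: it rescales each row (dimension 1 of a `(batch, dim)` tensor) to unit L2 norm. A standalone restatement of the same computation, using the published `candle_core` crate name:

```rust
use candle_core::{Device, Result, Tensor};

fn normalize_l2(v: &Tensor) -> Result<Tensor> {
    // Divide each row by the square root of its sum of squares.
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

fn main() -> Result<()> {
    let v = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?; // [[1, 2], [3, 4]]
    let n = normalize_l2(&v)?;
    println!("{:?}", n.to_vec2::<f32>()?); // roughly [[0.4472, 0.8944], [0.6, 0.8]]
    Ok(())
}
```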

View File

@ -1,15 +0,0 @@
# candle-efficientnet
Demonstrates a Candle implementation of EfficientNet for image classification based on ImageNet classes.
## Running an example
```bash
$ cargo run --example efficientnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which b1
> bicycle-built-for-two, tandem bicycle, tandem: 45.85%
> mountain bike, all-terrain bike, off-roader: 30.45%
> crash helmet : 2.58%
> unicycle, monocycle : 2.21%
> tricycle, trike, velocipede: 1.53%
```

View File

@ -1,10 +1,3 @@
# candle-falcon
Falcon is a general large language model.
## Running an example
Make sure to include the `--use-f32` flag if using CPU, because there isn't a BFloat16 implementation yet.
```
cargo run --example falcon --release -- --prompt "Flying monkeys are" --use-f32
```

View File

@ -12,7 +12,7 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
#+end_src
** Output Example

View File

@ -1,11 +0,0 @@
# candle-llama
Candle implementations of various Llama based architectures.
## Running an example
```bash
$ cargo run --example llama -- --prompt "Machine learning is " --which v32-3b-instruct
> Machine learning is the part of computer science which deals with the development of algorithms and
```

View File

@ -21,7 +21,7 @@ impl Config {
}
fn dt_rank(&self) -> usize {
self.d_model.div_ceil(16)
(self.d_model + 15) / 16
}
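The two bodies shown for `dt_rank` are equivalent: `div_ceil(16)` is just the ceiling-division idiom `(n + 15) / 16` spelled with the standard-library integer helper. A trivial standalone check:

```rust
fn main() {
    for d_model in [1usize, 15, 16, 17, 768, 2560] {
        assert_eq!(d_model.div_ceil(16), (d_model + 15) / 16);
    }
    // e.g. a 2560-wide model gets dt_rank = 160 either way.
    assert_eq!(2560usize.div_ceil(16), 160);
}
```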
fn d_conv(&self) -> usize {

View File

@ -12,6 +12,6 @@ would only work for inference.
## Running the example
```bash
$ cargo run --example mamba --release -- --prompt "Mamba is the"
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
```

View File

@ -18,19 +18,21 @@ I know you are waiting for me. I will go through the forest, I will go through t
mountain. I cannot stay far from you any longer.</s>
```
### Changing model and language pairs
```bash
$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
你好,你好吗?
```
## Generating the tokenizer.json files
The tokenizer for each `marian-mt` model was trained independently,
meaning each new model needs unique tokenizer encoders and decoders.
You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
the `tokenizer.json` config files from the hf-hub repos.
The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
to be installed, and has only been tested with `python 3.12.7`.
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be installed and uses the `convert_slow_tokenizer.py` script from this
directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```

File diff suppressed because it is too large

View File

@ -20,22 +20,6 @@ enum Which {
Big,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum LanguagePair {
#[value(name = "fr-en")]
FrEn,
#[value(name = "en-zh")]
EnZh,
#[value(name = "en-hi")]
EnHi,
#[value(name = "en-es")]
EnEs,
#[value(name = "en-fr")]
EnFr,
#[value(name = "en-ru")]
EnRu,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
@ -52,10 +36,6 @@ struct Args {
#[arg(long, default_value = "big")]
which: Which,
// Choose which language pair to use
#[arg(long, default_value = "fr-en")]
language_pair: LanguagePair,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
@ -73,43 +53,21 @@ pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
(Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
(Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
(Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
(Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
(Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
(Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
(Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
};
let tokenizer_default_repo = match args.language_pair {
LanguagePair::FrEn => "lmz/candle-marian",
LanguagePair::EnZh
| LanguagePair::EnHi
| LanguagePair::EnEs
| LanguagePair::EnFr
| LanguagePair::EnRu => "KeighBee/candle-marian",
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let filename = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
};
Api::new()?
.model(tokenizer_default_repo.to_string())
.get(filename)?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -119,21 +77,13 @@ pub fn main() -> anyhow::Result<()> {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let filename = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
};
Api::new()?
.model(tokenizer_default_repo.to_string())
.get(filename)?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -144,48 +94,18 @@ pub fn main() -> anyhow::Result<()> {
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => {
let api = Api::new()?;
let api = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
)),
(Which::Big, LanguagePair::FrEn) => {
api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
}
(Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-zh".to_string(),
hf_hub::RepoType::Model,
"refs/pr/13".to_string(),
)),
(Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-hi".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
)),
(Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-es".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
)),
(Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-fr".to_string(),
hf_hub::RepoType::Model,
"refs/pr/9".to_string(),
)),
(Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-ru".to_string(),
hf_hub::RepoType::Model,
"refs/pr/7".to_string(),
)),
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
api.get("model.safetensors")?
}
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};

View File

@ -1,53 +0,0 @@
from pathlib import Path
import warnings
from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf
class MarianConverter(SpmConverter):
def __init__(self, *args, index: int = 0):
requires_backends(self, "protobuf")
super(SpmConverter, self).__init__(*args)
# from .utils import sentencepiece_model_pb2 as model_pb2
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
print(self.original_tokenizer.spm_files)
with open(self.original_tokenizer.spm_files[index], "rb") as f:
m.ParseFromString(f.read())
self.proto = m
print(self.original_tokenizer)
#with open(self.original_tokenizer.vocab_path, "r") as f:
dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
with open(dir_path / "vocab.json", "r") as f:
import json
self._vocab = json.load(f)
if self.proto.trainer_spec.byte_fallback:
if not getattr(self, "handle_byte_fallback", None):
warnings.warn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)
def vocab(self, proto):
vocab_size = max(self._vocab.values()) + 1
vocab = [("<NIL>", -100) for _ in range(vocab_size)]
for piece in proto.pieces:
try:
index = self._vocab[piece.piece]
except Exception:
print(f"Ignored missing piece {piece.piece}")
vocab[index] = (piece.piece, piece.score)
return vocab
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")

View File

@ -1,22 +0,0 @@
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
filelock==3.18.0
fsspec==2025.3.2
huggingface-hub==0.30.1
idna==3.10
joblib==1.4.2
numpy==2.2.4
packaging==24.2
protobuf==6.30.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
sacremoses==0.1.1
safetensors==0.5.3
sentencepiece==0.2.0
tokenizers==0.21.1
tqdm==4.67.1
transformers==4.50.3
typing-extensions==4.13.0
urllib3==2.3.0

View File

@ -13,6 +13,6 @@ Note that the current candle implementation suffers from some limitations as of
## Run an example
```bash
cargo run --example metavoice --release -- \
cargo run --example metavoice --release -- \\
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -1,16 +0,0 @@
# candle-mnist-training
Training a 2-layer MLP on MNIST in Candle.
## Running an example
```bash
$ cargo run --example mnist-training --features candle-datasets
> train-images: [60000, 784]
> train-labels: [60000]
> test-images: [10000, 784]
> test-labels: [10000]
> 1 train loss: 2.30265 test acc: 68.08%
> 2 train loss: 1.50815 test acc: 60.77%
```

View File

@ -12,7 +12,7 @@ $ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jp
Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "Describe the people behind the bikers?" --image "candle-examples/examples/yolo-v8/assets/bike.jpg"
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
avavx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64

View File

@ -1,20 +0,0 @@
# candle-musicgen
Candle implementation of musicgen from [Simple and Controllable Music Generation](https://arxiv.org/pdf/2306.05284).
## Running an example
```bash
$ cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums"
> tokens: [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7, 1]
> Tensor[dims 1, 13; u32]
> [[[ 0.0902, 0.1256, -0.0585, ..., 0.1057, -0.5141, -0.4675],
> [ 0.1972, -0.0268, -0.3368, ..., -0.0495, -0.3597, -0.3940],
> [-0.0855, -0.0007, 0.2225, ..., -0.2804, -0.5360, -0.2436],
> ...
> [ 0.0515, 0.0235, -0.3855, ..., -0.4728, -0.6858, -0.2923],
> [-0.3728, -0.1442, -0.1179, ..., -0.4388, -0.0287, -0.3242],
> [ 0.0163, 0.0012, -0.0020, ..., 0.0142, 0.0173, -0.0103]]]
> Tensor[[1, 13, 768], f32]
```

View File

@ -1,14 +0,0 @@
# Orpheus
Orpheus is a 3B text-to-speech model based on Llama.
- Weights on HuggingFace
[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
```bash
cargo run --example orpheus --features cuda -r
```

View File

@ -1,329 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
use tokenizers::Tokenizer;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
const STOP_TOKEN_ID: u32 = 128258;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.6)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "3b-0.1-ft")]
which: Which,
#[arg(long, default_value = "tara")]
voice: Voice,
#[arg(long)]
use_flash_attn: bool,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Voice {
#[value(name = "tara")]
Tara,
#[value(name = "leah")]
Leah,
#[value(name = "jess")]
Jess,
#[value(name = "leo")]
Leo,
#[value(name = "dan")]
Dan,
#[value(name = "mia")]
Mia,
#[value(name = "zac")]
Zac,
#[value(name = "zoe")]
Zoe,
}
impl Voice {
fn as_str(&self) -> &'static str {
match self {
Voice::Tara => "tara",
Voice::Leah => "leah",
Voice::Jess => "jess",
Voice::Leo => "leo",
Voice::Dan => "dan",
Voice::Mia => "mia",
Voice::Zac => "zac",
Voice::Zoe => "zoe",
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3b-0.1-ft")]
ThreeB0_1Ft,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let prompt = args.prompt.clone();
let mut model = Model::load(args)?;
model.run(&prompt)?;
Ok(())
}
struct Model {
model: Llama,
tokenizer: Tokenizer,
logits_processor: candle_transformers::generation::LogitsProcessor,
cache: Cache,
device: Device,
verbose_prompt: bool,
snac: SnacModel,
out_file: String,
voice: Voice,
}
fn load_snac(device: &Device) -> Result<SnacModel> {
let api = hf_hub::api::sync::Api::new()?;
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
let config = m.get("config.json")?;
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let m = api.model("lmz/candle-snac".to_string());
let model = m.get("snac_24khz.safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
let model = SnacModel::new(&config, vb)?;
Ok(model)
}
impl Model {
fn load(args: Args) -> Result<Self> {
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::ThreeB0_1Ft => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let config = config.into_config(args.use_flash_attn);
let model = Llama::load(vb, &config)?;
let logits_processor = {
use candle_transformers::generation::{LogitsProcessor, Sampling};
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k.as_ref(), args.top_p.as_ref()) {
(None, None) => Sampling::All { temperature },
(Some(&k), None) => Sampling::TopK { k, temperature },
(None, Some(&p)) => Sampling::TopP { p, temperature },
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
println!("loaded the model in {:?}", start.elapsed());
let cache = Cache::new(true, dtype, &config, &device)?;
let snac = load_snac(&device)?;
Ok(Self {
model,
tokenizer,
logits_processor,
cache,
device,
verbose_prompt: args.verbose_prompt,
snac,
voice: args.voice,
out_file: args.out_file,
})
}
fn run(&mut self, prompt: &str) -> Result<()> {
println!("running the model on '{}'", prompt);
let device = &self.device;
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
let mut tokens = [
&[128259],
tokens.get_ids(),
&[128009, 128260, 128261, 128257],
]
.concat();
if self.verbose_prompt {
println!("{:?}", tokens);
}
let mut cache = self.cache.clone();
println!("starting the inference loop");
let mut index_pos = 0;
let mut audio_tokens = vec![];
for index in 0..2000 {
let (context_size, context_index) = if index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = self.logits_processor.sample(&logits)?;
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
match tok.strip_prefix("<custom_token_") {
Some(tok) => match tok.strip_suffix('>') {
Some(tok) => {
let tok = tok.parse::<u32>()?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
audio_tokens.push(tok);
}
None => {
println!("{index}: unexpected custom token {next_token} {tok}");
}
},
None => {
println!("{index}: unexpected token {next_token} {tok}");
}
}
}
if next_token == STOP_TOKEN_ID {
println!("reached stop token");
break;
}
tokens.push(next_token);
}
println!("generated {} audio tokens", audio_tokens.len());
let mut codes0 = vec![];
let mut codes1 = vec![];
let mut codes2 = vec![];
for audio_tokens in audio_tokens.chunks_exact(7) {
codes0.push(audio_tokens[0]);
for i in [1, 4] {
codes1.push(audio_tokens[i]);
}
for i in [2, 3, 5, 6] {
codes2.push(audio_tokens[i]);
}
}
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
println!("decoded to pcm {pcm:?}");
let mut output = std::fs::File::create(&self.out_file)?;
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
Ok(())
}
}

View File

@ -1,20 +0,0 @@
# candle-quantized-phi
Candle implementation of various quantized Phi models.
## Running an example
```bash
$ cargo run --example quantized-phi --release -- --prompt "The best thing about coding in rust is "
> - it's memory safe (without you having to worry too much)
> - the borrow checker is really smart and will catch your mistakes for free, making them show up as compile errors instead of segfaulting in runtime.
>
> This alone make me prefer using rust over c++ or go, python/Cython etc.
>
> The major downside I can see now:
> - it's slower than other languages (viz: C++) and most importantly lack of libraries to leverage existing work done by community in that language. There are so many useful machine learning libraries available for c++, go, python etc but none for Rust as far as I am aware of on the first glance.
> - there aren't a lot of production ready projects which also makes it very hard to start new one (given my background)
>
> Another downside:
```

View File

@ -27,8 +27,6 @@ enum Which {
W2_7b,
#[value(name = "72b")]
W2_72b,
#[value(name = "deepseekr1-qwen7b")]
DeepseekR1Qwen7B,
}
#[derive(Parser, Debug)]
@ -104,7 +102,6 @@ impl Args {
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
Which::DeepseekR1Qwen7B => "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
@ -138,11 +135,6 @@ impl Args {
"qwen2-72b-instruct-q4_0.gguf",
"main",
),
Which::DeepseekR1Qwen7B => (
"unsloth/DeepSeek-R1-Distill-Qwen-7B-GGUF",
"DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf",
"main",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
@ -219,15 +211,11 @@ fn main() -> anyhow::Result<()> {
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt_str = args
.prompt
.clone()
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let prompt_str = match args.which {
Which::DeepseekR1Qwen7B => format!("<User>{prompt_str}<Assistant>"),
_ => format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"),
};
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let prompt_str = format!(
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
prompt_str
);
print!("formatted instruct prompt: {}", &prompt_str);
let tokens = tos
.tokenizer()
@ -272,13 +260,7 @@ fn main() -> anyhow::Result<()> {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = match args.which {
Which::DeepseekR1Qwen7B => "<end▁of▁sentence>",
_ => "<|im_end|>",
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {

View File

@ -1,7 +1,5 @@
# candle-quantized-t5
Candle implementation for quantizing and running T5 translation models.
## Seq2Seq example
This example uses a quantized version of the t5 model.

View File

@ -75,8 +75,6 @@ enum Which {
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "deepseekr1-llama8b")]
DeepseekR1Llama8b,
}
impl Which {
@ -96,8 +94,7 @@ impl Which {
| Self::L8b
| Self::Phi3
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::DeepseekR1Llama8b => false,
| Self::SmolLM2_360MInstruct => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -135,8 +132,7 @@ impl Which {
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3
| Self::DeepseekR1Llama8b => false,
| Self::Phi3 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@ -164,41 +160,11 @@ impl Which {
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3
| Self::DeepseekR1Llama8b => false,
| Self::Phi3 => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
fn is_deepseek(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3
| Self::OpenChat35
| Self::Starling7bAlpha => false,
Self::DeepseekR1Llama8b => true,
}
}
fn tokenizer_repo(&self) -> &'static str {
match self {
Self::L7b
@ -225,7 +191,6 @@ impl Which {
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
Self::DeepseekR1Llama8b => "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}
}
}
@ -398,10 +363,6 @@ impl Args {
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
"smollm2-1.7b-instruct-q4_k_m.gguf",
),
Which::DeepseekR1Llama8b => (
"unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
"DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf",
),
};
let revision = if self.which == Which::Phi3 {
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
@ -516,7 +477,6 @@ fn main() -> anyhow::Result<()> {
| Which::L8b
| Which::SmolLM2_1BInstruct
| Which::SmolLM2_360MInstruct
| Which::DeepseekR1Llama8b
| Which::Phi3 => 1,
Which::Mixtral
| Which::MixtralInstruct
@ -570,8 +530,6 @@ fn main() -> anyhow::Result<()> {
}
} else if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else if args.which.is_deepseek() {
format!("<User>{prompt}<Assistant>")
} else {
prompt
}
@ -639,7 +597,6 @@ fn main() -> anyhow::Result<()> {
let eos_token = match args.which {
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
Which::L8b => "<|end_of_text|>",
Which::DeepseekR1Llama8b => "<end▁of▁sentence>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",
false => "</s>",

View File

@ -2,11 +2,6 @@
Reinforcement Learning examples for candle.
> [!WARNING]
> uv is not currently compatible with pyo3 as of 2025/3/28.
## System wide python
This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash

View File

@ -7,7 +7,7 @@ probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example resnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
$ cargo run --example resnet --release -- --image tiger.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built

View File

@ -10,11 +10,9 @@ If you want you can use the example images from this [pull request][pr], downloa
```bash
# run the image classification task
cargo run --example segformer classify candle-examples/examples/yolo-v8/assets/bike.jpg
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment candle-examples/examples/yolo-v8/assets/bike.jpg
cargo run --example segformer segment <path-to-image>
```
Example output for classification:

View File

@ -14,8 +14,8 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
```bash
cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg \
--use-tiny \
--image candle-examples/examples/yolo-v8/assets/bike.jpg
--use-tiny
--point 0.6,0.6 --point 0.6,0.55
```

View File

@ -5,7 +5,7 @@ SigLIP is a multi-modal text-vision model that improves over CLIP by using a sigmo
### Running an example
```
$ cargo run --features cuda -r --example siglip
$ cargo run --features cuda -r --example siglip -
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]

View File

@ -6,14 +6,7 @@ This example uses the models available in the hugging face [onnx-community/siler
## Running the example
### using arecord
```bash
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```
### using SoX
```bash
$ rec -t raw -r 48000 -b 16 -c 1 -e signed-integer - trim 0 5 | sox -t raw -r 48000 -b 16 -c 1 -e signed-integer - -t raw -r 16000 -b 16 -c 1 -e signed-integer - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```

View File

@ -1,275 +0,0 @@
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler =
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -1,197 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::snac::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "24khz")]
S24khz,
#[value(name = "32khz")]
S32khz,
#[value(name = "44khz")]
S44khz,
}
impl Which {
fn sample_rate(&self) -> u32 {
match self {
Which::S24khz => 24000,
Which::S32khz => 32000,
Which::S44khz => 44000,
}
}
fn config_repo(&self) -> &'static str {
match self {
Which::S24khz => "hubertsiuzdak/snac_24khz",
Which::S32khz => "hubertsiuzdak/snac_32khz",
Which::S44khz => "hubertsiuzdak/snac_44khz",
}
}
fn model_file(&self) -> &'static str {
match self {
Which::S24khz => "snac_24khz.safetensors",
Which::S32khz => "snac_32khz.safetensors",
Which::S44khz => "snac_44khz.safetensors",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some snac tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
out_file: String,
/// The model size to use.
#[arg(long, default_value = "24khz")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
/// The config file, in safetensor format.
#[arg(long)]
config: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model_sample_rate = args.which.sample_rate();
let config = match args.config {
Some(c) => std::path::PathBuf::from(c),
None => Api::new()?
.model(args.which.config_repo().to_string())
.get("config.json")?,
};
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("lmz/candle-snac".to_string())
.get(args.which.model_file())?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = Model::new(&config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let num_codebooks = model.num_codebooks();
(0..num_codebooks)
.map(|i| {
codes
.get(&format!("codes-{i}"))
.expect("no codes in input file")
.clone()
})
.collect::<Vec<_>>()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != model_sample_rate {
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
} else {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
for codes in codes.iter() {
println!("codes shape: {:?}", codes.shape());
}
match args.action {
Action::AudioToCode => {
let mut tensors = std::collections::HashMap::new();
for (i, codes) in codes.iter().enumerate() {
tensors.insert(format!("codes-{i}"), codes.clone());
}
candle::safetensors::save(&tensors, "codes.safetensors")?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let codes = codes.iter().collect::<Vec<_>>();
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
}
}
}
Ok(())
}

View File

@ -1,15 +0,0 @@
# candle-starcoder2
Candle implementation of the StarCoder 2 family of code generation models from [StarCoder 2 and The Stack v2: The Next Generation](https://arxiv.org/pdf/2402.19173).
## Running an example
```bash
$ cargo run --example starcoder2 -- --prompt "write a recursive fibonacci function in python "
> # that returns the nth number in the sequence.
>
> def fib(n):
> if n
```

View File

@ -10,7 +10,7 @@ Stella_en_1.5B_v5 is used to generate text embeddings for a prompt. T
are downloaded from the hub on the first run.
```bash
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?" --which 1.5b
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
> Tensor[[1, 1024], f32]

View File

@ -1,7 +1,5 @@
# candle-t5
Candle implementations of the T5 family of translation models.
## Encoder-decoder example:
```bash

View File

@ -7,7 +7,7 @@ The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main
You can run the example with the following command:
```bash
cargo run --example vgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which vgg13
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).

View File

@ -7,8 +7,8 @@ probabilities for the top-5 classes.
## Running an example
```bash
$ cargo run --example vit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
$ cargo run --example vit --release -- --image tiger.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built

View File

@ -1,15 +0,0 @@
# candle-whisper-microphone
Whisper implementation using microphone as input.
## Running an example
```bash
$ cargo run --example whisper-microphone --features microphone
> transcribing audio...
> 480256 160083
> language_token: None
> 0.0s -- 30.0s: Hello, hello, I don't know if this is working, but You know, how long did I make this?
> 480256 160085
```

View File

@ -1,13 +0,0 @@
# candle-yi
Candle implementations of the Yi family of bilingual (English, Chinese) LLMs.
## Running an example
```bash
$ cargo run --example yi -- --prompt "Here is a test sentence"
> python
> print("Hello World")
>
```

View File

@ -1,32 +0,0 @@
# candle-yolo-v3:
Candle implementation of Yolo-V3 for object detection.
## Running an example
```bash
$ cargo run --example yolo-v3 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
> generated predictions Tensor[dims 10647, 85; f32]
> person: Bbox { xmin: 46.362198, ymin: 72.177, xmax: 135.92522, ymax: 339.8356, confidence: 0.99705493, data: () }
> person: Bbox { xmin: 137.25645, ymin: 67.58148, xmax: 216.90437, ymax: 333.80756, confidence: 0.9898516, data: () }
> person: Bbox { xmin: 245.7842, ymin: 82.76726, xmax: 316.79053, ymax: 337.21613, confidence: 0.9884322, data: () }
> person: Bbox { xmin: 207.52783, ymin: 61.815224, xmax: 266.77884, ymax: 307.92606, confidence: 0.9860648, data: () }
> person: Bbox { xmin: 11.457404, ymin: 60.335564, xmax: 34.39357, ymax: 187.7714, confidence: 0.9545012, data: () }
> person: Bbox { xmin: 251.88353, ymin: 11.235481, xmax: 286.56607, ymax: 92.54697, confidence: 0.8439807, data: () }
> person: Bbox { xmin: -0.44309902, ymin: 55.486923, xmax: 13.160354, ymax: 184.09705, confidence: 0.8266243, data: () }
> person: Bbox { xmin: 317.40826, ymin: 55.39501, xmax: 370.6704, ymax: 153.74887, confidence: 0.7327442, data: () }
> person: Bbox { xmin: 370.02835, ymin: 66.120224, xmax: 404.22824, ymax: 142.09691, confidence: 0.7265741, data: () }
> person: Bbox { xmin: 250.36511, ymin: 57.349842, xmax: 280.06335, ymax: 116.29384, confidence: 0.709422, data: () }
> person: Bbox { xmin: 32.573215, ymin: 66.66239, xmax: 50.49056, ymax: 173.42068, confidence: 0.6998766, data: () }
> person: Bbox { xmin: 131.72215, ymin: 63.946213, xmax: 166.66151, ymax: 241.52773, confidence: 0.64457536, data: () }
> person: Bbox { xmin: 407.42416, ymin: 49.106407, xmax: 415.24307, ymax: 84.7134, confidence: 0.5955802, data: () }
> person: Bbox { xmin: 51.650482, ymin: 64.4985, xmax: 67.40904, ymax: 106.952385, confidence: 0.5196007, data: () }
> bicycle: Bbox { xmin: 160.10031, ymin: 183.90837, xmax: 200.86832, ymax: 398.609, confidence: 0.9623588, data: () }
> bicycle: Bbox { xmin: 66.570915, ymin: 192.56966, xmax: 112.06765, ymax: 369.28497, confidence: 0.9174347, data: () }
> bicycle: Bbox { xmin: 258.2856, ymin: 197.04532, xmax: 298.43106, ymax: 364.8627, confidence: 0.6851388, data: () }
> bicycle: Bbox { xmin: 214.0034, ymin: 175.76498, xmax: 252.45158, ymax: 356.53818, confidence: 0.67071193, data: () }
> motorbike: Bbox { xmin: 318.23938, ymin: 95.22487, xmax: 369.9743, ymax: 213.46263, confidence: 0.96691036, data: () }
> motorbike: Bbox { xmin: 367.46417, ymin: 100.07982, xmax: 394.9981, ymax: 174.6545, confidence: 0.9185384, data: () }
> writing "candle-examples/examples/yolo-v8/assets/bike.pp.jpg"
```

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.2"
version = "0.8.4"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,17 +11,14 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.2" }
candle = { path = "../candle-core", features = ["_cuda"], package = "candle-core", version = "0.8.4" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] }
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }
[features]
default = []
cudnn = ["candle/cudnn"]
candle-nn = { path = "../candle-nn", features = ["_cuda"] }

View File

@ -2,6 +2,7 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
@ -87,7 +88,6 @@ impl FlashAttn {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
@ -114,9 +114,7 @@ impl FlashAttn {
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
// Dropping the guard here doesn't seem very safe.
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
ptr as *const core::ffi::c_void
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
@ -141,8 +139,10 @@ impl FlashAttn {
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count)? };
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@ -161,17 +161,17 @@ impl FlashAttn {
}
unsafe {
let (q_ptr, _guard) = q.device_ptr(&stream);
let (k_ptr, _guard) = k.device_ptr(&stream);
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
ffi::run_mha(
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
@ -550,7 +550,6 @@ impl FlashAttnVarLen {
let batch_size = nseqlens_q - 1;
let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
@ -577,9 +576,7 @@ impl FlashAttnVarLen {
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
// Dropping the guard here doesn't seem very safe.
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
ptr as *const core::ffi::c_void
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
@ -604,8 +601,8 @@ impl FlashAttnVarLen {
let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@ -624,22 +621,22 @@ impl FlashAttnVarLen {
}
unsafe {
let (q_ptr, _guard) = q.device_ptr(&stream);
let (k_ptr, _guard) = k.device_ptr(&stream);
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
ffi::run_mha(
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.2"
version = "0.8.4"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -7,5 +7,5 @@ fn main() {
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write("src/ptx.rs").unwrap();
bindings.write("src/lib.rs").unwrap();
}

View File

@ -53,7 +53,7 @@ __device__ void conv1d(
template <typename T>
__device__ void im2col1d(
const size_t numel,
const size_t dst_numel,
const size_t l_out,
const size_t l_k,
const size_t stride,
@ -63,10 +63,10 @@ __device__ void im2col1d(
const T *src,
T *dst
) {
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (thread_i >= numel) {
if (dst_i >= dst_numel) {
return;
}
const size_t *src_dims = info;
@ -74,26 +74,26 @@ __device__ void im2col1d(
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s1 = c_in;
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = thread_i;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i;
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
size_t dst_i = thread_i * l_k + l_k_idx;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
}
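To make the indexing above easier to follow, here is a small Rust sketch of the same decomposition: a flat destination index is split into (batch, output position, channel, kernel tap) using the strides of the `(b_size, l_out, c_in, l_k)` layout. The function name is illustrative only; it mirrors the arithmetic in the updated kernel.
```rust
// Illustrative only: mirrors the index arithmetic of the im2col1d kernel above.
fn decompose_dst_index(dst_i: usize, l_out: usize, c_in: usize, l_k: usize) -> (usize, usize, usize, usize) {
    let dst_s2 = l_k;            // stride of the channel dimension
    let dst_s1 = c_in * dst_s2;  // stride of the output-position dimension
    let dst_s0 = l_out * dst_s1; // stride of the batch dimension
    let b_idx = dst_i / dst_s0;
    let mut rem = dst_i - b_idx * dst_s0;
    let l_idx = rem / dst_s1;
    rem -= l_idx * dst_s1;
    let c_idx = rem / dst_s2;
    let l_k_idx = rem - c_idx * dst_s2;
    (b_idx, l_idx, c_idx, l_k_idx)
}
```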

View File

@ -1,78 +1,11 @@
mod ptx;
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Id {
Affine,
Binary,
Cast,
Conv,
Fill,
Indexing,
Quantized,
Reduce,
Sort,
Ternary,
Unary,
}
pub const ALL_IDS: [Id; 11] = [
Id::Affine,
Id::Binary,
Id::Cast,
Id::Conv,
Id::Fill,
Id::Indexing,
Id::Quantized,
Id::Reduce,
Id::Sort,
Id::Ternary,
Id::Unary,
];
pub struct Module {
index: usize,
ptx: &'static str,
}
impl Module {
pub fn index(&self) -> usize {
self.index
}
pub fn ptx(&self) -> &'static str {
self.ptx
}
}
const fn module_index(id: Id) -> usize {
let mut i = 0;
while i < ALL_IDS.len() {
if ALL_IDS[i] as u32 == id as u32 {
return i;
}
i += 1;
}
panic!("id not found")
}
macro_rules! mdl {
($cst:ident, $id:ident) => {
pub const $cst: Module = Module {
index: module_index(Id::$id),
ptx: ptx::$cst,
};
};
}
mdl!(AFFINE, Affine);
mdl!(BINARY, Binary);
mdl!(CAST, Cast);
mdl!(CONV, Conv);
mdl!(FILL, Fill);
mdl!(INDEXING, Indexing);
mdl!(QUANTIZED, Quantized);
mdl!(REDUCE, Reduce);
mdl!(SORT, Sort);
mdl!(TERNARY, Ternary);
mdl!(UNARY, Unary);
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

View File

@ -1,11 +0,0 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.2"
version = "0.8.4"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -32,9 +32,10 @@ criterion = { workspace = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
_cuda = ["candle/_cuda"]
cuda = ["candle/cuda"]
cudnn = ["candle/cudnn"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
_mkl = ["dep:intel-mkl-src", "candle/_mkl"]
mkl = ["candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
[[bench]]

View File

@ -15,9 +15,9 @@ impl BenchDevice for Device {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
@ -34,7 +34,7 @@ impl BenchDevice for Device {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
} else if cfg!(feature = "_mkl") {
"mkl"
} else {
"cpu"
@ -56,7 +56,7 @@ impl BenchDeviceHandler {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
} else if cfg!(feature = "_cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,5 +1,5 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -41,36 +41,12 @@ impl Linear {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
// When possible, we avoid using a broadcasted matmul as it is much slower
// than the standard matmul for the cuda and cpu backends.
let x = match *x.dims() {
[b1, b2, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((b1 * b2 * m, k))?
.matmul(&w)?
.reshape((b1, b2, m, ()))?
} else {
let w = self.weight.broadcast_left((b1, b2))?.t()?;
x.matmul(&w)?
}
}
[bsize, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((bsize * m, k))?
.matmul(&w)?
.reshape((bsize, m, ()))?
} else {
let w = self.weight.broadcast_left(bsize)?.t()?;
x.matmul(&w)?
}
}
_ => {
let w = self.weight.t()?;
x.matmul(&w)?
}
let w = match *x.dims() {
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
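For reference, a minimal sketch (not the crate's code) of the reshape-based path from the longer variant above: when the input is contiguous, the batch dimensions are flattened so the matmul runs as a plain 2D GEMM and the result is reshaped back afterwards.
```rust
// Minimal sketch of the reshape-based path, assuming a contiguous (b, m, k) input
// and an (n, k) weight as used by Linear; not the crate's actual implementation.
use candle::Tensor;

fn linear_2d(x: &Tensor, w: &Tensor) -> candle::Result<Tensor> {
    let (b, m, k) = x.dims3()?;
    let wt = w.t()?; // (k, n)
    x.reshape((b * m, k))?.matmul(&wt)?.reshape((b, m, ()))
}
```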

View File

@ -7,7 +7,7 @@ use candle::{Result, Tensor};
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain log probabilities.
/// of categories. This is expected to contain log probabilities.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain raw logits.
/// of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain raw logits.
/// of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
/// of categories.
/// of categories.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
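A minimal usage sketch of the loss helpers documented above, with the shapes from the doc comments (`N, C` logits and `N` u32 targets); the numeric values are placeholders.
```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Two samples, three classes: raw logits as expected by `cross_entropy`.
    let logits = Tensor::new(&[[1.0f32, 0.5, -0.5], [0.1, 2.0, 0.3]], &dev)?;
    let targets = Tensor::new(&[0u32, 1], &dev)?;
    // Applies log-softmax internally, then averages the negative log-likelihood over the batch.
    let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
    println!("loss: {}", loss.to_scalar::<f32>()?);
    Ok(())
}
```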

View File

@ -82,7 +82,7 @@ impl candle::CustomOp1 for Sigmoid {
Ok((storage, layout.shape().clone()))
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use candle::cuda_backend::SlicePtrOrNull;
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
@ -110,17 +110,13 @@ impl candle::CustomOp1 for Sigmoid {
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count)? };
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let mut builder = func.builder();
candle::builder_arg!(builder, el_count, dims.len());
ds.builder_arg(&mut builder);
builder.arg(src);
builder.arg(&out);
let params = (el_count, dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
@ -337,14 +333,14 @@ impl candle::CustomOp1 for SoftmaxLastDim {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
use candle::{CudaDevice, WithDType};
@ -371,15 +367,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
block_dim: (1, 32, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
candle::builder_arg!(builder, n_cols as i32);
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, n_cols as i32);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
@ -514,7 +507,7 @@ impl candle::CustomOp2 for RmsNorm {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -523,7 +516,7 @@ impl candle::CustomOp2 for RmsNorm {
l2: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
use candle::{CudaDevice, WithDType};
@ -559,16 +552,19 @@ impl candle::CustomOp2 for RmsNorm {
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
@ -744,7 +740,7 @@ impl candle::CustomOp3 for LayerNorm {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -755,7 +751,7 @@ impl candle::CustomOp3 for LayerNorm {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
use candle::{CudaDevice, WithDType};
@ -797,18 +793,20 @@ impl candle::CustomOp3 for LayerNorm {
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func =
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
builder.arg(&beta);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
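A short usage sketch of the fused ops whose CUDA forward passes are shown above; it runs on the CPU backend here and the input values are placeholders.
```rust
use candle::{Device, Tensor};
use candle_nn::ops;

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]], &dev)?;
    let probs = ops::softmax_last_dim(&xs)?; // SoftmaxLastDim custom op
    let sig = ops::sigmoid(&xs)?;            // Sigmoid custom op
    println!("{probs}\n{sig}");
    Ok(())
}
```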

View File

@ -77,7 +77,7 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@ -117,17 +117,12 @@ impl candle::CustomOp3 for RotaryEmbI {
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
@ -327,7 +322,7 @@ impl candle::CustomOp3 for RotaryEmb {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -338,7 +333,7 @@ impl candle::CustomOp3 for RotaryEmb {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@ -367,17 +362,20 @@ impl candle::CustomOp3 for RotaryEmb {
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&cos,
&sin,
&dst,
(b * h) as u32,
(t * d) as u32,
d as u32,
);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
@ -578,7 +576,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -589,7 +587,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@ -618,17 +616,14 @@ impl candle::CustomOp3 for RotaryEmbThd {
let (b, t, h, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
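A minimal usage sketch of the rotary-embedding helpers these kernels back; shapes follow the `(b, h, t, d)` layout used above and the random values are placeholders.
```rust
use candle::{Device, Tensor};
use candle_nn::rotary_emb;

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let (b, h, t, d) = (1, 2, 4, 8);
    let xs = Tensor::randn(0f32, 1.0, (b, h, t, d), &dev)?;
    let cos = Tensor::randn(0f32, 1.0, (t, d / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1.0, (t, d / 2), &dev)?;
    // Interleaved variant; `rope` and `rope_thd` take the same arguments.
    let rotated = rotary_emb::rope_i(&xs, &cos, &sin)?;
    println!("{:?}", rotated.shape());
    Ok(())
}
```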

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -18,7 +18,7 @@ t = torch.tensor(
print(group_norm(t, num_groups=2))
print(group_norm(t, num_groups=3))
*/
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

Some files were not shown because too many files have changed in this diff.