Compare commits

...

4 Commits

32 changed files with 78 additions and 73 deletions

View File

@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
@ -51,7 +51,7 @@ half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_di
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
intel-mkl-src = { version = "0.8.1" }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }

View File

@ -43,9 +43,11 @@ criterion = { workspace = true }
[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
_cuda = ["dep:cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["_cuda"]
cudnn = ["_cuda", "cudarc?/cudnn"]
_mkl = ["dep:libc", "dep:intel-mkl-src"]
mkl = ["_mkl", "intel-mkl-src?/mkl-static-lp64-iomp"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]

View File

@ -20,9 +20,9 @@ impl BenchDevice for Device {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
@ -39,7 +39,7 @@ impl BenchDevice for Device {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
} else if cfg!(feature = "_mkl") {
"mkl"
} else {
"cpu"
@ -61,7 +61,7 @@ impl BenchDeviceHandler {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
} else if cfg!(feature = "_cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,7 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
use anyhow::Result;

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,7 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
use anyhow::Result;

View File

@ -1246,7 +1246,7 @@ impl MatMul {
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
#[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
#[cfg(all(not(feature = "_mkl"), not(feature = "accelerate")))]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
@ -1411,7 +1411,7 @@ impl Map2 for MatMul {
Ok(dst)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],

View File

@ -378,7 +378,7 @@ impl Tensor {
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
@ -392,7 +392,7 @@ impl UgIOp1 {
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
@ -404,7 +404,7 @@ impl UgIOp1 {
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
#[cfg(not(any(feature = "_cuda", feature = "metal")))]
{
Ok(Self { name })
}
@ -456,7 +456,7 @@ impl InplaceOp1 for UgIOp1 {
Ok(())
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;

View File

@ -55,7 +55,7 @@ pub mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
pub mod cuda_backend;
mod custom_op;
mod device;
@ -68,7 +68,7 @@ mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
mod mkl;
pub mod npy;
pub mod op;
@ -104,10 +104,10 @@ pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
pub use cuda_backend as cuda;
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
pub use dummy_cuda_backend as cuda;
pub use cuda::{CudaDevice, CudaStorage};
@ -118,7 +118,7 @@ pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -294,16 +294,16 @@ macro_rules! bin_op {
$e(v1, v2)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs1, xs2, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs1, xs2, ys)
@ -418,16 +418,16 @@ macro_rules! unary_op {
todo!("no unary function for i64")
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::$f32_vec(xs, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs, ys)
@ -518,19 +518,19 @@ impl UnaryOpT for Gelu {
}
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_gelu(xs, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
@ -625,19 +625,19 @@ impl UnaryOpT for Silu {
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)

View File

@ -16,9 +16,9 @@ pub mod metal;
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}

View File

@ -52,7 +52,7 @@ impl ArgSort {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
@ -118,7 +118,7 @@ impl crate::CustomOp1 for ArgSort {
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,

View File

@ -10,7 +10,7 @@ macro_rules! test_device {
$fn_name(&Device::Cpu)
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
#[test]
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)

View File

@ -17,11 +17,11 @@ pub fn has_accelerate() -> bool {
}
pub fn has_mkl() -> bool {
cfg!(feature = "mkl")
cfg!(feature = "_mkl")
}
pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
cfg!(feature = "_cuda")
}
pub fn metal_is_available() -> bool {

View File

@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> {
Ok(())
}
#[cfg(any(feature = "cuda", feature = "metal"))]
#[cfg(any(feature = "_cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {

View File

@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
candle = { path = "../candle-core", features = ["_cuda"], package = "candle-core", version = "0.8.4" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@ -21,4 +21,4 @@ anyhow = { version = "1", features = ["backtrace"] }
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }
candle-nn = { path = "../candle-nn", features = ["_cuda"] }

View File

@ -32,8 +32,10 @@ criterion = { workspace = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
_cuda = ["candle/_cuda"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
_mkl = ["dep:intel-mkl-src", "candle/_mkl"]
mkl = ["candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
[[bench]]

View File

@ -15,9 +15,9 @@ impl BenchDevice for Device {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
return Ok(device.synchronize()?);
#[cfg(not(feature = "cuda"))]
#[cfg(not(feature = "_cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
@ -34,7 +34,7 @@ impl BenchDevice for Device {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
} else if cfg!(feature = "_mkl") {
"mkl"
} else {
"cpu"
@ -56,7 +56,7 @@ impl BenchDeviceHandler {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
} else if cfg!(feature = "_cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,5 +1,5 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -82,7 +82,7 @@ impl candle::CustomOp1 for Sigmoid {
Ok((storage, layout.shape().clone()))
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
@ -333,7 +333,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
@ -507,7 +507,7 @@ impl candle::CustomOp2 for RmsNorm {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -740,7 +740,7 @@ impl candle::CustomOp3 for LayerNorm {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,

View File

@ -77,7 +77,7 @@ impl candle::CustomOp3 for RotaryEmbI {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -322,7 +322,7 @@ impl candle::CustomOp3 for RotaryEmb {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
@ -576,7 +576,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
}
}
#[cfg(feature = "cuda")]
#[cfg(feature = "_cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -18,7 +18,7 @@ t = torch.tensor(
print(group_norm(t, num_groups=2))
print(group_norm(t, num_groups=3))
*/
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -1,4 +1,4 @@
#[cfg(feature = "mkl")]
#[cfg(feature = "_mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]

View File

@ -28,6 +28,7 @@ tracing = { workspace = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
_cuda = ["candle/_cuda", "candle-nn/_cuda"]
cuda = ["candle/cuda", "candle-nn/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]