From d9904a3baf78d68ff2d773027a9245a4fec37bf9 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 3 Apr 2025 09:12:19 +0200
Subject: [PATCH] Update to cudarc 0.14 (breaking change). (#2858)

* Start updating to cudarc 0.14.
* Adapt a couple more things.
* And a couple more fixes.
* More tweaks.
* And a couple more fixes.
* Bump the major version number.
* Proper module system for the cuda kernels.
* Proper ptx loading.
* Launch the sort kernel.
* Custom op.
* Start using the builder pattern.
* More builder.
* More builder.
* Get candle-core to compile.
* Get the tests to pass.
* Get candle-nn to work too.
* Support for custom cuda functions.
* cudnn fixes.
* Get flash attn to run.
* Switch the crate versions to be alpha.
* Bump the ug dependency.
---
 Cargo.toml                                  |  26 +-
 candle-core/src/cuda_backend/cudnn.rs       |   4 +-
 candle-core/src/cuda_backend/device.rs      | 276 ++++---
 candle-core/src/cuda_backend/mod.rs         | 757 +++++++++++---------
 candle-core/src/custom_op.rs                |  13 +-
 candle-core/src/quantized/cuda.rs           | 121 ++--
 candle-core/src/sort.rs                     |  16 +-
 candle-examples/examples/custom-ops/main.rs |  12 +-
 candle-flash-attn/Cargo.toml                |   4 +-
 candle-flash-attn/src/lib.rs                |  60 +-
 candle-kernels/Cargo.toml                   |   2 +-
 candle-kernels/build.rs                     |   2 +-
 candle-kernels/src/lib.rs                   |  89 ++-
 candle-kernels/src/ptx.rs                   |  11 +
 candle-metal-kernels/Cargo.toml             |   2 +-
 candle-nn/src/ops.rs                        |  64 +-
 candle-nn/src/rotary_emb.rs                 |  49 +-
 candle-onnx/Cargo.toml                      |   6 +-
 18 files changed, 924 insertions(+), 590 deletions(-)
 create mode 100644 candle-kernels/src/ptx.rs

diff --git a/Cargo.toml b/Cargo.toml
index cd597eb4..aaefb02d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle" @@ -33,17 +33,17 @@ ab_glyph = "0.2.23" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" -candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" } -candle-datasets = { path = "./candle-datasets", version = "0.8.4" } -candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" } -candle-kernels = { path = "./candle-kernels", version = "0.8.4" } -candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" } -candle-nn = { path = "./candle-nn", version = "0.8.4" } -candle-onnx = { path = "./candle-onnx", version = "0.8.4" } -candle-transformers = { path = "./candle-transformers", version = "0.8.4" } +candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" } +candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" } +candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" } +candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" } +candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" } +candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" } +candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" } +candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } +cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.4.1" @@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" -ug = "0.1.0" -ug-cuda = "0.1.0" -ug-metal = "0.1.0" +ug = "0.2.0" +ug-cuda = "0.2.0" +ug-metal = "0.2.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "1.1.1", default-features = false } metal = { version = "0.27.0", features = ["mps"]} diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs index f5b4db90..318d6b56 100644 --- a/candle-core/src/cuda_backend/cudnn.rs +++ b/candle-core/src/cuda_backend/cudnn.rs @@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d< if let Some(cudnn) = cudnn.borrow().get(&device_id) { return Ok(cudnn.clone()); } - let c = Cudnn::new(dev.cuda_device()); + let c = Cudnn::new(dev.cuda_stream()); if let Ok(c) = &c { cudnn.borrow_mut().insert(device_id, c.clone()); } @@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d< Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT, }; let workspace_size = conv2d.get_workspace_size(alg)?; - let mut workspace = dev.cuda_device().alloc_zeros::(workspace_size)?; + let mut workspace = dev.cuda_stream().alloc_zeros::(workspace_size)?; unsafe { conv2d.launch::, _, _, _>( alg, diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index b9ab4349..8967eb98 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -2,8 +2,9 @@ use crate::backend::BackendDevice; use crate::{CpuStorage, 
CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; -use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg}; use half::{bf16, f16}; +use std::collections::HashMap; use std::sync::{Arc, Mutex}; use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; @@ -24,10 +25,17 @@ impl DeviceId { struct CudaRng(cudarc::curand::CudaRng); unsafe impl Send for CudaRng {} +pub struct ModuleStore { + mdls: [Option>; kernels::ALL_IDS.len()], +} + #[derive(Clone)] pub struct CudaDevice { id: DeviceId, - device: Arc, + context: Arc, + modules: Arc>, + custom_modules: Arc>>>, + stream: Arc, pub(crate) blas: Arc, curand: Arc>, } @@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice { } impl std::ops::Deref for CudaDevice { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { - &self.device + &self.stream + } +} + +pub struct CudaFunc { + func: CudaFunction, + stream: Arc, +} + +impl std::ops::Deref for CudaFunc { + type Target = CudaFunction; + + fn deref(&self) -> &Self::Target { + &self.func + } +} + +impl CudaFunc { + pub fn into_cuda_function(self) -> CudaFunction { + self.func + } +} + +#[macro_export] +macro_rules! builder_arg { + ($b:ident, $($arg:expr),*) => { + $( + let __arg = $arg; + $b.arg(&__arg); + )* + }; +} + +impl CudaFunc { + pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> { + self.stream.launch_builder(&self.func) } } impl CudaDevice { - pub fn cuda_device(&self) -> Arc { - self.device.clone() + pub fn cuda_stream(&self) -> Arc { + self.stream.clone() } #[cfg(not(target_arch = "wasm32"))] @@ -56,7 +99,7 @@ impl CudaDevice { &self, func_name: &'static str, kernel: ug::lang::ssa::Kernel, - ) -> Result { + ) -> Result { let mut buf = vec![]; ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?; let cuda_code = String::from_utf8(buf)?; @@ -65,12 +108,12 @@ impl CudaDevice { ..Default::default() }; let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?; - self.device.load_ptx(ptx, "ug", &[func_name]).w()?; - let func = match self.device.get_func("ug", func_name) { - Some(func) => func, - None => crate::bail!("unknown function ug::{func_name}"), - }; - Ok(func) + let module = self.context.load_module(ptx).w()?; + let func = module.load_function(func_name).w()?; + Ok(CudaFunc { + func, + stream: self.stream.clone(), + }) } pub fn id(&self) -> DeviceId { @@ -84,57 +127,84 @@ impl CudaDevice { DType::U8 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_u8", kernels::FILL)?; - let params = (&data, v as u8, elem_count); - unsafe { func.launch(cfg, params) }.w()?; + let func = self.get_or_load_func("fill_u8", &kernels::FILL)?; + let mut builder = self.stream.launch_builder(&func); + let v = v as u8; + builder.arg(&data); + builder.arg(&v); + builder.arg(&elem_count); + unsafe { builder.launch(cfg) }.w()?; CudaStorageSlice::U8(data) } DType::U32 => { // SAFETY: Set later by running the fill kernel. 
                 let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
-                let params = (&data, v as u32, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as u32;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U32(data)
             }
             DType::I64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
-                let params = (&data, v as i64, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as i64;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::I64(data)
             }
             DType::BF16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
-                let params = (&data, bf16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = bf16::from_f64(v);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::BF16(data)
             }
             DType::F16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
-                let params = (&data, f16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = f16::from_f64(v);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F16(data)
             }
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
-                let params = (&data, v as f32, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as f32;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
-                let params = (&data, v, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -144,38 +214,69 @@ impl CudaDevice {
         })
     }
 
-    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
-        if !self.has_func(module_name, module_name) {
-            // Leaking the string here is a bit sad but we need a &'static str and this is only
-            // done once per kernel name.
-            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
-            self.load_ptx(ptx.into(), module_name, &[static_module_name])
-                .map_err(|cuda| CudaError::Load {
-                    cuda,
-                    module_name: module_name.to_string(),
-                })
-                .w()?;
+    pub fn get_or_load_custom_func(
+        &self,
+        fn_name: &str,
+        module_name: &str,
+        ptx: &str,
+    ) -> Result<CudaFunc> {
+        let ms = self.custom_modules.read().unwrap();
+        if let Some(mdl) = ms.get(module_name).as_ref() {
+            let func = mdl.load_function(fn_name).w()?;
+            return Ok(CudaFunc {
+                func,
+                stream: self.stream.clone(),
+            });
         }
-        self.get_func(module_name, module_name)
-            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
-            // able to only build the error value if needed.
-            .ok_or(CudaError::MissingKernel {
-                module_name: module_name.to_string(),
-            })
-            .w()
+        drop(ms);
+        let mut ms = self.custom_modules.write().unwrap();
+        let cuda_module = self.context.load_module(ptx.into()).w()?;
+        ms.insert(module_name.to_string(), cuda_module.clone());
+        let func = cuda_module.load_function(fn_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
+    }
+
+    pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
+        let ms = self.modules.read().unwrap();
+        if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
+            let func = mdl.load_function(fn_name).w()?;
+            return Ok(CudaFunc {
+                func,
+                stream: self.stream.clone(),
+            });
+        }
+        drop(ms);
+        let mut ms = self.modules.write().unwrap();
+        let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
+        ms.mdls[mdl.index()] = Some(cuda_module.clone());
+        let func = cuda_module.load_function(fn_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
     }
 }
 
 impl CudaDevice {
     pub fn new_with_stream(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+        let stream = context.new_stream().w()?;
+        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+        let module_store = ModuleStore {
+            mdls: [const { None }; kernels::ALL_IDS.len()],
+        };
         Ok(Self {
             id: DeviceId::new(),
-            device,
+            context,
+            stream,
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
+            modules: Arc::new(std::sync::RwLock::new(module_store)),
+            custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
         })
     }
 }
@@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
     type Storage = CudaStorage;
 
     fn new(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+        let stream = context.default_stream();
+        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+        let module_store = ModuleStore {
+            mdls: [const { None }; kernels::ALL_IDS.len()],
+        };
         Ok(Self {
             id: DeviceId::new(),
-            device,
+            context,
+            stream,
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
+            modules: Arc::new(std::sync::RwLock::new(module_store)),
+            custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
         })
     }
 
@@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
         // We do not call set_seed but instead create a new curand object. This ensures that the
         // state will be identical and the same random numbers will be generated.
         let mut curand = self.curand.lock().unwrap();
-        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
+        curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
         Ok(())
     }
 
     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Cuda {
-            gpu_id: self.device.ordinal(),
+            gpu_id: self.context.ordinal(),
         }
     }
 
@@ -373,31 +481,31 @@ fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
         let slice = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorageRef::U32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorageRef::I64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorageRef::BF16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorageRef::F16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorageRef::F32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorageRef::F64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -410,31 +518,31 @@ fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -447,31 +555,31 @@ fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<Self::Storage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F16(data)
            }
             CpuStorage::F32(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
     }
 
     fn synchronize(&self) -> Result<()> {
-        self.device.synchronize().map_err(crate::Error::wrap)?;
+        self.stream.synchronize().map_err(crate::Error::wrap)?;
         Ok(())
     }
 }
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index c71b9694..a509e97a 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -2,12 +2,12 @@
 //!
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
+use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
 pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
-    CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+    CudaSlice, DevicePtr, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
 };
 use half::{bf16, f16};
 
@@ -25,12 +25,12 @@ pub enum SlicePtrOrNull<T> {
     Null,
 }
 
-unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
-    fn as_kernel_param(&self) -> *mut std::ffi::c_void {
+impl<T: DeviceRepr> SlicePtrOrNull<T> {
+    pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
         match self {
-            SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
-            SlicePtrOrNull::Null => 0usize.as_kernel_param(),
-        }
+            SlicePtrOrNull::Ptr(slice) => builder.arg(slice),
+            SlicePtrOrNull::Null => builder.arg(&0usize),
+        };
     }
 }
 
@@ -39,7 +39,7 @@ impl SlicePtrOrNull<usize> {
         let ds = if l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {
-            SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
+            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat()).w()?)
         };
         Ok(ds)
     }
@@ -87,20 +87,19 @@ impl Map1 for Affine {
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("affine"), &kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (
-            el,
-            dims.len(),
-            &ds,
-            src,
-            &out,
-            T::from_f64(self.0),
-            T::from_f64(self.1),
-        );
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, dims.len());
+        ds.builder_arg(&mut builder);
+        builder.arg(src);
+        builder.arg(&out);
+        barg!(builder, T::from_f64(self.0));
+        barg!(builder, T::from_f64(self.1));
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg).w() }?;
         Ok(out)
     }
 }
@@ -119,12 +118,18 @@ impl Map1 for Elu {
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, dims.len());
+        ds.builder_arg(&mut builder);
+        barg!(builder, T::from_f64(self.0));
+        builder.arg(src);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -154,24 +159,23 @@ impl Map1 for Im2Col1D {
         let l_out = self.l_out(dims[2]);
         let dst_el = dims[0] * l_out * dims[1] * self.l_k;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
         let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (
-            dst_el,
-            l_out,
-            self.l_k,
-            self.stride,
-            self.padding,
-            self.dilation,
-            &ds,
-            src,
-            &dst,
-        );
+        let mut builder = func.builder();
+        barg!(builder, dst_el);
+        barg!(builder, l_out);
+        barg!(builder, self.l_k);
+        barg!(builder, self.stride);
+        barg!(builder, self.padding);
+        barg!(builder, self.dilation);
+        builder.arg(&ds);
+        builder.arg(src);
+        builder.arg(&dst);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(dst)
     }
 }
@@ -206,26 +210,25 @@ impl Map1 for Im2Col {
         let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
         let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
         let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (
-            dst_el,
-            h_out,
-            w_out,
-            self.h_k,
-            self.w_k,
-            self.stride,
-            self.padding,
-            self.dilation,
-            &ds,
-            src,
-            &dst,
-        );
+        let mut builder = func.builder();
+        barg!(builder, dst_el);
+        barg!(builder, h_out);
+        barg!(builder, w_out);
+        barg!(builder, self.h_k);
+        barg!(builder, self.w_k);
+        barg!(builder, self.stride);
+        barg!(builder, self.padding);
+        barg!(builder, self.dilation);
+        builder.arg(&ds);
+        builder.arg(src);
+        builder.arg(&dst);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(dst)
     }
 }
@@ -244,12 +247,18 @@ impl Map1 for Powf {
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, dims.len());
+        ds.builder_arg(&mut builder);
+        barg!(builder, T::from_f64(self.0));
+        builder.arg(src);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -294,7 +303,7 @@ impl Map1Any for FastReduce<'_> {
             shared_mem_bytes: 0,
         };
         let ds = dev
-            .htod_copy([dims.as_slice(), stride.as_slice()].concat())
+            .memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())
             .w()?;
         let src = &src.slice(layout.start_offset()..);
         let (name, check_empty, return_index) = match self.1 {
@@ -307,20 +316,32 @@ impl Map1Any for FastReduce<'_> {
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
         if return_index {
             // SAFETY: filled in by the follow up kernel.
             let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
-            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            let mut builder = func.builder();
+            barg!(builder, src_el);
+            barg!(builder, el_to_sum_per_block);
+            barg!(builder, src_dims.len());
+            builder.arg(&ds);
+            builder.arg(src);
+            builder.arg(&out);
             // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
             Ok(S::U32(out))
         } else {
             // SAFETY: filled in by the follow up kernel.
             let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            let mut builder = func.builder();
+            barg!(builder, src_el);
+            barg!(builder, el_to_sum_per_block);
+            barg!(builder, src_dims.len());
+            builder.arg(&ds);
+            builder.arg(src);
+            builder.arg(&out);
             // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
             Ok(wrap(out))
         }
     }
@@ -339,16 +360,27 @@ impl<U: UnaryOpT> Map1 for U {
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
         let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
-        let params = (el_count, dims.len(), &ds, src, &out);
+        let mut out = unsafe { dev.alloc::<T>(el_count) }.w()?;
+        let mut builder = func.builder();
+        barg!(builder, el_count);
+        barg!(builder, dims.len());
+        ds.builder_arg(&mut builder);
+        builder.arg(src);
+        builder.arg(&mut out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
 
+fn slice_ptr<T: DeviceRepr>(v: &CudaSlice<T>, lo: usize) -> (u64, cudarc::driver::SyncOnDrop<'_>) {
+    let (_, guard) = v.device_ptr(v.stream());
+    let (ptr, _) = v.slice(lo..).device_ptr(v.stream());
+    (ptr, guard)
+}
+
 struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
 impl Map1 for IndexSelect<'_> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
         dev: &CudaDevice,
         src_l: &Layout,
     ) -> Result<CudaSlice<T>> {
         let ids_l = &self.1;
-        let (name, ids) = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => {
-                ("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            CudaStorageSlice::U8(slice) => {
-                ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
-            CudaStorageSlice::I64(slice) => {
-                ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
-            }
+        let (name, (ids, _guard)) = match &self.0.slice {
+            CudaStorageSlice::U32(slice) => ("is_u32", slice_ptr(slice, ids_l.start_offset())),
+            CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
+            CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "index_select ids should be u8 or u32",
                 expected: DType::U32,
@@ -377,7 +403,7 @@ impl Map1 for IndexSelect<'_> {
         };
         let ids_shape = ids_l.shape();
         let ids_dims = ids_shape.dims();
-        let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
+        let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat()).w()?;
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
         };
@@ -388,23 +414,22 @@ impl Map1 for IndexSelect<'_> {
         let ids_dim_size = ids_shape.elem_count();
         let dst_el = ids_shape.elem_count() * left_size * right_size;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (
-            dst_el,
-            ids_dims.len(),
-            &ds,
-            ids,
-            &src,
-            &out,
-            left_size,
-            src_dim_size,
-            ids_dim_size,
-            right_size,
-        );
+        let mut builder = func.builder();
+        barg!(builder, dst_el);
+        barg!(builder, ids_dims.len());
+        builder.arg(&ds);
+        barg!(builder, ids);
+        builder.arg(&src);
+        builder.arg(&out);
+        barg!(builder, left_size);
+        barg!(builder, src_dim_size);
+        barg!(builder, ids_dim_size);
+        barg!(builder, right_size);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -420,18 +445,14 @@ impl Map1 for Gather<'_> {
         let ids = &self.0;
         let ids_l = &self.1;
         let dim = self.2;
-        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
             Some(o12) => o12,
             None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
         };
-        let (name, ids) = match &ids.slice {
-            CudaStorageSlice::U32(slice) => {
-                ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
-            }
-            CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
-            CudaStorageSlice::I64(slice) => {
-                ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
-            }
+        let (name, (ids, _guard)) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("gather_u32", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::U8(slice) => ("gather_u8", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::I64(slice) => ("gather_i64", slice_ptr(slice, ids_o1)),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "gather ids should be u8/u32/i64",
                 expected: DType::U32,
@@ -448,14 +469,20 @@ impl Map1 for Gather<'_> {
         let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
         let src_dim_sz = src_l.dims()[dim];
         let ids_dim_sz = ids_l.dims()[dim];
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (
-            el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
-        );
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, ids);
+        builder.arg(&src);
+        builder.arg(&out);
+        barg!(builder, left_sz);
+        barg!(builder, src_dim_sz);
+        barg!(builder, ids_dim_sz);
+        barg!(builder, right_sz);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -473,14 +500,14 @@ impl Map2InPlace for IndexAdd<'_> {
         let ids = &self.0;
         let ids_l = &self.1;
         let dim = self.2;
-        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
             Some(o12) => o12,
             None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
         };
-        let (name, ids) = match &ids.slice {
-            CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
-            CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
-            CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+        let (name, (ids, _guard)) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("ia_u32", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::I64(slice) => ("ia_i64", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::U8(slice) => ("ia_u8", slice_ptr(slice, ids_o1)),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "index-add ids should be u8/u32/i64",
                 expected: DType::U32,
@@ -497,13 +524,15 @@ impl Map2InPlace for IndexAdd<'_> {
         let dst_dim_sz = dst_shape.dims()[dim];
         let ids_dim_sz = ids_l.dims()[0];
         let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
-        // SAFETY: Set later by running the kernel.
-        let params = (
-            ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
-        );
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
+        let mut builder = func.builder();
+        barg!(builder, ids);
+        barg!(builder, ids_dim_sz);
+        builder.arg(&src);
+        builder.arg(dst);
+        barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }
@@ -521,14 +550,14 @@ impl Map2InPlace for ScatterAdd<'_> {
         let ids = &self.0;
         let ids_l = &self.1;
         let dim = self.2;
-        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
             Some(o12) => o12,
             None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
         };
-        let (name, ids) = match &ids.slice {
-            CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
-            CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
-            CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+        let (name, (ids, _guard)) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("sa_u32", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::I64(slice) => ("sa_i64", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::U8(slice) => ("sa_u8", slice_ptr(slice, ids_o1)),
            _ => Err(CudaError::UnexpectedDType {
                 msg: "scatter-add ids should be u8/u32/i64",
                 expected: DType::U32,
@@ -544,11 +573,14 @@ impl Map2InPlace for ScatterAdd<'_> {
         let src_dim_sz = src_l.dims()[dim];
         let dst_dim_sz = dst_shape.dims()[dim];
         let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
-        // SAFETY: Set later by running the kernel.
-        let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
+        let mut builder = func.builder();
+        barg!(builder, ids);
+        builder.arg(&src);
+        builder.arg(dst);
+        barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }
@@ -574,7 +606,7 @@ impl Map2 for Conv1D<'_> {
         let l_out = p.l_out();
         let dst_el = p.c_out * l_out * p.b_size;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let ds = if dims.len() == 3 {
@@ -584,12 +616,15 @@ impl Map2 for Conv1D<'_> {
         } else {
             crate::bail!("unexpected input shape for conv1d {dims:?}")
         };
-        let ds = dev.htod_copy(ds).w()?;
-        let params = (
-            el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
-        );
+        let ds = dev.memcpy_stod(&ds).w()?;
+        let mut builder = func.builder();
+        barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
+        builder.arg(&ds);
+        builder.arg(inp);
+        builder.arg(k);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -618,18 +653,21 @@ impl Map2 for Conv2D<'_> {
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
         let ds = if dims.len() == 4 {
             [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
         } else {
             crate::bail!("unexpected input shape for conv2d {dims:?}")
         };
-        let ds = dev.htod_copy(ds).w()?;
-        let params = (
-            el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
-        );
+        let ds = dev.memcpy_stod(&ds).w()?;
+        let mut builder = func.builder();
+        barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
+        builder.arg(&ds);
+        builder.arg(inp);
+        builder.arg(k);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -652,9 +690,12 @@ impl Map1 for Col2Im1D {
         let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
 
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
-        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
-        unsafe { func.launch(cfg, params) }.w()?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
+        let mut builder = func.builder();
+        barg!(builder, dst_el, l_out, l_in, c_out, k_size, stride);
+        builder.arg(col);
+        builder.arg(&mut im);
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(im)
     }
 }
@@ -683,27 +724,26 @@ impl Map2 for ConvTranspose1D<'_> {
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
         let ds = if dims.len() == 3 {
             [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
         } else {
             crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
         };
-        let ds = dev.htod_copy(ds).w()?;
-        let params = (
-            el,
-            l_out,
-            p.stride,
-            p.padding,
-            p.output_padding,
-            p.dilation,
-            &ds,
-            inp,
-            k,
-            &out,
-        );
+        let ds = dev.memcpy_stod(&ds).w()?;
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, l_out);
+        barg!(builder, p.stride);
+        barg!(builder, p.padding);
+        barg!(builder, p.output_padding);
+        barg!(builder, p.dilation);
+        builder.arg(&ds);
+        builder.arg(inp);
+        builder.arg(k);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -732,28 +772,27 @@ impl Map2 for ConvTranspose2D<'_> {
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
         let ds = if dims.len() == 4 {
             [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
         } else {
             crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
         };
-        let ds = dev.htod_copy(ds).w()?;
-        let params = (
-            el,
-            out_w,
-            out_h,
-            p.stride,
-            p.padding,
-            p.output_padding,
-            p.dilation,
-            &ds,
-            inp,
-            k,
-            &out,
-        );
+        let ds = dev.memcpy_stod(&ds).w()?;
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, out_w);
+        barg!(builder, out_h);
+        barg!(builder, p.stride);
+        barg!(builder, p.padding);
+        barg!(builder, p.output_padding);
+        barg!(builder, p.dilation);
+        builder.arg(&ds);
+        builder.arg(inp);
+        builder.arg(k);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -796,22 +835,21 @@ impl Map1 for Pool2D {
             PoolOp::Max => "max_pool2d",
             PoolOp::Avg => "avg_pool2d",
         };
-        let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let ds = dev.htod_copy(ds).w()?;
-        let params = (
-            el,
-            self.w_k,
-            self.h_k,
-            self.w_stride,
-            self.h_stride,
-            &ds,
-            inp,
-            &out,
-        );
+        let ds = dev.memcpy_stod(&ds).w()?;
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, self.w_k);
+        barg!(builder, self.h_k);
+        barg!(builder, self.w_stride);
+        barg!(builder, self.h_stride);
+        builder.arg(&ds);
+        builder.arg(inp);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -836,15 +874,22 @@ impl Map1 for UpsampleNearest2D {
         let (out_w, out_h) = (self.0, self.1);
         let dst_el = out_w * out_h * dims[0] * dims[1];
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let ds = dev.htod_copy(ds).w()?;
+        let ds = dev.memcpy_stod(&ds).w()?;
         let scale_w = dims[2] as f64 / out_w as f64;
         let scale_h = dims[3] as f64 / out_h as f64;
-        let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
+        let mut builder = func.builder();
+        barg!(builder, out_w);
+        barg!(builder, out_h);
+        barg!(builder, scale_w);
+        barg!(builder, scale_h);
+        builder.arg(&ds);
+        builder.arg(inp);
+        builder.arg(&out);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -860,17 +905,17 @@ impl Map2 for WhereCond<'_> {
         dev: &CudaDevice,
     ) -> Result<CudaSlice<T>> {
         let ids_l = &self.1;
-        let (ids, name) = match &self.0.slice {
+        let ((ids, _guard), name) = match &self.0.slice {
             CudaStorageSlice::U8(slice) => {
-                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                let ptr = slice_ptr(slice, ids_l.start_offset());
                 (ptr, "where_u8")
             }
             CudaStorageSlice::U32(slice) => {
-                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                let ptr = slice_ptr(slice, ids_l.start_offset());
                 (ptr, "where_u32")
             }
             CudaStorageSlice::I64(slice) => {
-                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                let ptr = slice_ptr(slice, ids_l.start_offset());
                 (ptr, "where_i64")
             }
             _ => Err(CudaError::UnexpectedDType {
@@ -885,16 +930,23 @@ impl Map2 for WhereCond<'_> {
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = dev
-            .htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
+            .memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
             .w()?;
         let t = &t.slice(layout_t.start_offset()..);
         let f = &f.slice(layout_f.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
-        let params = (el, dims.len(), &ds, ids, t, f, &out);
+        let mut builder = func.builder();
+        barg!(builder, el);
+        barg!(builder, dims.len());
+        builder.arg(&ds);
+        barg!(builder, ids);
+        builder.arg(t);
+        builder.arg(f);
+        builder.arg(&out);
         // SAFETY: ffi
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -916,18 +968,24 @@ impl<U: BinaryOpT> Map2 for U {
             SlicePtrOrNull::Null
         } else {
             SlicePtrOrNull::Ptr(
-                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
                     .w()?,
             )
         };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+        let mut builder = func.builder();
+        barg!(builder, elem_count);
+        barg!(builder, dims.len());
+        dims_and_strides.builder_arg(&mut builder);
+        builder.arg(lhs);
+        builder.arg(rhs);
+        builder.arg(&out);
         // SAFETY: ffi
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(out)
     }
 }
@@ -950,7 +1008,7 @@ impl Map2Any for Cmp {
             SlicePtrOrNull::Null
         } else {
             SlicePtrOrNull::Ptr(
-                dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
+                dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
                     .w()?,
            )
         };
@@ -964,12 +1022,18 @@ impl Map2Any for Cmp {
             CmpOp::Gt => "gt",
             CmpOp::Ge => "ge",
         };
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::BINARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
-        let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+        let mut builder = func.builder();
+        barg!(builder, elem_count);
+        barg!(builder, dims.len());
+        dims_and_strides.builder_arg(&mut builder);
+        builder.arg(lhs);
+        builder.arg(rhs);
+        builder.arg(&out);
         // SAFETY: ffi
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(S::U8(out))
     }
 }
@@ -1190,60 +1254,95 @@ impl BackendStorage for CudaStorage {
         // This returns an i64 rather than a &i64, this is useful to get around some temporary
         // lifetime issue and is safe as long as self.slice does not go out of scope before inp
         // is used.
-        let inp = match &self.slice {
-            CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
+        let (inp, _guard) = match &self.slice {
+            CudaStorageSlice::U8(inp) => slice_ptr(inp, start_o),
+            CudaStorageSlice::U32(inp) => slice_ptr(inp, start_o),
+            CudaStorageSlice::I64(inp) => slice_ptr(inp, start_o),
+            CudaStorageSlice::BF16(inp) => slice_ptr(inp, start_o),
+            CudaStorageSlice::F16(inp) => slice_ptr(inp, start_o),
+            CudaStorageSlice::F32(inp) => slice_ptr(inp, start_o),
+            CudaStorageSlice::F64(inp) => slice_ptr(inp, start_o),
         };
         let inp = &inp;
         let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
-        let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
+        let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
         let slice = match dtype {
             DType::U8 => {
                 let out = unsafe { dev.alloc::<u8>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U8(out)
             }
             DType::U32 => {
                 let out = unsafe { dev.alloc::<u32>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U32(out)
             }
             DType::I64 => {
                 let out = unsafe { dev.alloc::<i64>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::I64(out)
             }
             DType::BF16 => {
                 let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::BF16(out)
             }
             DType::F16 => {
                 let out = unsafe { dev.alloc::<f16>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F16(out)
             }
             DType::F32 => {
                 let out = unsafe { dev.alloc::<f32>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F32(out)
             }
             DType::F64 => {
                 let out = unsafe { dev.alloc::<f64>(el) }.w()?;
-                let params = (el, dims.len(), &ds, *inp, &out);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let mut builder = func.builder();
+                barg!(builder, el);
+                barg!(builder, dims.len());
+                ds.builder_arg(&mut builder);
+                barg!(builder, *inp);
+                builder.arg(&out);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F64(out)
             }
         };
@@ -1303,38 +1402,31 @@ impl BackendStorage for CudaStorage {
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         match &self.slice {
             CudaStorageSlice::U8(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::U8(cpu_storage))
             }
             CudaStorageSlice::U32(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::U32(cpu_storage))
             }
             CudaStorageSlice::I64(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::I64(cpu_storage))
             }
             CudaStorageSlice::BF16(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::BF16(cpu_storage))
             }
             CudaStorageSlice::F16(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::F16(cpu_storage))
            }
             CudaStorageSlice::F32(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::F32(cpu_storage))
             }
             CudaStorageSlice::F64(slice) => {
-                let dev = slice.device();
-                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                let cpu_storage = slice.stream().memcpy_dtov(slice).w()?;
                 Ok(CpuStorage::F64(cpu_storage))
             }
         }
@@ -1753,49 +1845,27 @@ impl BackendStorage for CudaStorage {
         }
         let dst_s = dst_s as u32;
         let src_s = src_s as u32;
-        let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
-            (S::U8(s), S::U8(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_u8",
-            ),
-            (S::U32(s), S::U32(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_u32",
-            ),
-            (S::I64(s), S::I64(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_i64",
-            ),
-            (S::BF16(s), S::BF16(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_bf16",
-            ),
-            (S::F16(s), S::F16(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_f16",
-            ),
-            (S::F32(s), S::F32(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_f32",
-            ),
-            (S::F64(s), S::F64(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_f64",
-            ),
+        let ((src, _guard_src), (dst, _guard_dst), kname) = match (&self.slice, &mut dst.slice) {
+            (S::U8(s), S::U8(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u8"),
+            (S::U32(s), S::U32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_u32"),
+            (S::I64(s), S::I64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_i64"),
+            (S::BF16(s), S::BF16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_bf16"),
+            (S::F16(s), S::F16(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f16"),
+            (S::F32(s), S::F32(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f32"),
+            (S::F64(s), S::F64(d)) => (slice_ptr(s, src_o), slice_ptr(d, dst_o), "copy2d_f64"),
            _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
         };
-        let func = dev.get_or_load_func(kname, kernels::FILL)?;
+        let func = dev.get_or_load_func(kname, &kernels::FILL)?;
         let cfg = LaunchConfig::for_num_elems(d1 * d2);
-        let params = (src, dst, d1, d2, src_s, dst_s);
+        let mut builder = func.builder();
+        barg!(builder, src);
+        barg!(builder, dst);
+        barg!(builder, d1);
+        barg!(builder, d2);
+        builder.arg(&src_s);
+        builder.arg(&dst_s);
         // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 
@@ -1813,85 +1883,113 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
@@ -1813,85 +1883,113 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.dtod_copy(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst).w()?
                 } else {
-                    let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
-                    // SAFETY: Set later by running the kernel.
-                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
+                    let mut builder = func.builder();
+                    barg!(builder, el_count);
+                    barg!(builder, dims.len());
+                    ds.builder_arg(&mut builder);
+                    builder.arg(&src);
+                    builder.arg(&mut dst);
                     // SAFETY: ffi.
-                    unsafe { func.launch(cfg, params) }.w()?;
+                    unsafe { builder.launch(cfg) }.w()?;
                 }
             }
             _ => Err(CudaError::InternalError(
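Every copy_strided_src arm above has the same two-path shape; condensed with comments, fixing the dtype to f32 for illustration:

    if src_l.is_contiguous() {
        // Contiguous layouts need no kernel: a single device-to-device
        // copy (dtod_copy in cudarc 0.13, memcpy_dtod in 0.14) suffices.
        dev.memcpy_dtod(&src, &mut dst).w()?
    } else {
        // Strided layouts go through the ucopy kernel, which walks the
        // dims/strides descriptor `ds` element by element.
        let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
        let mut builder = func.builder();
        barg!(builder, el_count, dims.len());
        ds.builder_arg(&mut builder);
        builder.arg(&src);
        builder.arg(&mut dst);
        unsafe { builder.launch(cfg) }.w()?;
    }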
@@ -1965,6 +2063,11 @@ unsafe fn gemm_strided_batched_f32(
     let alpha = &cfg.gemm.alpha as *const f32 as *const _;
     let beta = &cfg.gemm.beta as *const f32 as *const _;
 
+    let stream = c.stream().clone();
+    let (a, _guard_a) = a.device_ptr(&stream);
+    let (b, _guard_b) = b.device_ptr(&stream);
+    let (c, _guard_c) = c.device_ptr_mut(&stream);
+
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -1973,16 +2076,16 @@
         cfg.gemm.n,
         cfg.gemm.k,
         alpha,
-        *a.device_ptr() as *const _,
+        a as *const _,
         sys::cudaDataType_t::CUDA_R_32F,
         cfg.gemm.lda,
         cfg.stride_a,
-        *b.device_ptr() as *const _,
+        b as *const _,
         sys::cudaDataType_t::CUDA_R_32F,
         cfg.gemm.ldb,
         cfg.stride_b,
         beta,
-        *c.device_ptr_mut() as *mut _,
+        c as *mut _,
         sys::cudaDataType_t::CUDA_R_32F,
         cfg.gemm.ldc,
         cfg.stride_c,
@@ -2020,6 +2123,10 @@ unsafe fn gemm_strided_batched_f16(
         )
     };
 
+    let stream = c.stream().clone();
+    let (a, _guard_a) = a.device_ptr(&stream);
+    let (b, _guard_b) = b.device_ptr(&stream);
+    let (c, _guard_c) = c.device_ptr_mut(&stream);
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -2028,16 +2135,16 @@
         cfg.gemm.n,
         cfg.gemm.k,
         alpha,
-        *a.device_ptr() as *const _,
+        a as *const _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.lda,
         cfg.stride_a,
-        *b.device_ptr() as *const _,
+        b as *const _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.ldb,
         cfg.stride_b,
         beta,
-        *c.device_ptr_mut() as *mut _,
+        c as *mut _,
         sys::cudaDataType_t::CUDA_R_16F,
         cfg.gemm.ldc,
         cfg.stride_c,
@@ -2075,6 +2182,10 @@ unsafe fn gemm_strided_batched_bf16(
         )
     };
 
+    let stream = c.stream().clone();
+    let (a, _guard_a) = a.device_ptr(&stream);
+    let (b, _guard_b) = b.device_ptr(&stream);
+    let (c, _guard_c) = c.device_ptr_mut(&stream);
     cudarc::cublas::result::gemm_strided_batched_ex(
         *cublas.handle(),
         cfg.gemm.transa,
@@ -2083,16 +2194,16 @@
         cfg.gemm.n,
         cfg.gemm.k,
         alpha,
-        *a.device_ptr() as *const _,
+        a as *const _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.lda,
         cfg.stride_a,
-        *b.device_ptr() as *const _,
+        b as *const _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.ldb,
         cfg.stride_b,
         beta,
-        *c.device_ptr_mut() as *mut _,
+        c as *mut _,
         sys::cudaDataType_t::CUDA_R_16BF,
         cfg.gemm.ldc,
         cfg.stride_c,
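In cudarc 0.14, device_ptr takes the stream and returns a (pointer, guard) pair instead of a bare pointer; the guard ties the pointer's validity to that stream, which is why the gemm wrappers above bind _guard_a/_guard_b/_guard_c. A sketch of the shape of such a call; some_gemm_ffi is a made-up stand-in for the raw cublas entry point:

    let stream = c.stream().clone();
    // Keep the guards in named locals: dropping them early would release
    // the pointers before the call is recorded on the stream.
    let (a_ptr, _guard_a) = a.device_ptr(&stream);
    let (b_ptr, _guard_b) = b.device_ptr(&stream);
    let (c_ptr, _guard_c) = c.device_ptr_mut(&stream);
    unsafe { some_gemm_ffi(a_ptr as *const _, b_ptr as *const _, c_ptr as *mut _) };
    // guards drop here, after the call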
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs
index 18d4786e..5d0fc9f8 100644
--- a/candle-core/src/custom_op.rs
+++ b/candle-core/src/custom_op.rs
@@ -396,7 +396,10 @@ impl UgIOp1 {
     {
         let device = device.as_cuda_device()?;
         let func = device.compile(name, kernel)?;
-        Ok(Self { name, func })
+        Ok(Self {
+            name,
+            func: func.into_cuda_function(),
+        })
     }
     #[cfg(feature = "metal")]
     {
@@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
     #[cfg(feature = "cuda")]
     fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
         use crate::cuda_backend::WrapErr;
-        use cudarc::driver::LaunchAsync;
+        use cudarc::driver::PushKernelArg;
 
         let elem_count = layout.shape().elem_count();
+        let stream = sto.device.cuda_stream();
         // TODO: support more dtypes.
         let sto = sto.as_cuda_slice::<f32>()?;
         let sto = match layout.contiguous_offsets() {
             None => crate::bail!("input has to be contiguous"),
             Some((o1, o2)) => sto.slice(o1..o2),
         };
-        let params = (&sto,);
         let (g, b) = if elem_count % 32 == 0 {
             (elem_count / 32, 32)
         } else {
@@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
             block_dim: (b as u32, 1, 1),
             shared_mem_bytes: 0,
         };
-        unsafe { self.func.clone().launch(cfg, params) }.w()?;
+        let mut builder = stream.launch_builder(&self.func);
+        builder.arg(&sto);
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }
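For a plain cudarc CudaFunction that isn't wrapped by candle's loader, such as the ug-compiled kernel above, the argument builder comes from the stream rather than from candle's wrapper. Condensed from the cuda_fwd above, with the grid/block sizing elided:

    use cudarc::driver::PushKernelArg;

    let stream = sto.device.cuda_stream();
    let mut builder = stream.launch_builder(&self.func);
    builder.arg(&sto); // single in-place f32 buffer argument
    unsafe { builder.launch(cfg) }.w()?;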
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 1a3d72c0..92dfe028 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -1,10 +1,10 @@
 use super::{GgmlDType, QStorage};
 use crate::quantized::k_quants::GgmlType;
 use crate::{backend::BackendDevice, cuda_backend::WrapErr};
-use crate::{CudaDevice, CudaStorage, Result};
+use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
 use half::f16;
 
-use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
+use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
 
 #[derive(Clone, Debug)]
 struct PaddedCudaSlice {
@@ -50,19 +50,20 @@ fn quantize_q8_1(
     ky: usize,
     dev: &CudaDevice,
 ) -> Result<()> {
-    use cudarc::driver::LaunchAsync;
-
     let kx = elem_count;
     let kx_padded = pad(kx, MATRIX_ROW_PADDING);
     let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
-    let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (num_blocks as u32, ky as u32, 1),
         block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
         shared_mem_bytes: 0,
     };
-    let params = (src, dst, kx as i32, kx_padded as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
+    let mut builder = func.builder();
+    builder.arg(src);
+    builder.arg(dst);
+    barg!(builder, kx as i32, kx_padded as i32);
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(())
 }
@@ -72,8 +73,6 @@ fn dequantize_f32(
     elem_count: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
         GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
@@ -99,7 +98,7 @@ fn dequantize_f32(
         GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@@ -110,15 +109,20 @@ fn dequantize_f32(
     };
 
     if is_k {
-        let params = (&data.inner, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        unsafe { builder.launch(cfg) }.w()?;
     } else {
         let nb32 = match dtype {
             GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
             _ => elem_count / 32,
         };
-        let params = (&data.inner, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        barg!(builder, nb32 as i32);
+        unsafe { builder.launch(cfg) }.w()?;
     }
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
@@ -129,8 +133,6 @@ fn dequantize_f16(
     elem_count: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
         GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
@@ -156,7 +158,7 @@ fn dequantize_f16(
         GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@@ -167,15 +169,20 @@
     };
 
     if is_k {
-        let params = (&data.inner, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        unsafe { builder.launch(cfg) }.w()?;
     } else {
         let nb32 = match dtype {
             GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
             _ => elem_count / 32,
         };
-        let params = (&data.inner, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        barg!(builder, nb32 as i32);
+        unsafe { builder.launch(cfg) }.w()?;
     }
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
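builder.arg takes references, so rvalues such as `nb32 as i32` cannot be passed to it directly; the builder_arg!/barg! macro added by this patch binds each expression to a local first. Roughly, and the macro's actual body in candle may differ:

    // barg!(builder, nb32 as i32) expands to approximately:
    let arg0 = nb32 as i32;
    builder.arg(&arg0);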
@@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
     nrows: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@@ -210,7 +215,7 @@
         GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
     let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
     let cfg = cudarc::driver::LaunchConfig {
@@ -219,8 +224,12 @@
         shared_mem_bytes: 0,
     };
 
-    let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
+    let mut builder = func.builder();
+    builder.arg(&data.inner);
+    builder.arg(y);
+    builder.arg(&dst);
+    barg!(builder, ncols as i32, nrows as i32);
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
@@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
     b_size: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@@ -266,7 +273,7 @@
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let kernel_name = format!("{kernel_name}{b_size}");
-    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
     // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
     let (nblocks, nwarps) = match b_size {
@@ -281,16 +288,18 @@
         shared_mem_bytes: 0,
     };
 
-    let params = (
-        &data.inner,
-        &y_q8_1,
-        &dst,
+    let mut builder = func.builder();
+    builder.arg(&data.inner);
+    builder.arg(&y_q8_1);
+    builder.arg(&dst);
+    barg!(
+        builder,
        /* ncols_x */ ncols as i32,
        /* nrows_x */ nrows as i32,
        /* nrows_y */ ncols_padded as i32,
-        /* nrows_dst */ nrows as i32,
+        /* nrows_dst */ nrows as i32
     );
-    unsafe { func.launch(cfg, params) }.w()?;
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
@@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
     y_cols: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < x_rows * x_cols {
         crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
@@ -338,7 +345,7 @@
         GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (
@@ -350,17 +357,19 @@
         shared_mem_bytes: 0,
     };
 
-    let params = (
-        /* vx */ &data.inner,
-        /* vy */ &y_q8_1,
-        /* dst */ &dst,
+    let mut builder = func.builder();
+    builder.arg(/* vx */ &data.inner);
+    builder.arg(/* vy */ &y_q8_1);
+    builder.arg(/* dst */ &dst);
+    barg!(
+        builder,
        /* ncols_x */ x_cols as i32,
        /* nrows_x */ x_rows as i32,
        /* ncols_y */ y_cols as i32,
        /* nrows_y */ k_padded as i32,
-        /* nrows_dst */ x_rows as i32,
+        /* nrows_dst */ x_rows as i32
     );
-    unsafe { func.launch(cfg, params) }.w()?;
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
@@ -416,7 +425,7 @@ impl QCudaStorage {
 
         let buffer = self
             .device
-            .dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
+            .memcpy_dtov(&self.data.inner.slice(..self.data.len))
             .w()?;
         let mut out = vec![0.0; elem_count];
         let block_len = elem_count / self.dtype.block_size();
@@ -449,7 +458,7 @@ impl QCudaStorage {
         // Run the quantization on cpu.
         let src = match &src.slice {
             crate::cuda_backend::CudaStorageSlice::F32(data) => {
-                self.device.dtoh_sync_copy(data).w()?
+                self.device.memcpy_dtov(data).w()?
             }
             _ => crate::bail!("only f32 can be quantized"),
         };
@@ -462,7 +471,7 @@ impl QCudaStorage {
             data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
         let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
         self.device
-            .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
             .w()?;
         self.data = PaddedCudaSlice {
             inner,
@@ -599,7 +608,7 @@ pub fn load_quantized(
     let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
     let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
     device
-        .htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
+        .memcpy_htod(data, &mut inner.slice_mut(..data.len()))
         .w()?;
     Ok(QStorage::Cuda(QCudaStorage {
         data: PaddedCudaSlice {
@@ -624,7 +633,7 @@ mod test {
             el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
         let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
@@ -634,7 +643,7 @@
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_vec_via_q8_1(
@@ -647,7 +656,7 @@
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         assert_eq!(vs.len(), 1);
         // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
         // Q8 means 1/256 precision.
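The test changes above are mechanical renames of the copy helpers. For reference, the 0.13 to 0.14 mapping used throughout this patch:

    dev.htod_sync_copy(&vec)      -> dev.memcpy_stod(&vec)      // host Vec -> new device slice
    dev.htod_sync_copy_into(h, d) -> dev.memcpy_htod(h, d)      // host slice -> existing device slice
    dev.dtoh_sync_copy(&slice)    -> dev.memcpy_dtov(&slice)    // device slice -> host Vec
    dev.dtod_copy(&src, &mut dst) -> dev.memcpy_dtod(&src, &mut dst) // device -> device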
@@ -662,7 +671,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         assert_eq!(vs.len(), 1);
         assert_eq!(vs[0], 5561851.0);
         Ok(())
@@ -673,7 +682,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -687,7 +696,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
 
         /*
         x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -714,7 +723,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let (x_rows, ncols, y_cols) = (4, 16, 2048);
         let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -728,7 +737,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         Ok(())
     }
 }
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
index 0ebb1835..9a8597d3 100644
--- a/candle-core/src/sort.rs
+++ b/candle-core/src/sort.rs
@@ -56,7 +56,7 @@ impl ArgSort {
 mod cuda {
     use super::*;
     use crate::cuda_backend::cudarc::driver::{
-        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
     };
     use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
     use crate::{CudaDevice, WithDType};
@@ -69,6 +69,8 @@ mod cuda {
         layout: &crate::Layout,
         _wrap: W,
     ) -> Result<S> {
+        use cudarc::driver::PushKernelArg;
+
         let slice = match layout.contiguous_offsets() {
             None => crate::bail!("input has to be contiguous"),
             Some((o1, o2)) => src.slice(o1..o2),
@@ -76,20 +78,24 @@ mod cuda {
         let elem_count = layout.shape().elem_count();
         let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
         let func = if self.asc {
-            dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
+            dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
         } else {
-            dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
+            dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
         };
         let ncols = self.last_dim;
         let nrows = elem_count / ncols;
         let ncols_pad = next_power_of_2(ncols);
-        let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
         let cfg = LaunchConfig {
             grid_dim: (1, nrows as u32, 1),
             block_dim: (ncols_pad as u32, 1, 1),
             shared_mem_bytes: (ncols_pad * std::mem::size_of::<T>()) as u32,
         };
-        unsafe { func.launch(cfg, params) }.w()?;
+        let stream = dev.cuda_stream();
+        let mut builder = stream.launch_builder(&func);
+        let ncols = ncols as i32;
+        let ncols_pad = ncols_pad as i32;
+        builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(S::U32(dst))
     }
 }
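The argsort launch pads each row to a power of two so the bitonic network lines up with the block size, and sizes dynamic shared memory for one padded row. A sketch of that configuration, mirroring the launch above, where T is the element type:

    let ncols_pad = next_power_of_2(ncols);
    let cfg = LaunchConfig {
        grid_dim: (1, nrows as u32, 1),      // one block per row
        block_dim: (ncols_pad as u32, 1, 1), // one thread per padded column
        // one padded row of T in dynamic shared memory
        shared_mem_bytes: (ncols_pad * std::mem::size_of::<T>()) as u32,
    };

Note also the `let ncols = ncols as i32;` rebindings above: the chained .arg() calls borrow their arguments, so the casts must live in locals for the duration of the builder.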
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index 30e413c1..9a312cb2 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
         layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::backend::BackendStorage;
-        use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
+        use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
         use candle::cuda_backend::WrapErr;
         let (d1, d2) = layout.shape().dims2()?;
         let d1 = d1 as u32;
@@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
         };
         let elem_count = layout.shape().elem_count();
         let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
-        let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
-        let params = (&dst, &slice, self.eps, d1, d2);
+        let func =
+            dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
         let cfg = LaunchConfig {
             grid_dim: (d1, 1, 1),
             block_dim: (d2, 1, 1),
             shared_mem_bytes: 0,
         };
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&dst);
+        builder.arg(&slice);
+        candle::builder_arg!(builder, self.eps, d1, d2);
+        unsafe { builder.launch(cfg) }.w()?;
         let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
         Ok((dst, layout.shape().clone()))
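get_or_load_func now takes a &Module from candle-kernels, so crates shipping their own PTX use the new get_or_load_custom_func instead; the extra string appears to act as a cache key naming the loaded module. Condensed from the example above, with comments:

    // "rms_f32" is the kernel entry point inside this crate's own PTX blob;
    // "mymodule" identifies that blob in the device's module cache.
    let func =
        dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
    let mut builder = func.builder();
    builder.arg(&dst);
    builder.arg(&slice);
    candle::builder_arg!(builder, self.eps, d1, d2);
    unsafe { builder.launch(cfg) }.w()?;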
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index f9c65fe9..91f3cb88 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "Flash attention layer for the candle ML framework."
 
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
 half = { version = "2.3.1", features = ["num-traits"] }
 
 [build-dependencies]
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 1b2e5e43..e84edd14 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -88,6 +88,7 @@ impl FlashAttn {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
 
+        let stream = dev.cuda_stream();
        let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
            if alibi_slopes.dtype() != DType::F32 {
                candle::bail!(
@@ -114,7 +115,9 @@ impl FlashAttn {
 
            let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
 
-            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+            // Dropping the guard here doesn't seem very safe.
+            let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
+            ptr as *const core::ffi::c_void
         } else {
             std::ptr::null()
         };
@@ -161,17 +164,17 @@ impl FlashAttn {
         }
 
         unsafe {
-            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
-            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
-            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
-            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
-            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
+            let (q_ptr, _guard) = q.device_ptr(&stream);
+            let (k_ptr, _guard) = k.device_ptr(&stream);
+            let (v_ptr, _guard) = v.device_ptr(&stream);
+            let (dst_ptr, _guard) = dst.device_ptr(&stream);
+            let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
             ffi::run_mha(
-                q_ptr,
-                k_ptr,
-                v_ptr,
-                dst_ptr,
-                softmax_lse_ptr,
+                q_ptr as *const core::ffi::c_void,
+                k_ptr as *const core::ffi::c_void,
+                v_ptr as *const core::ffi::c_void,
+                dst_ptr as *const core::ffi::c_void,
+                softmax_lse_ptr as *const core::ffi::c_void,
                 /* alibi_slopes_ptr */ alibi_slopes_ptr,
                 /* cu_seqlens_q_ptr */ std::ptr::null(),
                 /* cu_seqlens_k_ptr */ std::ptr::null(),
@@ -550,6 +553,7 @@ impl FlashAttnVarLen {
 
         let batch_size = nseqlens_q - 1;
 
+        let stream = dev.cuda_stream();
         let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
             if alibi_slopes.dtype() != DType::F32 {
                 candle::bail!(
@@ -576,7 +580,9 @@ impl FlashAttnVarLen {
 
             let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
 
-            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+            // Dropping the guard here doesn't seem very safe.
+            let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
+            ptr as *const core::ffi::c_void
         } else {
             std::ptr::null()
         };
@@ -621,22 +627,22 @@ impl FlashAttnVarLen {
         }
 
         unsafe {
-            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
-            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
-            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
-            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
-            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
-            let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
-            let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
+            let (q_ptr, _guard) = q.device_ptr(&stream);
+            let (k_ptr, _guard) = k.device_ptr(&stream);
+            let (v_ptr, _guard) = v.device_ptr(&stream);
+            let (dst_ptr, _guard) = dst.device_ptr(&stream);
+            let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
+            let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
+            let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
             ffi::run_mha(
-                q_ptr,
-                k_ptr,
-                v_ptr,
-                dst_ptr,
-                softmax_lse_ptr,
-                /* alibi_slopes_ptr */ alibi_slopes_ptr,
-                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
-                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
+                q_ptr as *const core::ffi::c_void,
+                k_ptr as *const core::ffi::c_void,
+                v_ptr as *const core::ffi::c_void,
+                dst_ptr as *const core::ffi::c_void,
+                softmax_lse_ptr as *const core::ffi::c_void,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
+                /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
+                /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
                 /* q_batch_stride */ 0,
                 /* k_batch_stride */ 0,
                 /* v_batch_stride */ 0,
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index 381489b8..ed4ae6cb 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "CUDA kernels for Candle"
diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs
index c28abd97..1acbe51d 100644
--- a/candle-kernels/build.rs
+++ b/candle-kernels/build.rs
@@ -7,5 +7,5 @@ fn main() {
     let builder = bindgen_cuda::Builder::default();
     println!("cargo:info={builder:?}");
     let bindings = builder.build_ptx().unwrap();
-    bindings.write("src/lib.rs").unwrap();
+    bindings.write("src/ptx.rs").unwrap();
 }
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index 1c73d6b7..78cacfbf 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -1,11 +1,78 @@
-pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
-pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
-pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
-pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
-pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
-pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
-pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
-pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
-pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
-pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
-pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
+mod ptx;
+
+#[repr(u32)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Id {
+    Affine,
+    Binary,
+    Cast,
+    Conv,
+    Fill,
+    Indexing,
+    Quantized,
+    Reduce,
+    Sort,
+    Ternary,
+    Unary,
+}
+
+pub const ALL_IDS: [Id; 11] = [
+    Id::Affine,
+    Id::Binary,
+    Id::Cast,
+    Id::Conv,
+    Id::Fill,
+    Id::Indexing,
+    Id::Quantized,
+    Id::Reduce,
+    Id::Sort,
+    Id::Ternary,
+    Id::Unary,
+];
+
+pub struct Module {
+    index: usize,
+    ptx: &'static str,
+}
+
+impl Module {
+    pub fn index(&self) -> usize {
+        self.index
+    }
+
+    pub fn ptx(&self) -> &'static str {
+        self.ptx
+    }
+}
+
+const fn module_index(id: Id) -> usize {
+    let mut i = 0;
+    while i < ALL_IDS.len() {
+        if ALL_IDS[i] as u32 == id as u32 {
+            return i;
+        }
+        i += 1;
+    }
+    panic!("id not found")
+}
+
+macro_rules! mdl {
mdl { + ($cst:ident, $id:ident) => { + pub const $cst: Module = Module { + index: module_index(Id::$id), + ptx: ptx::$cst, + }; + }; +} + +mdl!(AFFINE, Affine); +mdl!(BINARY, Binary); +mdl!(CAST, Cast); +mdl!(CONV, Conv); +mdl!(FILL, Fill); +mdl!(INDEXING, Indexing); +mdl!(QUANTIZED, Quantized); +mdl!(REDUCE, Reduce); +mdl!(SORT, Sort); +mdl!(TERNARY, Ternary); +mdl!(UNARY, Unary); diff --git a/candle-kernels/src/ptx.rs b/candle-kernels/src/ptx.rs new file mode 100644 index 00000000..1c73d6b7 --- /dev/null +++ b/candle-kernels/src/ptx.rs @@ -0,0 +1,11 @@ +pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx")); +pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); +pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); +pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); +pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); +pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); +pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx")); +pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); +pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx")); +pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); +pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 5a8b2cea..156a1962 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.8.4" +version = "0.9.0-alpha.1" edition = "2021" description = "Metal kernels for Candle" diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index d7f88a0b..74169190 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid { ) -> Result<(candle::CudaStorage, Shape)> { use candle::backend::BackendStorage; use candle::cuda_backend::cudarc::driver::{ - CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, + CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits, }; use candle::cuda_backend::SlicePtrOrNull; use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr}; @@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid { let cfg = LaunchConfig::for_num_elems(el_count as u32); let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); - let func = dev.get_or_load_func(&kernel_name::("usigmoid"), kernels::UNARY)?; + let func = dev.get_or_load_func(&kernel_name::("usigmoid"), &kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }.w()?; - let params = (el_count, dims.len(), &ds, src, &out); + let mut builder = func.builder(); + candle::builder_arg!(builder, el_count, dims.len()); + ds.builder_arg(&mut builder); + builder.arg(src); + builder.arg(&out); // SAFETY: ffi. 
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 5a8b2cea..156a1962 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "Metal kernels for Candle"
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index d7f88a0b..74169190 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::backend::BackendStorage;
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
         };
         use candle::cuda_backend::SlicePtrOrNull;
         use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
@@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
                 let cfg = LaunchConfig::for_num_elems(el_count as u32);
                 let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                 let src = &src.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
-                let params = (el_count, dims.len(), &ds, src, &out);
+                let mut builder = func.builder();
+                candle::builder_arg!(builder, el_count, dims.len());
+                ds.builder_arg(&mut builder);
+                builder.arg(src);
+                builder.arg(&out);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
                 Ok(out)
             }
         }
@@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
         };
         use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
         use candle::{CudaDevice, WithDType};
@@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
                     block_dim: (1, 32, 1),
                     shared_mem_bytes: 0,
                 };
-                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (&src, &dst, n_cols as i32);
+                let mut builder = func.builder();
+                builder.arg(&src);
+                builder.arg(&dst);
+                candle::builder_arg!(builder, n_cols as i32);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
                 Ok(dst)
             }
         }
@@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
         l2: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
         };
         use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
         use candle::{CudaDevice, WithDType};
@@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
                     block_dim: (block_size, 1, 1),
                     shared_mem_bytes: 0,
                 };
-                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (
-                    &src,
-                    &dst,
-                    &alpha,
-                    n_cols as i32,
-                    block_size as i32,
-                    self.eps,
-                );
+                let mut builder = func.builder();
+                builder.arg(&src);
+                builder.arg(&dst);
+                builder.arg(&alpha);
+                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
                 Ok(dst)
             }
         }
@@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
         l3: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
         };
         use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
         use candle::{CudaDevice, WithDType};
@@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
                     block_dim: (block_size, 1, 1),
                     shared_mem_bytes: 0,
                 };
-                let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
+                let func =
+                    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (
-                    &src,
-                    &dst,
-                    &alpha,
-                    &beta,
-                    n_cols as i32,
-                    block_size as i32,
-                    self.eps,
-                );
+                let mut builder = func.builder();
+                builder.arg(&src);
+                builder.arg(&dst);
+                builder.arg(&alpha);
+                builder.arg(&beta);
+                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
                 Ok(dst)
             }
         }
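builder_arg! also keeps mixed scalar types in a single call; RmsNorm above pushes two i32s and the f32 epsilon in kernel-signature order. Condensed, with comments:

    let mut builder = func.builder();
    builder.arg(&src);
    builder.arg(&dst);
    builder.arg(&alpha);
    // n_cols: i32, block_size: i32, eps: f32; the order must match the kernel.
    candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
    unsafe { builder.launch(cfg) }.w()?;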
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
index 0191bd7e..a1d7cfae 100644
--- a/candle-nn/src/rotary_emb.rs
+++ b/candle-nn/src/rotary_emb.rs
@@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
         l3: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
         };
         use candle::cuda_backend::{kernel_name, kernels, WrapErr};
         use candle::{CudaDevice, WithDType};
@@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
                 let (b, h, t, d) = l_src.shape().dims4()?;
                 let el = b * h * t * d;
                 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
-                let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
+                let mut builder = func.builder();
+                builder.arg(&src);
+                builder.arg(&cos);
+                builder.arg(&sin);
+                builder.arg(&dst);
+                candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
 
                 Ok(dst)
             }
@@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
         l3: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
         };
         use candle::cuda_backend::{kernel_name, kernels, WrapErr};
         use candle::{CudaDevice, WithDType};
@@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
                 let (b, h, t, d) = l_src.shape().dims4()?;
                 let el = b * h * t * d;
                 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
-                let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (
-                    &src,
-                    &cos,
-                    &sin,
-                    &dst,
-                    (b * h) as u32,
-                    (t * d) as u32,
-                    d as u32,
-                );
+                let mut builder = func.builder();
+                builder.arg(&src);
+                builder.arg(&cos);
+                builder.arg(&sin);
+                builder.arg(&dst);
+                candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
 
                 Ok(dst)
             }
@@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
         l3: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
         };
         use candle::cuda_backend::{kernel_name, kernels, WrapErr};
         use candle::{CudaDevice, WithDType};
@@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
                 let (b, t, h, d) = l_src.shape().dims4()?;
                 let el = b * h * t * d;
                 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
-                let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
                 // SAFETY: Set later by running the kernel.
                 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-                let params = (
-                    &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
-                );
+                let mut builder = func.builder();
+                builder.arg(&src);
+                builder.arg(&cos);
+                builder.arg(&sin);
+                builder.arg(&dst);
+                candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
                 // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
 
                 Ok(dst)
             }
diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml
index b80c7df3..b36de583 100644
--- a/candle-onnx/Cargo.toml
+++ b/candle-onnx/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-onnx"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
-candle-nn = { path = "../candle-nn", version = "0.8.4" }
+candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
+candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
 prost = "0.12.1"
 
 [build-dependencies]