mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
This commit is contained in:
@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
|
||||
};
|
||||
use candle::cuda_backend::SlicePtrOrNull;
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||
@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
|
||||
let params = (el_count, dims.len(), &ds, src, &out);
|
||||
let mut builder = func.builder();
|
||||
candle::builder_arg!(builder, el_count, dims.len());
|
||||
ds.builder_arg(&mut builder);
|
||||
builder.arg(src);
|
||||
builder.arg(&out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
block_dim: (1, 32, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &dst, n_cols as i32);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, n_cols as i32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
|
||||
l2: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
|
||||
block_dim: (block_size, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&dst,
|
||||
&alpha,
|
||||
n_cols as i32,
|
||||
block_size as i32,
|
||||
self.eps,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
builder.arg(&alpha);
|
||||
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
|
||||
block_dim: (block_size, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
||||
let func =
|
||||
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&dst,
|
||||
&alpha,
|
||||
&beta,
|
||||
n_cols as i32,
|
||||
block_size as i32,
|
||||
self.eps,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
builder.arg(&alpha);
|
||||
builder.arg(&beta);
|
||||
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&cos,
|
||||
&sin,
|
||||
&dst,
|
||||
(b * h) as u32,
|
||||
(t * d) as u32,
|
||||
d as u32,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user