mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
This commit is contained in:
@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
|
||||
use candle::cuda_backend::WrapErr;
|
||||
let (d1, d2) = layout.shape().dims2()?;
|
||||
let d1 = d1 as u32;
|
||||
@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
|
||||
};
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||
let params = (&dst, &slice, self.eps, d1, d2);
|
||||
let func =
|
||||
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (d1, 1, 1),
|
||||
block_dim: (d2, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&dst);
|
||||
builder.arg(&slice);
|
||||
candle::builder_arg!(builder, self.eps, d1, d2);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, layout.shape().clone()))
|
||||
|
Reference in New Issue
Block a user