mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
This commit is contained in:
@ -396,7 +396,10 @@ impl UgIOp1 {
|
||||
{
|
||||
let device = device.as_cuda_device()?;
|
||||
let func = device.compile(name, kernel)?;
|
||||
Ok(Self { name, func })
|
||||
Ok(Self {
|
||||
name,
|
||||
func: func.into_cuda_function(),
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
{
|
||||
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
||||
use crate::cuda_backend::WrapErr;
|
||||
use cudarc::driver::LaunchAsync;
|
||||
use cudarc::driver::PushKernelArg;
|
||||
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let stream = sto.device.cuda_stream();
|
||||
// TODO: support more dtypes.
|
||||
let sto = sto.as_cuda_slice::<f32>()?;
|
||||
let sto = match layout.contiguous_offsets() {
|
||||
None => crate::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => sto.slice(o1..o2),
|
||||
};
|
||||
let params = (&sto,);
|
||||
let (g, b) = if elem_count % 32 == 0 {
|
||||
(elem_count / 32, 32)
|
||||
} else {
|
||||
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
|
||||
block_dim: (b as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe { self.func.clone().launch(cfg, params) }.w()?;
|
||||
let mut builder = stream.launch_builder(&self.func);
|
||||
builder.arg(&sto);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user