Cuda cleanup. (#2880)

* Cuda cleanup.

* More fixes.
This commit is contained in:
Laurent Mazare
2025-04-11 21:43:35 +02:00
committed by GitHub
parent eb478ece92
commit acc5bd335f
8 changed files with 193 additions and 161 deletions

View File

@@ -68,7 +68,7 @@ impl CustomOp1 for LayerNorm {
Some((o1, o2)) => slice.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
- let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+ let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
let func =
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let cfg = LaunchConfig {