mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
@ -112,7 +112,7 @@ impl candle::CustomOp1 for Sigmoid {
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
let out = unsafe { dev.alloc::<T>(el_count)? };
|
||||
|
||||
let mut builder = func.builder();
|
||||
candle::builder_arg!(builder, el_count, dims.len());
|
||||
@ -373,7 +373,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
@ -561,7 +561,7 @@ impl candle::CustomOp2 for RmsNorm {
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
@ -800,7 +800,7 @@ impl candle::CustomOp3 for LayerNorm {
|
||||
let func =
|
||||
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
|
@ -119,7 +119,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
@ -369,7 +369,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
@ -620,7 +620,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
|
Reference in New Issue
Block a user