Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00.
Merge pull request #246 from LaurentMazare/rename_custom_op
Rename exposed ops.
@@ -109,11 +109,11 @@ pub trait CustomOp1: Send + Sync {
     /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
     /// offsets etc so the associated layout should be used to access it.
-    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)>;
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

     /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
     /// offsets etc so the associated layout should be used to access it.
-    fn cuda_fwd(&self, _: &CudaStorage, _: &Layout) -> Result<(CudaStorage, Shape)> {
+    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
         Err(crate::Error::Cuda(
             format!("no cuda implementation for {}", self.name()).into(),
         ))
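For context, a minimal sketch of how an implementation of the renamed trait looks after this change. Only the parameter names and signatures shown in the hunk above come from this diff; the `Scale` op, its `factor` field, and the exact import paths are illustrative assumptions.

```rust
use candle::{CpuStorage, CustomOp1, Error, Layout, Result, Shape};

// Hypothetical element-wise op: multiply every element by `factor`.
struct Scale {
    factor: f32,
}

impl CustomOp1 for Scale {
    fn name(&self) -> &'static str {
        "scale"
    }

    // Only the CPU path is implemented here; the default `cuda_fwd` shown in
    // the trait above returns a "no cuda implementation" error.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let slice = storage.as_slice::<f32>()?;
        // The storage may carry arbitrary strides/offsets, so go through the
        // layout; only the contiguous case is handled in this sketch.
        let src = match layout.contiguous_offsets() {
            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
            Some((o1, o2)) => &slice[o1..o2],
        };
        let dst: Vec<f32> = src.iter().map(|x| x * self.factor).collect();
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, layout.shape().clone()))
    }
}
```

Note that the default `cuda_fwd` keeps leading underscores (`_storage`, `_layout`) because both arguments are unused in the error-returning default implementation.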
@@ -33,12 +33,12 @@ impl CustomOp1 for LayerNorm {
         "layer-norm"
     }

-    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
-        let (dim1, dim2) = l.shape().dims2()?;
-        let s = s.as_slice::<f32>()?;
-        let src = match l.contiguous_offsets() {
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        let (dim1, dim2) = layout.shape().dims2()?;
+        let slice = storage.as_slice::<f32>()?;
+        let src = match layout.contiguous_offsets() {
             None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-            Some((o1, o2)) => &s[o1..o2],
+            Some((o1, o2)) => &slice[o1..o2],
         };
         let mut dst = Vec::with_capacity(dim1 * dim2);
         for idx1 in 0..dim1 {
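The per-row computation between this hunk and the next is untouched by the rename and is therefore omitted from the diff. Below is a hedged reconstruction of what that omitted loop likely does, inferred from the `s_variance` variable and the `rms_f32` CUDA kernel further down; the helper name and the factoring into a standalone function are assumptions, not part of the diff.

```rust
/// Hedged sketch of the per-row scaling the diff omits: an RMS-style
/// normalisation where each row is scaled by 1 / sqrt(mean(x^2) + eps).
fn rms_scale_rows(src: &[f32], dim1: usize, dim2: usize, eps: f32) -> Vec<f32> {
    let mut dst = Vec::with_capacity(dim1 * dim2);
    for idx1 in 0..dim1 {
        let row = &src[idx1 * dim2..(idx1 + 1) * dim2];
        // Mean of squares over the row, then the reciprocal root used as the
        // scaling factor (the `s_variance` seen in the next hunk).
        let variance = row.iter().map(|x| x * x).sum::<f32>() / dim2 as f32;
        let s_variance = 1f32 / (variance + eps).sqrt();
        dst.extend(row.iter().map(|x| x * s_variance));
    }
    dst
}
```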
@@ -48,30 +48,30 @@ impl CustomOp1 for LayerNorm {
             dst.extend(src.iter().map(|x| x * s_variance))
         }
         let storage = candle::WithDType::to_cpu_storage_owned(dst);
-        Ok((storage, l.shape().clone()))
+        Ok((storage, layout.shape().clone()))
     }

     #[cfg(feature = "cuda")]
     fn cuda_fwd(
         &self,
-        s: &candle::CudaStorage,
-        l: &Layout,
+        storage: &candle::CudaStorage,
+        layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::{cudarc, WrapErr};
         use cudarc::driver::{LaunchAsync, LaunchConfig};
-        let (d1, d2) = l.shape().dims2()?;
+        let (d1, d2) = layout.shape().dims2()?;
         let d1 = d1 as u32;
         let d2 = d2 as u32;
-        let dev = s.device().clone();
-        let s = s.as_cuda_slice::<f32>()?;
-        let s = match l.contiguous_offsets() {
+        let dev = storage.device().clone();
+        let slice = storage.as_cuda_slice::<f32>()?;
+        let slice = match layout.contiguous_offsets() {
             None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-            Some((o1, o2)) => s.slice(o1..o2),
+            Some((o1, o2)) => slice.slice(o1..o2),
         };
-        let elem_count = l.shape().elem_count();
+        let elem_count = layout.shape().elem_count();
         let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
         let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
-        let params = (&dst, &s, self.eps, d1, d2);
+        let params = (&dst, &slice, self.eps, d1, d2);
         let cfg = LaunchConfig {
             grid_dim: (d1, 1, 1),
             block_dim: (d2, 1, 1),
@@ -80,7 +80,7 @@ impl CustomOp1 for LayerNorm {
         unsafe { func.launch(cfg, params) }.w()?;

         let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-        Ok((dst, l.shape().clone()))
+        Ok((dst, layout.shape().clone()))
     }
 }
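Finally, a hedged usage sketch showing how calling code would exercise the renamed op. The tensor-side entry point `Tensor::apply_op1` and the `LayerNorm { eps }` initialiser do not appear in these hunks and are assumptions.

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A small 2D input; the cpu_fwd above expects a contiguous (dim1, dim2) layout.
    let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
    // `apply_op1` and the struct literal are assumptions; only the trait impl
    // itself is shown in this diff.
    let y = x.apply_op1(LayerNorm { eps: 1e-5 })?;
    println!("{y}");
    Ok(())
}
```

On the CUDA side, the launch configuration in the hunk above runs one block per row (`grid_dim = d1`) with one thread per element of that row (`block_dim = d2`), so that path implicitly assumes `d2` stays within the device's per-block thread limit.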