mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Merge pull request #33 from LaurentMazare/cuda-map
Simplify the dtype matchings in the cuda backend
This commit is contained in:
@ -487,6 +487,7 @@ fn main() -> Result<()> {
|
|||||||
let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||||
let input = Tensor::new(ctxt, &device)?;
|
let input = Tensor::new(ctxt, &device)?;
|
||||||
let logits = llama.forward(&input, &freqs_cis)?;
|
let logits = llama.forward(&input, &freqs_cis)?;
|
||||||
@ -496,6 +497,7 @@ fn main() -> Result<()> {
|
|||||||
let next_token = distr.sample(&mut rng) as u32;
|
let next_token = distr.sample(&mut rng) as u32;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
|
println!("> {:?}", start_gen.elapsed());
|
||||||
println!(
|
println!(
|
||||||
"{} token: {} '{}'",
|
"{} token: {} '{}'",
|
||||||
index + 1,
|
index + 1,
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
use crate::{CpuStorage, DType, Layout, Shape};
|
use crate::{CpuStorage, DType, Layout, Shape, WithDType};
|
||||||
use candle_kernels as kernels;
|
use candle_kernels as kernels;
|
||||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||||
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{
|
||||||
|
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
|
};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -242,6 +244,260 @@ enum CudaStorageSlice {
|
|||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
|
type S = CudaStorageSlice;
|
||||||
|
|
||||||
|
trait Map1 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||||
|
let out = match s {
|
||||||
|
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||||
|
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||||
|
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||||
|
S::F32(s) => S::F32(self.f(s, d, l)?),
|
||||||
|
S::F64(s) => S::F64(self.f(s, d, l)?),
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
trait Map2 {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src1: &CudaSlice<T>,
|
||||||
|
layout1: &Layout,
|
||||||
|
src2: &CudaSlice<T>,
|
||||||
|
layout2: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>>;
|
||||||
|
|
||||||
|
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||||
|
let out = match (s1, s2) {
|
||||||
|
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
|
||||||
|
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||||
|
};
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Clone;
|
||||||
|
impl Map1 for Clone {
|
||||||
|
fn f<T: DeviceRepr>(
|
||||||
|
&self,
|
||||||
|
s: &CudaSlice<T>,
|
||||||
|
_: &CudaDevice,
|
||||||
|
_: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
Ok(s.try_clone()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||||
|
let dtype = T::DTYPE.as_str();
|
||||||
|
format!("{root}_{dtype}")
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Affine(f64, f64);
|
||||||
|
impl Map1 for Affine {
|
||||||
|
fn f<T: DeviceRepr + WithDType>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = layout.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el) }?;
|
||||||
|
let params = (
|
||||||
|
el,
|
||||||
|
dims.len(),
|
||||||
|
&ds,
|
||||||
|
src,
|
||||||
|
&out,
|
||||||
|
T::from_f64(self.0),
|
||||||
|
T::from_f64(self.1),
|
||||||
|
);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Sum<'a>(&'a [usize]);
|
||||||
|
impl<'a> Map1 for Sum<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = layout.shape();
|
||||||
|
let src_dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let mut dst_el = el;
|
||||||
|
for &sum_dim in self.0.iter() {
|
||||||
|
dst_el /= src_dims[sum_dim];
|
||||||
|
}
|
||||||
|
let mut sum_dims = self.0.to_vec();
|
||||||
|
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||||
|
// indexes.
|
||||||
|
sum_dims.sort();
|
||||||
|
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
||||||
|
let sum_dims_s: Vec<usize> = sum_dims
|
||||||
|
.iter()
|
||||||
|
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
||||||
|
.collect();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
|
||||||
|
let out = dev.alloc_zeros::<T>(dst_el)?;
|
||||||
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<U: crate::op::UnaryOp> Map1 for U {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = layout.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
|
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el_count) }?;
|
||||||
|
let params = (el_count, dims.len(), &ds, src, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
|
||||||
|
impl<'a> Map1 for Embedding<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
rhs: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let ids = match &self.0.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "embedding ids should be u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: self.0.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let ids = &ids;
|
||||||
|
let shape = ids_l.shape();
|
||||||
|
let (v_size, h_size) = rhs_l
|
||||||
|
.shape()
|
||||||
|
.r2()
|
||||||
|
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
|
||||||
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el * h_size) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
|
||||||
|
impl<'a> Map2 for WhereCond<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
t: &CudaSlice<T>,
|
||||||
|
layout_t: &Layout,
|
||||||
|
f: &CudaSlice<T>,
|
||||||
|
layout_f: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let ids = match &self.0.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "where conditions should be u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: self.0.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let ids = &ids;
|
||||||
|
let shape = ids_l.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds =
|
||||||
|
dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
||||||
|
let t = &t.slice(layout_t.start_offset()..);
|
||||||
|
let f = &f.slice(layout_f.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<U: crate::op::BinaryOp> Map2 for U {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
lhs: &CudaSlice<T>,
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs: &CudaSlice<T>,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = lhs_l.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
|
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
|
||||||
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(elem_count) }?;
|
||||||
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||||
|
// SAFETY: ffi
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn slice_src_and_dst<'a, T>(
|
fn slice_src_and_dst<'a, T>(
|
||||||
src: &'a CudaSlice<T>,
|
src: &'a CudaSlice<T>,
|
||||||
@ -332,14 +588,8 @@ fn gemm_config<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl CudaStorage {
|
impl CudaStorage {
|
||||||
pub fn try_clone(&self) -> Result<Self> {
|
pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
|
||||||
let slice = match &self.slice {
|
let slice = Clone.map(&self.slice, self.device(), layout)?;
|
||||||
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
|
||||||
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
|
|
||||||
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
|
|
||||||
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
|
||||||
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
|
||||||
};
|
|
||||||
let device = self.device.clone();
|
let device = self.device.clone();
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
@ -420,152 +670,14 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
let shape = layout.shape();
|
let device = self.device().clone();
|
||||||
let dims = shape.dims();
|
let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
|
||||||
let el_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
|
||||||
let slice = match &self.slice {
|
|
||||||
CudaStorageSlice::U32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<u32>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::BF16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
|
||||||
let params = (
|
|
||||||
el_count,
|
|
||||||
dims.len(),
|
|
||||||
&ds,
|
|
||||||
arg,
|
|
||||||
&out,
|
|
||||||
bf16::from_f64(mul),
|
|
||||||
bf16::from_f64(add),
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
|
||||||
let params = (
|
|
||||||
el_count,
|
|
||||||
dims.len(),
|
|
||||||
&ds,
|
|
||||||
arg,
|
|
||||||
&out,
|
|
||||||
f16::from_f64(mul),
|
|
||||||
f16::from_f64(add),
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F64(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
let shape = layout.shape();
|
let device = self.device().clone();
|
||||||
let src_dims = shape.dims();
|
let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
|
||||||
let el = shape.elem_count();
|
|
||||||
let mut dst_el = el;
|
|
||||||
for &sum_dim in sum_dims.iter() {
|
|
||||||
dst_el /= src_dims[sum_dim];
|
|
||||||
}
|
|
||||||
let mut sum_dims = sum_dims.to_vec();
|
|
||||||
// Sort the sum_dims as they have to be processed from left to right when converting the
|
|
||||||
// indexes.
|
|
||||||
sum_dims.sort();
|
|
||||||
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
|
||||||
let sum_dims_s: Vec<usize> = sum_dims
|
|
||||||
.iter()
|
|
||||||
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
|
||||||
.collect();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
|
||||||
let slice = match &self.slice {
|
|
||||||
CudaStorageSlice::U32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<u32>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::BF16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<bf16>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<f16>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F64(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
|
|
||||||
let out = dev.alloc_zeros::<f64>(dst_el)?;
|
|
||||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -576,58 +688,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||||
let shape = layout.shape();
|
let device = self.device().clone();
|
||||||
let dims = shape.dims();
|
let slice = U::V.map(&self.slice, &device, layout)?;
|
||||||
let el_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
|
||||||
let dev = &self.device;
|
|
||||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
|
||||||
let slice = match &self.slice {
|
|
||||||
CudaStorageSlice::U32(_arg) => {
|
|
||||||
todo!("No unary kernels for u32");
|
|
||||||
}
|
|
||||||
CudaStorageSlice::BF16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F16(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F32(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F64(arg) => {
|
|
||||||
let arg = &arg.slice(layout.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
|
||||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -637,72 +699,8 @@ impl CudaStorage {
|
|||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let shape = lhs_l.shape();
|
let device = self.device().clone();
|
||||||
let dims = shape.dims();
|
let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
|
|
||||||
let slice = match (&self.slice, &rhs.slice) {
|
|
||||||
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
|
|
||||||
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
|
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
|
|
||||||
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
|
|
||||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
// The dtypes should have been checked at this point so this is an internal error.
|
|
||||||
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -740,163 +738,18 @@ impl CudaStorage {
|
|||||||
&self,
|
&self,
|
||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
t: &Self,
|
t: &Self,
|
||||||
layout_t: &Layout,
|
t_l: &Layout,
|
||||||
f: &Self,
|
f: &Self,
|
||||||
layout_f: &Layout,
|
f_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ids = match &self.slice {
|
let device = self.device().clone();
|
||||||
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
|
||||||
msg: "where conditions should be u32",
|
|
||||||
expected: DType::U32,
|
|
||||||
got: self.dtype(),
|
|
||||||
})?,
|
|
||||||
};
|
|
||||||
let ids = &ids;
|
|
||||||
let shape = layout.shape();
|
|
||||||
let dims = shape.dims();
|
|
||||||
let el = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds =
|
|
||||||
dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
|
||||||
let slice = match (&t.slice, &f.slice) {
|
|
||||||
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
|
|
||||||
let out = unsafe { dev.alloc::<f64>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
|
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
|
|
||||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
||||||
// SAFETY: ffi
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
// The dtypes should have been checked at this point so this is an internal error.
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let ids = match &self.slice {
|
let device = self.device().clone();
|
||||||
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
|
||||||
msg: "embedding ids should be u32",
|
|
||||||
expected: DType::U32,
|
|
||||||
got: self.dtype(),
|
|
||||||
})?,
|
|
||||||
};
|
|
||||||
let ids = &ids;
|
|
||||||
let shape = layout.shape();
|
|
||||||
let (v_size, h_size) = rhs_l
|
|
||||||
.shape()
|
|
||||||
.r2()
|
|
||||||
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
|
|
||||||
let dims = shape.dims();
|
|
||||||
let el = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
|
||||||
let slice = match &rhs.slice {
|
|
||||||
// The kernels below assume that rhs is contiguous.
|
|
||||||
CudaStorageSlice::U32(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::BF16(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F16(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F32(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F64(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ impl CudaDevice {
|
|||||||
pub struct CudaStorage;
|
pub struct CudaStorage;
|
||||||
|
|
||||||
impl CudaStorage {
|
impl CudaStorage {
|
||||||
pub fn try_clone(&self) -> Result<Self> {
|
pub fn try_clone(&self, _: &Layout) -> Result<Self> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,11 +43,8 @@ pub(crate) enum Op {
|
|||||||
|
|
||||||
pub(crate) trait UnaryOp {
|
pub(crate) trait UnaryOp {
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
const KERNEL_BF16: &'static str;
|
const KERNEL: &'static str;
|
||||||
const KERNEL_F16: &'static str;
|
const V: Self;
|
||||||
const KERNEL_F32: &'static str;
|
|
||||||
const KERNEL_F64: &'static str;
|
|
||||||
const KERNEL_U32: &'static str;
|
|
||||||
fn bf16(v1: bf16) -> bf16;
|
fn bf16(v1: bf16) -> bf16;
|
||||||
fn f16(v1: f16) -> f16;
|
fn f16(v1: f16) -> f16;
|
||||||
fn f32(v1: f32) -> f32;
|
fn f32(v1: f32) -> f32;
|
||||||
@ -57,11 +54,8 @@ pub(crate) trait UnaryOp {
|
|||||||
|
|
||||||
pub(crate) trait BinaryOp {
|
pub(crate) trait BinaryOp {
|
||||||
const NAME: &'static str;
|
const NAME: &'static str;
|
||||||
const KERNEL_BF16: &'static str;
|
const KERNEL: &'static str;
|
||||||
const KERNEL_F16: &'static str;
|
const V: Self;
|
||||||
const KERNEL_F32: &'static str;
|
|
||||||
const KERNEL_F64: &'static str;
|
|
||||||
const KERNEL_U32: &'static str;
|
|
||||||
fn bf16(v1: bf16, v2: bf16) -> bf16;
|
fn bf16(v1: bf16, v2: bf16) -> bf16;
|
||||||
fn f16(v1: f16, v2: f16) -> f16;
|
fn f16(v1: f16, v2: f16) -> f16;
|
||||||
fn f32(v1: f32, v2: f32) -> f32;
|
fn f32(v1: f32, v2: f32) -> f32;
|
||||||
@ -88,11 +82,8 @@ macro_rules! bin_op {
|
|||||||
($op:ident, $name: literal, $e: expr) => {
|
($op:ident, $name: literal, $e: expr) => {
|
||||||
impl BinaryOp for $op {
|
impl BinaryOp for $op {
|
||||||
const NAME: &'static str = $name;
|
const NAME: &'static str = $name;
|
||||||
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
|
const KERNEL: &'static str = concat!("b", $name);
|
||||||
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
|
const V: Self = $op;
|
||||||
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
|
|
||||||
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
|
|
||||||
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
|
|
||||||
fn bf16(v1: bf16, v2: bf16) -> bf16 {
|
fn bf16(v1: bf16, v2: bf16) -> bf16 {
|
||||||
$e(v1, v2)
|
$e(v1, v2)
|
||||||
}
|
}
|
||||||
@ -121,11 +112,8 @@ macro_rules! unary_op {
|
|||||||
($op: ident, $name: literal, $a: ident, $e: expr) => {
|
($op: ident, $name: literal, $a: ident, $e: expr) => {
|
||||||
impl UnaryOp for $op {
|
impl UnaryOp for $op {
|
||||||
const NAME: &'static str = $name;
|
const NAME: &'static str = $name;
|
||||||
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
|
const KERNEL: &'static str = concat!("u", $name);
|
||||||
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
|
const V: Self = $op;
|
||||||
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
|
|
||||||
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
|
|
||||||
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
|
|
||||||
fn bf16($a: bf16) -> bf16 {
|
fn bf16($a: bf16) -> bf16 {
|
||||||
$e
|
$e
|
||||||
}
|
}
|
||||||
@ -158,6 +146,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
|||||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||||
impl UnaryOp for Gelu {
|
impl UnaryOp for Gelu {
|
||||||
const NAME: &'static str = "gelu";
|
const NAME: &'static str = "gelu";
|
||||||
|
const V: Self = Gelu;
|
||||||
fn bf16(v: bf16) -> bf16 {
|
fn bf16(v: bf16) -> bf16 {
|
||||||
bf16::from_f32_const(0.5)
|
bf16::from_f32_const(0.5)
|
||||||
* v
|
* v
|
||||||
@ -191,20 +180,13 @@ impl UnaryOp for Gelu {
|
|||||||
fn u32(_: u32) -> u32 {
|
fn u32(_: u32) -> u32 {
|
||||||
0
|
0
|
||||||
}
|
}
|
||||||
const KERNEL_BF16: &'static str = "ugelu_bf16";
|
const KERNEL: &'static str = "ugelu";
|
||||||
const KERNEL_F16: &'static str = "ugelu_f16";
|
|
||||||
const KERNEL_F32: &'static str = "ugelu_f32";
|
|
||||||
const KERNEL_F64: &'static str = "ugelu_f64";
|
|
||||||
const KERNEL_U32: &'static str = "ugelu_u32";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UnaryOp for Relu {
|
impl UnaryOp for Relu {
|
||||||
const NAME: &'static str = "relu";
|
const NAME: &'static str = "relu";
|
||||||
const KERNEL_BF16: &'static str = "urelu_bf16";
|
const KERNEL: &'static str = "urelu";
|
||||||
const KERNEL_F16: &'static str = "urelu_f16";
|
const V: Self = Relu;
|
||||||
const KERNEL_F32: &'static str = "urelu_f32";
|
|
||||||
const KERNEL_F64: &'static str = "urelu_f64";
|
|
||||||
const KERNEL_U32: &'static str = "urelu_u32";
|
|
||||||
fn bf16(v: bf16) -> bf16 {
|
fn bf16(v: bf16) -> bf16 {
|
||||||
v.max(bf16::ZERO)
|
v.max(bf16::ZERO)
|
||||||
}
|
}
|
||||||
|
@ -9,11 +9,11 @@ pub enum Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Storage {
|
impl Storage {
|
||||||
pub fn try_clone(&self) -> Result<Self> {
|
pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
|
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
|
||||||
Self::Cuda(storage) => {
|
Self::Cuda(storage) => {
|
||||||
let storage = storage.try_clone()?;
|
let storage = storage.try_clone(layout)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -709,7 +709,7 @@ impl Tensor {
|
|||||||
pub fn copy(&self) -> Result<Tensor> {
|
pub fn copy(&self) -> Result<Tensor> {
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: Arc::new(self.storage.try_clone()?),
|
storage: Arc::new(self.storage.try_clone(self.layout())?),
|
||||||
layout: self.layout.clone(),
|
layout: self.layout.clone(),
|
||||||
op: None, // TODO
|
op: None, // TODO
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
|
Reference in New Issue
Block a user