@@ -1,4 +1,4 @@
-use crate::{CpuStorage, DType, Shape};
+use crate::{CpuStorage, DType, Layout, Shape};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
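
This patch swaps the separate shape/stride/offset arguments used throughout this file for a single `Layout`. A minimal sketch of the surface the hunks below rely on — only `shape()`, `stride()`, and `start_offset()` are ever called here, so the field names and anything beyond these accessors are assumptions about the real `crate::Layout`:

    // Sketch only: inferred from usage in this diff, not the actual definition.
    pub struct Layout {
        shape: Shape,
        stride: Vec<usize>,
        start_offset: usize, // offset (in elements) of the first value in storage
    }

    impl Layout {
        pub fn shape(&self) -> &Shape { &self.shape }
        pub fn stride(&self) -> &[usize] { &self.stride }
        pub fn start_offset(&self) -> usize { self.start_offset }
    }
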
@@ -26,6 +26,9 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),
 
+    #[error("internal error '{0}'")]
+    WrappedError(Box<dyn std::error::Error>),
+
     #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
     MatMulNonContiguous {
         lhs_stride: Vec<usize>,
@@ -268,12 +271,14 @@ fn gemm_config<T>(
     alpha: T,
     beta: T,
     (b, m, n, k): (usize, usize, usize, usize),
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    lhs_l: &Layout,
+    rhs_l: &Layout,
 ) -> Result<StridedBatchedConfig<T>> {
     // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
     use cudarc::cublas::sys::cublasOperation_t;
 
+    let lhs_stride = lhs_l.stride();
+    let rhs_stride = rhs_l.stride();
     let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
     let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
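
For intuition on the stride bookkeeping above (a sketch, not part of the patch): a contiguous row-major tensor of shape (b, m, k) has stride [m * k, k, 1], so `lhs_m1` reads the innermost stride and `lhs_m2` the next one out — the values `gemm_config` inspects when choosing the cuBLAS operation flags.

    // Row-major strides for a dims slice; e.g. contiguous_stride(&[2, 3, 4]) == [12, 4, 1].
    fn contiguous_stride(dims: &[usize]) -> Vec<usize> {
        let mut stride = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * dims[i + 1];
        }
        stride
    }
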
@@ -352,19 +357,21 @@ impl CudaStorage {
         &self.device
     }
 
-    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         use cudarc::driver::DevicePtr;
+        let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, stride].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let start_o = layout.start_offset();
         let inp = match &self.slice {
-            CudaStorageSlice::U32(inp) => inp.device_ptr(),
-            CudaStorageSlice::BF16(inp) => inp.device_ptr(),
-            CudaStorageSlice::F16(inp) => inp.device_ptr(),
-            CudaStorageSlice::F32(inp) => inp.device_ptr(),
-            CudaStorageSlice::F64(inp) => inp.device_ptr(),
+            CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(),
         };
         let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
         let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
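
`to_dtype` now offsets the device pointer by `layout.start_offset()` before handing it to the cast kernel; the kernel then walks the element space using the `[dims, stride]` buffer uploaded via `htod_copy`. A CPU sketch of that addressing, assuming the usual strided indexing these kernels implement:

    // Storage offset of a logical index, relative to the already-applied start_offset.
    fn strided_index(index: &[usize], stride: &[usize]) -> usize {
        index.iter().zip(stride).map(|(i, s)| i * s).sum()
    }
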
@@ -406,20 +413,16 @@ impl CudaStorage {
         })
     }
 
-    pub(crate) fn affine_impl(
-        &self,
-        shape: &Shape,
-        stride: &[usize],
-        mul: f64,
-        add: f64,
-    ) -> Result<Self> {
+    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+        let shape = layout.shape();
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, stride].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
         let slice = match &self.slice {
             CudaStorageSlice::U32(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<u32>(el_count) }?;
@@ -429,6 +432,7 @@ impl CudaStorage {
                 CudaStorageSlice::U32(out)
             }
             CudaStorageSlice::BF16(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<bf16>(el_count) }?;
@@ -446,6 +450,7 @@ impl CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             CudaStorageSlice::F16(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f16>(el_count) }?;
@@ -463,6 +468,7 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             CudaStorageSlice::F32(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f32>(el_count) }?;
@@ -472,6 +478,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             CudaStorageSlice::F64(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f64>(el_count) }?;
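
The `affine_*` kernels compute an elementwise `x * mul + add`; a contiguous CPU reference (a sketch under that assumption, matching candle's CPU-side affine semantics):

    // CPU reference for affine_f32 (sketch).
    fn affine_ref(xs: &[f32], mul: f64, add: f64) -> Vec<f32> {
        xs.iter().map(|&x| (x as f64 * mul + add) as f32).collect()
    }
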
@@ -485,7 +492,8 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
+    pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        let shape = layout.shape();
         let src_dims = shape.dims();
         let el = shape.elem_count();
         let mut dst_el = el;
@@ -503,9 +511,10 @@ impl CudaStorage {
             .collect();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?;
+        let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
         let slice = match &self.slice {
             CudaStorageSlice::U32(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
                 let out = dev.alloc_zeros::<u32>(dst_el)?;
                 let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -514,6 +523,7 @@ impl CudaStorage {
                 CudaStorageSlice::U32(out)
             }
             CudaStorageSlice::BF16(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
                 let out = dev.alloc_zeros::<bf16>(dst_el)?;
                 let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -522,6 +532,7 @@ impl CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             CudaStorageSlice::F16(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
                 let out = dev.alloc_zeros::<f16>(dst_el)?;
                 let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -530,6 +541,7 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             CudaStorageSlice::F32(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
                 let out = dev.alloc_zeros::<f32>(dst_el)?;
                 let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
@@ -538,6 +550,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             CudaStorageSlice::F64(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
                 let out = dev.alloc_zeros::<f64>(dst_el)?;
                 let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
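
`sum` reduces over `sum_dims`; assuming the keep-dims behavior this code appears to use (`dst_el` starts at the full element count `el`), the summed axes collapse to size 1 in the output. A sketch of that shape computation under that assumption:

    // Output dims for a keep-dims sum over sum_dims (sketch).
    fn sum_dst_dims(src_dims: &[usize], sum_dims: &[usize]) -> Vec<usize> {
        src_dims
            .iter()
            .enumerate()
            .map(|(i, &d)| if sum_dims.contains(&i) { 1 } else { d })
            .collect()
    }
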
@@ -556,21 +569,19 @@ impl CudaStorage {
         ))
     }
 
-    pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
-        &self,
-        shape: &Shape,
-        stride: &[usize],
-    ) -> Result<Self> {
+    pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
+        let shape = layout.shape();
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, stride].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
         let slice = match &self.slice {
             CudaStorageSlice::U32(_arg) => {
                 todo!("No unary kernels for u32");
             }
             CudaStorageSlice::BF16(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<bf16>(el_count) }?;
@@ -580,6 +591,7 @@ impl CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             CudaStorageSlice::F16(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f16>(el_count) }?;
@@ -589,6 +601,7 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             CudaStorageSlice::F32(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f32>(el_count) }?;
@@ -598,6 +611,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             CudaStorageSlice::F64(arg) => {
+                let arg = &arg.slice(layout.start_offset()..);
                 let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f64>(el_count) }?;
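
`unary_impl` is generic over an op that only has to name its per-dtype kernels, as the `U::KERNEL_*` lookups above show. A minimal sketch of that trait surface (the real `crate::op::UnaryOp` certainly carries more than this):

    // Only the members this file touches (sketch).
    trait UnaryOpSketch {
        const KERNEL_BF16: &'static str;
        const KERNEL_F16: &'static str;
        const KERNEL_F32: &'static str;
        const KERNEL_F64: &'static str;
    }
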
@@ -614,17 +628,19 @@ impl CudaStorage {
     pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_l: &Layout,
+        rhs_l: &Layout,
     ) -> Result<Self> {
+        let shape = lhs_l.shape();
         let dims = shape.dims();
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
-        let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
+        let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
         let slice = match (&self.slice, &rhs.slice) {
             (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
@@ -634,6 +650,8 @@ impl CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f16>(elem_count) }?;
@@ -643,6 +661,8 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f32>(elem_count) }?;
@@ -652,6 +672,8 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
                 // SAFETY: Set later by running the kernel.
                 let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
                 let out = unsafe { dev.alloc::<f64>(elem_count) }?;
@@ -661,6 +683,8 @@ impl CudaStorage {
                 CudaStorageSlice::F64(out)
             }
             (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
                 // SAFETY: Set later by running the kernel.
                 let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
                 let out = unsafe { dev.alloc::<u32>(elem_count) }?;
@@ -708,28 +732,31 @@ impl CudaStorage {
 
     pub(crate) fn where_cond(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         t: &Self,
-        stride_t: &[usize],
+        layout_t: &Layout,
         f: &Self,
-        stride_f: &[usize],
+        layout_f: &Layout,
     ) -> Result<Self> {
         let ids = match &self.slice {
-            CudaStorageSlice::U32(slice) => slice,
+            CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "where conditions should be u32",
                 expected: DType::U32,
                 got: self.dtype(),
             })?,
         };
+        let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
+        let ds =
+            dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
         let slice = match (&t.slice, &f.slice) {
             (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
+                let t = &t.slice(layout_t.start_offset()..);
+                let f = &f.slice(layout_f.start_offset()..);
                 let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<bf16>(el) }?;
@@ -739,6 +766,8 @@ impl CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
+                let t = &t.slice(layout_t.start_offset()..);
+                let f = &f.slice(layout_f.start_offset()..);
                 let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f16>(el) }?;
@@ -748,6 +777,8 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
+                let t = &t.slice(layout_t.start_offset()..);
+                let f = &f.slice(layout_f.start_offset()..);
                 let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f32>(el) }?;
@@ -757,6 +788,8 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
+                let t = &t.slice(layout_t.start_offset()..);
+                let f = &f.slice(layout_f.start_offset()..);
                 // SAFETY: Set later by running the kernel.
                 let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
                 let out = unsafe { dev.alloc::<f64>(el) }?;
@@ -766,6 +799,8 @@ impl CudaStorage {
                 CudaStorageSlice::F64(out)
             }
             (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
+                let t = &t.slice(layout_t.start_offset()..);
+                let f = &f.slice(layout_f.start_offset()..);
                 // SAFETY: Set later by running the kernel.
                 let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
                 let out = unsafe { dev.alloc::<u32>(el) }?;
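
`where_cond` picks from `t` where the `u32` condition is non-zero and from `f` otherwise, with each of the three operands walked through its own layout. A contiguous CPU reference (sketch):

    // CPU reference for the where_* kernels on contiguous inputs (sketch).
    fn where_cond_ref<T: Copy>(ids: &[u32], t: &[T], f: &[T]) -> Vec<T> {
        ids.iter()
            .zip(t.iter().zip(f.iter()))
            .map(|(&c, (&t, &f))| if c != 0 { t } else { f })
            .collect()
    }
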
@@ -775,36 +810,35 @@ impl CudaStorage {
                 CudaStorageSlice::U32(out)
             }
             // The dtypes should have been checked at this point so this is an internal error.
-            _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
         };
         let device = dev.clone();
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn embedding_impl(
-        &self,
-        shape: &Shape,
-        stride: &[usize],
-        rhs: &Self,
-        h_size: usize, // hidden size
-        v_size: usize, // vocab size
-    ) -> Result<Self> {
+    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = match &self.slice {
-            CudaStorageSlice::U32(slice) => slice,
+            CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "embedding ids should be u32",
                 expected: DType::U32,
                 got: self.dtype(),
             })?,
         };
+        let shape = layout.shape();
+        let (v_size, h_size) = rhs_l
+            .shape()
+            .r2()
+            .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let dev = self.device();
-        let ds = dev.htod_copy([dims, stride].concat())?;
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
         let slice = match &rhs.slice {
             // The kernels below assume that rhs is contiguous.
             CudaStorageSlice::U32(arg) => {
+                let arg = &arg.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
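
`embedding` now derives `(v_size, h_size)` from the rhs layout (via `r2()`, presumably a rank-2 shape accessor) instead of taking them as parameters: the table is a `(v_size, h_size)` matrix and each u32 id selects one row, so the output holds `el * h_size` elements. A contiguous CPU reference (sketch):

    // CPU reference for the emb_* kernels on a contiguous (v_size, h_size) table.
    fn embedding_ref(ids: &[u32], table: &[f32], h_size: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(ids.len() * h_size);
        for &id in ids {
            let row = id as usize * h_size;
            out.extend_from_slice(&table[row..row + h_size]);
        }
        out
    }
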
@@ -814,6 +848,7 @@ impl CudaStorage {
                 CudaStorageSlice::U32(out)
             }
             CudaStorageSlice::BF16(arg) => {
+                let arg = &arg.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
@@ -823,6 +858,7 @@ impl CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             CudaStorageSlice::F16(arg) => {
+                let arg = &arg.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
@@ -832,6 +868,7 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             CudaStorageSlice::F32(arg) => {
+                let arg = &arg.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
@@ -841,6 +878,7 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             CudaStorageSlice::F64(arg) => {
+                let arg = &arg.slice(rhs_l.start_offset()..);
                 let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
@@ -854,12 +892,12 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
-    pub(crate) fn matmul_impl(
+    pub(crate) fn matmul(
         &self,
         rhs: &Self,
         (b, m, n, k): (usize, usize, usize, usize),
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_l: &Layout,
+        rhs_l: &Layout,
     ) -> Result<Self> {
         let elem_count = b * m * n;
         let dev = &self.device;
@@ -868,7 +906,9 @@ impl CudaStorage {
                 todo!("bf16")
             }
             (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
-                let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
+                let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
                 let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
                 unsafe {
                     self.device
@@ -878,7 +918,9 @@ impl CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                 let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
                 unsafe {
                     self.device
@@ -888,7 +930,9 @@ impl CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
-                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
+                let lhs = &lhs.slice(lhs_l.start_offset()..);
+                let rhs = &rhs.slice(rhs_l.start_offset()..);
+                let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
                 let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
                 unsafe {
                     self.device
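
In the matmul hunks above, `(b, m, n, k)` describes a batched product of a `(b, m, k)` lhs with a `(b, k, n)` rhs into a `(b, m, n)` output — hence `elem_count = b * m * n` — and the layouts now feed `gemm_config`, which translates strides into a cuBLAS `StridedBatchedConfig`. A naive single-batch CPU reference (sketch; contiguous row-major, alpha = 1 and beta = 0 as in the calls above):

    // One batch of the (m, k) x (k, n) product computed by the gemm call (sketch).
    fn matmul_ref(lhs: &[f32], rhs: &[f32], (m, n, k): (usize, usize, usize)) -> Vec<f32> {
        let mut out = vec![0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                for l in 0..k {
                    out[i * n + j] += lhs[i * k + l] * rhs[l * n + j];
                }
            }
        }
        out
    }
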
@@ -907,13 +951,8 @@
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_shape: &Shape,
-        src_stride: &[usize],
-        src_offset: usize,
+        src_l: &Layout,
     ) -> Result<()> {
-        if src_shape.rank() != src_stride.len() {
-            panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
-        }
         let dims = src_shape.dims();
         let el_count = src_shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);