mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
@ -79,6 +79,10 @@ impl CudaDevice {
|
||||
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U32 => {
|
||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
@ -98,6 +102,14 @@ impl CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||
let params = (&data, v as u32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||
@ -127,6 +139,10 @@ impl CudaDevice {
|
||||
|
||||
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
@ -159,6 +175,7 @@ impl CudaDevice {
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CudaStorageSlice {
|
||||
U32(CudaSlice<u32>),
|
||||
F32(CudaSlice<f32>),
|
||||
F64(CudaSlice<f64>),
|
||||
}
|
||||
@ -205,6 +222,7 @@ fn gemm_config<T>(
|
||||
impl CudaStorage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
||||
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
||||
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
||||
};
|
||||
@ -214,6 +232,7 @@ impl CudaStorage {
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self.slice {
|
||||
CudaStorageSlice::U32(_) => DType::U32,
|
||||
CudaStorageSlice::F32(_) => DType::F32,
|
||||
CudaStorageSlice::F64(_) => DType::F64,
|
||||
}
|
||||
@ -236,6 +255,15 @@ impl CudaStorage {
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
@ -270,6 +298,9 @@ impl CudaStorage {
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(_arg) => {
|
||||
todo!("No unary kernels for u32");
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
@ -333,6 +364,11 @@ impl CudaStorage {
|
||||
|
||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::U32(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::F32(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
@ -348,9 +384,9 @@ impl CudaStorage {
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
_rhs: &Self,
|
||||
_hidden_size: usize,
|
||||
_vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
todo!("Implement embedding for gpu");
|
||||
}
|
||||
|
Reference in New Issue
Block a user