mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Backport.
This commit is contained in:
@ -79,6 +79,10 @@ impl CudaDevice {
|
|||||||
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
|
DType::U32 => {
|
||||||
|
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
@ -98,6 +102,14 @@ impl CudaDevice {
|
|||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
|
DType::U32 => {
|
||||||
|
// SAFETY: Set later by running the fill kernel.
|
||||||
|
let data = unsafe { self.alloc::<u32>(elem_count) }?;
|
||||||
|
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||||
|
let params = (&data, v as u32, elem_count);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||||
@ -127,6 +139,10 @@ impl CudaDevice {
|
|||||||
|
|
||||||
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
let slice = match storage {
|
let slice = match storage {
|
||||||
|
CpuStorage::U32(storage) => {
|
||||||
|
let data = self.htod_sync_copy(storage)?;
|
||||||
|
CudaStorageSlice::U32(data)
|
||||||
|
}
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage)?;
|
let data = self.htod_sync_copy(storage)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
@ -159,6 +175,7 @@ impl CudaDevice {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum CudaStorageSlice {
|
enum CudaStorageSlice {
|
||||||
|
U32(CudaSlice<u32>),
|
||||||
F32(CudaSlice<f32>),
|
F32(CudaSlice<f32>),
|
||||||
F64(CudaSlice<f64>),
|
F64(CudaSlice<f64>),
|
||||||
}
|
}
|
||||||
@ -205,6 +222,7 @@ fn gemm_config<T>(
|
|||||||
impl CudaStorage {
|
impl CudaStorage {
|
||||||
pub fn try_clone(&self) -> Result<Self> {
|
pub fn try_clone(&self) -> Result<Self> {
|
||||||
let slice = match &self.slice {
|
let slice = match &self.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
||||||
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
||||||
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
||||||
};
|
};
|
||||||
@ -214,6 +232,7 @@ impl CudaStorage {
|
|||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self.slice {
|
match self.slice {
|
||||||
|
CudaStorageSlice::U32(_) => DType::U32,
|
||||||
CudaStorageSlice::F32(_) => DType::F32,
|
CudaStorageSlice::F32(_) => DType::F32,
|
||||||
CudaStorageSlice::F64(_) => DType::F64,
|
CudaStorageSlice::F64(_) => DType::F64,
|
||||||
}
|
}
|
||||||
@ -236,6 +255,15 @@ impl CudaStorage {
|
|||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||||
let slice = match &self.slice {
|
let slice = match &self.slice {
|
||||||
|
CudaStorageSlice::U32(arg) => {
|
||||||
|
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<u32>(el_count) }?;
|
||||||
|
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::U32(out)
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -270,6 +298,9 @@ impl CudaStorage {
|
|||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||||
let slice = match &self.slice {
|
let slice = match &self.slice {
|
||||||
|
CudaStorageSlice::U32(_arg) => {
|
||||||
|
todo!("No unary kernels for u32");
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(arg) => {
|
CudaStorageSlice::F32(arg) => {
|
||||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
@ -333,6 +364,11 @@ impl CudaStorage {
|
|||||||
|
|
||||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
match &self.slice {
|
match &self.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => {
|
||||||
|
let dev = slice.device();
|
||||||
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
|
Ok(CpuStorage::U32(cpu_storage))
|
||||||
|
}
|
||||||
CudaStorageSlice::F32(slice) => {
|
CudaStorageSlice::F32(slice) => {
|
||||||
let dev = slice.device();
|
let dev = slice.device();
|
||||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||||
@ -348,9 +384,9 @@ impl CudaStorage {
|
|||||||
|
|
||||||
pub(crate) fn embedding_impl(
|
pub(crate) fn embedding_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
_rhs: &Self,
|
||||||
hidden_size: usize,
|
_hidden_size: usize,
|
||||||
vocab_size: usize,
|
_vocab_size: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!("Implement embedding for gpu");
|
todo!("Implement embedding for gpu");
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user