mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Kernel build example (#224)
* Build example kernels. * Add some sample custom kernel. * Get the example kernel to compile. * Add some cuda code. * More cuda custom op. * More cuda custom ops.
This commit is contained in:
@ -771,6 +771,50 @@ pub struct CudaStorage {
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
pub trait CudaDType: Sized {
|
||||
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
|
||||
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
|
||||
}
|
||||
|
||||
/// Implements [`CudaDType`] for the element type `$elem`, pairing it with the
/// `CudaStorageSlice::$variant` / `DType::$variant` variants of the same name.
macro_rules! cuda_dtype {
    ($elem:ty, $variant:ident) => {
        impl CudaDType for $elem {
            /// Returns the inner slice when `s` holds the matching variant,
            /// otherwise an `UnexpectedDType` error reporting both dtypes.
            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
                if let CudaStorageSlice::$variant(data) = &s.slice {
                    return Ok(data);
                }
                // NOTE(review): `.bt()` presumably attaches a backtrace to the
                // error; confirm against `crate::Error`.
                Err(crate::Error::UnexpectedDType {
                    expected: DType::$variant,
                    got: s.dtype(),
                    msg: "unexpected dtype",
                }
                .bt())
            }

            /// Tags `slice` with the matching storage variant and bundles it
            /// with `device`.
            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
                CudaStorage {
                    slice: CudaStorageSlice::$variant(slice),
                    device,
                }
            }
        }
    };
}
|
||||
cuda_dtype!(u8, U8);
|
||||
cuda_dtype!(u32, U32);
|
||||
cuda_dtype!(f16, F16);
|
||||
cuda_dtype!(bf16, BF16);
|
||||
cuda_dtype!(f32, F32);
|
||||
cuda_dtype!(f64, F64);
|
||||
|
||||
impl CudaStorage {
|
||||
pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
|
||||
T::wrap_cuda_slice(slice, device)
|
||||
}
|
||||
|
||||
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
|
||||
T::as_cuda_slice(self)
|
||||
}
|
||||
}
|
||||
|
||||
fn gemm_config<T>(
|
||||
alpha: T,
|
||||
beta: T,
|
||||
|
Reference in New Issue
Block a user