Kernel build example (#224)

* Build example kernels.

* Add some sample custom kernel.

* Get the example kernel to compile.

* Add some cuda code.

* More cuda custom op.

* More cuda custom ops.
This commit is contained in:
Laurent Mazare
2023-07-23 08:15:37 +02:00
committed by GitHub
parent 5f20acf080
commit b8a10425ad
7 changed files with 427 additions and 0 deletions

View File

@ -771,6 +771,50 @@ pub struct CudaStorage {
device: CudaDevice,
}
pub trait CudaDType: Sized {
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
}
macro_rules! cuda_dtype {
($ty:ty, $dtype:ident) => {
impl CudaDType for $ty {
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
match &s.slice {
CudaStorageSlice::$dtype(data) => Ok(&data),
_ => Err(crate::Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}
.bt()),
}
}
fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
let slice = CudaStorageSlice::$dtype(slice);
CudaStorage { slice, device }
}
}
};
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
cuda_dtype!(f64, F64);
impl CudaStorage {
pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
T::wrap_cuda_slice(slice, device)
}
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
T::as_cuda_slice(self)
}
}
fn gemm_config<T>(
alpha: T,
beta: T,