From ee3d290f8b75d527fffbc24eb25adc1d8cf642bc Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 27 Jun 2023 09:15:46 +0100
Subject: [PATCH] Cuda support for dtype conversions.

---
 examples/cuda_basics.rs |  2 ++
 examples/llama/main.rs  | 18 ++++++----
 kernels/src/cast.cu     | 34 ++++++++++++++++++
 kernels/src/lib.rs      |  1 +
 src/cuda_backend.rs     | 78 ++++++++++++++++++++++++++++++++++-------
 src/dtype.rs            | 10 ++++++
 6 files changed, 125 insertions(+), 18 deletions(-)
 create mode 100644 kernels/src/cast.cu

diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs
index 6f95723d..f288cb40 100644
--- a/examples/cuda_basics.rs
+++ b/examples/cuda_basics.rs
@@ -7,6 +7,8 @@ fn main() -> Result<()> {
     println!("> {:?}", x.sum(&[0])?.to_vec2::<f32>()?);
     println!("> {:?}", x.sum(&[1])?.to_vec2::<f32>()?);
     println!("> {:?}", x.sum(&[0, 1])?.to_vec2::<f32>()?);
+    let x = x.to_dtype(candle::DType::F16)?;
+    println!("> {:?}", x.sum(&[0])?.to_vec2::<half::f16>()?);
 
     let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
     println!("{:?}", x.to_vec1::<f32>()?);
diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index d2b16446..32e4b746 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -14,7 +14,7 @@
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle::{Device, Tensor};
+use candle::{DType, Device, Tensor};
 
 mod var_store;
 use var_store::VarBuilder;
@@ -135,7 +135,10 @@ impl Embedding {
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings)?)
+        Ok(Tensor::embedding(
+            indexes,
+            &self.embeddings.to_dtype(DType::F32)?,
+        )?)
     }
 }
 
@@ -158,10 +161,10 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.ws)?;
+        let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
         let y = match &self.bs {
             None => x,
-            Some(bs) => x.broadcast_add(bs)?,
+            Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
         };
         Ok(y)
     }
@@ -183,7 +186,10 @@ impl RmsNorm {
         let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let scale = self.scale.broadcast_as((seq_len, self.size))?;
+        let scale = self
+            .scale
+            .to_dtype(DType::F32)?
+            .broadcast_as((seq_len, self.size))?;
         Ok((scale * x_normed)?)
     }
 }
@@ -431,7 +437,7 @@ fn main() -> Result<()> {
         .get_ids()
         .to_vec();
 
-    let weight_path = std::path::Path::new("llama-f32.npz");
+    let weight_path = std::path::Path::new("llama.npz");
     let weights = if weight_path.exists() {
         println!("loading weights from {weight_path:?}");
         let start_load = std::time::Instant::now();
diff --git a/kernels/src/cast.cu b/kernels/src/cast.cu
new file mode 100644
index 00000000..0e47d297
--- /dev/null
+++ b/kernels/src/cast.cu
@@ -0,0 +1,34 @@
+#include "cuda_utils.cuh"
+
+#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const SRC_TYPENAME *inp, \
+    DST_TYPENAME *out \
+) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    if (is_contiguous(num_dims, dims, strides)) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            out[i] = inp[i]; \
+        } \
+    } \
+    else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+            out[i] = inp[strided_i]; \
+        } \
+    } \
+} \
+
+#if __CUDA_ARCH__ >= 530
+CAST_OP(__half, __half, cast_f16_f16)
+CAST_OP(__half, float, cast_f16_f32)
+CAST_OP(float, __half, cast_f32_f16)
+#endif
+
+CAST_OP(float, float, cast_f32_f32)
+CAST_OP(float, double, cast_f32_f64)
+CAST_OP(double, float, cast_f64_f32)
diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs
index d29022da..c3a927ad 100644
--- a/kernels/src/lib.rs
+++ b/kernels/src/lib.rs
@@ -1,5 +1,6 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
+pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
 pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 0739b6b3..2410f1d7 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -21,7 +21,7 @@ pub enum CudaError {
     RequiresContiguous { op: &'static str },
 
     #[error("missing kernel '{module_name}'")]
-    MissingKernel { module_name: &'static str },
+    MissingKernel { module_name: String },
 
     #[error("internal error '{0}'")]
     InternalError(&'static str),
@@ -43,7 +43,7 @@ pub enum CudaError {
     #[error("{cuda} when loading {module_name}")]
     Load {
         cuda: cudarc::driver::DriverError,
-        module_name: &'static str,
+        module_name: String,
     },
 }
 
@@ -211,19 +211,23 @@ impl CudaDevice {
         })
     }
 
-    fn get_or_load_func(
-        &self,
-        module_name: &'static str,
-        ptx: &'static str,
-    ) -> Result<CudaFunction> {
+    fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
         if !self.has_func(module_name, module_name) {
-            self.load_ptx(ptx.into(), module_name, &[module_name])
-                .map_err(|cuda| CudaError::Load { cuda, module_name })?;
+            // Leaking the string here is a bit sad but we need a &'static str and this is only
+            // done once per kernel name.
+            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
+            self.load_ptx(ptx.into(), module_name, &[static_module_name])
+                .map_err(|cuda| CudaError::Load {
+                    cuda,
+                    module_name: module_name.to_string(),
+                })?;
         }
         self.get_func(module_name, module_name)
             // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
            // able to only build the error value if needed.
-            .ok_or(CudaError::MissingKernel { module_name })
+            .ok_or(CudaError::MissingKernel {
+                module_name: module_name.to_string(),
+            })
     }
 }
 
@@ -330,8 +334,58 @@ impl CudaStorage {
         &self.device
     }
 
-    pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement to_dtype"))
+    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+        use cudarc::driver::DevicePtr;
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let dev = self.device();
+        let ds = dev.htod_copy([dims, stride].concat())?;
+        let inp = match &self.slice {
+            CudaStorageSlice::U32(inp) => inp.device_ptr(),
+            CudaStorageSlice::BF16(inp) => inp.device_ptr(),
+            CudaStorageSlice::F16(inp) => inp.device_ptr(),
+            CudaStorageSlice::F32(inp) => inp.device_ptr(),
+            CudaStorageSlice::F64(inp) => inp.device_ptr(),
+        };
+        let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
+        let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
+        let slice = match dtype {
+            DType::U32 => {
+                let out = unsafe { dev.alloc::<u32>(el) }?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::U32(out)
+            }
+            DType::BF16 => {
+                let out = unsafe { dev.alloc::<bf16>(el) }?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::BF16(out)
+            }
+            DType::F16 => {
+                let out = unsafe { dev.alloc::<f16>(el) }?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F16(out)
+            }
+            DType::F32 => {
+                let out = unsafe { dev.alloc::<f32>(el) }?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F32(out)
+            }
+            DType::F64 => {
+                let out = unsafe { dev.alloc::<f64>(el) }?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F64(out)
+            }
+        };
+        Ok(Self {
+            slice,
+            device: dev.clone(),
+        })
     }
 
     pub(crate) fn affine_impl(
diff --git a/src/dtype.rs b/src/dtype.rs
index 53b61ce0..1711c2b4 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -10,6 +10,16 @@ pub enum DType {
 }
 
 impl DType {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::U32 => "u32",
+            Self::BF16 => "bf16",
+            Self::F16 => "f16",
+            Self::F32 => "f32",
+            Self::F64 => "f64",
+        }
+    }
+
     pub fn size_in_bytes(&self) -> usize {
         match self {
             Self::U32 => 4,