mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Cuda support for dtype conversions.
This commit is contained in:
@ -7,6 +7,8 @@ fn main() -> Result<()> {
|
|||||||
println!("> {:?}", x.sum(&[0])?.to_vec2::<f32>()?);
|
println!("> {:?}", x.sum(&[0])?.to_vec2::<f32>()?);
|
||||||
println!("> {:?}", x.sum(&[1])?.to_vec2::<f32>()?);
|
println!("> {:?}", x.sum(&[1])?.to_vec2::<f32>()?);
|
||||||
println!("> {:?}", x.sum(&[0, 1])?.to_vec2::<f32>()?);
|
println!("> {:?}", x.sum(&[0, 1])?.to_vec2::<f32>()?);
|
||||||
|
let x = x.to_dtype(candle::DType::F16)?;
|
||||||
|
println!("> {:?}", x.sum(&[0])?.to_vec2::<half::f16>()?);
|
||||||
|
|
||||||
let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
|
let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
|
||||||
println!("{:?}", x.to_vec1::<f32>()?);
|
println!("{:?}", x.to_vec1::<f32>()?);
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle::{Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
|
|
||||||
mod var_store;
|
mod var_store;
|
||||||
use var_store::VarBuilder;
|
use var_store::VarBuilder;
|
||||||
@ -135,7 +135,10 @@ impl Embedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||||
Ok(Tensor::embedding(indexes, &self.embeddings)?)
|
Ok(Tensor::embedding(
|
||||||
|
indexes,
|
||||||
|
&self.embeddings.to_dtype(DType::F32)?,
|
||||||
|
)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,10 +161,10 @@ impl Linear {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.matmul(&self.ws)?;
|
let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
|
||||||
let y = match &self.bs {
|
let y = match &self.bs {
|
||||||
None => x,
|
None => x,
|
||||||
Some(bs) => x.broadcast_add(bs)?,
|
Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
|
||||||
};
|
};
|
||||||
Ok(y)
|
Ok(y)
|
||||||
}
|
}
|
||||||
@ -183,7 +186,10 @@ impl RmsNorm {
|
|||||||
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
||||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||||
let scale = self.scale.broadcast_as((seq_len, self.size))?;
|
let scale = self
|
||||||
|
.scale
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.broadcast_as((seq_len, self.size))?;
|
||||||
Ok((scale * x_normed)?)
|
Ok((scale * x_normed)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -431,7 +437,7 @@ fn main() -> Result<()> {
|
|||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let weight_path = std::path::Path::new("llama-f32.npz");
|
let weight_path = std::path::Path::new("llama.npz");
|
||||||
let weights = if weight_path.exists() {
|
let weights = if weight_path.exists() {
|
||||||
println!("loading weights from {weight_path:?}");
|
println!("loading weights from {weight_path:?}");
|
||||||
let start_load = std::time::Instant::now();
|
let start_load = std::time::Instant::now();
|
||||||
|
34
kernels/src/cast.cu
Normal file
34
kernels/src/cast.cu
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#include "cuda_utils.cuh"
|
||||||
|
|
||||||
|
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const size_t numel, \
|
||||||
|
const size_t num_dims, \
|
||||||
|
const size_t *info, \
|
||||||
|
const SRC_TYPENAME *inp, \
|
||||||
|
DST_TYPENAME *out \
|
||||||
|
) { \
|
||||||
|
const size_t *dims = info; \
|
||||||
|
const size_t *strides = info + num_dims; \
|
||||||
|
if (is_contiguous(num_dims, dims, strides)) { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
out[i] = inp[i]; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
else { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||||
|
out[i] = inp[strided_i]; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= 530
|
||||||
|
CAST_OP(__half, __half, cast_f16_f16)
|
||||||
|
CAST_OP(__half, float, cast_f16_f32)
|
||||||
|
CAST_OP(float, __half, cast_f32_f16)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
CAST_OP(float, float, cast_f32_f32)
|
||||||
|
CAST_OP(float, double, cast_f32_f64)
|
||||||
|
CAST_OP(double, float, cast_f64_f32)
|
@ -1,5 +1,6 @@
|
|||||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||||
|
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||||
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
|
||||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||||
|
@ -21,7 +21,7 @@ pub enum CudaError {
|
|||||||
RequiresContiguous { op: &'static str },
|
RequiresContiguous { op: &'static str },
|
||||||
|
|
||||||
#[error("missing kernel '{module_name}'")]
|
#[error("missing kernel '{module_name}'")]
|
||||||
MissingKernel { module_name: &'static str },
|
MissingKernel { module_name: String },
|
||||||
|
|
||||||
#[error("internal error '{0}'")]
|
#[error("internal error '{0}'")]
|
||||||
InternalError(&'static str),
|
InternalError(&'static str),
|
||||||
@ -43,7 +43,7 @@ pub enum CudaError {
|
|||||||
#[error("{cuda} when loading {module_name}")]
|
#[error("{cuda} when loading {module_name}")]
|
||||||
Load {
|
Load {
|
||||||
cuda: cudarc::driver::DriverError,
|
cuda: cudarc::driver::DriverError,
|
||||||
module_name: &'static str,
|
module_name: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,19 +211,23 @@ impl CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_load_func(
|
fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||||
&self,
|
|
||||||
module_name: &'static str,
|
|
||||||
ptx: &'static str,
|
|
||||||
) -> Result<CudaFunction> {
|
|
||||||
if !self.has_func(module_name, module_name) {
|
if !self.has_func(module_name, module_name) {
|
||||||
self.load_ptx(ptx.into(), module_name, &[module_name])
|
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||||
.map_err(|cuda| CudaError::Load { cuda, module_name })?;
|
// done once per kernel name.
|
||||||
|
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||||
|
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||||
|
.map_err(|cuda| CudaError::Load {
|
||||||
|
cuda,
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
self.get_func(module_name, module_name)
|
self.get_func(module_name, module_name)
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||||
// able to only build the error value if needed.
|
// able to only build the error value if needed.
|
||||||
.ok_or(CudaError::MissingKernel { module_name })
|
.ok_or(CudaError::MissingKernel {
|
||||||
|
module_name: module_name.to_string(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -330,8 +334,58 @@ impl CudaStorage {
|
|||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||||
Err(CudaError::InternalError("TODO: implement to_dtype"))
|
use cudarc::driver::DevicePtr;
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let dev = self.device();
|
||||||
|
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||||
|
let inp = match &self.slice {
|
||||||
|
CudaStorageSlice::U32(inp) => inp.device_ptr(),
|
||||||
|
CudaStorageSlice::BF16(inp) => inp.device_ptr(),
|
||||||
|
CudaStorageSlice::F16(inp) => inp.device_ptr(),
|
||||||
|
CudaStorageSlice::F32(inp) => inp.device_ptr(),
|
||||||
|
CudaStorageSlice::F64(inp) => inp.device_ptr(),
|
||||||
|
};
|
||||||
|
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
||||||
|
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
||||||
|
let slice = match dtype {
|
||||||
|
DType::U32 => {
|
||||||
|
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, *inp, &out);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::U32(out)
|
||||||
|
}
|
||||||
|
DType::BF16 => {
|
||||||
|
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, *inp, &out);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::BF16(out)
|
||||||
|
}
|
||||||
|
DType::F16 => {
|
||||||
|
let out = unsafe { dev.alloc::<f16>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, *inp, &out);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F16(out)
|
||||||
|
}
|
||||||
|
DType::F32 => {
|
||||||
|
let out = unsafe { dev.alloc::<f32>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, *inp, &out);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F32(out)
|
||||||
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
let out = unsafe { dev.alloc::<f64>(el) }?;
|
||||||
|
let params = (el, dims.len(), &ds, *inp, &out);
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
CudaStorageSlice::F64(out)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
slice,
|
||||||
|
device: dev.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine_impl(
|
||||||
|
10
src/dtype.rs
10
src/dtype.rs
@ -10,6 +10,16 @@ pub enum DType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DType {
|
impl DType {
|
||||||
|
pub fn as_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::U32 => "u32",
|
||||||
|
Self::BF16 => "bf16",
|
||||||
|
Self::F16 => "f16",
|
||||||
|
Self::F32 => "f32",
|
||||||
|
Self::F64 => "f64",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn size_in_bytes(&self) -> usize {
|
pub fn size_in_bytes(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
Self::U32 => 4,
|
Self::U32 => 4,
|
||||||
|
Reference in New Issue
Block a user