mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a first kernel.
This commit is contained in:
@ -4,9 +4,9 @@ use candle::{Device, Tensor};
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
|
||||
let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
|
||||
println!("{:?}", x.to_vec1::<f32>()?);
|
||||
let z = (x + y)?;
|
||||
let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
|
||||
let z = (y * 3.)?;
|
||||
println!("{:?}", z.to_vec1::<f32>()?);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,11 +1,28 @@
|
||||
use crate::{CpuStorage, DType, Result, Shape};
|
||||
use cudarc::driver::CudaSlice;
|
||||
use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
|
||||
|
||||
pub(crate) type Error = cudarc::driver::DriverError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
|
||||
|
||||
// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
|
||||
const AFFINE_CU: &str = r#"
|
||||
extern "C" __global__ void affine_f32(
|
||||
const size_t numel,
|
||||
const float *x,
|
||||
float *y,
|
||||
const float mul,
|
||||
const float add
|
||||
) {
|
||||
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= numel) {
|
||||
return;
|
||||
}
|
||||
y[i] = x[i] * mul + add;
|
||||
}
|
||||
"#;
|
||||
|
||||
impl CudaDevice {
|
||||
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||
@ -65,6 +82,39 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
_stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
// TODO: Handle the stride.
|
||||
let dev = arg.device();
|
||||
let module_name = "affine_f32";
|
||||
if !dev.has_func(module_name, module_name) {
|
||||
let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU).unwrap();
|
||||
dev.load_ptx(ptx, module_name, &[module_name])?;
|
||||
}
|
||||
let elem_count = shape.elem_count();
|
||||
let fwd_fn = dev.get_func(module_name, module_name).unwrap();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
// SAFETY: if this function returns Ok(..), the kernel has been applied
|
||||
// and has set the initially unset memory.
|
||||
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out, mul as f32, add as f32);
|
||||
// SAFETY: well, well, well...
|
||||
unsafe { fwd_fn.launch(cfg, params) }?;
|
||||
Ok(Self::F32(out))
|
||||
}
|
||||
Self::F64(_) => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
match self {
|
||||
Self::F32(slice) => {
|
||||
|
@ -144,7 +144,10 @@ impl Storage {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda { .. } => todo!(),
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user