diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs new file mode 100644 index 00000000..e1dca6a9 --- /dev/null +++ b/examples/cuda_basics.rs @@ -0,0 +1,9 @@ +use anyhow::Result; +use candle::{DType, Device, Tensor}; + +fn main() -> Result<()> { + let device = Device::new_cuda(0)?; + let x = Tensor::zeros(4, DType::F32, device)?; + println!("{:?}", x.to_vec1::()?); + Ok(()) +} diff --git a/src/device.rs b/src/device.rs index bb3d8870..19e1a302 100644 --- a/src/device.rs +++ b/src/device.rs @@ -62,6 +62,11 @@ impl NdArray for &[[S; N]; } impl Device { + pub fn new_cuda(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new(ordinal)?; + Ok(Self::Cuda(device)) + } + pub fn location(&self) -> DeviceLocation { match self { Self::Cpu => DeviceLocation::Cpu, @@ -74,11 +79,14 @@ impl Device { pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result { match self { Device::Cpu => { - let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype)); - Ok(storage) + let storage = CpuStorage::ones_impl(shape, dtype); + Ok(Storage::Cpu(storage)) } - Device::Cuda(_) => { - todo!() + Device::Cuda(device) => { + // TODO: Instead of allocating memory on the host and transferring it, + // allocate some zeros on the device and use a kernel to set them to 1. NOTE(review): this arm always uploads f32 ones and never reads `dtype`. + let storage = device.htod_copy(vec![1f32; shape.elem_count()])?; + Ok(Storage::Cuda(storage)) } } } @@ -98,12 +106,18 @@ impl Device { pub(crate) fn tensor(&self, array: A) -> Result { match self { - Device::Cpu => { - let storage = Storage::Cpu(array.to_cpu_storage()); - Ok(storage) - } - Device::Cuda(_) => { - todo!() + Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), + Device::Cuda(device) => { + // TODO: Avoid making a copy through the CPU. NOTE(review): only the F32 arm uploads data; the F64 arm is still todo!().
+ match array.to_cpu_storage() { + CpuStorage::F64(_) => { + todo!() + } + CpuStorage::F32(data) => { + let storage = device.htod_copy(data)?; + Ok(Storage::Cuda(storage)) + } + } } } } diff --git a/src/tensor.rs b/src/tensor.rs index 9ba412f9..5faf886f 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -250,7 +250,7 @@ impl Tensor { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok(self.strided_index().map(|i| data[i]).collect()) } - Storage::Cuda { .. } => todo!(), + Storage::Cuda(_) => todo!(), } }