Add a specific example for cuda.

This commit is contained in:
laurent
2023-06-21 18:56:04 +01:00
parent 2bfe8f18ab
commit c654ecdb16
3 changed files with 34 additions and 11 deletions

9
examples/cuda_basics.rs Normal file
View File

@ -0,0 +1,9 @@
use anyhow::Result;
use candle::{DType, Device, Tensor};
/// Minimal CUDA example: opens CUDA device 0, asks for a zero tensor of shape
/// `4` with dtype `F32` on that device, and prints the values read back through
/// `to_vec1::<f32>()`.
///
/// Any error from device creation, tensor creation, or the read-back is
/// propagated to the caller via `?`.
fn main() -> Result<()> {
    let gpu = Device::new_cuda(0)?;
    let zeros = Tensor::zeros(4, DType::F32, gpu)?;
    let values = zeros.to_vec1::<f32>()?;
    println!("{:?}", values);
    Ok(())
}

View File

@ -62,6 +62,11 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
} }
impl Device { impl Device {
/// Creates a `Device::Cuda` wrapping the CUDA device with the given `ordinal`.
///
/// # Errors
/// Propagates (after conversion through `?`) any error returned by
/// `cudarc::driver::CudaDevice::new`.
pub fn new_cuda(ordinal: usize) -> Result<Self> {
    use cudarc::driver::CudaDevice;

    Ok(Self::Cuda(CudaDevice::new(ordinal)?))
}
pub fn location(&self) -> DeviceLocation { pub fn location(&self) -> DeviceLocation {
match self { match self {
Self::Cpu => DeviceLocation::Cpu, Self::Cpu => DeviceLocation::Cpu,
@ -74,11 +79,14 @@ impl Device {
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> { pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self { match self {
Device::Cpu => { Device::Cpu => {
let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype)); let storage = CpuStorage::ones_impl(shape, dtype);
Ok(storage) Ok(Storage::Cpu(storage))
} }
Device::Cuda(_) => { Device::Cuda(device) => {
todo!() // TODO: Instead of allocating memory on the host and transfering it,
// allocate some zeros on the device and use a shader to set them to 1.
let storage = device.htod_copy(vec![1f32; shape.elem_count()])?;
Ok(Storage::Cuda(storage))
} }
} }
} }
@ -98,13 +106,19 @@ impl Device {
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> { pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
match self { match self {
Device::Cpu => { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
let storage = Storage::Cpu(array.to_cpu_storage()); Device::Cuda(device) => {
Ok(storage) // TODO: Avoid making a copy through the cpu.
} match array.to_cpu_storage() {
Device::Cuda(_) => { CpuStorage::F64(_) => {
todo!() todo!()
} }
CpuStorage::F32(data) => {
let storage = device.htod_copy(data)?;
Ok(Storage::Cuda(storage))
}
}
}
} }
} }
} }

View File

@ -250,7 +250,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?; let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect()) Ok(self.strided_index().map(|i| data[i]).collect())
} }
Storage::Cuda { .. } => todo!(), Storage::Cuda(_) => todo!(),
} }
} }