mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Start adding support for cuda.
This commit is contained in:
@ -3,7 +3,7 @@ use crate::{CpuStorage, DType, Device, Error, Result, Shape};
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
|
||||
Cuda(cudarc::driver::CudaSlice<f32>),
|
||||
}
|
||||
|
||||
pub(crate) trait UnaryOp {
|
||||
@ -100,7 +100,7 @@ impl Storage {
|
||||
pub fn device(&self) -> Device {
|
||||
match self {
|
||||
Self::Cpu(_) => Device::Cpu,
|
||||
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
|
||||
Self::Cuda(slice) => Device::Cuda(slice.device()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,8 +112,8 @@ impl Storage {
|
||||
}
|
||||
|
||||
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||
let lhs = self.device();
|
||||
let rhs = rhs.device();
|
||||
let lhs = self.device().location();
|
||||
let rhs = rhs.device().location();
|
||||
if lhs != rhs {
|
||||
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
|
||||
} else {
|
||||
@ -179,8 +179,8 @@ impl Storage {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device(),
|
||||
rhs: rhs.device(),
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: B::NAME,
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user