mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add some skeleton code for GPU support.
This commit is contained in:
@ -20,6 +20,7 @@ impl CpuStorage {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
|
||||
}
|
||||
|
||||
trait UnaryOp {
|
||||
@ -116,12 +117,14 @@ impl Storage {
|
||||
pub fn device(&self) -> Device {
|
||||
match self {
|
||||
Self::Cpu(_) => Device::Cpu,
|
||||
Self::Cuda { gpu_id } => Device::Cuda { gpu_id: *gpu_id },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::Cpu(storage) => storage.dtype(),
|
||||
Self::Cuda { .. } => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,6 +171,7 @@ impl Storage {
|
||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||
}
|
||||
},
|
||||
Self::Cuda { .. } => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,6 +190,7 @@ impl Storage {
|
||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||
}
|
||||
},
|
||||
Self::Cuda { .. } => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,6 +237,16 @@ impl Storage {
|
||||
})
|
||||
}
|
||||
},
|
||||
(Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device(),
|
||||
rhs: rhs.device(),
|
||||
op: B::NAME,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user