mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Abstract the gradient storage.
This commit is contained in:
@ -54,27 +54,36 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Storage::Cpu(CpuStorage::ones_impl(shape, dtype)),
|
||||
Device::Cpu => {
|
||||
let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
|
||||
Ok(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Storage::Cpu(CpuStorage::zeros_impl(shape, dtype)),
|
||||
Device::Cpu => {
|
||||
let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
|
||||
Ok(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
||||
Device::Cpu => {
|
||||
let storage = Storage::Cpu(array.to_cpu_storage());
|
||||
Ok(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
}
|
||||
|
Reference in New Issue
Block a user