mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add some skeleton code for GPU support.
This commit is contained in:
@ -6,6 +6,7 @@ use crate::{
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
Cuda { gpu_id: usize },
|
||||
}
|
||||
|
||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||
@ -72,6 +73,9 @@ impl Device {
|
||||
};
|
||||
Storage::Cpu(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,12 +95,18 @@ impl Device {
|
||||
};
|
||||
Storage::Cpu(storage)
|
||||
}
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
||||
match self {
|
||||
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
||||
Device::Cuda { gpu_id: _ } => {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user