mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Rename as_slice to storage_data and implement the cuda version.
This commit is contained in:
11
src/dtype.rs
11
src/dtype.rs
@ -26,6 +26,7 @@ pub trait WithDType: Sized + Copy {
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
|
||||
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
|
||||
}
|
||||
|
||||
macro_rules! with_dtype {
|
||||
@ -37,6 +38,16 @@ macro_rules! with_dtype {
|
||||
CpuStorage::$dtype(data)
|
||||
}
|
||||
|
||||
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
|
||||
match s {
|
||||
CpuStorage::$dtype(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::$dtype,
|
||||
got: s.dtype(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||
match s {
|
||||
CpuStorage::$dtype(data) => Ok(data),
|
||||
|
@ -313,10 +313,20 @@ impl Tensor {
|
||||
crate::StridedIndex::new(self.dims(), self.stride())
|
||||
}
|
||||
|
||||
pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
|
||||
/// Returns data from the underlying storage, this does not take the strides
|
||||
/// into account so the size of the resulting buffer might be larger than the
|
||||
/// tensor number of elements.
|
||||
pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
|
||||
match &self.storage {
|
||||
Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
|
||||
Storage::Cuda { .. } => todo!(),
|
||||
Storage::Cpu(cpu_storage) => {
|
||||
let slice = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok(std::borrow::Cow::Borrowed(slice))
|
||||
}
|
||||
Storage::Cuda(slice) => {
|
||||
let cpu_storage = slice.to_cpu_storage()?;
|
||||
let storage_data = S::cpu_storage_data(cpu_storage)?;
|
||||
Ok(std::borrow::Cow::Owned(storage_data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user