mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add an easy way to create tensor objects.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
storage::{CpuStorage, Storage},
|
||||
DType,
|
||||
DType, Result, Shape,
|
||||
};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
@ -8,11 +8,38 @@ pub enum Device {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||
pub trait NdArray {
|
||||
fn shape(&self) -> Result<Shape>;
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage;
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for S {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(&[*self])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for &[S] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let elem_count: usize = shape.iter().product();
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
DType::F32 => {
|
||||
let data = vec![0f32; elem_count];
|
||||
@ -27,4 +54,10 @@ impl Device {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
||||
match self {
|
||||
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user