mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add an easy way to create tensor objects.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
storage::{CpuStorage, Storage},
|
||||
DType,
|
||||
DType, Result, Shape,
|
||||
};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
@ -8,11 +8,38 @@ pub enum Device {
|
||||
Cpu,
|
||||
}
|
||||
|
||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||
pub trait NdArray {
|
||||
fn shape(&self) -> Result<Shape>;
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage;
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for S {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(&[*self])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for &[S] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub(crate) fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let elem_count: usize = shape.iter().product();
|
||||
let elem_count = shape.elem_count();
|
||||
let storage = match dtype {
|
||||
DType::F32 => {
|
||||
let data = vec![0f32; elem_count];
|
||||
@ -27,4 +54,10 @@ impl Device {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
|
||||
match self {
|
||||
Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
14
src/dtype.rs
14
src/dtype.rs
@ -1,3 +1,5 @@
|
||||
use crate::CpuStorage;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
F32,
|
||||
@ -13,14 +15,24 @@ impl DType {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithDType {
|
||||
pub trait WithDType: Sized + Copy {
|
||||
const DTYPE: DType;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
|
||||
}
|
||||
|
||||
impl WithDType for f32 {
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
CpuStorage::F32(data.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl WithDType for f64 {
|
||||
const DTYPE: DType = DType::F64;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
CpuStorage::F64(data.to_vec())
|
||||
}
|
||||
}
|
||||
|
@ -9,4 +9,6 @@ mod tensor;
|
||||
pub use device::Device;
|
||||
pub use dtype::{DType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use shape::Shape;
|
||||
pub use storage::{CpuStorage, Storage};
|
||||
pub use tensor::Tensor;
|
||||
|
@ -2,7 +2,7 @@ use crate::{DType, Device};
|
||||
|
||||
// TODO: Think about whether we would be better off with a dtype and
|
||||
// a buffer as an owned slice of bytes.
|
||||
pub(crate) enum CpuStorage {
|
||||
pub enum CpuStorage {
|
||||
F32(Vec<f32>),
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
@ -17,18 +17,18 @@ impl CpuStorage {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum Storage {
|
||||
pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
pub(crate) fn device(&self) -> Device {
|
||||
pub fn device(&self) -> Device {
|
||||
match self {
|
||||
Self::Cpu(_) => Device::Cpu,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn dtype(&self) -> DType {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::Cpu(storage) => storage.dtype(),
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::{op::Op, shape, storage::Storage, DType, Device};
|
||||
use crate::{op::Op, shape, storage::Storage, DType, Device, Result};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -14,15 +14,28 @@ pub struct Tensor(Arc<Tensor_>);
|
||||
impl Tensor {
|
||||
pub fn zeros<S: Into<shape::Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape.0, dtype);
|
||||
let rank = shape.0.len();
|
||||
let storage = device.zeros(&shape, dtype);
|
||||
let rank = shape.rank();
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape,
|
||||
stride: vec![1; rank],
|
||||
op: None,
|
||||
};
|
||||
Tensor(Arc::new(tensor_))
|
||||
Self(Arc::new(tensor_))
|
||||
}
|
||||
|
||||
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||
let shape = array.shape()?;
|
||||
let storage = device.tensor(array);
|
||||
let rank = shape.rank();
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape,
|
||||
stride: vec![1; rank],
|
||||
op: None,
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
@ -38,7 +51,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
&self.shape().dims()
|
||||
self.shape().dims()
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
|
@ -2,9 +2,12 @@ use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[test]
|
||||
fn add() -> Result<()> {
|
||||
let tensor = Tensor::zeros(&[5, 2], DType::F32, Device::Cpu);
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
|
||||
let (dim1, dim2) = tensor.shape().r2()?;
|
||||
assert_eq!(dim1, 5);
|
||||
assert_eq!(dim2, 2);
|
||||
let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
|
||||
let dim1 = tensor.shape().r1()?;
|
||||
assert_eq!(dim1, 3);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user