mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Start adding some ops.
This commit is contained in:
24
src/dtype.rs
24
src/dtype.rs
@ -1,4 +1,4 @@
|
|||||||
use crate::CpuStorage;
|
use crate::{CpuStorage, Error, Result};
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum DType {
|
pub enum DType {
|
||||||
@ -19,6 +19,8 @@ pub trait WithDType: Sized + Copy {
|
|||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
|
||||||
|
|
||||||
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WithDType for f32 {
|
impl WithDType for f32 {
|
||||||
@ -27,6 +29,16 @@ impl WithDType for f32 {
|
|||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
CpuStorage::F32(data.to_vec())
|
CpuStorage::F32(data.to_vec())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||||
|
match s {
|
||||||
|
CpuStorage::F32(data) => Ok(data),
|
||||||
|
_ => Err(Error::UnexpectedDType {
|
||||||
|
expected: DType::F32,
|
||||||
|
got: s.dtype(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WithDType for f64 {
|
impl WithDType for f64 {
|
||||||
@ -35,4 +47,14 @@ impl WithDType for f64 {
|
|||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
CpuStorage::F64(data.to_vec())
|
CpuStorage::F64(data.to_vec())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||||
|
match s {
|
||||||
|
CpuStorage::F64(data) => Ok(data),
|
||||||
|
_ => Err(Error::UnexpectedDType {
|
||||||
|
expected: DType::F64,
|
||||||
|
got: s.dtype(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
15
src/error.rs
15
src/error.rs
@ -1,10 +1,15 @@
|
|||||||
|
use crate::{DType, Shape};
|
||||||
|
|
||||||
/// Main library error type.
|
/// Main library error type.
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
#[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
#[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
|
||||||
BinaryInvalidShape {
|
UnexpectedDType { expected: DType, got: DType },
|
||||||
lhs: Vec<usize>,
|
|
||||||
rhs: Vec<usize>,
|
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||||
|
ShapeMismatchBinaryOp {
|
||||||
|
lhs: Shape,
|
||||||
|
rhs: Shape,
|
||||||
op: &'static str,
|
op: &'static str,
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -12,7 +17,7 @@ pub enum Error {
|
|||||||
UnexpectedNumberOfDims {
|
UnexpectedNumberOfDims {
|
||||||
expected: usize,
|
expected: usize,
|
||||||
got: usize,
|
got: usize,
|
||||||
shape: Vec<usize>,
|
shape: Shape,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
16
src/shape.rs
16
src/shape.rs
@ -1,4 +1,6 @@
|
|||||||
use crate::{Error, Result};
|
use crate::{Error, Result};
|
||||||
|
|
||||||
|
#[derive(Clone, PartialEq, Eq)]
|
||||||
pub struct Shape(pub(crate) Vec<usize>);
|
pub struct Shape(pub(crate) Vec<usize>);
|
||||||
|
|
||||||
impl std::fmt::Debug for Shape {
|
impl std::fmt::Debug for Shape {
|
||||||
@ -56,6 +58,10 @@ impl From<(usize, usize, usize)> for Shape {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Shape {
|
impl Shape {
|
||||||
|
pub fn from_dims(dims: &[usize]) -> Self {
|
||||||
|
Self(dims.to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
self.0.len()
|
self.0.len()
|
||||||
}
|
}
|
||||||
@ -76,7 +82,7 @@ impl Shape {
|
|||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
expected: 0,
|
expected: 0,
|
||||||
got: shape.len(),
|
got: shape.len(),
|
||||||
shape: shape.to_vec(),
|
shape: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -89,7 +95,7 @@ impl Shape {
|
|||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
expected: 1,
|
expected: 1,
|
||||||
got: shape.len(),
|
got: shape.len(),
|
||||||
shape: shape.to_vec(),
|
shape: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,7 +108,7 @@ impl Shape {
|
|||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
expected: 2,
|
expected: 2,
|
||||||
got: shape.len(),
|
got: shape.len(),
|
||||||
shape: shape.to_vec(),
|
shape: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,7 +121,7 @@ impl Shape {
|
|||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
expected: 3,
|
expected: 3,
|
||||||
got: shape.len(),
|
got: shape.len(),
|
||||||
shape: shape.to_vec(),
|
shape: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -128,7 +134,7 @@ impl Shape {
|
|||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
expected: 4,
|
expected: 4,
|
||||||
got: shape.len(),
|
got: shape.len(),
|
||||||
shape: shape.to_vec(),
|
shape: self.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
|
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
@ -45,11 +45,46 @@ impl Tensor {
|
|||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||||
// TODO: properly use the strides here.
|
let lhs = self.shape();
|
||||||
|
let rhs = rhs.shape();
|
||||||
|
if lhs != rhs {
|
||||||
|
Err(Error::ShapeMismatchBinaryOp {
|
||||||
|
lhs: lhs.clone(),
|
||||||
|
rhs: rhs.clone(),
|
||||||
|
op,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add(&self, rhs: &Self) -> Result<Self> {
|
||||||
|
self.same_shape_binary_op(rhs, "add")?;
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn mul(&self, rhs: &Self) -> Result<Self> {
|
||||||
|
self.same_shape_binary_op(rhs, "mul")?;
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||||
|
if self.rank() != 0 {
|
||||||
|
return Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: 0,
|
||||||
|
got: self.rank(),
|
||||||
|
shape: self.shape().clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
match &self.0.storage {
|
||||||
|
Storage::Cpu(cpu_storage) => {
|
||||||
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
|
Ok(data[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||||
// TODO: properly use the strides here.
|
// TODO: properly use the strides here.
|
||||||
todo!()
|
todo!()
|
||||||
|
@ -6,7 +6,7 @@ fn add() -> Result<()> {
|
|||||||
let (dim1, dim2) = tensor.shape().r2()?;
|
let (dim1, dim2) = tensor.shape().r2()?;
|
||||||
assert_eq!(dim1, 5);
|
assert_eq!(dim1, 5);
|
||||||
assert_eq!(dim2, 2);
|
assert_eq!(dim2, 2);
|
||||||
let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
|
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
|
||||||
let dim1 = tensor.shape().r1()?;
|
let dim1 = tensor.shape().r1()?;
|
||||||
assert_eq!(dim1, 3);
|
assert_eq!(dim1, 3);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Reference in New Issue
Block a user