mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Start adding some ops.
This commit is contained in:
24
src/dtype.rs
24
src/dtype.rs
@ -1,4 +1,4 @@
|
||||
use crate::CpuStorage;
|
||||
use crate::{CpuStorage, Error, Result};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
@ -19,6 +19,8 @@ pub trait WithDType: Sized + Copy {
|
||||
const DTYPE: DType;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage;
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||
}
|
||||
|
||||
impl WithDType for f32 {
|
||||
@ -27,6 +29,16 @@ impl WithDType for f32 {
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
CpuStorage::F32(data.to_vec())
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||
match s {
|
||||
CpuStorage::F32(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::F32,
|
||||
got: s.dtype(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WithDType for f64 {
|
||||
@ -35,4 +47,14 @@ impl WithDType for f64 {
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
CpuStorage::F64(data.to_vec())
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||
match s {
|
||||
CpuStorage::F64(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::F64,
|
||||
got: s.dtype(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
15
src/error.rs
15
src/error.rs
@ -1,10 +1,15 @@
|
||||
use crate::{DType, Shape};
|
||||
|
||||
/// Main library error type.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("invalid shapes in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
BinaryInvalidShape {
|
||||
lhs: Vec<usize>,
|
||||
rhs: Vec<usize>,
|
||||
#[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
|
||||
UnexpectedDType { expected: DType, got: DType },
|
||||
|
||||
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
ShapeMismatchBinaryOp {
|
||||
lhs: Shape,
|
||||
rhs: Shape,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
@ -12,7 +17,7 @@ pub enum Error {
|
||||
UnexpectedNumberOfDims {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
shape: Vec<usize>,
|
||||
shape: Shape,
|
||||
},
|
||||
}
|
||||
|
||||
|
16
src/shape.rs
16
src/shape.rs
@ -1,4 +1,6 @@
|
||||
use crate::{Error, Result};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct Shape(pub(crate) Vec<usize>);
|
||||
|
||||
impl std::fmt::Debug for Shape {
|
||||
@ -56,6 +58,10 @@ impl From<(usize, usize, usize)> for Shape {
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn from_dims(dims: &[usize]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
@ -76,7 +82,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 0,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -89,7 +95,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -102,7 +108,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 2,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -115,7 +121,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 3,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -128,7 +134,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 4,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -45,11 +45,46 @@ impl Tensor {
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
// TODO: properly use the strides here.
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||
let lhs = self.shape();
|
||||
let rhs = rhs.shape();
|
||||
if lhs != rhs {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: lhs.clone(),
|
||||
rhs: rhs.clone(),
|
||||
op,
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&self, rhs: &Self) -> Result<Self> {
|
||||
self.same_shape_binary_op(rhs, "add")?;
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn mul(&self, rhs: &Self) -> Result<Self> {
|
||||
self.same_shape_binary_op(rhs, "mul")?;
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
if self.rank() != 0 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 0,
|
||||
got: self.rank(),
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
match &self.0.storage {
|
||||
Storage::Cpu(cpu_storage) => {
|
||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok(data[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||
// TODO: properly use the strides here.
|
||||
todo!()
|
||||
|
@ -6,7 +6,7 @@ fn add() -> Result<()> {
|
||||
let (dim1, dim2) = tensor.shape().r2()?;
|
||||
assert_eq!(dim1, 5);
|
||||
assert_eq!(dim2, 2);
|
||||
let tensor = Tensor::new([3., 1., 4.].as_slice(), Device::Cpu)?;
|
||||
let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
|
||||
let dim1 = tensor.shape().r1()?;
|
||||
assert_eq!(dim1, 3);
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user