mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Start adding some ops.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Result, Shape};
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
@ -45,11 +45,46 @@ impl Tensor {
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
// TODO: properly use the strides here.
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||
let lhs = self.shape();
|
||||
let rhs = rhs.shape();
|
||||
if lhs != rhs {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
lhs: lhs.clone(),
|
||||
rhs: rhs.clone(),
|
||||
op,
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&self, rhs: &Self) -> Result<Self> {
|
||||
self.same_shape_binary_op(rhs, "add")?;
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn mul(&self, rhs: &Self) -> Result<Self> {
|
||||
self.same_shape_binary_op(rhs, "mul")?;
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
if self.rank() != 0 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 0,
|
||||
got: self.rank(),
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
match &self.0.storage {
|
||||
Storage::Cpu(cpu_storage) => {
|
||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok(data[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||
// TODO: properly use the strides here.
|
||||
todo!()
|
||||
|
Reference in New Issue
Block a user