mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add more to the binary operators.
This commit is contained in:
@ -2,7 +2,7 @@ use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct Tensor_ {
|
||||
pub struct Tensor_ {
|
||||
storage: Storage,
|
||||
shape: Shape,
|
||||
// The strides are given in number of elements and not in bytes.
|
||||
@ -10,8 +10,17 @@ pub(crate) struct Tensor_ {
|
||||
op: Option<Op>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Tensor(Arc<Tensor_>);
|
||||
|
||||
impl std::ops::Deref for Tensor {
|
||||
type Target = Tensor_;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device())
|
||||
@ -45,7 +54,7 @@ impl Tensor {
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||
let lhs = self.shape();
|
||||
let rhs = rhs.shape();
|
||||
if lhs != rhs {
|
||||
@ -55,18 +64,38 @@ impl Tensor {
|
||||
op,
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
Ok(lhs)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Also make an inplace version or a pre-allocated? This could be tricky
|
||||
// if this can create cycles in the compute graph.
|
||||
pub fn add(&self, rhs: &Self) -> Result<Self> {
|
||||
self.same_shape_binary_op(rhs, "add")?;
|
||||
todo!()
|
||||
let shape = self.same_shape_binary_op(rhs, "add")?;
|
||||
let storage = self
|
||||
.storage
|
||||
.add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
op: Some(Op::Add(self.clone(), rhs.clone())),
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub fn mul(&self, rhs: &Self) -> Result<Self> {
|
||||
self.same_shape_binary_op(rhs, "mul")?;
|
||||
todo!()
|
||||
let shape = self.same_shape_binary_op(rhs, "mul")?;
|
||||
let storage = self
|
||||
.storage
|
||||
.mul_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
|
||||
let tensor_ = Tensor_ {
|
||||
storage,
|
||||
shape: shape.clone(),
|
||||
stride: shape.stride_contiguous(),
|
||||
op: Some(Op::Mul(self.clone(), rhs.clone())),
|
||||
};
|
||||
Ok(Self(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
@ -77,7 +106,7 @@ impl Tensor {
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
match &self.0.storage {
|
||||
match &self.storage {
|
||||
Storage::Cpu(cpu_storage) => {
|
||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||
Ok(data[0])
|
||||
@ -96,15 +125,15 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.0.storage.dtype()
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> Device {
|
||||
self.0.storage.device()
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &Shape {
|
||||
&self.0.shape
|
||||
&self.shape
|
||||
}
|
||||
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
@ -112,7 +141,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
&self.0.stride
|
||||
&self.stride
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
|
Reference in New Issue
Block a user