mirror of https://github.com/huggingface/candle.git
Add some unary ops.
@@ -27,6 +27,40 @@ impl std::fmt::Debug for Tensor {
     }
 }
 
+macro_rules! unary_op {
+    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+        pub fn $fn_name(&self) -> Result<Self> {
+            let shape = self.shape();
+            let storage = self.storage.$impl_name(self.shape(), self.stride())?;
+            let tensor_ = Tensor_ {
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: Some(Op::$op_name(self.clone())),
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
+macro_rules! binary_op {
+    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
+            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
+            let storage =
+                self.storage
+                    .$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
+            let tensor_ = Tensor_ {
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: Some(Op::$op_name(self.clone(), rhs.clone())),
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
 impl Tensor {
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
         let shape = shape.into();
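
For readers unfamiliar with these macros: invoking `unary_op!(sqr, Sqr, sqr_impl)` inside `impl Tensor` (as done further down in this diff) expands mechanically to the method below. The expansion is shown only for clarity and is not itself part of the commit:

    pub fn sqr(&self) -> Result<Self> {
        let shape = self.shape();
        // Delegate the element-wise computation to the storage backend.
        let storage = self.storage.sqr_impl(self.shape(), self.stride())?;
        let tensor_ = Tensor_ {
            storage,
            shape: shape.clone(),
            stride: shape.stride_contiguous(),
            // Remember the source op so the compute graph can be traversed.
            op: Some(Op::Sqr(self.clone())),
        };
        Ok(Self(Arc::new(tensor_)))
    }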
@@ -70,34 +104,11 @@ impl Tensor {
 
     // TODO: Also make an inplace version or a pre-allocated? This could be tricky
     // if this can create cycles in the compute graph.
-    pub fn add(&self, rhs: &Self) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "add")?;
-        let storage = self
-            .storage
-            .add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
-        let tensor_ = Tensor_ {
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: Some(Op::Add(self.clone(), rhs.clone())),
-        };
-        Ok(Self(Arc::new(tensor_)))
-    }
-
-    pub fn mul(&self, rhs: &Self) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "mul")?;
-        let storage = self
-            .storage
-            .mul_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
-        let tensor_ = Tensor_ {
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: Some(Op::Mul(self.clone(), rhs.clone())),
-        };
-        Ok(Self(Arc::new(tensor_)))
-    }
+    binary_op!(add, Add, add_impl);
+    binary_op!(mul, Mul, mul_impl);
+
+    unary_op!(sqr, Sqr, sqr_impl);
+    unary_op!(sqrt, Sqrt, sqrt_impl);
 
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
             return Err(Error::UnexpectedNumberOfDims {
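
A minimal usage sketch of the generated methods, written as it might appear in a test inside this crate. Only `zeros`, `add`, `mul`, `sqr` and `sqrt` are confirmed by this diff; the `DType::F32` and `Device::Cpu` variant names and the `(usize, usize): Into<Shape>` conversion are assumptions from context:

    #[test]
    fn unary_and_binary_ops() -> Result<()> {
        // Hypothetical setup: variant names assumed; `zeros` is from this diff.
        let t = Tensor::zeros((2, 3), DType::F32, Device::Cpu);
        let y = t.sqr()?;    // element-wise square, records Op::Sqr
        let z = t.sqrt()?;   // element-wise square root, records Op::Sqrt
        let s = y.add(&z)?;  // shapes checked via same_shape_binary_op
        let _p = s.mul(&s)?; // records Op::Mul in the compute graph
        Ok(())
    }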
@@ -135,8 +146,20 @@ impl Tensor {
     }
 
     pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
-        // TODO: Similar to to_vec1 then reshape the resulting vec?
-        todo!()
+        let (dim1, dim2) = self.shape().r2()?;
+        match &self.storage {
+            Storage::Cpu(cpu_storage) => {
+                let data = S::cpu_storage_as_slice(cpu_storage)?;
+                let mut rows = vec![];
+                let mut src_index = self.strided_index();
+                for _idx_row in 0..dim1 {
+                    let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                    rows.push(row)
+                }
+                assert!(src_index.next().is_none());
+                Ok(rows)
+            }
+        }
     }
 
     pub fn dtype(&self) -> DType {
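
To see the new `to_vec2` in action, a hedged round-trip sketch: `strided_index()` yields element indices in logical row-major order, so consuming `dim2` indices per row for `dim1` rows reconstructs the nested `Vec`s (the trailing `assert!` checks the iterator is exhausted). As above, `DType::F32` and `Device::Cpu` are assumed spellings, and `f32: WithDType` is assumed to hold:

    // Hypothetical round-trip: a 2x3 zero tensor back to nested Vecs.
    let t = Tensor::zeros((2, 3), DType::F32, Device::Cpu);
    let rows: Vec<Vec<f32>> = t.to_vec2()?;
    assert_eq!(rows, vec![vec![0f32, 0., 0.], vec![0., 0., 0.]]);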