Mirror of https://github.com/huggingface/candle.git
Add some unary ops.
src/op.rs:

@@ -4,5 +4,7 @@ use crate::Tensor;
 pub(crate) enum Op {
     Add(Tensor, Tensor),
     Mul(Tensor, Tensor),
+    Sqr(Tensor),
+    Sqrt(Tensor),
     // TODO: Support for custom ops.
 }
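
The two new variants record which operation produced a tensor, so the compute graph can later be walked backwards. A standalone sketch of why that record matters for a future backward pass (illustrative only: candle's variants hold the operand Tensor, and no autodiff exists yet at this commit; a plain f64 stands in here):

// Illustrative only: an op-tagged value knows how it was produced, so a
// gradient rule can be selected by matching on the tag.
enum Op {
    Sqr(f64),  // stands in for Sqr(Tensor)
    Sqrt(f64), // stands in for Sqrt(Tensor)
}

// d/dx x^2 = 2x and d/dx sqrt(x) = 1 / (2 sqrt(x))
fn grad(op: &Op) -> f64 {
    match op {
        Op::Sqr(x) => 2.0 * x,
        Op::Sqrt(x) => 0.5 / x.sqrt(),
    }
}

fn main() {
    assert_eq!(grad(&Op::Sqr(3.0)), 6.0);
    assert_eq!(grad(&Op::Sqrt(4.0)), 0.25);
    println!("ok");
}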
src/shape.rs:

@@ -91,18 +91,18 @@ impl Shape {
 }
 
 extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
-extract_dims!(r1, 1, |d: &Vec<usize>| d[0], usize);
-extract_dims!(r2, 2, |d: &Vec<usize>| (d[0], d[1]), (usize, usize));
+extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
+extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
 extract_dims!(
     r3,
     3,
-    |d: &Vec<usize>| (d[0], d[1], d[2]),
+    |d: &[usize]| (d[0], d[1], d[2]),
     (usize, usize, usize)
 );
 extract_dims!(
     r4,
     4,
-    |d: &Vec<usize>| (d[0], d[1], d[2], d[3]),
+    |d: &[usize]| (d[0], d[1], d[2], d[3]),
     (usize, usize, usize, usize)
 );
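
The `&Vec<usize>` to `&[usize]` change in these closures follows standard Rust practice (clippy's `ptr_arg` lint): a slice parameter is strictly more general, since both vectors and arrays coerce to it. A minimal standalone illustration:

// Taking &[usize] lets callers pass either a Vec or an array slice:
fn first_two(d: &[usize]) -> (usize, usize) {
    (d[0], d[1])
}

fn main() {
    let v: Vec<usize> = vec![3, 4, 5];
    let a: [usize; 2] = [7, 8];
    assert_eq!(first_two(&v), (3, 4)); // &Vec<usize> coerces to &[usize]
    assert_eq!(first_two(&a), (7, 8)); // array references coerce too
    println!("ok");
}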
src/storage.rs:
@@ -81,6 +81,63 @@ pub enum Storage {
     Cpu(CpuStorage),
 }
 
+trait UnaryOp {
+    const NAME: &'static str;
+    fn f32(v1: f32) -> f32;
+    fn f64(v1: f64) -> f64;
+}
+
+trait BinaryOp {
+    const NAME: &'static str;
+    fn f32(v1: f32, v2: f32) -> f32;
+    fn f64(v1: f64, v2: f64) -> f64;
+}
+
+struct Add;
+struct Mul;
+struct Sqr;
+struct Sqrt;
+
+impl BinaryOp for Add {
+    const NAME: &'static str = "add";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 + v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 + v2
+    }
+}
+
+impl BinaryOp for Mul {
+    const NAME: &'static str = "mul";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 * v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 * v2
+    }
+}
+
+impl UnaryOp for Sqr {
+    const NAME: &'static str = "sqr";
+    fn f32(v1: f32) -> f32 {
+        v1 * v1
+    }
+    fn f64(v1: f64) -> f64 {
+        v1 * v1
+    }
+}
+
+impl UnaryOp for Sqrt {
+    const NAME: &'static str = "sqrt";
+    fn f32(v1: f32) -> f32 {
+        v1.sqrt()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.sqrt()
+    }
+}
+
 impl Storage {
     pub fn device(&self) -> Device {
         match self {
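
These zero-sized op types plus trait-level associated functions give static dispatch: each instantiation of a generic kernel is monomorphized for its op, so there is no per-element branching or function-pointer call. A self-contained sketch of the pattern (names mirror the diff; simplified to f32 only):

// Minimal standalone sketch of the same static-dispatch pattern: each op
// is a zero-sized type, and a generic function monomorphizes over it.
trait UnaryOp {
    const NAME: &'static str;
    fn f32(v: f32) -> f32;
}

struct Sqrt;
impl UnaryOp for Sqrt {
    const NAME: &'static str = "sqrt";
    fn f32(v: f32) -> f32 {
        v.sqrt()
    }
}

// The compiler emits a dedicated copy of `apply` per op type.
fn apply<B: UnaryOp>(data: &[f32]) -> Vec<f32> {
    data.iter().map(|&v| B::f32(v)).collect()
}

fn main() {
    let out = apply::<Sqrt>(&[1.0, 4.0, 9.0]);
    assert_eq!(out, vec![1.0, 2.0, 3.0]);
    println!("{} ok", Sqrt::NAME);
}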
@@ -114,16 +171,34 @@ impl Storage {
         }
     }
 
+    fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Storage::Cpu(storage) => match storage {
+                CpuStorage::F32(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let data = index.map(|i| B::f32(storage[i])).collect();
+                    Ok(Storage::Cpu(CpuStorage::F32(data)))
+                }
+                CpuStorage::F64(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let data = index.map(|i| B::f64(storage[i])).collect();
+                    Ok(Storage::Cpu(CpuStorage::F64(data)))
+                }
+            },
+        }
+    }
+
     // TODO: Support broadcasting?
-    pub(crate) fn add_impl(
+    fn binary_impl<B: BinaryOp>(
         &self,
         rhs: &Self,
         shape: &Shape,
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        self.same_device(rhs, "add")?;
-        self.same_dtype(rhs, "add")?;
+        self.same_device(rhs, B::NAME)?;
+        self.same_dtype(rhs, B::NAME)?;
         // The ggml implementation has different paths based on whether the rhs is contiguous
         // or not, for now we only consider the general case but we should benchmark and do the
         // same if it helps.
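
`StridedIndex`, used throughout these kernels, turns a shape plus per-dimension strides into a stream of flat buffer offsets, which is what lets one kernel handle non-contiguous layouts such as transposes. Its implementation is not part of this diff; the following is a hypothetical minimal version to show the idea:

// Hypothetical minimal strided-index iterator (not candle's code): walks a
// tensor of the given dims and yields flat offsets into the underlying
// buffer according to the per-dimension strides.
struct StridedIndex<'a> {
    dims: &'a [usize],
    stride: &'a [usize],
    multi_index: Vec<usize>,
    done: bool,
}

impl<'a> StridedIndex<'a> {
    fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
        let done = dims.iter().any(|&d| d == 0);
        StridedIndex { dims, stride, multi_index: vec![0; dims.len()], done }
    }
}

impl<'a> Iterator for StridedIndex<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        if self.done {
            return None;
        }
        // Flat offset = sum over dimensions of index * stride.
        let offset = self
            .multi_index
            .iter()
            .zip(self.stride)
            .map(|(i, s)| i * s)
            .sum();
        // Advance the multi-index like an odometer, last dimension fastest.
        self.done = true;
        for (i, d) in self.multi_index.iter_mut().zip(self.dims).rev() {
            *i += 1;
            if *i < *d {
                self.done = false;
                break;
            }
            *i = 0;
        }
        Some(offset)
    }
}

fn main() {
    // A 2x3 tensor stored contiguously has strides [3, 1].
    let idx: Vec<usize> = StridedIndex::new(&[2, 3], &[3, 1]).collect();
    assert_eq!(idx, vec![0, 1, 2, 3, 4, 5]);
    // The same buffer viewed as a 3x2 transpose via strides [1, 3].
    let idx: Vec<usize> = StridedIndex::new(&[3, 2], &[1, 3]).collect();
    assert_eq!(idx, vec![0, 3, 1, 4, 2, 5]);
    println!("ok");
}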
@@ -135,7 +210,7 @@ impl Storage {
                     let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
                     let data = lhs_index
                         .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] + rhs[rhs_i])
+                        .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
                         .collect();
                     Ok(Storage::Cpu(CpuStorage::F32(data)))
                 }
@@ -144,7 +219,7 @@ impl Storage {
                     let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
                     let data = lhs_index
                         .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] + rhs[rhs_i])
+                        .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
                         .collect();
                     Ok(Storage::Cpu(CpuStorage::F64(data)))
                 }
@@ -153,14 +228,23 @@ impl Storage {
                     Err(Error::DTypeMismatchBinaryOp {
                         lhs: lhs.dtype(),
                         rhs: rhs.dtype(),
-                        op: "add",
+                        op: B::NAME,
                     })
                 }
             },
         }
     }
 
-    // TODO: Support broadcasting?
+    pub(crate) fn add_impl(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
+    }
+
     pub(crate) fn mul_impl(
         &self,
         rhs: &Self,
@@ -168,38 +252,14 @@ impl Storage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        self.same_device(rhs, "mul")?;
-        self.same_dtype(rhs, "mul")?;
-        // TODO: share this code with the add implementation, using a macro or a trait?
-        match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
-                (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] * rhs[rhs_i])
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] * rhs[rhs_i])
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-                _ => {
-                    // This should be covered by the dtype check above.
-                    Err(Error::DTypeMismatchBinaryOp {
-                        lhs: lhs.dtype(),
-                        rhs: rhs.dtype(),
-                        op: "add",
-                    })
-                }
-            },
-        }
+        self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
+    }
+
+    pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        self.unary_impl::<Sqr>(shape, stride)
+    }
+
+    pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        self.unary_impl::<Sqrt>(shape, stride)
     }
 }
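
A side note on the `op: "add"` to `op: B::NAME` change above: the removed mul_impl also built its DTypeMismatchBinaryOp error with `op: "add"`, a copy-paste slip from the add path that is visible in the deleted lines. Funnelling both operations through the generic binary_impl means the reported op name can no longer drift out of sync with the actual operation.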
src/tensor.rs:

@@ -27,6 +27,40 @@ impl std::fmt::Debug for Tensor {
     }
 }
 
+macro_rules! unary_op {
+    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+        pub fn $fn_name(&self) -> Result<Self> {
+            let shape = self.shape();
+            let storage = self.storage.$impl_name(self.shape(), self.stride())?;
+            let tensor_ = Tensor_ {
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: Some(Op::$op_name(self.clone())),
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
+macro_rules! binary_op {
+    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
+            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
+            let storage = self
+                .storage
+                .$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
+            let tensor_ = Tensor_ {
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: Some(Op::$op_name(self.clone(), rhs.clone())),
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
 impl Tensor {
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
         let shape = shape.into();
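
Declarative macros are used here rather than generics because the generated items are inherent methods with different arities and names. A tiny standalone demonstration of the same pattern (Wrapper is a made-up stand-in for Tensor):

// One macro_rules! pattern stamps out near-identical methods, so each new
// op costs a single line at the use site.
struct Wrapper(f64);

macro_rules! unary_method {
    ($fn_name:ident, $impl_name:ident) => {
        pub fn $fn_name(&self) -> Wrapper {
            Wrapper(self.0.$impl_name())
        }
    };
}

impl Wrapper {
    unary_method!(sqrt, sqrt);
    unary_method!(abs, abs);
}

fn main() {
    assert_eq!(Wrapper(9.0).sqrt().0, 3.0);
    assert_eq!(Wrapper(-2.0).abs().0, 2.0);
    println!("ok");
}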
@@ -70,34 +104,11 @@ impl Tensor {
 
     // TODO: Also make an inplace version or a pre-allocated? This could be tricky
     // if this can create cycles in the compute graph.
-    pub fn add(&self, rhs: &Self) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "add")?;
-        let storage = self
-            .storage
-            .add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
-        let tensor_ = Tensor_ {
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: Some(Op::Add(self.clone(), rhs.clone())),
-        };
-        Ok(Self(Arc::new(tensor_)))
-    }
-
-    pub fn mul(&self, rhs: &Self) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "mul")?;
-        let storage = self
-            .storage
-            .mul_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
-        let tensor_ = Tensor_ {
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: Some(Op::Mul(self.clone(), rhs.clone())),
-        };
-        Ok(Self(Arc::new(tensor_)))
-    }
-
+    binary_op!(add, Add, add_impl);
+    binary_op!(mul, Mul, mul_impl);
+    unary_op!(sqr, Sqr, sqr_impl);
+    unary_op!(sqrt, Sqrt, sqrt_impl);
 
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
             return Err(Error::UnexpectedNumberOfDims {
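
For reference, `binary_op!(add, Add, add_impl)` expands (with `stringify!(add)` producing the "add" label) to essentially the hand-written method this hunk deletes, which is why the two methods collapse to one line each:

pub fn add(&self, rhs: &Self) -> Result<Self> {
    let shape = self.same_shape_binary_op(rhs, "add")?;
    let storage = self
        .storage
        .add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
    let tensor_ = Tensor_ {
        storage,
        shape: shape.clone(),
        stride: shape.stride_contiguous(),
        op: Some(Op::Add(self.clone(), rhs.clone())),
    };
    Ok(Self(Arc::new(tensor_)))
}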
@@ -135,8 +146,20 @@ impl Tensor {
     }
 
     pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
-        // TODO: Similar to to_vec1 then reshape the resulting vec?
-        todo!()
+        let (dim1, dim2) = self.shape().r2()?;
+        match &self.storage {
+            Storage::Cpu(cpu_storage) => {
+                let data = S::cpu_storage_as_slice(cpu_storage)?;
+                let mut rows = vec![];
+                let mut src_index = self.strided_index();
+                for _idx_row in 0..dim1 {
+                    let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                    rows.push(row)
+                }
+                assert!(src_index.next().is_none());
+                Ok(rows)
+            }
+        }
     }
 
     pub fn dtype(&self) -> DType {
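
The row loop relies on the strided-offset iterator yielding exactly dim1 * dim2 offsets in row-major order of the logical shape; the final assert checks that invariant. A standalone sketch of the same consumption pattern, assuming a plain slice in place of CpuStorage:

// Consume dim2 offsets per row from a strided iterator, so non-contiguous
// layouts (e.g. transposes) come out as correctly shaped nested vectors.
fn to_vec2(
    data: &[f32],
    dim1: usize,
    dim2: usize,
    offsets: impl IntoIterator<Item = usize>,
) -> Vec<Vec<f32>> {
    let mut src_index = offsets.into_iter();
    let mut rows = vec![];
    for _ in 0..dim1 {
        let row: Vec<f32> = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
        rows.push(row);
    }
    assert!(src_index.next().is_none());
    rows
}

fn main() {
    // A 2x3 matrix stored row-major, read back as its 3x2 transpose by
    // feeding offsets in transposed stride order.
    let data = [1., 2., 3., 4., 5., 6.];
    let t = to_vec2(&data, 3, 2, [0, 3, 1, 4, 2, 5]);
    assert_eq!(t, vec![vec![1., 4.], vec![2., 5.], vec![3., 6.]]);
    println!("ok");
}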