mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add some unary ops.
This commit is contained in:
140
src/storage.rs
140
src/storage.rs
@ -81,6 +81,63 @@ pub enum Storage {
|
||||
Cpu(CpuStorage),
|
||||
}
|
||||
|
||||
/// An element-wise operation taking a single scalar operand.
///
/// Implementors are selected at compile time through a type parameter, so the
/// functions below are associated functions rather than methods on a value.
trait UnaryOp {
    /// Short lowercase identifier of the operation.
    const NAME: &'static str;

    /// Evaluates the operation on one `f32` value.
    fn f32(v1: f32) -> f32;

    /// Evaluates the operation on one `f64` value.
    fn f64(v1: f64) -> f64;
}
|
||||
|
||||
/// An element-wise operation taking two scalar operands of the same type.
///
/// Implementors are selected at compile time through a type parameter, so the
/// functions below are associated functions rather than methods on a value.
trait BinaryOp {
    /// Short lowercase identifier of the operation; `binary_impl` passes it to
    /// the device/dtype checks and to the dtype-mismatch error.
    const NAME: &'static str;

    /// Evaluates the operation on a pair of `f32` values.
    fn f32(v1: f32, v2: f32) -> f32;

    /// Evaluates the operation on a pair of `f64` values.
    fn f64(v1: f64, v2: f64) -> f64;
}
|
||||
|
||||
/// Marker type selecting addition (implements `BinaryOp`).
struct Add;

/// Marker type selecting multiplication (implements `BinaryOp`).
struct Mul;

/// Marker type selecting squaring (implements `UnaryOp`).
struct Sqr;

/// Marker type selecting the square root (implements `UnaryOp`).
struct Sqrt;
|
||||
|
||||
/// Addition: each function returns the sum of its two operands.
impl BinaryOp for Add {
    const NAME: &'static str = "add";

    fn f32(lhs: f32, rhs: f32) -> f32 {
        lhs + rhs
    }

    fn f64(lhs: f64, rhs: f64) -> f64 {
        lhs + rhs
    }
}
|
||||
|
||||
/// Multiplication: each function returns the product of its two operands.
impl BinaryOp for Mul {
    const NAME: &'static str = "mul";

    fn f32(lhs: f32, rhs: f32) -> f32 {
        lhs * rhs
    }

    fn f64(lhs: f64, rhs: f64) -> f64 {
        lhs * rhs
    }
}
|
||||
|
||||
/// Squaring: each function returns its operand multiplied by itself.
impl UnaryOp for Sqr {
    const NAME: &'static str = "sqr";

    fn f32(x: f32) -> f32 {
        x * x
    }

    fn f64(x: f64) -> f64 {
        x * x
    }
}
|
||||
|
||||
/// Square root: each function forwards to the standard-library `sqrt` of the
/// matching float type.
impl UnaryOp for Sqrt {
    const NAME: &'static str = "sqrt";

    fn f32(x: f32) -> f32 {
        f32::sqrt(x)
    }

    fn f64(x: f64) -> f64 {
        f64::sqrt(x)
    }
}
|
||||
|
||||
impl Storage {
|
||||
pub fn device(&self) -> Device {
|
||||
match self {
|
||||
@ -114,16 +171,34 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
Storage::Cpu(storage) => match storage {
|
||||
CpuStorage::F32(storage) => {
|
||||
let index = StridedIndex::new(shape.dims(), stride);
|
||||
let data = index.map(|i| B::f32(storage[i])).collect();
|
||||
Ok(Storage::Cpu(CpuStorage::F32(data)))
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let index = StridedIndex::new(shape.dims(), stride);
|
||||
let data = index.map(|i| B::f64(storage[i])).collect();
|
||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Support broadcasting?
|
||||
pub(crate) fn add_impl(
|
||||
fn binary_impl<B: BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, "add")?;
|
||||
self.same_dtype(rhs, "add")?;
|
||||
self.same_device(rhs, B::NAME)?;
|
||||
self.same_dtype(rhs, B::NAME)?;
|
||||
// The ggml implementation has different paths based on whether the rhs is contiguous
|
||||
// or not, for now we only consider the general case but we should benchmark and do the
|
||||
// same if it helps.
|
||||
@ -135,7 +210,7 @@ impl Storage {
|
||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
||||
let data = lhs_index
|
||||
.zip(rhs_index)
|
||||
.map(|(lhs_i, rhs_i)| lhs[lhs_i] + rhs[rhs_i])
|
||||
.map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect();
|
||||
Ok(Storage::Cpu(CpuStorage::F32(data)))
|
||||
}
|
||||
@ -144,7 +219,7 @@ impl Storage {
|
||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
||||
let data = lhs_index
|
||||
.zip(rhs_index)
|
||||
.map(|(lhs_i, rhs_i)| lhs[lhs_i] + rhs[rhs_i])
|
||||
.map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect();
|
||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||
}
|
||||
@ -153,14 +228,23 @@ impl Storage {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: lhs.dtype(),
|
||||
rhs: rhs.dtype(),
|
||||
op: "add",
|
||||
op: B::NAME,
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Support broadcasting?
|
||||
pub(crate) fn add_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
|
||||
}
|
||||
|
||||
pub(crate) fn mul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
@ -168,38 +252,14 @@ impl Storage {
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, "mul")?;
|
||||
self.same_dtype(rhs, "mul")?;
|
||||
// TODO: share this code with the add implementation, using a macro or a trait?
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
|
||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
|
||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
||||
let data = lhs_index
|
||||
.zip(rhs_index)
|
||||
.map(|(lhs_i, rhs_i)| lhs[lhs_i] * rhs[rhs_i])
|
||||
.collect();
|
||||
Ok(Storage::Cpu(CpuStorage::F32(data)))
|
||||
}
|
||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
|
||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
||||
let data = lhs_index
|
||||
.zip(rhs_index)
|
||||
.map(|(lhs_i, rhs_i)| lhs[lhs_i] * rhs[rhs_i])
|
||||
.collect();
|
||||
Ok(Storage::Cpu(CpuStorage::F64(data)))
|
||||
}
|
||||
_ => {
|
||||
// This should be covered by the dtype check above.
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: lhs.dtype(),
|
||||
rhs: rhs.dtype(),
|
||||
op: "add",
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
|
||||
}
|
||||
|
||||
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
self.unary_impl::<Sqr>(shape, stride)
|
||||
}
|
||||
|
||||
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
self.unary_impl::<Sqrt>(shape, stride)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user