mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Add more to the binary operators.
src/error.rs (16 lines changed)
@@ -1,4 +1,4 @@
-use crate::{DType, Shape};
+use crate::{DType, Device, Shape};
 
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
@@ -13,6 +13,20 @@ pub enum Error {
         op: &'static str,
     },
 
+    #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
+    DeviceMismatchBinaryOp {
+        lhs: Device,
+        rhs: Device,
+        op: &'static str,
+    },
+
+    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
+    DTypeMismatchBinaryOp {
+        lhs: DType,
+        rhs: DType,
+        op: &'static str,
+    },
+
     #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
     UnexpectedNumberOfDims {
         expected: usize,
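
The `#[error(...)]` attributes above are what thiserror expands into `Display` impls. As a quick illustration, a self-contained sketch with a stand-in `DType` enum (not candle's own definition) showing the message the new dtype variant renders:

    use thiserror::Error;

    // Stand-in dtype enum, only for this sketch.
    #[derive(Debug, Clone, Copy)]
    enum DType {
        F32,
        F64,
    }

    #[derive(Error, Debug)]
    enum Error {
        #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
        DTypeMismatchBinaryOp {
            lhs: DType,
            rhs: DType,
            op: &'static str,
        },
    }

    fn main() {
        let err = Error::DTypeMismatchBinaryOp {
            lhs: DType::F32,
            rhs: DType::F64,
            op: "add",
        };
        println!("{err}"); // dtype mismatch in add, lhs: F32, rhs: F64
    }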
src/shape.rs

@@ -144,6 +144,7 @@ impl Shape {
     pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
         self.0
             .iter()
+            .rev()
             .scan(1, |prod, u| {
                 let prod_pre_mult = *prod;
                 *prod *= u;
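
The added `.rev()` makes the running product walk the dimensions from innermost to outermost, which is what row-major (C-order) strides need. The rest of the function is outside this hunk, so the following is a standalone sketch of the intended computation, assuming the scanned values are collected and then reversed back into dimension order:

    // Row-major contiguous strides: the last dimension varies fastest.
    fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
        let mut stride: Vec<usize> = dims
            .iter()
            .rev()
            .scan(1, |prod, u| {
                let prod_pre_mult = *prod;
                *prod *= u;
                Some(prod_pre_mult)
            })
            .collect();
        stride.reverse();
        stride
    }

    fn main() {
        // For a (2, 3, 4) tensor, stepping along dim 0 skips 3 * 4 = 12
        // elements, dim 1 skips 4, dim 2 skips 1.
        assert_eq!(stride_contiguous(&[2, 3, 4]), vec![12, 4, 1]);
    }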
src/storage.rs

@@ -1,4 +1,4 @@
-use crate::{DType, Device};
+use crate::{DType, Device, Error, Result, Shape};
 
 // TODO: Think about whether we would be better off with a dtype and
 // a buffer as an owned slice of bytes.
@@ -35,4 +35,75 @@ impl Storage {
             Self::Cpu(storage) => storage.dtype(),
         }
     }
+
+    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
+        let lhs = self.device();
+        let rhs = rhs.device();
+        if lhs != rhs {
+            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
+        } else {
+            Ok(())
+        }
+    }
+
+    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
+        let lhs = self.dtype();
+        let rhs = rhs.dtype();
+        if lhs != rhs {
+            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
+        } else {
+            Ok(())
+        }
+    }
+
+    pub(crate) fn add_impl(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        _lhs_stride: &[usize],
+        _rhs_stride: &[usize],
+    ) -> Result<Self> {
+        self.same_device(rhs, "add")?;
+        self.same_dtype(rhs, "add")?;
+        // The ggml implementation has different paths based on whether the rhs is contiguous
+        // or not, for now we only consider the general case but we should benchmark and do the
+        // same if it helps.
+        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
+        match (self, rhs) {
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
+                (CpuStorage::F32(_), CpuStorage::F32(_)) => {
+                    let elem_count = shape.elem_count();
+                    let data = vec![0f32; elem_count];
+                    // TODO: properly fill data with the sum
+                    Ok(Storage::Cpu(CpuStorage::F32(data)))
+                }
+                (CpuStorage::F64(_), CpuStorage::F64(_)) => {
+                    let elem_count = shape.elem_count();
+                    let data = vec![0f64; elem_count];
+                    // TODO: properly fill data with the sum
+                    Ok(Storage::Cpu(CpuStorage::F64(data)))
+                }
+                _ => {
+                    // This should be covered by the dtype check above.
+                    Err(Error::DTypeMismatchBinaryOp {
+                        lhs: lhs.dtype(),
+                        rhs: rhs.dtype(),
+                        op: "add",
+                    })
+                }
+            },
+        }
+    }
+
+    pub(crate) fn mul_impl(
+        &self,
+        rhs: &Self,
+        _shape: &Shape,
+        _lhs_stride: &[usize],
+        _rhs_stride: &[usize],
+    ) -> Result<Self> {
+        self.same_device(rhs, "mul")?;
+        self.same_dtype(rhs, "mul")?;
+        todo!()
+    }
 }
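
The two `TODO: properly fill data with the sum` markers leave the output zero-filled for now. As a hedged sketch of one way the contiguous case could be completed (a hypothetical helper, not part of this commit; the unused `_lhs_stride`/`_rhs_stride` arguments suggest a strided path is planned later):

    // Hypothetical helper: elementwise sum, assuming both buffers are
    // contiguous and of equal length.
    fn add_f32_contiguous(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
        lhs.iter().zip(rhs.iter()).map(|(l, r)| l + r).collect()
    }

    fn main() {
        let sum = add_f32_contiguous(&[1.0, 2.0, 3.0], &[10.0, 20.0, 30.0]);
        assert_eq!(sum, vec![11.0, 22.0, 33.0]);
    }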
src/tensor.rs

@@ -2,7 +2,7 @@ use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
 use std::sync::Arc;
 
 #[allow(dead_code)]
-pub(crate) struct Tensor_ {
+pub struct Tensor_ {
     storage: Storage,
     shape: Shape,
     // The strides are given in number of elements and not in bytes.
@@ -10,8 +10,17 @@ pub(crate) struct Tensor_ {
     op: Option<Op>,
 }
 
+#[derive(Clone)]
 pub struct Tensor(Arc<Tensor_>);
 
+impl std::ops::Deref for Tensor {
+    type Target = Tensor_;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+
 impl std::fmt::Debug for Tensor {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "[{:?}, {:?}]", &self.shape().dims(), self.device())
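
The new `Deref` impl is what lets the accessors later in this diff drop the explicit `.0` (e.g. `self.0.storage` becoming `self.storage`). A minimal standalone sketch of the pattern, with toy types (`Inner` and `Handle` are illustrative names, not candle's):

    use std::sync::Arc;

    struct Inner {
        shape: Vec<usize>,
    }

    // A cheap-to-clone handle around shared inner state, like Tensor(Arc<Tensor_>).
    #[derive(Clone)]
    struct Handle(Arc<Inner>);

    impl std::ops::Deref for Handle {
        type Target = Inner;

        fn deref(&self) -> &Self::Target {
            self.0.as_ref()
        }
    }

    fn main() {
        let h = Handle(Arc::new(Inner { shape: vec![2, 3] }));
        // Deref coercion: field access goes through the Arc transparently.
        assert_eq!(h.shape, vec![2, 3]);
    }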
@@ -45,7 +54,7 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<()> {
+    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
         let rhs = rhs.shape();
         if lhs != rhs {
@@ -55,18 +64,38 @@ impl Tensor {
                 op,
             })
         } else {
-            Ok(())
+            Ok(lhs)
         }
     }
 
+    // TODO: Also make an inplace version or a pre-allocated? This could be tricky
+    // if this can create cycles in the compute graph.
     pub fn add(&self, rhs: &Self) -> Result<Self> {
-        self.same_shape_binary_op(rhs, "add")?;
-        todo!()
+        let shape = self.same_shape_binary_op(rhs, "add")?;
+        let storage = self
+            .storage
+            .add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
+        let tensor_ = Tensor_ {
+            storage,
+            shape: shape.clone(),
+            stride: shape.stride_contiguous(),
+            op: Some(Op::Add(self.clone(), rhs.clone())),
+        };
+        Ok(Self(Arc::new(tensor_)))
     }
 
     pub fn mul(&self, rhs: &Self) -> Result<Self> {
-        self.same_shape_binary_op(rhs, "mul")?;
-        todo!()
+        let shape = self.same_shape_binary_op(rhs, "mul")?;
+        let storage = self
+            .storage
+            .mul_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
+        let tensor_ = Tensor_ {
+            storage,
+            shape: shape.clone(),
+            stride: shape.stride_contiguous(),
+            op: Some(Op::Mul(self.clone(), rhs.clone())),
+        };
+        Ok(Self(Arc::new(tensor_)))
     }
 
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
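
With those implementations in place, a binary op checks shapes, dispatches to the `Storage` impl, and records the op in the result. A hedged usage sketch, assuming the crate is imported as `candle` and that `a` and `b` were constructed elsewhere (constructors are outside this diff); note that `mul` would still hit the `todo!()` in `mul_impl` at this commit:

    use candle::{Result, Tensor};

    // Hypothetical caller: `a` and `b` must agree in shape, dtype, and device,
    // otherwise one of the mismatch errors above is returned.
    fn demo(a: &Tensor, b: &Tensor) -> Result<Tensor> {
        // The sum is zero-filled for now (see the add_impl TODOs); the result's
        // `op` field records Op::Add(a, b) for a future compute graph.
        a.add(b)
    }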
@@ -77,7 +106,7 @@ impl Tensor {
                 shape: self.shape().clone(),
             });
         }
-        match &self.0.storage {
+        match &self.storage {
             Storage::Cpu(cpu_storage) => {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(data[0])
@@ -96,15 +125,15 @@ impl Tensor {
     }
 
     pub fn dtype(&self) -> DType {
-        self.0.storage.dtype()
+        self.storage.dtype()
     }
 
     pub fn device(&self) -> Device {
-        self.0.storage.device()
+        self.storage.device()
     }
 
     pub fn shape(&self) -> &Shape {
-        &self.0.shape
+        &self.shape
     }
 
     pub fn dims(&self) -> &[usize] {
@@ -112,7 +141,7 @@ impl Tensor {
     }
 
     pub fn stride(&self) -> &[usize] {
-        &self.0.stride
+        &self.stride
     }
 
     pub fn rank(&self) -> usize {