mirror of https://github.com/huggingface/candle.git, synced 2025-06-20 04:00:28 +00:00
Dummy broadcast placeholder functions.
@@ -86,6 +86,10 @@ impl CpuStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
+        let dims = shape.dims();
+        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
+            todo!("implement broadcast");
+        }
         // The ggml implementation has different paths based on whether the rhs is contiguous
         // or not, for now we only consider the general case but we should benchmark and do the
         // same if it helps.
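Note: the comment above refers to ggml's fast path for a contiguous rhs. As a point of reference, here is a minimal, self-contained sketch (not candle's actual kernel, and independent of this commit) of what the "general case" looks like: every logical index over `dims` is mapped to a storage offset through each operand's strides, so non-contiguous layouts work without any special casing.

// Standalone sketch of a strided elementwise binary op; names are illustrative.
fn binary_strided(
    dims: &[usize],
    lhs: &[f32],
    lhs_stride: &[usize],
    rhs: &[f32],
    rhs_stride: &[usize],
    op: impl Fn(f32, f32) -> f32,
) -> Vec<f32> {
    let elem_count: usize = dims.iter().product();
    let mut out = Vec::with_capacity(elem_count);
    for i in 0..elem_count {
        // Decompose the flat output index into per-dimension indices (innermost
        // dimension first), then re-linearize it with each operand's strides.
        let (mut rem, mut lhs_off, mut rhs_off) = (i, 0, 0);
        for ((&d, &ls), &rs) in dims.iter().zip(lhs_stride).zip(rhs_stride).rev() {
            let idx = rem % d;
            rem /= d;
            lhs_off += idx * ls;
            rhs_off += idx * rs;
        }
        out.push(op(lhs[lhs_off], rhs[rhs_off]));
    }
    out
}

fn main() {
    // 2x3 row-major lhs added to a transposed view of a 3x2 rhs buffer.
    let lhs = [1f32, 2., 3., 4., 5., 6.];
    let rhs = [10f32, 40., 20., 50., 30., 60.];
    let out = binary_strided(&[2, 3], &lhs, &[3, 1], &rhs, &[1, 2], |a, b| a + b);
    assert_eq!(out, vec![11., 22., 33., 44., 55., 66.]);
}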
@@ -331,8 +331,11 @@ impl CudaStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        let elem_count = shape.elem_count();
         let dims = shape.dims();
+        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
+            return Err(CudaError::InternalError("TODO: implement broadcast"));
+        }
+        let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
         let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
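Note: `[dims, lhs_stride, rhs_stride].concat()` above packs everything the kernel needs into a single buffer before the host-to-device copy. A small host-only sketch of that layout (the htod_copy and kernel side are candle/cudarc specific and not shown here):

fn main() {
    let dims: &[usize] = &[2, 3, 4];
    let lhs_stride: &[usize] = &[12, 4, 1];
    let rhs_stride: &[usize] = &[12, 4, 1];
    // Layout: [dims.. | lhs_stride.. | rhs_stride..], recoverable from the rank.
    let dims_and_strides = [dims, lhs_stride, rhs_stride].concat();
    assert_eq!(dims_and_strides.len(), 3 * dims.len());
    assert_eq!(&dims_and_strides[..3], dims);
    assert_eq!(&dims_and_strides[3..6], lhs_stride);
    assert_eq!(&dims_and_strides[6..], rhs_stride);
}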
@@ -6,6 +6,10 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
+    BroadcastAdd(Tensor, Tensor),
+    BroadcastMul(Tensor, Tensor),
+    BroadcastSub(Tensor, Tensor),
+    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
@@ -91,7 +91,6 @@ impl Storage {
         }
     }
 
-    // TODO: Support broadcasting?
     pub(crate) fn binary_impl<B: op::BinaryOp>(
         &self,
         rhs: &Self,
@@ -95,6 +95,34 @@ macro_rules! binary_op {
     };
 }
 
+macro_rules! broadcast_binary_op {
+    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
+            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
+                &rhs.storage,
+                shape,
+                self.stride(),
+                rhs.stride(),
+            )?;
+            let op = if self.track_op() || rhs.track_op() {
+                Some(Op::$op_name(self.clone(), rhs.clone()))
+            } else {
+                None
+            };
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op,
+                is_variable: false,
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
 impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
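Note: `broadcast_binary_op!` follows the same pattern as the existing `binary_op!` macro: one declarative macro stamps out several near-identical methods, with `stringify!($fn_name)` turning the method name into the op string used for error reporting. A self-contained toy version of that pattern (names and types here are illustrative, not candle's):

struct T(f64);

macro_rules! broadcast_binary_op_demo {
    ($fn_name:ident, $op:tt) => {
        // Expands to an associated method named after $fn_name.
        fn $fn_name(&self, rhs: &Self) -> (f64, &'static str) {
            (self.0 $op rhs.0, stringify!($fn_name))
        }
    };
}

impl T {
    broadcast_binary_op_demo!(broadcast_add, +);
    broadcast_binary_op_demo!(broadcast_mul, *);
}

fn main() {
    let (a, b) = (T(3.0), T(4.0));
    assert_eq!(a.broadcast_add(&b), (7.0, "broadcast_add"));
    assert_eq!(a.broadcast_mul(&b), (12.0, "broadcast_mul"));
}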
@@ -210,6 +238,27 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, true)
     }
 
+    pub(crate) fn broadcast_shape_binary_op<'a>(
+        &'a self,
+        rhs: &'a Self,
+        op: &'static str,
+    ) -> Result<&'a Shape> {
+        let lhs = self;
+        let lhs_dims = lhs.shape().dims();
+        let rhs_dims = rhs.shape().dims();
+        if lhs_dims.strip_suffix(rhs_dims).is_some() {
+            Ok(self.shape())
+        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
+            Ok(rhs.shape())
+        } else {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op,
+            })
+        }
+    }
+
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
         let rhs = rhs.shape();
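Note: the broadcasting rule implemented by `broadcast_shape_binary_op` is suffix-based: one operand's dims must be a suffix of the other's, and the longer shape is the result. This is narrower than NumPy-style broadcasting (no size-1 dimension matching). A runnable, candle-independent illustration of the same rule:

fn broadcast_shape<'a>(lhs: &'a [usize], rhs: &'a [usize]) -> Option<&'a [usize]> {
    if lhs.strip_suffix(rhs).is_some() {
        Some(lhs)
    } else if rhs.strip_suffix(lhs).is_some() {
        Some(rhs)
    } else {
        None // candle returns Error::ShapeMismatchBinaryOp here
    }
}

fn main() {
    assert_eq!(broadcast_shape(&[2, 3, 4], &[3, 4]), Some(&[2, 3, 4][..]));
    assert_eq!(broadcast_shape(&[4], &[2, 3, 4]), Some(&[2, 3, 4][..]));
    assert_eq!(broadcast_shape(&[2, 3], &[4, 3]), None);
}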
@@ -236,6 +285,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
+    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
+    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
+    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
+    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
 
     unary_op!(neg, Neg);
     unary_op!(sqr, Sqr);
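Note: a hedged usage sketch of the methods registered above; the `Tensor::new` / `Device::Cpu` constructors are assumed from the surrounding crate, and at this commit the calls still land on the placeholder paths (todo! on CPU, InternalError on CUDA) whenever the operand ranks differ, so this only shows the intended shape semantics:

// Hypothetical once real broadcast kernels exist.
let x = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?; // shape (2, 2)
let y = Tensor::new(&[10f32, 20.], &Device::Cpu)?;           // shape (2,), a suffix of (2, 2)
let z = x.broadcast_add(&y)?;                                // result takes the lhs shape (2, 2)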
@@ -773,6 +826,10 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
+                | Op::BroadcastAdd(lhs, rhs)
+                | Op::BroadcastMul(lhs, rhs)
+                | Op::BroadcastSub(lhs, rhs)
+                | Op::BroadcastDiv(lhs, rhs)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -865,6 +922,26 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
+                Op::BroadcastAdd(_lhs, _rhs) => {
+                    return Err(Error::BackwardNotSupported {
+                        op: "broadcast_add",
+                    })
+                }
+                Op::BroadcastSub(_lhs, _rhs) => {
+                    return Err(Error::BackwardNotSupported {
+                        op: "broadcast_sub",
+                    })
+                }
+                Op::BroadcastMul(_lhs, _rhs) => {
+                    return Err(Error::BackwardNotSupported {
+                        op: "broadcast_mul",
+                    })
+                }
+                Op::BroadcastDiv(_lhs, _rhs) => {
+                    return Err(Error::BackwardNotSupported {
+                        op: "broadcast_div",
+                    })
+                }
                 Op::Embedding(_lhs, _rhs) => {
                     return Err(Error::BackwardNotSupported { op: "embedding" })
                 }