mirror of https://github.com/huggingface/candle.git
Dummy broadcast placeholder functions.

@@ -86,6 +86,10 @@ impl CpuStorage {
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        let dims = shape.dims();
        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
            todo!("implement broadcast");
        }
        // The ggml implementation has different paths based on whether the rhs is contiguous
        // or not, for now we only consider the general case but we should benchmark and do the
        // same if it helps.
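
Note: the todo!("implement broadcast") path above only fires when the operand ranks differ. As a rough, hypothetical sketch of how that gap could later be filled (not part of this commit; broadcast_strides is an illustrative name), broadcasting can be expressed by left-padding the smaller operand's strides with zeros so the broadcast dimensions keep re-reading the same data:

// Hypothetical sketch, not candle code: broadcasting via zero strides.
fn broadcast_strides(out_dims: &[usize], in_dims: &[usize], in_strides: &[usize]) -> Vec<usize> {
    // Leading (broadcast) dimensions get stride 0; trailing dimensions keep their real strides.
    let pad = out_dims.len() - in_dims.len();
    let mut strides = vec![0; pad];
    strides.extend_from_slice(in_strides);
    strides
}

fn main() {
    // A contiguous [3, 4] operand broadcast to a [2, 3, 4] output: stride 0 on the new leading dim.
    assert_eq!(broadcast_strides(&[2, 3, 4], &[3, 4], &[4, 1]), vec![0, 4, 1]);
}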

@@ -331,8 +331,11 @@ impl CudaStorage {
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        let elem_count = shape.elem_count();
        let dims = shape.dims();
        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
            return Err(CudaError::InternalError("TODO: implement broadcast"));
        }
        let elem_count = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
        let dev = self.device();
        let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
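
The kernel launched with this dims_and_strides buffer presumably maps each flat output index to a per-operand offset using that operand's strides. A standalone Rust sketch of that index arithmetic (illustrative only, not the actual CUDA kernel):

// Hypothetical sketch of the flat-index -> strided-offset mapping a strided binary kernel needs.
fn strided_offset(flat: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut rem = flat;
    let mut offset = 0;
    // Peel coordinates from the innermost dimension outwards.
    for (dim, stride) in dims.iter().zip(strides.iter()).rev() {
        offset += (rem % dim) * stride;
        rem /= dim;
    }
    offset
}

fn main() {
    // Flat index 4 in a [2, 3] shape is element (1, 1); with transposed strides [1, 2] it sits at offset 3.
    assert_eq!(strided_offset(4, &[2, 3], &[1, 2]), 3);
    // With contiguous strides [3, 1] the offset equals the flat index.
    assert_eq!(strided_offset(4, &[2, 3], &[3, 1]), 4);
}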

@@ -6,6 +6,10 @@ pub(crate) enum Op {
    Mul(Tensor, Tensor),
    Sub(Tensor, Tensor),
    Div(Tensor, Tensor),
    BroadcastAdd(Tensor, Tensor),
    BroadcastMul(Tensor, Tensor),
    BroadcastSub(Tensor, Tensor),
    BroadcastDiv(Tensor, Tensor),
    Matmul(Tensor, Tensor),
    Embedding(Tensor, Tensor),

@@ -91,7 +91,6 @@ impl Storage {
        }
    }

    // TODO: Support broadcasting?
    pub(crate) fn binary_impl<B: op::BinaryOp>(
        &self,
        rhs: &Self,

@@ -95,6 +95,34 @@ macro_rules! binary_op {
    };
}

macro_rules! broadcast_binary_op {
    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
                &rhs.storage,
                shape,
                self.stride(),
                rhs.stride(),
            )?;
            let op = if self.track_op() || rhs.track_op() {
                Some(Op::$op_name(self.clone(), rhs.clone()))
            } else {
                None
            };
            let tensor_ = Tensor_ {
                id: TensorId::new(),
                storage,
                shape: shape.clone(),
                stride: shape.stride_contiguous(),
                op,
                is_variable: false,
            };
            Ok(Self(Arc::new(tensor_)))
        }
    };
}

impl Tensor {
    fn ones_impl<S: Into<Shape>>(
        shape: S,

@@ -210,6 +238,27 @@ impl Tensor {
        Self::new_impl(array, shape.into(), device, true)
    }

    pub(crate) fn broadcast_shape_binary_op<'a>(
        &'a self,
        rhs: &'a Self,
        op: &'static str,
    ) -> Result<&'a Shape> {
        let lhs = self;
        let lhs_dims = lhs.shape().dims();
        let rhs_dims = rhs.shape().dims();
        if lhs_dims.strip_suffix(rhs_dims).is_some() {
            Ok(self.shape())
        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
            Ok(rhs.shape())
        } else {
            Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op,
            })
        }
    }

    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
        let lhs = self.shape();
        let rhs = rhs.shape();
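
So broadcast_shape_binary_op only accepts the case where one shape is a suffix of the other and returns the larger shape. A self-contained sketch of the same rule, with ends_with standing in for the strip_suffix(..).is_some() checks above (illustrative, not candle API):

// Standalone illustration of the suffix-based broadcast rule.
fn broadcast_shape<'a>(lhs: &'a [usize], rhs: &'a [usize]) -> Option<&'a [usize]> {
    if lhs.ends_with(rhs) {
        Some(lhs) // rhs broadcasts over lhs's extra leading dims
    } else if rhs.ends_with(lhs) {
        Some(rhs)
    } else {
        None // incompatible, mirroring Error::ShapeMismatchBinaryOp
    }
}

fn main() {
    assert_eq!(broadcast_shape(&[2, 3, 4], &[3, 4]), Some(&[2, 3, 4][..]));
    assert_eq!(broadcast_shape(&[4], &[2, 3, 4]), Some(&[2, 3, 4][..]));
    assert_eq!(broadcast_shape(&[2, 3], &[3, 4]), None);
}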

@@ -236,6 +285,10 @@ impl Tensor {
    binary_op!(mul, Mul);
    binary_op!(sub, Sub);
    binary_op!(div, Div);
    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);

    unary_op!(neg, Neg);
    unary_op!(sqr, Sqr);

@@ -773,6 +826,10 @@ impl Tensor {
                    | Op::Mul(lhs, rhs)
                    | Op::Sub(lhs, rhs)
                    | Op::Div(lhs, rhs)
                    | Op::BroadcastAdd(lhs, rhs)
                    | Op::BroadcastMul(lhs, rhs)
                    | Op::BroadcastSub(lhs, rhs)
                    | Op::BroadcastDiv(lhs, rhs)
                    | Op::Embedding(lhs, rhs)
                    | Op::Matmul(lhs, rhs) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);

@@ -865,6 +922,26 @@ impl Tensor {
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                }
                Op::BroadcastAdd(_lhs, _rhs) => {
                    return Err(Error::BackwardNotSupported {
                        op: "broadcast_add",
                    })
                }
                Op::BroadcastSub(_lhs, _rhs) => {
                    return Err(Error::BackwardNotSupported {
                        op: "broadcast_sub",
                    })
                }
                Op::BroadcastMul(_lhs, _rhs) => {
                    return Err(Error::BackwardNotSupported {
                        op: "broadcast_mul",
                    })
                }
                Op::BroadcastDiv(_lhs, _rhs) => {
                    return Err(Error::BackwardNotSupported {
                        op: "broadcast_div",
                    })
                }
                Op::Embedding(_lhs, _rhs) => {
                    return Err(Error::BackwardNotSupported { op: "embedding" })
                }
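
The backward passes for the broadcast ops are stubbed out because the gradient of a broadcast operand is not a plain copy: the output gradient first has to be summed over the broadcast dimensions. A hypothetical standalone sketch of that reduction, assuming the suffix-style broadcast above where only leading dimensions are added (not part of this commit):

// Hypothetical sketch: fold an output gradient back onto a broadcast operand's shape.
fn reduce_leading_dims(grad: &[f32], out_dims: &[usize], in_dims: &[usize]) -> Vec<f32> {
    debug_assert_eq!(grad.len(), out_dims.iter().product::<usize>());
    let in_elems: usize = in_dims.iter().product();
    let mut reduced = vec![0f32; in_elems];
    for (i, g) in grad.iter().enumerate() {
        // Output elements that differ only in the leading dims accumulate into the same slot.
        reduced[i % in_elems] += *g;
    }
    reduced
}

fn main() {
    // Gradient of shape [2, 3] reduced onto a broadcast operand of shape [3].
    let grad = [1f32, 2., 3., 4., 5., 6.];
    assert_eq!(reduce_leading_dims(&grad, &[2, 3], &[3]), vec![5f32, 7., 9.]);
}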