Dummy broadcast placeholder functions.

laurent
2023-06-23 14:07:05 +01:00
parent f8848db001
commit 92da45879c
5 changed files with 89 additions and 2 deletions

View File

@@ -86,6 +86,10 @@ impl CpuStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
+        let dims = shape.dims();
+        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
+            todo!("implement broadcast");
+        }
         // The ggml implementation has different paths based on whether the rhs is contiguous
         // or not, for now we only consider the general case but we should benchmark and do the
         // same if it helps.
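As context for the comment above, here is a minimal standalone sketch (not the crate's actual kernel; the f32-only `binary_map` helper and its signature are made up for illustration, and the lhs is assumed contiguous) of the two paths ggml distinguishes: a plain linear zip when the rhs is contiguous, and per-element index recomputation through the strides otherwise.

fn binary_map(
    dims: &[usize],
    lhs: &[f32],
    rhs: &[f32],
    rhs_stride: &[usize],
    f: impl Fn(f32, f32) -> f32,
) -> Vec<f32> {
    // The rhs is contiguous when its strides are the row-major products of the dims.
    let mut expected = 1usize;
    let contiguous = dims.iter().zip(rhs_stride.iter()).rev().all(|(d, s)| {
        let ok = *s == expected;
        expected *= *d;
        ok
    });
    if contiguous {
        // Fast path: both buffers can be walked linearly (lhs assumed contiguous).
        lhs.iter().zip(rhs.iter()).map(|(l, r)| f(*l, *r)).collect()
    } else {
        // General path: recompute the strided rhs index for every output element.
        let elem_count: usize = dims.iter().product();
        (0..elem_count)
            .map(|i| {
                let (mut rem, mut rhs_index) = (i, 0);
                for (d, s) in dims.iter().zip(rhs_stride.iter()).rev() {
                    rhs_index += (rem % *d) * *s;
                    rem /= *d;
                }
                f(lhs[i], rhs[rhs_index])
            })
            .collect()
    }
}

fn main() {
    // rhs stored as the transpose of a 2x3 matrix: dims [2, 3], rhs strides [1, 2].
    let lhs = [1., 2., 3., 4., 5., 6.];
    let rhs = [10., 40., 20., 50., 30., 60.];
    let out = binary_map(&[2, 3], &lhs, &rhs, &[1, 2], |a, b| a + b);
    assert_eq!(out, vec![11.0f32, 22., 33., 44., 55., 66.]);
}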

View File

@@ -331,8 +331,11 @@ impl CudaStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        let elem_count = shape.elem_count();
         let dims = shape.dims();
+        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
+            return Err(CudaError::InternalError("TODO: implement broadcast"));
+        }
+        let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
         let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
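The `dims_and_strides` buffer above packs the three slices back to back before the host-to-device copy; a tiny standalone illustration with made-up values:

fn main() {
    let dims: &[usize] = &[2, 3, 4];
    let lhs_stride: &[usize] = &[12, 4, 1];
    let rhs_stride: &[usize] = &[12, 4, 1]; // example values only
    // Same packing as `[dims, lhs_stride, rhs_stride].concat()` above.
    let dims_and_strides = [dims, lhs_stride, rhs_stride].concat();
    assert_eq!(dims_and_strides, vec![2, 3, 4, 12, 4, 1, 12, 4, 1]);
    // A kernel receiving this buffer can slice it back into three parts of length dims.len() each.
    assert_eq!(dims_and_strides.len(), 3 * dims.len());
}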

View File

@@ -6,6 +6,10 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
+    BroadcastAdd(Tensor, Tensor),
+    BroadcastMul(Tensor, Tensor),
+    BroadcastSub(Tensor, Tensor),
+    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),

View File

@@ -91,7 +91,6 @@ impl Storage {
         }
     }
-    // TODO: Support broadcasting?
     pub(crate) fn binary_impl<B: op::BinaryOp>(
         &self,
         rhs: &Self,

View File

@@ -95,6 +95,34 @@ macro_rules! binary_op {
     };
 }
+macro_rules! broadcast_binary_op {
+    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
+            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
+                &rhs.storage,
+                shape,
+                self.stride(),
+                rhs.stride(),
+            )?;
+            let op = if self.track_op() || rhs.track_op() {
+                Some(Op::$op_name(self.clone(), rhs.clone()))
+            } else {
+                None
+            };
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op,
+                is_variable: false,
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
 impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
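The methods generated by `broadcast_binary_op!` above still route through the same `binary_impl` as the non-broadcast ops, so they only become useful once the backends implement broadcasting. As a plain-Rust sketch of the intended semantics (no candle types; `broadcast_add_2d` is a made-up helper), an rhs of shape (cols,) added to an lhs of shape (rows, cols) is repeated for every row:

// Made-up helper: element-wise add with the rhs broadcast over the rows of lhs.
fn broadcast_add_2d(lhs: &[f32], rows: usize, cols: usize, rhs: &[f32]) -> Vec<f32> {
    assert_eq!(lhs.len(), rows * cols);
    assert_eq!(rhs.len(), cols);
    (0..rows * cols).map(|i| lhs[i] + rhs[i % cols]).collect()
}

fn main() {
    let lhs = [1., 2., 3., 4., 5., 6.]; // shape (2, 3)
    let rhs = [10., 20., 30.]; // shape (3,)
    assert_eq!(
        broadcast_add_2d(&lhs, 2, 3, &rhs),
        vec![11.0f32, 22., 33., 14., 25., 36.]
    );
}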
@@ -210,6 +238,27 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, true)
     }
+    pub(crate) fn broadcast_shape_binary_op<'a>(
+        &'a self,
+        rhs: &'a Self,
+        op: &'static str,
+    ) -> Result<&'a Shape> {
+        let lhs = self;
+        let lhs_dims = lhs.shape().dims();
+        let rhs_dims = rhs.shape().dims();
+        if lhs_dims.strip_suffix(rhs_dims).is_some() {
+            Ok(self.shape())
+        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
+            Ok(rhs.shape())
+        } else {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op,
+            })
+        }
+    }
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
         let rhs = rhs.shape();
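The rule in `broadcast_shape_binary_op` above is deliberately narrow: a pair of shapes is accepted only when one dims list is a suffix of the other, and the longer shape wins; numpy-style expansion of size-1 dimensions is not covered. A standalone sketch of the same check on plain slices (`broadcast_dims` is a made-up helper):

// Same suffix check as above, on plain dims slices instead of Shape.
fn broadcast_dims<'a>(lhs: &'a [usize], rhs: &'a [usize]) -> Option<&'a [usize]> {
    if lhs.strip_suffix(rhs).is_some() {
        Some(lhs)
    } else if rhs.strip_suffix(lhs).is_some() {
        Some(rhs)
    } else {
        None
    }
}

fn main() {
    assert_eq!(broadcast_dims(&[2, 3, 4], &[3, 4]), Some(&[2, 3, 4][..])); // lhs wins
    assert_eq!(broadcast_dims(&[4], &[2, 3, 4]), Some(&[2, 3, 4][..])); // rhs wins
    assert_eq!(broadcast_dims(&[2, 3], &[3, 4]), None); // the tensor version returns ShapeMismatchBinaryOp
}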
@@ -236,6 +285,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
+    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
+    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
+    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
+    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
     unary_op!(neg, Neg);
     unary_op!(sqr, Sqr);
@@ -773,6 +826,10 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
+                | Op::BroadcastAdd(lhs, rhs)
+                | Op::BroadcastMul(lhs, rhs)
+                | Op::BroadcastSub(lhs, rhs)
+                | Op::BroadcastDiv(lhs, rhs)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -865,6 +922,26 @@ impl Tensor {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
+                    Op::BroadcastAdd(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_add",
+                        })
+                    }
+                    Op::BroadcastSub(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_sub",
+                        })
+                    }
+                    Op::BroadcastMul(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_mul",
+                        })
+                    }
+                    Op::BroadcastDiv(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_div",
+                        })
+                    }
                    Op::Embedding(_lhs, _rhs) => {
                        return Err(Error::BackwardNotSupported { op: "embedding" })
                    }
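The `BackwardNotSupported` arms above are placeholders because a broadcast op's gradient cannot simply be passed through like the plain `Add`/`Sub` cases: the incoming gradient has the output shape and must be summed back down to the smaller operand's shape. A standalone sketch of that reduction for the (rows, cols) + (cols,) `broadcast_add` case (pure Rust, made-up helper, not the eventual candle implementation):

// The gradient w.r.t. the broadcast rhs is the output gradient summed over the
// dimensions the rhs was repeated along (here: the rows).
fn broadcast_add_rhs_grad(out_grad: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(out_grad.len(), rows * cols);
    let mut rhs_grad = vec![0f32; cols];
    for r in 0..rows {
        for c in 0..cols {
            rhs_grad[c] += out_grad[r * cols + c];
        }
    }
    rhs_grad
}

fn main() {
    // An output gradient of shape (2, 3) reduced to the rhs shape (3,).
    let out_grad = [1., 2., 3., 4., 5., 6.];
    assert_eq!(broadcast_add_rhs_grad(&out_grad, 2, 3), vec![5.0f32, 7., 9.]);
}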