Add the index-select op. (#209)

* Add the index-select op.

* CPU implementation of index-select.

* Add the CPU implementation for index-select.
Laurent Mazare
2023-07-20 15:01:03 +02:00
committed by GitHub
parent 2a8f28d687
commit fa08fb3126
10 changed files with 168 additions and 20 deletions
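
For context before the diff: a minimal usage sketch of the op this commit adds. The crate path, the `index_select` signature, and the u32 index dtype below are assumptions for illustration and may differ slightly at this revision; this is not code from the commit itself.

```rust
// Hypothetical usage sketch of the new index-select op; the crate name,
// method signature, and index dtype are assumptions, not from this commit.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 3x2 source tensor.
    let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
    // Integer indices picking rows 0 and 2 along dimension 0.
    let ids = Tensor::new(&[0u32, 2u32], &dev)?;
    // index_select gathers the indexed rows; the result has shape (2, 2).
    let rows = t.index_select(&ids, 0)?;
    println!("{:?}", rows.to_vec2::<f32>()?);
    Ok(())
}
```

The diff below also rewrites `return Err(e)` as `Err(e)?` in several backward match arms; the two behave the same here, since `?` on an `Err` value returns early from the function (after a `From` conversion) while keeping each arm a single expression.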


@@ -40,6 +40,7 @@ impl Tensor {
..
}
| Op::Binary(lhs, rhs, _)
+ | Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -143,9 +144,12 @@
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
- Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
+ Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::IndexSelect(_lhs, _rhs, _) => {
+ Err(Error::BackwardNotSupported { op: "index-select" })?
+ }
Op::Embedding(_lhs, _rhs) => {
- return Err(Error::BackwardNotSupported { op: "embedding" })
+ Err(Error::BackwardNotSupported { op: "embedding" })?
}
Op::Matmul(lhs, rhs) => {
// Skipping checks, the op went ok, we can skip
@@ -195,10 +199,10 @@
}
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
Op::Reduce(_args, ReduceOp::Max, _) => {
- return Err(Error::BackwardNotSupported { op: "max" })
+ Err(Error::BackwardNotSupported { op: "max" })?
}
Op::Reduce(_args, ReduceOp::Min, _) => {
- return Err(Error::BackwardNotSupported { op: "min" })
+ Err(Error::BackwardNotSupported { op: "min" })?
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
@@ -221,9 +225,7 @@
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
- Op::Unary(_, UnaryOp::Abs) => {
- return Err(Error::BackwardNotSupported { op: "abs" })
- }
+ Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad / arg)?)?
@@ -258,21 +260,15 @@
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Softmax(_arg, _) => {
- return Err(Error::BackwardNotSupported { op: "softmax" })
- }
+ Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Unary(_, UnaryOp::Gelu) => {
- return Err(Error::BackwardNotSupported { op: "gelu" })
- }
- Op::Unary(_, UnaryOp::Relu) => {
- return Err(Error::BackwardNotSupported { op: "relu" })
- }
- Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
+ Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+ Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+ Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;