Add the copy op. (#227)

* Add the copy op.

* Tweak some cat error messages.

* Handle the contiguous case in to_vec1.

* Fast variant for to_vec2.

* Add a faster to_vec3 variant.
This commit is contained in:
Laurent Mazare
2023-07-23 19:06:47 +02:00
committed by GitHub
parent 23827c49cd
commit fe87778223
3 changed files with 72 additions and 40 deletions

View File

@ -82,6 +82,7 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, _, _)
@ -246,6 +247,10 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;