Skeleton implementation for the narrow method and op.

laurent
2023-06-24 08:17:35 +01:00
parent 3deacba5f9
commit dd657397b2
2 changed files with 39 additions and 2 deletions

View File

@@ -27,9 +27,9 @@ pub(crate) enum Op {
Sin(Tensor),
Cos(Tensor),
Abs(Tensor),
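// Narrow(src, dim, start, length), mirroring the arguments of `Tensor::narrow` below.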
Narrow(Tensor, usize, usize, usize),
Neg(Tensor),
Reshape(Tensor),
#[allow(dead_code)]
Softmax(Tensor, usize),
Sqr(Tensor),
Sqrt(Tensor),

View File

@@ -298,6 +298,36 @@ impl Tensor {
Ok(from_storage(storage, shape.clone(), op, false))
}
/// Returns a new tensor that is a narrowed version of the input; along
/// dimension `dim` it spans the indices from `start` to `start + length`.
// TODO: Once we've refactored the shape and stride handling, make this return a view of
// the same data rather than copying.
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
return Err(Error::UnexpectedNumberOfDims {
expected: dim + 1,
got: dims.len(),
shape: self.shape().clone(),
});
}
if start + length > dims[dim] {
todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
}
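// For example, with dims = [3, 4], narrow(1, 2, 3) is rejected: indices
// 2..5 would run past the end of dimension 1, which only has size 4.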
let mut dims = dims.to_vec();
dims[dim] = length;
let shape = Shape::from(dims);
let storage = self.device().zeros(&shape, self.dtype())?;
// TODO: Actually copy the data, compared to copy_strided_src this requires a src start
// offset as well as a way to specify the number of elements to be copied.
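// A hedged sketch of the extended helper (hypothetical signature, not in this
// commit): something like
//   src.copy_strided_src(&mut storage, /* dst_offset */ 0, /* src_offset */ start * stride[dim])
// together with an element count derived from the narrowed shape.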
let op = if self.track_op() {
Some(Op::Narrow(self.clone(), dim, start, length))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
}
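// A usage sketch, assuming a `Tensor::zeros`-style constructor (values are
// illustrative only): narrowing dim 0 of a (3, 4) tensor at start 1 with
// length 2 yields shape (2, 4); until the copy above lands, the result is
// zero-filled.
//   let t = Tensor::zeros(&[3usize, 4], DType::F32, &Device::Cpu)?;
//   let narrowed = t.narrow(0, 1, 2)?;
//   assert_eq!(narrowed.shape().dims(), &[2, 4]);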
pub fn softmax(&self, dim: usize) -> Result<Self> {
let shape = self.shape();
let mut storage = self
@@ -817,6 +847,7 @@ impl Tensor {
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Sqr(node)
| Op::Sqrt(node)
@@ -933,7 +964,10 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Cat(_args, _dim) => {
// TODO: Use narrow here.
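// Hedged sketch of that TODO (assumes a fully working `narrow`, and
// renaming `_args`/`_dim` so they are bound): walk the concatenated args
// and give each one its slice of `grad` along `dim`:
//   let mut start = 0;
//   for arg in args {
//       let len = arg.shape().dims()[*dim];
//       let arg_grad = grad.narrow(*dim, start, len)?;
//       let sum_grad = grads.or_insert(arg)?;
//       *sum_grad = sum_grad.add(&arg_grad)?;
//       start += len;
//   }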
return Err(Error::BackwardNotSupported { op: "cat" });
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
@@ -964,6 +998,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
Op::Narrow(_arg, _, _, _) => {
return Err(Error::BackwardNotSupported { op: "narrow" })
}
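// Hedged sketch of a future narrow backward pass: re-embed `grad` into a
// zero tensor of the source shape, e.g. via a hypothetical
// `pad_with_zeros(dim, left, right)` helper:
//   let left = *start;
//   let right = arg.shape().dims()[*dim] - start - length;
//   let arg_grad = grad.pad_with_zeros(*dim, left, right)?;
//   let sum_grad = grads.or_insert(arg)?;
//   *sum_grad = sum_grad.add(&arg_grad)?;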
Op::Softmax(_arg, _) => {
return Err(Error::BackwardNotSupported { op: "softmax" })
}