Add the continuous method.

This commit is contained in:
laurent
2023-06-23 10:45:20 +01:00
parent 4712dcc2f6
commit c4c6167949
3 changed files with 29 additions and 4 deletions

View File

@ -12,6 +12,9 @@ pub enum Error {
#[error("{op} expects at least one tensor")] #[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str }, OpRequiresAtLeastOneTensor { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
#[error("the candle crate has not been built with cuda support")] #[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport, NotCompiledWithCudaSupport,

View File

@ -17,6 +17,8 @@ pub(crate) enum Op {
add: f64, add: f64,
}, },
Neg(Tensor), Neg(Tensor),
#[allow(dead_code)]
Reshape(Tensor),
Sqr(Tensor), Sqr(Tensor),
Sqrt(Tensor), Sqrt(Tensor),
ToDevice(Tensor), ToDevice(Tensor),

View File

@ -575,6 +575,26 @@ impl Tensor {
} }
} }
pub fn contiguous(&self) -> Result<Tensor> {
if self.is_contiguous() {
Ok(self.clone())
} else {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, shape, &self.stride, 0)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: self.op.clone(),
is_variable: self.is_variable,
};
Ok(Tensor(Arc::new(tensor_)))
}
}
pub fn cat(args: &[Self], dim: usize) -> Result<Self> { pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
if args.is_empty() { if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
@ -708,7 +728,8 @@ impl Tensor {
nodes nodes
} }
} }
Op::ToDevice(node) Op::Reshape(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _) | Op::Transpose(node, _, _)
| Op::Sqr(node) | Op::Sqr(node)
| Op::Sqrt(node) | Op::Sqrt(node)
@ -788,9 +809,7 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?; let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
} }
Op::Cat(_args, _dim) => { Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
todo!()
}
Op::Affine { arg, mul, .. } => { Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?; let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
@ -801,6 +820,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
Op::Sqr(arg) => { Op::Sqr(arg) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?; let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;