Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Add the contiguous method.
@@ -12,6 +12,9 @@ pub enum Error {
     #[error("{op} expects at least one tensor")]
     OpRequiresAtLeastOneTensor { op: &'static str },
 
+    #[error("backward is not supported for {op}")]
+    BackwardNotSupported { op: &'static str },
+
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
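These `#[error("...")]` attributes are the standard thiserror pattern: the attribute derives the `Display` message for each variant. A minimal standalone sketch of how the new variant surfaces to callers (the enum below mirrors only the variants touched in this hunk; the `demo` function is illustrative, not crate code):

    use thiserror::Error;

    #[derive(Error, Debug)]
    pub enum Error {
        #[error("{op} expects at least one tensor")]
        OpRequiresAtLeastOneTensor { op: &'static str },
        #[error("backward is not supported for {op}")]
        BackwardNotSupported { op: &'static str },
    }

    fn demo() {
        let err = Error::BackwardNotSupported { op: "cat" };
        // Display is derived from the #[error("...")] attribute:
        assert_eq!(err.to_string(), "backward is not supported for cat");
    }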
@@ -17,6 +17,8 @@ pub(crate) enum Op {
         add: f64,
     },
     Neg(Tensor),
+    #[allow(dead_code)]
+    Reshape(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
     ToDevice(Tensor),
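The `#[allow(dead_code)]` on the new `Reshape` variant is presumably there because nothing constructs `Op::Reshape` yet at this point in the commit; only the backward-pass code further down learns to recognize it.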
@@ -575,6 +575,26 @@ impl Tensor {
         }
     }
 
+    pub fn contiguous(&self) -> Result<Tensor> {
+        if self.is_contiguous() {
+            Ok(self.clone())
+        } else {
+            let shape = self.shape();
+            let mut storage = self.device().zeros(shape, self.dtype())?;
+            self.storage
+                .copy_strided_src(&mut storage, shape, &self.stride, 0)?;
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: self.op.clone(),
+                is_variable: self.is_variable,
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
+    }
+
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
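The new `contiguous` method is a cheap clone when the strides already match the shape's contiguous layout, and otherwise copies the strided data into a freshly allocated buffer with standard strides. A hypothetical usage sketch; the crate path, `Tensor::new`, `transpose`, and `Device::Cpu` are assumed from candle's public API and may differ at this early commit:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        // Transposing swaps strides without moving any data, so the
        // result generally no longer has a contiguous memory layout.
        let tt = t.transpose(0, 1)?;
        // contiguous() copies into a fresh buffer with standard strides;
        // on an already-contiguous tensor it just clones the handle.
        let c = tt.contiguous()?;
        assert!(c.is_contiguous());
        Ok(())
    }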
@@ -708,7 +728,8 @@ impl Tensor {
                     nodes
                 }
             }
-            Op::ToDevice(node)
+            Op::Reshape(node)
+            | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
             | Op::Sqr(node)
             | Op::Sqrt(node)
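This hunk extends the walk that collects nodes for the backward pass: `Op::Reshape` joins the group of single-input ops whose argument is recursively visited, so tensors produced by a reshape are tracked in the autograd graph traversal even though their gradient is not implemented yet.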
@@ -788,9 +809,7 @@ impl Tensor {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
-                    Op::Cat(_args, _dim) => {
-                        todo!()
-                    }
+                    Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
                    Op::Affine { arg, mul, .. } => {
                        let arg_grad = grad.affine(*mul, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
@@ -801,6 +820,7 @@ impl Tensor {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
+                    Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                    Op::Sqr(arg) => {
                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
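The `Sqr` arm in the context above is a compact chain-rule application: for y = x^2 the local derivative is dy/dx = 2x, so the contribution to x's gradient is grad * 2x. The code gets there with `arg.mul(&grad)?` for the elementwise x * grad product, then `affine(2., 0.)?` to scale by 2, since `affine(mul, add)` computes mul * x + add elementwise (the same helper the `Affine` arm differentiates through above).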