Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Add the contiguous method.

This commit adds a BackwardNotSupported error variant, a (not yet constructed) Reshape op, and a Tensor::contiguous method, and it replaces the todo!() panic in cat's backward pass with the new error.
@@ -12,6 +12,9 @@ pub enum Error {
     #[error("{op} expects at least one tensor")]
     OpRequiresAtLeastOneTensor { op: &'static str },
 
+    #[error("backward is not supported for {op}")]
+    BackwardNotSupported { op: &'static str },
+
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
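The #[error("...")] attributes indicate thiserror's derive, which generates the Display message for each variant. A minimal standalone sketch of how the new variant renders (hypothetical demo, not candle code; requires the thiserror crate):

    use thiserror::Error;

    #[derive(Error, Debug)]
    pub enum Error {
        #[error("backward is not supported for {op}")]
        BackwardNotSupported { op: &'static str },
    }

    fn main() {
        let e = Error::BackwardNotSupported { op: "cat" };
        println!("{e}"); // prints: backward is not supported for cat
    }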
@@ -17,6 +17,8 @@ pub(crate) enum Op {
         add: f64,
     },
     Neg(Tensor),
+    #[allow(dead_code)]
+    Reshape(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
     ToDevice(Tensor),
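Each Op variant owns the parent Tensor(s) of the operation; that is what later lets the backward pass walk from a result back to its inputs. The new Reshape(Tensor) variant follows the same pattern, with #[allow(dead_code)] silencing the lint until some forward op actually constructs it. Cloning a Tensor into an op is cheap because Tensor wraps an Arc (visible in the contiguous hunk below); a standalone mini-model of the pattern, not candle code:

    use std::sync::Arc;

    // Mini-model: ops own their parent tensors; cloning only bumps an Arc.
    #[derive(Clone)]
    struct Tensor(Arc<Vec<f32>>);

    enum Op {
        #[allow(dead_code)]
        Neg(Tensor),
        Reshape(Tensor),
    }

    fn main() {
        let t = Tensor(Arc::new(vec![1.0, 2.0]));
        let op = Op::Reshape(t.clone()); // cheap: one extra Arc reference
        let Op::Reshape(parent) = op else { unreachable!() };
        // `t` and `parent` share the same allocation.
        assert_eq!(Arc::strong_count(&parent.0), 2);
    }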
@@ -575,6 +575,26 @@ impl Tensor {
         }
     }
 
+    pub fn contiguous(&self) -> Result<Tensor> {
+        if self.is_contiguous() {
+            Ok(self.clone())
+        } else {
+            let shape = self.shape();
+            let mut storage = self.device().zeros(shape, self.dtype())?;
+            self.storage
+                .copy_strided_src(&mut storage, shape, &self.stride, 0)?;
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: self.op.clone(),
+                is_variable: self.is_variable,
+            };
+            Ok(Tensor(Arc::new(tensor_)))
+        }
+    }
+
     pub fn cat(args: &[Self], dim: usize) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
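Usage-wise, contiguous() is nearly free when the layout is already contiguous (an Arc clone), and otherwise copies the strided data into freshly allocated storage. Note that the copy carries over op and is_variable, so the result stays wired into the autograd graph. A usage sketch, assuming zeros/transpose constructors along these lines (their exact signatures are not shown in this diff):

    // Hypothetical sketch, inside a fn returning Result<(), Error>:
    let t = Tensor::zeros(&[2, 3], DType::F32, &Device::Cpu)?;
    let tt = t.transpose(0, 1)?;  // a strided view; no longer contiguous
    let c = tt.contiguous()?;     // copies via copy_strided_src into fresh storage
    assert!(c.is_contiguous());

The fresh strides come from shape.stride_contiguous(), i.e. ordinary row-major strides; the computation is equivalent to this self-contained sketch (not candle's code):

    // Row-major (C-order) strides: the last dimension varies fastest.
    fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
        let mut strides = vec![0; dims.len()];
        let mut acc = 1;
        for (stride, &dim) in strides.iter_mut().zip(dims).rev() {
            *stride = acc;
            acc *= dim;
        }
        strides
    }

    fn main() {
        assert_eq!(stride_contiguous(&[2, 3, 4]), vec![12, 4, 1]);
    }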
@@ -708,7 +728,8 @@ impl Tensor {
                     nodes
                 }
             }
-            Op::ToDevice(node)
+            Op::Reshape(node)
+            | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
             | Op::Sqr(node)
             | Op::Sqrt(node)
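These arms are part of the depth-first traversal that collects graph nodes in topological order (parents before children) so gradients can then be propagated in reverse; all unary ops share one arm and recurse into their single parent, and Reshape now joins them. A standalone mini-version of that pattern (hypothetical simplification; candle deduplicates nodes by TensorId):

    use std::collections::HashSet;

    enum Node {
        Leaf(u32),
        Unary(u32, Box<Node>),
    }

    fn walk(node: &Node, seen: &mut HashSet<u32>, order: &mut Vec<u32>) {
        let (id, parent) = match node {
            Node::Leaf(id) => (*id, None),
            Node::Unary(id, p) => (*id, Some(p.as_ref())),
        };
        if !seen.insert(id) {
            return; // already visited via another path
        }
        if let Some(p) = parent {
            walk(p, seen, order); // parents first...
        }
        order.push(id); // ...then the node; iterate in reverse for backprop
    }

    fn main() {
        // leaf(0) -> unary(1) -> unary(2), e.g. x -> sqr -> sqrt
        let g = Node::Unary(2, Box::new(Node::Unary(1, Box::new(Node::Leaf(0)))));
        let (mut seen, mut order) = (HashSet::new(), Vec::new());
        walk(&g, &mut seen, &mut order);
        assert_eq!(order, vec![0, 1, 2]);
    }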
@@ -788,9 +809,7 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::Cat(_args, _dim) => {
-                        todo!()
-                    }
+                    Op::Cat(_args, _dim) => return Err(Error::BackwardNotSupported { op: "cat" }),
                     Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
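The behavioral change here: backpropagating through a cat result used to hit the todo!() and panic; it now surfaces a recoverable error. A hypothetical call site (names a and b assumed; backward() must return a Result, since the new arm uses return Err(...)):

    // Hypothetical sketch, inside a fn returning Result<(), Error>:
    let y = Tensor::cat(&[a, b], 0)?;
    match y.backward() {
        // This commit: a recoverable error instead of a todo!() panic.
        Err(Error::BackwardNotSupported { op }) => {
            eprintln!("no gradient implemented for {op}"); // op == "cat"
        }
        Ok(_grads) => { /* gradients computed */ }
        Err(e) => return Err(e),
    }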
@@ -801,6 +820,7 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
+                    Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                     Op::Sqr(arg) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
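Like Cat, the new Reshape arm defers its gradient with BackwardNotSupported rather than leaving a panic path. The unchanged Sqr lines below it show a rule that is implemented: for y = x², dy/dx = 2x, so the argument is multiplied by the upstream gradient and doubled (affine(2., 0.) is elementwise x * 2 + 0). A quick numeric check of that rule, independent of candle:

    // y = x^2  =>  dy/dx = 2x; with upstream gradient g, arg_grad = 2 * x * g.
    fn sqr_backward(x: f64, g: f64) -> f64 {
        2.0 * x * g
    }

    fn main() {
        // At x = 3 with unit upstream gradient: 2 * 3 * 1 = 6.
        assert_eq!(sqr_backward(3.0, 1.0), 6.0);
    }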