Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Add the copy op. (#227)
* Add the copy op.
* Tweak some cat error messages.
* Handle the contiguous case in to_vec1.
* Fast variant for to_vec2.
* Add a faster to_vec3 variant.
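Taken together, the diff below touches three areas: the autograd graph gains a Copy op, the to_vec1/to_vec2/to_vec3 conversions get a fast path for contiguous layouts, and cat reports shape problems eagerly. As a rough usage sketch of what this enables (the crate path and the Tensor::new, copy, and to_vec2 calls are assumed to match the crate's API at this commit; they are not part of the diff itself):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A small 2x3 tensor on the CPU.
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // `copy` duplicates the storage and, with this change, records Op::Copy,
    // so the duplicate stays connected to the autograd graph.
    let c = t.copy()?;
    // `to_vec2` now takes the contiguous fast path: one slice per row instead
    // of a per-element strided gather.
    assert_eq!(c.to_vec2::<f32>()?, vec![vec![1., 2., 3.], vec![4., 5., 6.]]);
    Ok(())
}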
@@ -82,6 +82,7 @@ impl Tensor {
                 }
             }
             Op::Reshape(node)
+            | Op::Copy(node)
             | Op::Broadcast(node)
             | Op::Cmp(node, _)
             | Op::Reduce(node, _, _)
@@ -246,6 +247,10 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                     }
+                    Op::Copy(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad)?
+                    }
                     Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
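The backward rule added here is the simplest one possible: a copy is an element-wise identity, so the gradient flowing into the copy is added, unchanged, to the gradient accumulated for its argument. A standalone sketch of that accumulation rule (a toy HashMap store standing in for the crate's GradStore; not the diff's code):

use std::collections::HashMap;

// Toy gradient store keyed by a node id; the real code uses GradStore::or_insert.
fn accumulate_copy_grad(
    grads: &mut HashMap<usize, Vec<f32>>,
    arg_id: usize,
    upstream: &[f32],
) {
    // d(copy(x))/dx is the identity, so the argument's gradient just gains `upstream`.
    let sum_grad = grads
        .entry(arg_id)
        .or_insert_with(|| vec![0.0; upstream.len()]);
    for (g, u) in sum_grad.iter_mut().zip(upstream) {
        *g += u;
    }
}

fn main() {
    let mut grads = HashMap::new();
    accumulate_copy_grad(&mut grads, 0, &[1.0, 2.0, 3.0]);
    accumulate_copy_grad(&mut grads, 0, &[0.5, 0.5, 0.5]); // second use of the same arg
    assert_eq!(grads[&0], vec![1.5, 2.5, 3.5]);
}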
@@ -89,6 +89,7 @@ pub(crate) enum Op {
         add: f64,
     },
     ToDType(Tensor),
+    Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
     Reshape(Tensor),
@@ -1128,17 +1128,17 @@ impl Tensor {
             }
             .bt())?
         }
+        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+            let data = S::cpu_storage_as_slice(cpu_storage)?;
+            let data = match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => data[o1..o2].to_vec(),
+                None => self.strided_index().map(|i| data[i]).collect(),
+            };
+            Ok::<Vec<_>, Error>(data)
+        };
         match &*self.storage() {
-            Storage::Cpu(cpu_storage) => {
-                let data = S::cpu_storage_as_slice(cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
-            Storage::Cuda(slice) => {
-                // TODO: Would it be possible to only fetch the necessary data?
-                let cpu_storage = slice.to_cpu_storage()?;
-                let data = S::cpu_storage_as_slice(&cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
+            Storage::Cpu(storage) => from_cpu_storage(storage),
+            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
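The rewritten to_vec1 hinges on Layout::contiguous_offsets: when the tensor occupies one contiguous run of its storage it returns the (start, end) of that run, so the whole result is a single slice copy; otherwise the code falls back to the element-by-element strided_index walk. A standalone sketch of that dispatch for a 1-D strided view (the simplified contiguous_offsets below is an illustration, not the crate's Layout type):

// Simplified stand-in for Layout::contiguous_offsets: a 1-D view with a start
// offset and a stride; it is contiguous exactly when stride == 1.
fn contiguous_offsets(start: usize, len: usize, stride: usize) -> Option<(usize, usize)> {
    (stride == 1).then(|| (start, start + len))
}

fn to_vec1(data: &[f32], start: usize, len: usize, stride: usize) -> Vec<f32> {
    match contiguous_offsets(start, len, stride) {
        // Fast path: one bulk copy of the contiguous run.
        Some((o1, o2)) => data[o1..o2].to_vec(),
        // Slow path: gather element by element, as the strided_index iterator does.
        None => (0..len).map(|i| data[start + i * stride]).collect(),
    }
}

fn main() {
    let data: [f32; 6] = [0., 1., 2., 3., 4., 5.];
    assert_eq!(to_vec1(&data, 1, 3, 1), vec![1., 2., 3.]); // contiguous
    assert_eq!(to_vec1(&data, 0, 3, 2), vec![0., 2., 4.]); // strided fallback
}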
@@ -1148,12 +1148,22 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx_row in 0..dim1 {
-                let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
-                rows.push(row)
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    for idx_row in 0..dim1 {
+                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx_row in 0..dim1 {
+                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                        rows.push(row)
+                    }
+                    assert!(src_index.next().is_none());
+                }
             }
-            assert!(src_index.next().is_none());
             Ok(rows)
         };
         match &*self.storage() {
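On the contiguous path of to_vec2, row idx_row of a dim1 x dim2 tensor is exactly data[idx_row * dim2..(idx_row + 1) * dim2], so each row becomes one bulk slice-to-Vec copy instead of dim2 iterator steps. A quick standalone check of that index arithmetic:

fn to_vec2_contiguous(data: &[f32], dim1: usize, dim2: usize) -> Vec<Vec<f32>> {
    let mut rows = Vec::with_capacity(dim1);
    for idx_row in 0..dim1 {
        // Row-major layout: row `idx_row` is one contiguous run of `dim2` elements.
        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec());
    }
    rows
}

fn main() {
    // 2 x 3 tensor stored row-major.
    let data: [f32; 6] = [1., 2., 3., 4., 5., 6.];
    assert_eq!(
        to_vec2_contiguous(&data, 2, 3),
        vec![vec![1., 2., 3.], vec![4., 5., 6.]]
    );
}

An equivalent formulation would be data.chunks_exact(dim2).map(|c| c.to_vec()).collect(); the explicit indexing above mirrors the diff.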
@@ -1168,16 +1178,32 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut top_rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx in 0..dim1 {
-                let mut rows = vec![];
-                for _jdx in 0..dim2 {
-                    let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
-                    rows.push(row)
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    let dim23 = dim2 * dim3;
+                    for idx1 in 0..dim1 {
+                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
+                        let mut rows = vec![];
+                        for idx2 in 0..dim2 {
+                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
+                        }
+                        top_rows.push(rows);
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx in 0..dim1 {
+                        let mut rows = vec![];
+                        for _jdx in 0..dim2 {
+                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+                            rows.push(row)
+                        }
+                        top_rows.push(rows);
+                    }
+                    assert!(src_index.next().is_none());
                 }
-                top_rows.push(rows);
             }
-            assert!(src_index.next().is_none());
             Ok(top_rows)
         };
         match &*self.storage() {
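The 3-D variant applies the same slicing twice: idx1 selects a dim2 * dim3 block (dim23), and within that block idx2 selects a dim3-element row. A standalone check with a 2 x 2 x 2 buffer:

fn to_vec3_contiguous(data: &[f32], dim1: usize, dim2: usize, dim3: usize) -> Vec<Vec<Vec<f32>>> {
    let dim23 = dim2 * dim3;
    let mut top_rows = Vec::with_capacity(dim1);
    for idx1 in 0..dim1 {
        // Outer slice: one dim2 x dim3 block per idx1.
        let block = &data[idx1 * dim23..(idx1 + 1) * dim23];
        let mut rows = Vec::with_capacity(dim2);
        for idx2 in 0..dim2 {
            // Inner slice: one dim3-element row per idx2.
            rows.push(block[idx2 * dim3..(idx2 + 1) * dim3].to_vec());
        }
        top_rows.push(rows);
    }
    top_rows
}

fn main() {
    let data: [f32; 8] = [0., 1., 2., 3., 4., 5., 6., 7.];
    assert_eq!(
        to_vec3_contiguous(&data, 2, 2, 2),
        vec![
            vec![vec![0., 1.], vec![2., 3.]],
            vec![vec![4., 5.], vec![6., 7.]],
        ]
    );
}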
@@ -1404,7 +1430,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
             layout: self.layout.clone(),
-            op: None, // TODO
+            op: Some(Op::Copy(self.clone())),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1540,7 +1566,7 @@ impl Tensor {
             Ok(from_storage(
                 storage,
                 shape.clone(),
-                None, // TODO
+                Some(Op::Copy(self.clone())),
                 false,
             ))
         }
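Both construction sites above previously left op: None with a TODO, which silently disconnected the result from the autograd graph; they now record Op::Copy(self.clone()) so gradients can flow back through a copy or a device transfer. A hedged sketch of exercising those paths (the method names and the same-device to_device call are assumed from the crate's public API, not shown in the diff):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // Storage is duplicated and Op::Copy is recorded on the result.
    let c = t.copy()?;
    assert_eq!(c.to_vec1::<f32>()?, vec![1., 2., 3.]);
    // The device-transfer path records the same op; transferring to the same
    // device keeps this example runnable without a GPU.
    let d = t.to_device(&Device::Cpu)?;
    assert_eq!(d.to_vec1::<f32>()?, vec![1., 2., 3.]);
    Ok(())
}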
@@ -1734,7 +1760,6 @@ impl Tensor {
         for (arg_idx, arg) in args.iter().enumerate() {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
-                // TODO: Improve the error message.
                 Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
@@ -1743,7 +1768,6 @@ impl Tensor {
                 .bt())?
             }
             if arg.device().location() != device.location() {
-                // TODO: Improve the error message.
                 Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
@@ -1751,7 +1775,14 @@ impl Tensor {
                 }
                 .bt())?
             }
-            let mut mismatch = arg.rank() != rank;
+            if rank != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: rank,
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
             for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
@@ -1763,20 +1794,15 @@ impl Tensor {
                     cat_dims[0] += v2;
                 }
                 if dim_idx != 0 && v1 != v2 {
-                    // TODO: It would probably be good to have a nicer error message here, i.e.
-                    // mention the problematic dimension and the values.
-                    mismatch = true;
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
                 }
             }
-            if mismatch {
-                Err(Error::ShapeMismatchCat {
-                    dim: 0, // TODO: not the appropriate error message
-                    first_shape: arg0.shape().clone(),
-                    n: arg_idx + 1,
-                    nth_shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
         }
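With the last two hunks, cat validates every argument as it goes: a rank mismatch raises UnexpectedNumberOfDims immediately, and a size mismatch in any non-cat dimension raises ShapeMismatchCat carrying the offending dimension index, both shapes, and the argument position, instead of the old deferred report hard-coded to dim: 0. A hedged sketch of triggering the new error (Tensor::zeros and Tensor::cat signatures assumed from the crate's API at this commit):

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
    // Concatenating along dim 0 requires every other dimension to match;
    // dim 1 differs (3 vs 4), so this reports ShapeMismatchCat with dim: 1
    // and both shapes rather than a generic dim-0 message.
    let err = Tensor::cat(&[&a, &b], 0).unwrap_err();
    println!("{err}");
    Ok(())
}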