Add the copy op. (#227)
* Add the copy op.
* Tweak some cat error messages.
* Handle the contiguous case in to_vec1.
* Fast variant for to_vec2.
* Add a faster to_vec3 variant.
@@ -82,6 +82,7 @@ impl Tensor {
                     }
                 }
                 Op::Reshape(node)
+                | Op::Copy(node)
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
                 | Op::Reduce(node, _, _)
@@ -246,6 +247,10 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
                 }
+                Op::Copy(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad)?
+                }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
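The new Op::Copy arm makes the gradient of a copy the identity: the incoming gradient is accumulated onto the source tensor unchanged, with no dtype conversion or scaling. A minimal sketch of what this enables; Var, sum_all, backward, and the candle_core crate name follow candle's later public API and are assumptions relative to this commit:

use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let x = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // `copy` now records Op::Copy(self), so it participates in backprop.
    let y = x.copy()?;
    let loss = y.sum_all()?;
    let grads = loss.backward()?;
    // d(copy)/dx is the identity, so the gradient is all ones.
    let grad_x = grads.get(&x).expect("gradient for x");
    assert_eq!(grad_x.to_vec1::<f32>()?, vec![1., 1., 1.]);
    Ok(())
}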
@@ -89,6 +89,7 @@ pub(crate) enum Op {
         add: f64,
     },
     ToDType(Tensor),
+    Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
     Reshape(Tensor),
@@ -1128,17 +1128,17 @@ impl Tensor {
             }
             .bt())?
         }
+        let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+            let data = S::cpu_storage_as_slice(cpu_storage)?;
+            let data = match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => data[o1..o2].to_vec(),
+                None => self.strided_index().map(|i| data[i]).collect(),
+            };
+            Ok::<Vec<_>, Error>(data)
+        };
         match &*self.storage() {
-            Storage::Cpu(cpu_storage) => {
-                let data = S::cpu_storage_as_slice(cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
-            Storage::Cuda(slice) => {
-                // TODO: Would it be possible to only fetch the necessary data?
-                let cpu_storage = slice.to_cpu_storage()?;
-                let data = S::cpu_storage_as_slice(&cpu_storage)?;
-                Ok(self.strided_index().map(|i| data[i]).collect())
-            }
+            Storage::Cpu(storage) => from_cpu_storage(storage),
+            Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
     }
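Both storage backends now share one closure, and the copy has a fast path: when contiguous_offsets reports that the layout is a dense row-major block it returns the (start, end) element range, and extraction becomes a single slice copy instead of a per-element strided gather. A standalone illustration of that predicate; this Layout is a toy stand-in for candle's type, not its actual implementation:

// Toy layout: shape, per-dimension strides, and a start offset into storage.
struct Layout {
    shape: Vec<usize>,
    stride: Vec<usize>,
    start_offset: usize,
}

impl Layout {
    // Some((start, end)) iff the elements form one dense row-major block.
    fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        // Row-major contiguity: the last dim has stride 1 and each earlier
        // stride equals the product of the dims that follow it.
        let mut expected = 1;
        for (&d, &s) in self.shape.iter().zip(self.stride.iter()).rev() {
            if s != expected {
                return None;
            }
            expected *= d;
        }
        let elem_count: usize = self.shape.iter().product();
        Some((self.start_offset, self.start_offset + elem_count))
    }
}

fn main() {
    // A 2x3 row-major block starting at offset 4 is contiguous...
    let l = Layout { shape: vec![2, 3], stride: vec![3, 1], start_offset: 4 };
    assert_eq!(l.contiguous_offsets(), Some((4, 10)));
    // ...but its transpose (strides swapped) is not.
    let t = Layout { shape: vec![3, 2], stride: vec![1, 3], start_offset: 4 };
    assert_eq!(t.contiguous_offsets(), None);
}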
@@ -1148,12 +1148,22 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx_row in 0..dim1 {
-                let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
-                rows.push(row)
-            }
-            assert!(src_index.next().is_none());
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    for idx_row in 0..dim1 {
+                        rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx_row in 0..dim1 {
+                        let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                        rows.push(row)
+                    }
+                    assert!(src_index.next().is_none());
+                }
+            }
             Ok(rows)
         };
         match &*self.storage() {
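to_vec2 now slices each row straight out of the flat buffer when the tensor is contiguous, and only falls back to the strided iterator for views; both paths must return the same nested vectors. A quick usage sketch, with the candle_core crate name assumed:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    // Contiguous: rows are copied as whole slices via the fast path.
    assert_eq!(t.to_vec2::<f32>()?, vec![vec![1., 2.], vec![3., 4.]]);
    // A transposed view is not contiguous, so the strided fallback runs;
    // the result must follow the layout, not the storage order.
    assert_eq!(t.t()?.to_vec2::<f32>()?, vec![vec![1., 3.], vec![2., 4.]]);
    Ok(())
}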
@@ -1168,16 +1178,32 @@ impl Tensor {
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut top_rows = vec![];
-            let mut src_index = self.strided_index();
-            for _idx in 0..dim1 {
-                let mut rows = vec![];
-                for _jdx in 0..dim2 {
-                    let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
-                    rows.push(row)
-                }
-                top_rows.push(rows);
-            }
-            assert!(src_index.next().is_none());
+            match self.layout.contiguous_offsets() {
+                Some((o1, o2)) => {
+                    let data = &data[o1..o2];
+                    let dim23 = dim2 * dim3;
+                    for idx1 in 0..dim1 {
+                        let data = &data[idx1 * dim23..(idx1 + 1) * dim23];
+                        let mut rows = vec![];
+                        for idx2 in 0..dim2 {
+                            rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec())
+                        }
+                        top_rows.push(rows);
+                    }
+                }
+                None => {
+                    let mut src_index = self.strided_index();
+                    for _idx in 0..dim1 {
+                        let mut rows = vec![];
+                        for _jdx in 0..dim2 {
+                            let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+                            rows.push(row)
+                        }
+                        top_rows.push(rows);
+                    }
+                    assert!(src_index.next().is_none());
+                }
+            }
             Ok(top_rows)
         };
         match &*self.storage() {
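The dim23 blocking in the contiguous branch is plain row-major index arithmetic: element (i, j, k) of a [dim1, dim2, dim3] tensor lives at flat offset (i * dim2 + j) * dim3 + k, so each (i, j) row is one dim3-long slice. A standalone check of that arithmetic:

// Row-major flat offset for element (i, j, k) of a [dim1, dim2, dim3] tensor.
fn flat_index(i: usize, j: usize, k: usize, dim2: usize, dim3: usize) -> usize {
    (i * dim2 + j) * dim3 + k
}

fn main() {
    let (dim2, dim3) = (3, 4);
    // Matches the idx1 * dim23 + idx2 * dim3 decomposition in the fast path.
    assert_eq!(flat_index(1, 2, 3, dim2, dim3), 1 * dim2 * dim3 + 2 * dim3 + 3);
}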
@@ -1404,7 +1430,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
             layout: self.layout.clone(),
-            op: None, // TODO
+            op: Some(Op::Copy(self.clone())),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1540,7 +1566,7 @@ impl Tensor {
             Ok(from_storage(
                 storage,
                 shape.clone(),
-                None, // TODO
+                Some(Op::Copy(self.clone())),
                 false,
             ))
         }
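Both copy and the cross-device transfer in to_device used to drop their provenance (op: None, // TODO); they now record Op::Copy(self.clone()), so the result stays attached to the autograd graph. A small sketch, with the candle_core crate name assumed and noting that a same-device to_device may short-circuit without creating a copy:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // The duplicate has fresh storage but is now tracked as Op::Copy,
    // so gradients can flow back through it (see the backprop hunk above).
    let y = x.copy()?;
    assert_eq!(x.to_vec1::<f32>()?, y.to_vec1::<f32>()?);
    Ok(())
}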
@@ -1734,7 +1760,6 @@ impl Tensor {
         for (arg_idx, arg) in args.iter().enumerate() {
             let arg = arg.as_ref();
             if arg.dtype() != dtype {
-                // TODO: Improve the error message.
                 Err(Error::DTypeMismatchBinaryOp {
                     lhs: dtype,
                     rhs: arg.dtype(),
@@ -1743,7 +1768,6 @@ impl Tensor {
                 .bt())?
             }
             if arg.device().location() != device.location() {
-                // TODO: Improve the error message.
                 Err(Error::DeviceMismatchBinaryOp {
                     lhs: device.location(),
                     rhs: arg.device().location(),
@@ -1751,7 +1775,14 @@ impl Tensor {
                 }
                 .bt())?
             }
-            let mut mismatch = arg.rank() != rank;
+            if rank != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: rank,
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
             for (dim_idx, (v1, v2)) in arg0
                 .shape()
                 .dims()
@@ -1763,20 +1794,15 @@ impl Tensor {
                     cat_dims[0] += v2;
                 }
                 if dim_idx != 0 && v1 != v2 {
-                    // TODO: It would probably be good to have a nicer error message here, i.e.
-                    // mention the problematic dimension and the values.
-                    mismatch = true;
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
                 }
             }
-            if mismatch {
-                Err(Error::ShapeMismatchCat {
-                    dim: 0, // TODO: not the appropriate error message
-                    first_shape: arg0.shape().clone(),
-                    n: arg_idx + 1,
-                    nth_shape: arg.shape().clone(),
-                }
-                .bt())?
-            }
             let next_offset = offsets.last().unwrap() + arg.elem_count();
             offsets.push(next_offset);
         }
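With these changes, cat reports a rank mismatch through a dedicated UnexpectedNumberOfDims error and raises shape mismatches eagerly, naming the offending dimension in ShapeMismatchCat { dim: dim_idx, .. }, rather than a deferred report that always blamed dim 0. A sketch of the improved failure; the candle_core crate name and the exact error rendering are assumptions:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
    // Concatenating over dim 0 requires every other dim to match; dim 1
    // differs (3 vs 4), so this fails with ShapeMismatchCat { dim: 1, .. }.
    match Tensor::cat(&[&a, &b], 0) {
        Ok(_) => unreachable!("shapes mismatch on dim 1"),
        Err(e) => eprintln!("cat failed as expected: {e}"),
    }
    Ok(())
}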