mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add the source offset when copying the data around.
This commit is contained in:
@ -317,9 +317,10 @@ impl Tensor {
|
||||
let mut dims = dims.to_vec();
|
||||
dims[dim] = length;
|
||||
let shape = Shape::from(dims);
|
||||
let storage = self.device().zeros(&shape, self.dtype())?;
|
||||
// TODO: Actually copy the data, compared to copy_strided_src this requires a src start
|
||||
// offset as well as a way to specify the number of elements to be copied.
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
let src_offset = 0; // TODO
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, 0, &self.shape, &self.stride, src_offset)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, length))
|
||||
} else {
|
||||
@ -666,7 +667,7 @@ impl Tensor {
|
||||
let shape = self.shape();
|
||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||
.copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
|
||||
Ok(from_storage(
|
||||
storage,
|
||||
shape.clone(),
|
||||
@ -709,7 +710,7 @@ impl Tensor {
|
||||
} else {
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, &self.shape, &self.stride, 0)?;
|
||||
.copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
}
|
||||
@ -786,7 +787,7 @@ impl Tensor {
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
arg.storage
|
||||
.copy_strided_src(&mut storage, &arg.shape, &arg.stride, offset)?
|
||||
.copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?
|
||||
}
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
Reference in New Issue
Block a user