mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add a currently wrong test for narrow.
This commit is contained in:
@ -300,7 +300,7 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + length`.
|
||||
// TODO: Once we've refactor the shape and strides, make this return a view of the same data
|
||||
// TODO: Once we've refactored the shape and strides, make this return a view of the same data
|
||||
// rather than copying.
|
||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
||||
let dims = self.shape().dims();
|
||||
@ -318,7 +318,8 @@ impl Tensor {
|
||||
dims[dim] = length;
|
||||
let shape = Shape::from(dims);
|
||||
let mut storage = self.device().zeros(&shape, self.dtype())?;
|
||||
let src_offset = 0; // TODO
|
||||
let src_offset = self.stride[dim] * start;
|
||||
// TODO: This is incorrect, see the currently wrong test in tensor_tests.rs
|
||||
self.storage
|
||||
.copy_strided_src(&mut storage, 0, &self.shape, &self.stride, src_offset)?;
|
||||
let op = if self.track_op() {
|
||||
|
Reference in New Issue
Block a user