diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 22cd4950..feab30c8 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2093,8 +2093,19 @@ impl Tensor {
         let dim = dim.to_index(self.shape(), "squeeze")?;
         if dims[dim] == 1 {
             let mut dims = dims.to_vec();
+            let mut strides = self.stride().to_vec();
             dims.remove(dim);
-            self.reshape(dims)
+            strides.remove(dim);
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage: self.storage.clone(),
+                layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
+                op: BackpropOp::new1(self, Op::Reshape),
+                is_variable: false,
+                dtype: self.dtype,
+                device: self.device.clone(),
+            };
+            Ok(Tensor(Arc::new(tensor_)))
         } else {
             Ok(self.clone())
         }
@@ -2115,10 +2126,24 @@ impl Tensor {
     /// ```
     pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
         let mut dims = self.dims().to_vec();
+        let mut strides = self.stride().to_vec();
         let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
         // Cannot panic because to_index_plus_one already checks dimensions
         dims.insert(dim, 1);
-        self.reshape(dims)
+        // Any stride would work here, but we pick one so as to maximize the probability to remain
+        // C contiguous.
+        let stride = if dim < strides.len() { strides[dim] } else { 1 };
+        strides.insert(dim, stride);
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
+            op: BackpropOp::new1(self, Op::Reshape),
+            is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
+        };
+        Ok(Tensor(Arc::new(tensor_)))
     }

     /// Stacks two or more tensors along a particular dimension.
diff --git a/candle-core/tests/layout_tests.rs b/candle-core/tests/layout_tests.rs
index e0618850..bc67f7de 100644
--- a/candle-core/tests/layout_tests.rs
+++ b/candle-core/tests/layout_tests.rs
@@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
         }
     };
     let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
-    let tensor = tensor.i((.., 1))?;
+    let tensor = tensor.i((.., 1))?.contiguous()?;
     match tensor.strided_blocks() {
         candle::StridedBlocks::SingleBlock { start_offset, len } => {
             assert_eq!(start_offset, 0);
@@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> {
         }
     };
     let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
+    let tensor = tensor.i((.., 1))?;
+    match tensor.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { .. } => {
+            panic!("unexpected block structure")
+        }
+        candle::StridedBlocks::MultipleBlocks {
+            block_len,
+            block_start_index,
+        } => {
+            assert_eq!(block_len, 4);
+            assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
+        }
+    };
+    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
     match tensor.t()?.strided_blocks() {
         candle::StridedBlocks::SingleBlock { .. } => {
             panic!("unexpected block structure")
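The squeeze/unsqueeze hunks above make both ops zero-copy: instead of routing through reshape, which copies the data whenever the source tensor is not contiguous, the new code shares the existing storage and only edits the shape and stride vectors (the Reshape backprop op is kept, so gradients flow as before). Below is a minimal standalone sketch of that stride bookkeeping; the free functions over plain `Vec<usize>` are hypothetical stand-ins for candle's private `Tensor_`/`Layout` plumbing.

```rust
/// Drop a size-1 dimension: remove the dim and its stride, keep the data in place.
fn squeeze_dims(dims: &[usize], strides: &[usize], dim: usize) -> (Vec<usize>, Vec<usize>) {
    let mut dims = dims.to_vec();
    let mut strides = strides.to_vec();
    dims.remove(dim);
    strides.remove(dim);
    (dims, strides)
}

/// Insert a size-1 dimension. Any stride is valid for a size-1 dim, but reusing
/// the stride of the dimension being displaced keeps a C-contiguous input
/// C contiguous (contiguity checks skip dims of size 1).
fn unsqueeze_dims(dims: &[usize], strides: &[usize], dim: usize) -> (Vec<usize>, Vec<usize>) {
    let mut dims = dims.to_vec();
    let mut strides = strides.to_vec();
    let stride = if dim < strides.len() { strides[dim] } else { 1 };
    dims.insert(dim, 1);
    strides.insert(dim, stride);
    (dims, strides)
}

fn main() {
    // A contiguous (2, 3, 4) tensor has strides (12, 4, 1).
    let (dims, strides) = unsqueeze_dims(&[2, 3, 4], &[12, 4, 1], 1);
    assert_eq!(dims, [2, 1, 3, 4]);
    assert_eq!(strides, [12, 4, 4, 1]); // still C contiguous
    // Squeezing the inserted dim restores the original layout exactly.
    let (dims, strides) = squeeze_dims(&dims, &strides, 1);
    assert_eq!(dims, [2, 3, 4]);
    assert_eq!(strides, [12, 4, 1]);
}
```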
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index f467903a..f0707010 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -288,12 +288,12 @@ fn main() -> Result<()> {
     };

     let device = candle_examples::device(args.cpu)?;
-    let (model, device) = if args.quantized {
+    let model = if args.quantized {
         let filename = &filenames[0];
         let vb =
             candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QStableLM::new(&config, vb)?;
-        (Model::Quantized(model), Device::Cpu)
+        Model::Quantized(model)
     } else {
         let dtype = if device.is_cuda() {
             DType::BF16
@@ -302,7 +302,7 @@ fn main() -> Result<()> {
         };
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         let model = StableLM::new(&config, vb)?;
-        (Model::StableLM(model), device)
+        Model::StableLM(model)
     };

     println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs
index 81828ad5..597dd2cd 100644
--- a/candle-transformers/src/models/mamba.rs
+++ b/candle-transformers/src/models/mamba.rs
@@ -121,7 +121,7 @@ impl MambaBlock {

         // Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
         let x_proj = self.x_proj.forward(&proj_for_conv)?;
-        let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
+        let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?.contiguous()?;
         let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
         let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;

diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 94324149..5ce2de59 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -512,7 +512,7 @@ impl ModelWeights {
             layer_in = x
         }
         let x = self.norm.forward(&layer_in)?;
-        let x = x.i((.., seq_len - 1, ..))?;
+        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
         let _enter = self.span_output.enter();
         self.output.forward(&x)
     }
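The remaining hunks are fallout from that change. Indexing with `IndexOp::i` narrows and then squeezes, so selecting a middle dimension now yields a strided view instead of a fresh copy; the layout test gains a MultipleBlocks branch to pin that down, and mamba.rs and quantized_llama.rs add explicit `.contiguous()` calls before operators that expect dense inputs. The stable-lm example also drops its forced `Device::Cpu` so the quantized path runs on the selected device. A small usage sketch against the patched candle-core, assuming the `IndexOp` trait is in scope:

```rust
use candle_core::{Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::arange(0u32, 24u32, &Device::Cpu)?.reshape((2, 3, 4))?;
    // With this patch, selecting index 1 on the middle dim returns a view with
    // shape (2, 4), strides (12, 1) and start offset 4: two blocks of 4 elements.
    let row = t.i((.., 1))?;
    assert!(!row.is_contiguous());
    // Callers that need one dense block must now ask for it explicitly, which is
    // what the added .contiguous() calls in mamba.rs and quantized_llama.rs do.
    let row = row.contiguous()?;
    assert!(row.is_contiguous());
    Ok(())
}
```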