Avoid copying the data on squeeze and unsqueeze. (#1884)

* Avoid copying the data on squeeze and unsqueeze.

* Fix the quantized llama example.

* Unrelated fix for the quantized stable-lm example on cuda.

* Fix for mamba on cuda (unrelated to the PR).
This commit is contained in:
Laurent Mazare
2024-03-20 13:04:36 +01:00
committed by GitHub
parent 2a8679509e
commit 455c42aa72
5 changed files with 47 additions and 8 deletions

View File

@ -2093,8 +2093,19 @@ impl Tensor {
let dim = dim.to_index(self.shape(), "squeeze")?;
if dims[dim] == 1 {
let mut dims = dims.to_vec();
let mut strides = self.stride().to_vec();
dims.remove(dim);
self.reshape(dims)
strides.remove(dim);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
} else {
Ok(self.clone())
}
@ -2115,10 +2126,24 @@ impl Tensor {
/// ```
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
let mut dims = self.dims().to_vec();
let mut strides = self.stride().to_vec();
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
// Cannot panic because to_index_plus_one already checks dimensions
dims.insert(dim, 1);
self.reshape(dims)
// Any stride would work here, but we pick one so as to maximize the probability to remain
// C contiguous.
let stride = if dim < strides.len() { strides[dim] } else { 1 };
strides.insert(dim, stride);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
/// Stacks two or more tensors along a particular dimension.

View File

@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?;
let tensor = tensor.i((.., 1))?.contiguous()?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { start_offset, len } => {
assert_eq!(start_offset, 0);
@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> {
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")
}
candle::StridedBlocks::MultipleBlocks {
block_len,
block_start_index,
} => {
assert_eq!(block_len, 4);
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
match tensor.t()?.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")

View File

@ -288,12 +288,12 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let model = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
Model::Quantized(model)
} else {
let dtype = if device.is_cuda() {
DType::BF16
@ -302,7 +302,7 @@ fn main() -> Result<()> {
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = StableLM::new(&config, vb)?;
(Model::StableLM(model), device)
Model::StableLM(model)
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -121,7 +121,7 @@ impl MambaBlock {
// Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
let x_proj = self.x_proj.forward(&proj_for_conv)?;
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?.contiguous()?;
let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;

View File

@ -512,7 +512,7 @@ impl ModelWeights {
layer_in = x
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let _enter = self.span_output.enter();
self.output.forward(&x)
}