mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Avoid copying the data on squeeze and unsqueeze. (#1884)
* Avoid copying the data on squeeze and unsqueeze. * Fix the quantized llama example. * Unrelated fix for the quantized stable-lm example on cuda. * Fix for mamba on cuda (unrelated to the PR).
This commit is contained in:
@ -2093,8 +2093,19 @@ impl Tensor {
|
||||
let dim = dim.to_index(self.shape(), "squeeze")?;
|
||||
if dims[dim] == 1 {
|
||||
let mut dims = dims.to_vec();
|
||||
let mut strides = self.stride().to_vec();
|
||||
dims.remove(dim);
|
||||
self.reshape(dims)
|
||||
strides.remove(dim);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
||||
op: BackpropOp::new1(self, Op::Reshape),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
}
|
||||
@ -2115,10 +2126,24 @@ impl Tensor {
|
||||
/// ```
|
||||
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let mut dims = self.dims().to_vec();
|
||||
let mut strides = self.stride().to_vec();
|
||||
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
|
||||
// Cannot panic because to_index_plus_one already checks dimensions
|
||||
dims.insert(dim, 1);
|
||||
self.reshape(dims)
|
||||
// Any stride would work here, but we pick one so as to maximize the probability to remain
|
||||
// C contiguous.
|
||||
let stride = if dim < strides.len() { strides[dim] } else { 1 };
|
||||
strides.insert(dim, stride);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
|
||||
op: BackpropOp::new1(self, Op::Reshape),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// Stacks two or more tensors along a particular dimension.
|
||||
|
Reference in New Issue
Block a user