Avoid copying the data on squeeze and unsqueeze. (#1884)

* Avoid copying the data on squeeze and unsqueeze.

* Fix the quantized llama example.

* Unrelated fix for the quantized stable-lm example on cuda.

* Fix for mamba on cuda (unrelated to the PR).
This commit is contained in:
Laurent Mazare
2024-03-20 13:04:36 +01:00
committed by GitHub
parent 2a8679509e
commit 455c42aa72
5 changed files with 47 additions and 8 deletions

View File

@ -121,7 +121,7 @@ impl MambaBlock {
// Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
let x_proj = self.x_proj.forward(&proj_for_conv)?;
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?.contiguous()?;
let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;

View File

@ -512,7 +512,7 @@ impl ModelWeights {
layer_in = x
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let _enter = self.span_output.enter();
self.output.forward(&x)
}