Mirror of https://github.com/huggingface/candle.git (synced 2025-06-20 20:09:50 +00:00)
Avoid copying the data on squeeze and unsqueeze. (#1884)
* Avoid copying the data on squeeze and unsqueeze.
* Fix the quantized llama example.
* Unrelated fix for the quantized stable-lm example on cuda.
* Fix for mamba on cuda (unrelated to the PR).
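As context for the hunks below, here is a minimal standalone sketch (not part of the commit; the shapes and values are made up) of the behaviour this change relies on: squeeze and unsqueeze now return views over the same storage instead of copying, so a tensor that is non-contiguous before the reshape stays non-contiguous afterwards, and callers feeding kernels that expect a flat layout have to request a copy explicitly with contiguous().

use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A contiguous 2x6 tensor to slice from.
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((2, 6))?;

    // narrow() yields a view over the same storage; slicing the last
    // dimension at a non-zero offset leaves that view non-contiguous.
    let slice = t.narrow(D::Minus1, 2, 3)?;
    assert!(!slice.is_contiguous());

    // With this commit, unsqueeze/squeeze are views as well: they no longer
    // copy the data, so they no longer restore a contiguous layout.
    let roundtrip = slice.unsqueeze(0)?.squeeze(0)?;
    assert!(!roundtrip.is_contiguous());

    // Kernels that need a flat buffer (e.g. some CUDA ops) get an explicit copy.
    let flat = roundtrip.contiguous()?;
    assert!(flat.is_contiguous());
    Ok(())
}

The two hunks below apply exactly that last step, appending .contiguous()? to views produced by narrow() and i((.., seq_len - 1, ..)).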
@@ -121,7 +121,7 @@ impl MambaBlock {
         // Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
         let x_proj = self.x_proj.forward(&proj_for_conv)?;
-        let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
+        let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?.contiguous()?;
         let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
         let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;
@@ -512,7 +512,7 @@ impl ModelWeights {
             layer_in = x
         }
         let x = self.norm.forward(&layer_in)?;
-        let x = x.i((.., seq_len - 1, ..))?;
+        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
         let _enter = self.span_output.enter();
         self.output.forward(&x)
     }
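Similarly, for the quantized-llama hunk, a purely illustrative helper (the function name is invented; i(), is_contiguous() and contiguous() are the calls used in the diff): indexing the last sequence position keeps a view into the full hidden-state buffer, and the copy inside contiguous() only happens when that view is not already laid out contiguously.

use candle_core::{IndexOp, Result, Tensor};

// Hypothetical helper mirroring the tail of the method shown in the hunk
// above: pick the hidden state of the last position and hand downstream
// code a contiguous tensor.
fn last_position(hidden: &Tensor, seq_len: usize) -> Result<Tensor> {
    // A (batch, hidden_dim) view sharing storage with `hidden`.
    let x = hidden.i((.., seq_len - 1, ..))?;
    // Copies only if the view is not already contiguous.
    x.contiguous()
}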