Starcoder fix (#264)

* Bugfix for starcoder.

* Get some proper code generation.

* Slightly simpler softmax.
Laurent Mazare
2023-07-28 11:17:49 +01:00
committed by GitHub
parent 6a54ca115e
commit 3e89df938c
3 changed files with 20 additions and 7 deletions


@@ -705,7 +705,8 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
                 expected: 1,
                 got: d.len(),
                 shape: self.ids_l.shape().clone(),
-            })?,
+            }
+            .bt())?,
         };
         let stride_ids = self.ids_l.stride()[0];
         let mut dst_dims = layout.dims().to_vec();


@@ -127,7 +127,7 @@ fn main() -> Result<()> {
     ));
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let filenames = match args.weight_file {
-        Some(weight_file) => vec![std::path::PathBuf::from(weight_file.clone())],
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => {
             let repo_filenames: Vec<String> = vec![];
             repo_filenames


@@ -24,12 +24,22 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
 fn make_causal_mask(t: usize) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
-        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
         .collect();
     let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
     Ok(mask)
 }
 
+// TODO: Use a numerically stable implementation by default.
+fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let num = diff.exp()?;
+    let den = num.sum_keepdim(d)?;
+    num.broadcast_div(&den)
+}
+
 #[derive(Debug)]
 pub struct Config {
     pub vocab_size: usize,
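The flipped comparison above is the heart of the bugfix: with j > i the mask marked only future positions as visible, so a token could never attend to its own prefix; j <= i gives the usual lower-triangular causal mask. The new local softmax is the standard numerically stable form, since exp(x_i - max) / sum_j exp(x_j - max) equals exp(x_i) / sum_j exp(x_j) while keeping the exponentials bounded. A minimal sketch, assuming the same candle Tensor/Device API the diff itself uses, of what the corrected mask helper returns for t = 3:

use candle::{Device, Result, Tensor};

fn make_causal_mask(t: usize) -> Result<Tensor> {
    // 1 where key position j is visible to query position i (j <= i), 0 elsewhere.
    let mask: Vec<_> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
        .collect();
    Tensor::from_slice(&mask, (t, t), &Device::Cpu)
}

fn main() -> Result<()> {
    let mask = make_causal_mask(3)?;
    // Lower-triangular ones: [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
    println!("{:?}", mask.to_vec2::<u32>()?);
    Ok(())
}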
@@ -176,19 +186,21 @@ impl Attention {
             (query, key, attn_shape, attn_view)
         };
-        let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?;
+        let attn_weights =
+            (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
         let attention_mask = attention_mask.broadcast_as(attn_shape)?;
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_weights = softmax(&attn_weights, D::Minus1)?;
+        let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights
                 .reshape(attn_view)?
-                .matmul(value)?
+                .matmul(&value)?
                 .reshape(initial_query_shape)?
         } else {
-            attn_weights.matmul(value)?
+            attn_weights.matmul(&value)?
         };
         Ok(attn_output)
     }
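The added contiguous() calls address the other half of the fix: transposing the key only permutes strides, and the commit message's "proper code generation" suggests the matmul on that non-contiguous layout was not producing correct results, so the key is materialized first (and value, now an owned tensor, is passed as &value). A small standalone sketch of the same pattern, assuming candle's Tensor::randn, t(), contiguous() and matmul as public API and a made-up 4x8 query/key pair:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Made-up stand-ins for the attention tensors: 4 positions, 8-dim heads.
    let query = Tensor::randn(0f32, 1f32, (4, 8), &dev)?;
    let key = Tensor::randn(0f32, 1f32, (4, 8), &dev)?;
    // The transpose leaves key_t with permuted strides; materializing it with
    // contiguous() before matmul mirrors the fix above.
    let key_t = key.t()?.contiguous()?;
    let scores = query.matmul(&key_t)?; // (4, 4) attention scores
    println!("{:?}", scores.shape());
    Ok(())
}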