mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Starcoder fix (#264)
* Bugfix for starcoder. * Get some proper code generation. * Slightly simpler softmax.
This commit is contained in:
@ -705,7 +705,8 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
|
||||
expected: 1,
|
||||
got: d.len(),
|
||||
shape: self.ids_l.shape().clone(),
|
||||
})?,
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
let stride_ids = self.ids_l.stride()[0];
|
||||
let mut dst_dims = layout.dims().to_vec();
|
||||
|
@ -127,7 +127,7 @@ fn main() -> Result<()> {
|
||||
));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file.clone())],
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => {
|
||||
let repo_filenames: Vec<String> = vec![];
|
||||
repo_filenames
|
||||
|
@ -24,12 +24,22 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
|
||||
fn make_causal_mask(t: usize) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
// TODO: Use a numerically stable implementation by default.
|
||||
fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
|
||||
let d = d.to_index(xs.shape(), "log-softmax")?;
|
||||
let max = xs.max_keepdim(d)?;
|
||||
let diff = xs.broadcast_sub(&max)?;
|
||||
let num = diff.exp()?;
|
||||
let den = num.sum_keepdim(d)?;
|
||||
num.broadcast_div(&den)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
@ -176,19 +186,21 @@ impl Attention {
|
||||
(query, key, attn_shape, attn_view)
|
||||
};
|
||||
|
||||
let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?;
|
||||
let attn_weights =
|
||||
(query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
|
||||
let attention_mask = attention_mask.broadcast_as(attn_shape)?;
|
||||
let mask_value =
|
||||
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
|
||||
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
|
||||
let attn_weights = attn_weights.softmax(D::Minus1)?;
|
||||
let attn_weights = softmax(&attn_weights, D::Minus1)?;
|
||||
let value = value.contiguous()?;
|
||||
let attn_output = if self.multi_query {
|
||||
attn_weights
|
||||
.reshape(attn_view)?
|
||||
.matmul(value)?
|
||||
.matmul(&value)?
|
||||
.reshape(initial_query_shape)?
|
||||
} else {
|
||||
attn_weights.matmul(value)?
|
||||
attn_weights.matmul(&value)?
|
||||
};
|
||||
Ok(attn_output)
|
||||
}
|
||||
|
Reference in New Issue
Block a user