mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Starcoder fix (#264)
* Bugfix for starcoder.
* Get some proper code generation.
* Slightly simpler softmax.
@@ -705,7 +705,8 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
                 expected: 1,
                 got: d.len(),
                 shape: self.ids_l.shape().clone(),
-            })?,
+            }
+            .bt())?,
         };
         let stride_ids = self.ids_l.stride()[0];
         let mut dst_dims = layout.dims().to_vec();
@@ -127,7 +127,7 @@ fn main() -> Result<()> {
     ));
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let filenames = match args.weight_file {
-        Some(weight_file) => vec![std::path::PathBuf::from(weight_file.clone())],
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => {
             let repo_filenames: Vec<String> = vec![];
             repo_filenames
@@ -24,12 +24,22 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
 
 fn make_causal_mask(t: usize) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
-        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
         .collect();
     let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
     Ok(mask)
 }
 
+// TODO: Use a numerically stable implementation by default.
+fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let num = diff.exp()?;
+    let den = num.sum_keepdim(d)?;
+    num.broadcast_div(&den)
+}
+
 #[derive(Debug)]
 pub struct Config {
     pub vocab_size: usize,
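For context on the mask change: `where_cond` further down in the attention code keeps the attention weight where the mask is nonzero and substitutes negative infinity elsewhere, so the mask has to be 1 on and below the diagonal (positions a token is allowed to attend to), not above it. Below is a minimal sketch in plain Rust (no candle dependency, and `t = 3` is an arbitrary choice) of what the corrected expression produces; only the `flat_map` line is taken from the diff, the rest is illustration.

// Build the corrected causal mask for a toy sequence length and print it.
// Row i marks which positions j a token at position i may attend to (j <= i).
fn main() {
    let t = 3;
    let mask: Vec<u32> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
        .collect();
    for row in mask.chunks(t) {
        println!("{row:?}"); // [1, 0, 0] / [1, 1, 0] / [1, 1, 1]
    }
}

The new `softmax` helper is the usual numerically stable form: it subtracts the per-row maximum before exponentiating, so exp(x_i - max(x)) / sum_j exp(x_j - max(x)) gives the same result as exp(x_i) / sum_j exp(x_j) without overflowing for large logits.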
@@ -176,19 +186,21 @@ impl Attention {
             (query, key, attn_shape, attn_view)
         };
 
-        let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?;
+        let attn_weights =
+            (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
         let attention_mask = attention_mask.broadcast_as(attn_shape)?;
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_weights = softmax(&attn_weights, D::Minus1)?;
+        let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights
                 .reshape(attn_view)?
-                .matmul(value)?
+                .matmul(&value)?
                 .reshape(initial_query_shape)?
         } else {
-            attn_weights.matmul(value)?
+            attn_weights.matmul(&value)?
        };
        Ok(attn_output)
    }
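A note on the added `.contiguous()` calls: the key and value tensors come out of transposes and cache concatenations, so they can be strided views, and candle's matmul kernels generally expect (or are much faster with) contiguous, row-major inputs. A minimal sketch, assuming the `candle` crate as it is used inside this repository (the shapes and random values are made up), of materializing a transposed view before a matmul:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 2x3 matrix multiplied against the transpose of another 2x3 matrix.
    let a = Tensor::randn(0f32, 1.0, (2, 3), &dev)?;
    let b = Tensor::randn(0f32, 1.0, (2, 3), &dev)?;
    // b.t() is a strided view; .contiguous() copies it into a row-major buffer
    // before the matmul, mirroring the key/value handling in the diff above.
    let prod = a.matmul(&b.t()?.contiguous()?)?;
    println!("{prod}");
    Ok(())
}

Whether the extra copy is strictly required depends on the backend and layout; calling `.contiguous()` up front is the straightforward way to keep the matmul on a supported path.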