mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.
* Nicer output.
This commit is contained in:
@@ -22,11 +22,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
-fn make_causal_mask(t: usize) -> Result<Tensor> {
|
||||
+fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
|
||||
.collect();
|
||||
-let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
+let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
@@ -327,7 +327,7 @@ impl GPTBigCode {
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
|
||||
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
|
||||
-let bias = make_causal_mask(cfg.max_position_embeddings)?;
|
||||
+let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
|
||||
Ok(Self {
|
||||
wte,
|
||||
wpe,
|
||||
|
Reference in New Issue
Block a user