Cuda fix for starcoder. (#266)

* Cuda fix for starcoder.

* Nicer output.
This commit is contained in:
Laurent Mazare
2023-07-28 12:13:41 +01:00
committed by GitHub
parent 54ccf94472
commit 68eab38de6
2 changed files with 13 additions and 15 deletions

View File

@ -22,11 +22,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
fn make_causal_mask(t: usize) -> Result<Tensor> {
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
let mask = Tensor::from_slice(&mask, (t, t), device)?;
Ok(mask)
}
@ -327,7 +327,7 @@ impl GPTBigCode {
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
let bias = make_causal_mask(cfg.max_position_embeddings)?;
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
Ok(Self {
wte,
wpe,