Fix for the llama model. (#1906)

Laurent Mazare
2024-03-21 19:36:10 +01:00
committed by GitHub
parent c0bdd9c7a6
commit c07e4057ab

@@ -390,7 +390,7 @@ impl Llama {
             x = block.forward(&x, index_pos, block_idx, cache)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.i((.., seq_len - 1, ..))?;
+        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
         let logits = self.lm_head.forward(&x)?;
         logits.to_dtype(DType::F32)
     }
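
The commit message does not spell out the cause, but a plausible reading is
that indexing out the last sequence position with x.i((.., seq_len - 1, ..))
returns a strided view rather than a densely packed tensor, and the matmul
behind lm_head can require a contiguous layout on some backends. Below is a
minimal sketch of that effect, assuming candle-core's Tensor, IndexOp, and
is_contiguous APIs; the shapes are made up for illustration and are not from
the commit.

use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in shapes for (batch, seq_len, hidden); hypothetical values.
    let (batch, seq_len, hidden) = (2, 5, 4);
    let x = Tensor::zeros((batch, seq_len, hidden), DType::F32, &device)?;

    // Select the last position, as the llama forward pass does.
    // This narrows and squeezes the sequence dimension, leaving a
    // (batch, hidden) view whose rows are not adjacent in memory.
    let last = x.i((.., seq_len - 1, ..))?;
    println!("contiguous after indexing: {}", last.is_contiguous());

    // The fix: materialize a contiguous copy before the matmul.
    let last = last.contiguous()?;
    println!("contiguous after the fix: {}", last.is_contiguous());
    Ok(())
}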