diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 400351f3..c311d4c4 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -390,7 +390,7 @@ impl Llama { x = block.forward(&x, index_pos, block_idx, cache)?; } let x = self.ln_f.forward(&x)?; - let x = x.i((.., seq_len - 1, ..))?; + let x = x.i((.., seq_len - 1, ..))?.contiguous()?; let logits = self.lm_head.forward(&x)?; logits.to_dtype(DType::F32) }