Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Merge pull request #38 from LaurentMazare/llama_f16
Moving llama to f16.
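
In short, this PR keeps the llama weights and activations in f16 end to end and drops the blanket upcasts, while bracketing the numerically sensitive spots (the RmsNorm statistics, the attention matmuls, and the final logits) with explicit f32 round-trips. A minimal sketch of that recurring pattern, assuming the candle crate and its `Tensor::to_dtype` API as used in the hunks below (the `f32_island` helper and its `sqrt` stand-in op are illustrative, not code from this PR):

    use candle::{DType, Result, Tensor};

    // Upcast into f32 for a numerically sensitive computation, then drop
    // back to the f16 activation stream. `sqrt` is only a stand-in for the
    // ops the PR actually guards this way (RmsNorm, attention, logits).
    fn f32_island(x: &Tensor) -> Result<Tensor> {
        let x = x.to_dtype(DType::F32)?; // compute in full precision
        let x = x.sqrt()?;               // hypothetical sensitive op
        x.to_dtype(DType::F16)           // downcast back to f16
    }
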
@@ -152,7 +152,7 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
+        let x = x.matmul(&self.weight.t()?)?;
         Ok(x)
     }
 }
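
Since the activations reaching Linear are now f16 (see the attention hunks below), the weight no longer has to be upcast to f32 before the transpose; the matmul runs directly on f16 operands.
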
@@ -167,8 +167,9 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.to_dtype(DType::F32)?;
         let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
@@ -176,7 +177,9 @@ impl RmsNorm {
             .scale
             .to_dtype(DType::F32)?
             .broadcast_as((seq_len, size))?;
-        Ok((scale * x_normed)?)
+        let x = (scale * x_normed)?;
+        let x = x.to_dtype(DType::F16)?;
+        Ok(x)
     }
 }
 
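
Taken together, the two RmsNorm hunks bracket the whole normalization in f32: the f16 input is upcast on entry, the statistics x / sqrt(mean(x²) + 1e-5) and the scale are applied in f32, and the result is downcast to f16 on exit. The switch from (x * x) to (&x * &x) follows from x now being an owned local (the result of to_dtype) rather than a borrowed parameter, so the multiplication has to borrow it instead of moving it.
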
@@ -285,6 +288,7 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
+        let qkv = qkv.to_dtype(DType::F32)?;
         let n_embd = c;
         let q = qkv.narrow(1, 0, n_embd)?;
         let k = qkv.narrow(1, n_embd, n_embd)?;
@@ -303,6 +307,7 @@ impl CausalSelfAttention {
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.to_dtype(DType::F16)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
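
In CausalSelfAttention::forward, the qkv projection is upcast to f32 immediately after c_attn, so the query/key/value splits, the attention scores, and the att.matmul(&v) product all stay in f32; y is downcast to f16 just before it is fed to the (now f16) c_proj projection.
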
@@ -352,14 +357,14 @@ impl Llama {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         // TODO: Support for mini-batches? (i.e. r2)
         let t = x.shape().r1()?;
-        let x = self.wte.forward(x)?;
-        let mut x = x.to_dtype(DType::F32)?;
+        let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
         }
         let x = self.ln_f.forward(&x)?;
         let x = x.narrow(0, t - 1, 1)?;
         let logits = self.lm_head.forward(&x)?;
+        let logits = logits.to_dtype(DType::F32)?;
         let (b, vocab_size) = logits.shape().r2()?;
         assert_eq!(b, 1);
         Ok(logits.reshape(vocab_size)?)
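
In Llama::forward, the old upcast after the embedding lookup is gone, so token embeddings flow through the transformer blocks in f16 (hence `let mut x` directly on the wte output). Only the final row of logits is upcast to f32, presumably so the downstream sampling arithmetic keeps full precision.
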
@@ -420,7 +425,6 @@ async fn main() -> Result<()> {
     } else {
         Device::new_cuda(0)?
     };
-    let api = Api::new()?;
     let config = Config::config_7b();
     let cache = Cache::new(&device);
     let start = std::time::Instant::now();
@@ -431,7 +435,9 @@ async fn main() -> Result<()> {
            std::path::Path::new("llama-tokenizer.json").to_path_buf(),
        )
    } else {
+        let api = Api::new()?;
         let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+        println!("building the model");
         let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
         let mut filenames = vec![];
         for rfilename in [
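
The last two hunks in main are a small scoping cleanup rather than a dtype change: `Api::new()?` moves into the else branch, so the hub client is only constructed when the tokenizer and weights actually need to be fetched, and a progress println is added before the download.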