Cuda support for dtype conversions.

This commit is contained in:
laurent
2023-06-27 09:15:46 +01:00
parent 51640ba7e6
commit ee3d290f8b
6 changed files with 125 additions and 18 deletions

View File

@ -14,7 +14,7 @@
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{Device, Tensor};
use candle::{DType, Device, Tensor};
mod var_store;
use var_store::VarBuilder;
@ -135,7 +135,10 @@ impl Embedding {
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
Ok(Tensor::embedding(indexes, &self.embeddings)?)
Ok(Tensor::embedding(
indexes,
&self.embeddings.to_dtype(DType::F32)?,
)?)
}
}
@ -158,10 +161,10 @@ impl Linear {
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.ws)?;
let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
let y = match &self.bs {
None => x,
Some(bs) => x.broadcast_add(bs)?,
Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
};
Ok(y)
}
@ -183,7 +186,10 @@ impl RmsNorm {
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let scale = self.scale.broadcast_as((seq_len, self.size))?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((seq_len, self.size))?;
Ok((scale * x_normed)?)
}
}
@ -431,7 +437,7 @@ fn main() -> Result<()> {
.get_ids()
.to_vec();
let weight_path = std::path::Path::new("llama-f32.npz");
let weight_path = std::path::Path::new("llama.npz");
let weights = if weight_path.exists() {
println!("loading weights from {weight_path:?}");
let start_load = std::time::Instant::now();