mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Cuda support for dtype conversions.
This commit is contained in:
@ -14,7 +14,7 @@
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{Device, Tensor};
|
||||
use candle::{DType, Device, Tensor};
|
||||
|
||||
mod var_store;
|
||||
use var_store::VarBuilder;
|
||||
@ -135,7 +135,10 @@ impl Embedding {
|
||||
}
|
||||
|
||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
Ok(Tensor::embedding(indexes, &self.embeddings)?)
|
||||
Ok(Tensor::embedding(
|
||||
indexes,
|
||||
&self.embeddings.to_dtype(DType::F32)?,
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
@ -158,10 +161,10 @@ impl Linear {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.matmul(&self.ws)?;
|
||||
let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
|
||||
let y = match &self.bs {
|
||||
None => x,
|
||||
Some(bs) => x.broadcast_add(bs)?,
|
||||
Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
|
||||
};
|
||||
Ok(y)
|
||||
}
|
||||
@ -183,7 +186,10 @@ impl RmsNorm {
|
||||
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||
let scale = self.scale.broadcast_as((seq_len, self.size))?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((seq_len, self.size))?;
|
||||
Ok((scale * x_normed)?)
|
||||
}
|
||||
}
|
||||
@ -431,7 +437,7 @@ fn main() -> Result<()> {
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let weight_path = std::path::Path::new("llama-f32.npz");
|
||||
let weight_path = std::path::Path::new("llama.npz");
|
||||
let weights = if weight_path.exists() {
|
||||
println!("loading weights from {weight_path:?}");
|
||||
let start_load = std::time::Instant::now();
|
||||
|
Reference in New Issue
Block a user