mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the cuda mode to llama.
This commit is contained in:
@ -377,7 +377,7 @@ impl Llama {
|
||||
}
|
||||
}
|
||||
|
||||
fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
|
||||
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
|
||||
let seq_len = CONTEXT_SIZE;
|
||||
let n_elem = config.n_embd / config.n_head;
|
||||
let theta: Vec<_> = (0..n_elem)
|
||||
@ -385,8 +385,8 @@ fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||
.collect();
|
||||
let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &candle::Device::Cpu)?;
|
||||
let arange = Tensor::new(arange.as_slice(), &candle::Device::Cpu)?;
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let arange = Tensor::new(arange.as_slice(), device)?;
|
||||
let idx_theta = arange
|
||||
.reshape((arange.elem_count(), 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
@ -418,6 +418,11 @@ fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
println!("loading tokenizer config");
|
||||
let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
|
||||
let mut tokens = tokenizer
|
||||
@ -438,20 +443,20 @@ fn main() -> Result<()> {
|
||||
println!("cannot find {weight_path:?}, using zero weights");
|
||||
None
|
||||
};
|
||||
let vb = VarBuilder::new::<f32>(weights);
|
||||
let vb = VarBuilder::new::<f32>(&device, weights);
|
||||
|
||||
println!("building the model");
|
||||
let config = Config::config_7b();
|
||||
let llama = Llama::new(vb, &config)?;
|
||||
|
||||
println!("pre-computing the positional embeddings");
|
||||
let freqs_cis = precompute_freqs_cis(&config)?;
|
||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||
println!("starting the inference loop");
|
||||
let mut new_tokens = vec![];
|
||||
let mut rng = thread_rng();
|
||||
for index in 0..args.sample_len {
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||
let input = Tensor::new(ctxt, &Device::Cpu)?;
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let logits = llama.forward(&input, &freqs_cis)?;
|
||||
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
|
@ -15,6 +15,7 @@ pub struct VarBuilder {
|
||||
path: Vec<String>,
|
||||
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
|
||||
default_dtype: DType,
|
||||
default_device: Device,
|
||||
tensors: Arc<Option<HashMap<String, Tensor>>>,
|
||||
}
|
||||
|
||||
@ -24,13 +25,14 @@ pub struct VarStore {
|
||||
}
|
||||
|
||||
impl VarBuilder {
|
||||
pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
|
||||
pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
|
||||
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
|
||||
Self {
|
||||
path: vec![],
|
||||
vars,
|
||||
default_dtype: B::DTYPE,
|
||||
tensors: Arc::new(tensors),
|
||||
default_device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,9 +45,9 @@ impl VarBuilder {
|
||||
let path = format!("{}.{s}", self.path.join("."));
|
||||
let mut vars = self.vars.borrow_mut();
|
||||
let parameter = match self.tensors.as_ref() {
|
||||
None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
|
||||
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
|
||||
Some(tensors) => match tensors.get(&path) {
|
||||
Some(tensor) => tensor.clone(),
|
||||
Some(tensor) => tensor.to_device(&self.default_device)?,
|
||||
None => panic!("cannot find tensor for {path}"),
|
||||
},
|
||||
};
|
||||
@ -75,6 +77,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
|
||||
path,
|
||||
vars: self.vars.clone(),
|
||||
default_dtype: self.default_dtype,
|
||||
default_device: self.default_device.clone(),
|
||||
tensors: self.tensors.clone(),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user