From 59a59f41a63d8c91492f6c221a8118f288cd1819 Mon Sep 17 00:00:00 2001
From: laurent <laurent.mazare@gmail.com>
Date: Mon, 26 Jun 2023 10:06:44 +0100
Subject: [PATCH] Add the cuda mode to llama.

---
 examples/llama/main.rs      | 17 +++++++++++------
 examples/llama/var_store.rs |  9 ++++++---
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index d0dd0d19..d2b16446 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -377,7 +377,7 @@ impl Llama {
     }
 }
 
-fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
+fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
     let seq_len = CONTEXT_SIZE;
     let n_elem = config.n_embd / config.n_head;
     let theta: Vec<_> = (0..n_elem)
@@ -385,8 +385,8 @@ fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
         .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
         .collect();
     let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
-    let theta = Tensor::new(theta.as_slice(), &candle::Device::Cpu)?;
-    let arange = Tensor::new(arange.as_slice(), &candle::Device::Cpu)?;
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let arange = Tensor::new(arange.as_slice(), device)?;
     let idx_theta = arange
         .reshape((arange.elem_count(), 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
@@ -418,6 +418,11 @@ fn main() -> Result<()> {
     use tokenizers::Tokenizer;
 
     let args = Args::parse();
+    let device = if args.cpu {
+        Device::Cpu
+    } else {
+        Device::new_cuda(0)?
+    };
     println!("loading tokenizer config");
     let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
     let mut tokens = tokenizer
@@ -438,20 +443,20 @@ fn main() -> Result<()> {
         println!("cannot find {weight_path:?}, using zero weights");
         None
     };
-    let vb = VarBuilder::new::<f32>(weights);
+    let vb = VarBuilder::new::<f32>(&device, weights);
     println!("building the model");
     let config = Config::config_7b();
     let llama = Llama::new(vb, &config)?;
 
     println!("pre-computing the positional embeddings");
-    let freqs_cis = precompute_freqs_cis(&config)?;
+    let freqs_cis = precompute_freqs_cis(&config, &device)?;
 
     println!("starting the inference loop");
     let mut new_tokens = vec![];
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?;
+        let input = Tensor::new(ctxt, &device)?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
diff --git a/examples/llama/var_store.rs b/examples/llama/var_store.rs
index fb3d0c61..1a400edc 100644
--- a/examples/llama/var_store.rs
+++ b/examples/llama/var_store.rs
@@ -15,6 +15,7 @@ pub struct VarBuilder {
     path: Vec<String>,
     vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
     default_dtype: DType,
+    default_device: Device,
     tensors: Arc<Option<HashMap<String, Tensor>>>,
 }
 
@@ -24,13 +25,14 @@ pub struct VarStore {
 }
 
 impl VarBuilder {
-    pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
+    pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
         let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
         Self {
             path: vec![],
             vars,
             default_dtype: B::DTYPE,
             tensors: Arc::new(tensors),
+            default_device: device.clone(),
         }
     }
 
@@ -43,9 +45,9 @@ impl VarBuilder {
         let path = format!("{}.{s}", self.path.join("."));
         let mut vars = self.vars.borrow_mut();
         let parameter = match self.tensors.as_ref() {
-            None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
+            None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
             Some(tensors) => match tensors.get(&path) {
-                Some(tensor) => tensor.clone(),
+                Some(tensor) => tensor.to_device(&self.default_device)?,
                 None => panic!("cannot find tensor for {path}"),
             },
         };
@@ -75,6 +77,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
             path,
             vars: self.vars.clone(),
             default_dtype: self.default_dtype,
+            default_device: self.default_device.clone(),
             tensors: self.tensors.clone(),
         }
     }