Add the cuda mode to llama.

This commit is contained in:
laurent
2023-06-26 10:06:44 +01:00
parent 512d12e38d
commit 59a59f41a6
2 changed files with 17 additions and 9 deletions

View File

@@ -377,7 +377,7 @@ impl Llama {
}
}
fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
let seq_len = CONTEXT_SIZE;
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
@@ -385,8 +385,8 @@ fn precompute_freqs_cis(config: &Config) -> Result<Tensor> {
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
let theta = Tensor::new(theta.as_slice(), &candle::Device::Cpu)?;
let arange = Tensor::new(arange.as_slice(), &candle::Device::Cpu)?;
let theta = Tensor::new(theta.as_slice(), device)?;
let arange = Tensor::new(arange.as_slice(), device)?;
let idx_theta = arange
.reshape((arange.elem_count(), 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
@@ -418,6 +418,11 @@ fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
let device = if args.cpu {
Device::Cpu
} else {
Device::new_cuda(0)?
};
println!("loading tokenizer config");
let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
let mut tokens = tokenizer
@@ -438,20 +443,20 @@ fn main() -> Result<()> {
println!("cannot find {weight_path:?}, using zero weights");
None
};
let vb = VarBuilder::new::<f32>(weights);
let vb = VarBuilder::new::<f32>(&device, weights);
println!("building the model");
let config = Config::config_7b();
let llama = Llama::new(vb, &config)?;
println!("pre-computing the positional embeddings");
let freqs_cis = precompute_freqs_cis(&config)?;
let freqs_cis = precompute_freqs_cis(&config, &device)?;
println!("starting the inference loop");
let mut new_tokens = vec![];
let mut rng = thread_rng();
for index in 0..args.sample_len {
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &Device::Cpu)?;
let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?;
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;

View File

@@ -15,6 +15,7 @@ pub struct VarBuilder {
path: Vec<String>,
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
default_dtype: DType,
default_device: Device,
tensors: Arc<Option<HashMap<String, Tensor>>>,
}
@@ -24,13 +25,14 @@ pub struct VarStore {
}
impl VarBuilder {
pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
Self {
path: vec![],
vars,
default_dtype: B::DTYPE,
tensors: Arc::new(tensors),
default_device: device.clone(),
}
}
@@ -43,9 +45,9 @@ impl VarBuilder {
let path = format!("{}.{s}", self.path.join("."));
let mut vars = self.vars.borrow_mut();
let parameter = match self.tensors.as_ref() {
None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
Some(tensors) => match tensors.get(&path) {
Some(tensor) => tensor.clone(),
Some(tensor) => tensor.to_device(&self.default_device)?,
None => panic!("cannot find tensor for {path}"),
},
};
@@ -75,6 +77,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
path,
vars: self.vars.clone(),
default_dtype: self.default_dtype,
default_device: self.default_device.clone(),
tensors: self.tensors.clone(),
}
}