Adding cast + binary kernels.

Nicolas Patry
2023-11-07 23:45:53 +01:00
parent 0c24a885a6
commit 480a3e22e6
7 changed files with 601 additions and 84 deletions
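The kernels named in the commit title back candle's element-wise cast and binary ops on the Metal backend. As a rough, hedged sketch of what they enable at the Tensor API level (the device fallback, tensor sizes and values are illustrative and not taken from this commit; the crate is referred to as `candle`, matching the import style in the file changed below):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Use the Metal device when available, otherwise fall back to the CPU.
    let device = Device::new_metal(0).unwrap_or(Device::Cpu);

    // Cast kernel: materialize integer indices on the device, then cast to f32.
    let a = Tensor::arange(0u32, 8u32, &device)?.to_dtype(DType::F32)?;

    // Binary kernel: element-wise add of two device tensors.
    let b = Tensor::ones(8, DType::F32, &device)?;
    let c = (&a + &b)?;

    println!("{c}");
    Ok(())
}

With a cast kernel available, integer index tensors can be created and converted to f32 directly on the device instead of being staged on the CPU, which is what the change to precomput_freqs_cis below relies on.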


@@ -2,7 +2,7 @@ use std::collections::HashMap;
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 pub const MAX_SEQ_LEN: usize = 4096;
@@ -196,15 +196,15 @@ fn precomput_freqs_cis(
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
-    let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
-    let idx_theta = Tensor::new(range.as_slice(), device)?
-        .reshape((MAX_SEQ_LEN, 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-    // TODO This change avoids allocating on Metal and then casting since allocating directly on
-    // CPU as f32 seems just as fast
-    // let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
-    //     .to_dtype(DType::F32)?
+    // let idx_theta = Tensor::new(range.as_slice(), device)?
     //     .reshape((MAX_SEQ_LEN, 1))?
     //     .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    // TODO This change avoids allocating on Metal and then casting since allocating directly on
+    // CPU as f32 seems just as fast
+    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+        .to_dtype(DType::F32)?
+        .reshape((MAX_SEQ_LEN, 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let cos = idx_theta.cos()?;
     let sin = idx_theta.sin()?;
     Ok((cos, sin))
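For comparison, here is a self-contained sketch of the two precomput_freqs_cis variants this hunk swaps: the host-side f32 workaround being commented out, and the arange-then-cast path being re-enabled. The tensor chains follow the diff; the helper names idx_theta_host / idx_theta_device and the sample theta values are illustrative only, and the sketch runs on the CPU device so it works without Metal:

use candle::{DType, Device, Result, Tensor};

const MAX_SEQ_LEN: usize = 4096;

// Workaround path (the code being commented out): build the position index
// as f32 on the host, so no device-side cast is needed.
fn idx_theta_host(theta: &Tensor, device: &Device) -> Result<Tensor> {
    let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
    Tensor::new(range.as_slice(), device)?
        .reshape((MAX_SEQ_LEN, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)
}

// On-device path (the code being re-enabled): arange in u32, then cast to f32,
// which exercises the cast kernel this commit adds.
fn idx_theta_device(theta: &Tensor, device: &Device) -> Result<Tensor> {
    Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((MAX_SEQ_LEN, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let theta = Tensor::new(&[1f32, 0.5, 0.25, 0.125], &device)?;
    let a = idx_theta_host(&theta, &device)?;
    let b = idx_theta_device(&theta, &device)?;
    // Both paths produce the same (MAX_SEQ_LEN, theta_len) matrix of angles.
    println!("{:?} {:?}", a.dims(), b.dims());
    Ok(())
}

The author's TODO suggests the host-side path was about as fast; the commit nonetheless switches back to arange + to_dtype, presumably because the new cast kernel now makes that path work directly on Metal.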