Use F16 for moondream on cuda. (#2013)

This commit is contained in:
Laurent Mazare
2024-04-04 23:30:10 +02:00
committed by GitHub
parent c5626b8271
commit c87381fc96
2 changed files with 17 additions and 8 deletions

View File

@ -135,7 +135,9 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let on_true = Tensor::new(on_true, on_false.device())?
.to_dtype(on_false.dtype())?
.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
@ -147,7 +149,7 @@ struct RotaryEmbedding {
}
impl RotaryEmbedding {
fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
@ -159,8 +161,8 @@ impl RotaryEmbedding {
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
@ -274,7 +276,8 @@ impl MHA {
let op_size = cfg.n_embd;
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
let rotary_emb =
RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self {
wqkv,