mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Use F16 for moondream on cuda. (#2013)
This commit is contained in:
@ -283,6 +283,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config = moondream::Config::v2();
|
||||
let dtype = if device.is_cuda() && !args.quantized {
|
||||
DType::F16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let model = if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||
&model_file,
|
||||
@ -291,15 +296,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let model = quantized_moondream::Model::new(&config, vb)?;
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let model = moondream::Model::new(&config, vb)?;
|
||||
Model::Moondream(model)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let image = load_image(args.image)?.to_device(&device)?;
|
||||
let image = load_image(args.image)?
|
||||
.to_device(&device)?
|
||||
.to_dtype(dtype)?;
|
||||
let image_embeds = image.unsqueeze(0)?;
|
||||
let image_embeds = match model {
|
||||
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
|
||||
|
@ -135,7 +135,9 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let on_true = Tensor::new(on_true, on_false.device())?
|
||||
.to_dtype(on_false.dtype())?
|
||||
.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
@ -147,7 +149,7 @@ struct RotaryEmbedding {
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
|
||||
fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||
@ -159,8 +161,8 @@ impl RotaryEmbedding {
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||
})
|
||||
}
|
||||
|
||||
@ -274,7 +276,8 @@ impl MHA {
|
||||
let op_size = cfg.n_embd;
|
||||
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
|
||||
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
|
||||
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
|
||||
let rotary_emb =
|
||||
RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
Ok(Self {
|
||||
wqkv,
|
||||
|
Reference in New Issue
Block a user