Mirror of
https://github.com/huggingface/candle.git
(last synced 2025-06-16 18:48:51 +00:00)
Use F16 for moondream on cuda. (#2013)
This commit is contained in:
@ -283,6 +283,11 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let config = moondream::Config::v2();
|
let config = moondream::Config::v2();
|
||||||
|
let dtype = if device.is_cuda() && !args.quantized {
|
||||||
|
DType::F16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
let model = if args.quantized {
|
let model = if args.quantized {
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
||||||
&model_file,
|
&model_file,
|
||||||
@ -291,15 +296,16 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let model = quantized_moondream::Model::new(&config, vb)?;
|
let model = quantized_moondream::Model::new(&config, vb)?;
|
||||||
Model::Quantized(model)
|
Model::Quantized(model)
|
||||||
} else {
|
} else {
|
||||||
let vb =
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = moondream::Model::new(&config, vb)?;
|
let model = moondream::Model::new(&config, vb)?;
|
||||||
Model::Moondream(model)
|
Model::Moondream(model)
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let image = load_image(args.image)?.to_device(&device)?;
|
let image = load_image(args.image)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.to_dtype(dtype)?;
|
||||||
let image_embeds = image.unsqueeze(0)?;
|
let image_embeds = image.unsqueeze(0)?;
|
||||||
let image_embeds = match model {
|
let image_embeds = match model {
|
||||||
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
|
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
|
||||||
|
@ -135,7 +135,9 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
|||||||
|
|
||||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
let shape = mask.shape();
|
let shape = mask.shape();
|
||||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
let on_true = Tensor::new(on_true, on_false.device())?
|
||||||
|
.to_dtype(on_false.dtype())?
|
||||||
|
.broadcast_as(shape.dims())?;
|
||||||
let m = mask.where_cond(&on_true, on_false)?;
|
let m = mask.where_cond(&on_true, on_false)?;
|
||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
@ -147,7 +149,7 @@ struct RotaryEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
|
fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||||
@ -159,8 +161,8 @@ impl RotaryEmbedding {
|
|||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?,
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
cos: freqs.cos()?,
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,7 +276,8 @@ impl MHA {
|
|||||||
let op_size = cfg.n_embd;
|
let op_size = cfg.n_embd;
|
||||||
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
|
let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
|
||||||
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
|
let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
|
||||||
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
|
let rotary_emb =
|
||||||
|
RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
|
||||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
wqkv,
|
wqkv,
|
||||||
|
Reference in New Issue
Block a user