mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Quantized moondream implementation and BOS token (#1980)
* moondream implementation
* add moondream example
* change config default activation
* Add assets and integrate phi mixformer with example
* Make use of kv cache and fix seq_len bug; Clean up example code
* Add README link to example
* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig
* Delete image
* Use apply instead of forward
* Pass bos token at the beginning of tensor.
* Quantize moondream.
* Forward with image bos token.
* Clippy.
* Use q4_0 quantization.
* Add pointers for sequence and tokens; Remove seq_len conditional
This commit is contained in:
@ -337,6 +337,30 @@ impl MixFormerSequentialForCausalLM {
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
|
||||
pub fn forward_with_img(
|
||||
&mut self,
|
||||
bos_token: &Tensor,
|
||||
xs: &Tensor,
|
||||
img_embeds: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = xs.apply(&self.embedding)?;
|
||||
let bos_token = bos_token.apply(&self.embedding)?;
|
||||
// Python implementation sequence order is <bos token embedding><img embedding><rest of text embedding>
|
||||
// https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
|
||||
let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
|
||||
let (_b_size, seq_len, _embds) = xs.dims3()?;
|
||||
let mask = Some(get_mask(seq_len, xs.device())?);
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
}
|
||||
let xs = xs
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.head)?
|
||||
.squeeze(1)?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
|
||||
}
|
||||
|
Reference in New Issue
Block a user