Mirror of https://github.com/huggingface/candle.git
@@ -73,22 +73,18 @@ impl TextGeneration {
         let img_break = get_token("[IMG_BREAK]")?;
         let img_end = get_token("[IMG_END]")?;
         let start_gen = std::time::Instant::now();
-        let mut pos = 0;
         for index in 0..sample_len {
             let logits = if index > 0 {
                 let context_size = if index > 0 { 1 } else { tokens.len() };
                 let start_pos = tokens.len().saturating_sub(context_size);
                 let ctxt = &tokens[start_pos..];
                 let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-                let logits = self.model.language_model.forward(&input, pos)?;
-                pos += context_size;
-                logits
+                self.model.lm_forward(&input)?
             } else {
                 let (_b, _c, h, w) = self.image.dims4()?;
                 let h = h / self.model.patch_size;
                 let w = w / self.model.patch_size;
-                let image_embeds = self.model.vision_tower.forward(&self.image)?;
-                let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?;
+                let image_embeds = self.model.encode_image(&self.image)?;
                 println!("generated image embeddings {image_embeds:?}");
                 let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
                 for &t in tokens.iter() {
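The `context_size` logic kept above is the standard incremental-decoding pattern: the first iteration feeds the whole prompt, and every later iteration feeds only the newest token, since earlier positions are already in the KV cache. A minimal sketch of that slicing in plain Rust (the token values are made up; no candle types needed):

fn main() {
    // Stand-in prompt token ids.
    let mut tokens: Vec<u32> = vec![5, 9, 2, 7];
    for index in 0..3 {
        // First step: the whole prompt. Later steps: just the last token.
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let start_pos = tokens.len().saturating_sub(context_size);
        println!("step {index}: feeding {:?}", &tokens[start_pos..]);
        tokens.push(100 + index as u32); // stand-in for the freshly sampled token
    }
}

This prints the full prompt once, then single-token slices, mirroring what the loop hands to `lm_forward`.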
@@ -124,12 +120,7 @@ impl TextGeneration {
                 input_embeds.push(end_embeds);

                 let input_embeds = Tensor::cat(&input_embeds, 1)?;
-                let logits = self
-                    .model
-                    .language_model
-                    .forward_embeds(&input_embeds, None, pos)?;
-                pos += input_embeds.dim(1)?;
-                logits
+                self.model.lm_forward_embeds(&input_embeds)?
             };
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
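The prefill branch above assembles `input_embeds` from per-token text embeddings plus the image embeddings, then splices them along the sequence dimension with `Tensor::cat` before a single `lm_forward_embeds` call. A small sketch of that splice, assuming only candle's public `Tensor` API with the crate imported as `candle`, as in the diff (shapes and the hidden size are made up):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let hidden = 8; // stand-in hidden size
    // Stand-ins for one text-token embedding and a four-patch image
    // embedding; both have shape (batch, seq, hidden).
    let text = Tensor::zeros((1, 1, hidden), DType::F32, &dev)?;
    let image = Tensor::ones((1, 4, hidden), DType::F32, &dev)?;
    // Concatenating on dim 1 yields one (1, 5, hidden) sequence, the shape
    // the prefill path hands to the language model in one go.
    let seq = Tensor::cat(&[&text, &image], 1)?;
    println!("{:?}", seq.dims()); // [1, 5, 8]
    Ok(())
}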
@@ -48,6 +48,7 @@ pub struct Model {
     pub vision_tower: vision_model::Model,
     pub patch_size: usize,
     pub dtype: candle::DType,
+    pub pos: usize,
 }

 impl Model {
@@ -67,6 +68,31 @@ impl Model {
             vision_tower,
             patch_size: cfg.vision_config.patch_size,
             dtype: vb.dtype(),
+            pos: 0,
         })
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.language_model.clear_kv_cache();
+        self.pos = 0;
+    }
+
+    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
+        let image_embeds = self.vision_tower.forward(image)?;
+        self.multi_modal_projector.forward(&image_embeds)
+    }
+
+    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let (_, seq_len) = input_ids.dims2()?;
+        let logits = self.language_model.forward(input_ids, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
+
+    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
 }
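The refactor moves decode state into the model itself: each forward advances `self.pos` by the number of positions it consumed, and `clear_kv_cache` resets the position together with the language model's cache. A toy stand-in (deliberately not the real candle types) showing the resulting calling pattern:

// ToyModel is hypothetical; it only mimics the pos bookkeeping above.
struct ToyModel {
    pos: usize,
}

impl ToyModel {
    fn lm_forward(&mut self, seq_len: usize) {
        println!("forward at pos {} over {} tokens", self.pos, seq_len);
        self.pos += seq_len; // the model, not the caller, tracks position
    }

    fn clear_kv_cache(&mut self) {
        // The real method also clears the language model's KV cache.
        self.pos = 0;
    }
}

fn main() {
    let mut m = ToyModel { pos: 0 };
    m.lm_forward(7); // prefill: prompt plus image embeddings
    m.lm_forward(1); // each decode step feeds a single token
    m.lm_forward(1);
    m.clear_kv_cache(); // ready for a fresh generation
    assert_eq!(m.pos, 0);
}

Because `pos` is mutable state, `lm_forward` and `lm_forward_embeds` take `&mut self`, which is why the generation loop needs a mutable model.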