diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs index 008346f0..bcc21337 100644 --- a/candle-examples/examples/moondream/main.rs +++ b/candle-examples/examples/moondream/main.rs @@ -72,20 +72,16 @@ impl TextGeneration { let mut tokens = tokens.get_ids().to_vec(); let mut generated_tokens = 0usize; - // Moondream tokenizer bos_token is "<|endoftext|>" + // Moondream tokenizer bos_token and eos_token is "<|endoftext|>" // https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json - let bos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { + let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { Some(token) => *token, - None => anyhow::bail!("cannot find the BOS token"), - }; - // eos_token is "END" - // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L100 - let eos_token = match self.tokenizer.get_vocab(true).get("END") { - Some(token) => *token, - None => anyhow::bail!("cannot find the EOS token"), + None => anyhow::bail!("cannot find the special token"), }; + let (bos_token, eos_token) = (special_token, special_token); let start_gen = std::time::Instant::now(); + let mut load_t = std::time::Duration::from_secs_f64(0f64); for index in 0..sample_len { let context_size = if index > 0 { 1 } else { tokens.len() }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; @@ -97,7 +93,7 @@ impl TextGeneration { } } else { let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?; - match self.model { + let logits = match self.model { Model::Moondream(ref mut model) => { model .text_model @@ -108,7 +104,10 @@ impl TextGeneration { .text_model .forward_with_img(&bos_token, &input, image_embeds)? } - } + }; + load_t = start_gen.elapsed(); + println!("load_t: {:?}", load_t); + logits }; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; let logits = if self.repeat_penalty == 1. { @@ -132,10 +131,11 @@ impl TextGeneration { std::io::stdout().flush()?; } - let dt = start_gen.elapsed(); + let dt = start_gen.elapsed() - load_t; println!( - "\n{generated_tokens} tokens generated ({:.2} token/s)", - generated_tokens as f64 / dt.as_secs_f64() + "\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)", + dt.as_secs_f64(), + (generated_tokens - 1) as f64 / dt.as_secs_f64() ); Ok(()) diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs index 42b24fb8..717f3bb4 100644 --- a/candle-transformers/src/models/moondream.rs +++ b/candle-transformers/src/models/moondream.rs @@ -47,7 +47,7 @@ impl VisionConfig { embed_dim: 1152, num_blocks: 27, num_heads: 16, - act: candle_nn::Activation::Gelu, + act: candle_nn::Activation::GeluPytorchTanh, } } }