Compare commits


3 Commits

SHA1        Message                                                         Date
33c9b66554  Add the new gemma models. (#2023)                               2024-04-06 21:25:38 +02:00
            * Add the new gemma models.
            * Revert the lightning changes.
            * Support for the 1.1 models.
9fd52b3b71  Handle the batch dimension in quantized MMV on metal. (#2022)  2024-04-06 20:02:24 +02:00
e662431acf  Fix the final rmsnorm for quantized-metavoice. (#2021)          2024-04-06 19:35:01 +02:00
6 changed files with 35 additions and 13 deletions

candle-core/src/quantized/metal.rs

@@ -149,8 +149,11 @@ impl QMetalStorage {
         let (n, k) = self_shape.dims2()?;
         let mut dst_shape = src_shape.dims().to_vec();
+        // We always use a single batch dimension and stack all the tensors in the batch on the
+        // second dimension as the implementation in candle-metal-kernels doesn't handle batch
+        // properly.
         let (b, m) = match dst_shape.len() {
-            3 => (dst_shape[0], dst_shape[1]),
+            3 => (1, dst_shape[0] * dst_shape[1]),
             2 => (1, dst_shape[0]),
             n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
         };
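
In the new 3-D arm the batch is folded into the row count rather than passed through, so the kernel only ever sees a single (b * m, k) x (n, k) problem. A standalone sketch of the shape bookkeeping (plain Rust with illustrative names, not candle's actual code):

// `dims` is the activation shape, (b, m, k) or (m, k); the weight is (n, k).
// Returns the collapsed (b, m) handed to the kernel plus the output shape.
fn quantized_matmul_shapes(dims: &[usize], n: usize) -> Option<(usize, usize, Vec<usize>)> {
    let (b, m) = match dims {
        // Stack the whole batch on the second dimension: one batch of b * m rows.
        &[b, m, _k] => (1, b * m),
        &[m, _k] => (1, m),
        _ => return None,
    };
    let mut dst_shape = dims.to_vec();
    *dst_shape.last_mut()? = n; // leading dims are kept, the last one becomes n
    Some((b, m, dst_shape))
}

For instance, a (3, 4, 8) activation against a (16, 8) weight runs as one (12, 8) matrix, and the (12, 16) result can then be viewed as (3, 4, 16) again.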

candle-examples/examples/gemma/main.rs

@@ -16,6 +16,22 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "2b")]
+    Base2B,
+    #[value(name = "7b")]
+    Base7B,
+    #[value(name = "2b-it")]
+    Instruct2B,
+    #[value(name = "7b-it")]
+    Instruct7B,
+    #[value(name = "1.1-2b-it")]
+    InstructV1_1_2B,
+    #[value(name = "1.1-7b-it")]
+    InstructV1_1_7B,
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -165,6 +181,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model to use.
+    #[arg(long, default_value = "2b")]
+    which: Which,
 }
 
 fn main() -> Result<()> {
@@ -196,14 +216,15 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let model_id = match &args.model_id {
-        Some(model_id) => match model_id.as_str() {
-            "7b-it" => "google/gemma-7b-it".to_string(),
-            "7b" => "google/gemma-7b".to_string(),
-            "2b-it" => "google/gemma-2b-it".to_string(),
-            "2b" => "google/gemma-2b".to_string(),
-            _ => model_id.to_string(),
+        Some(model_id) => model_id.to_string(),
+        None => match args.which {
+            Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
+            Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
+            Which::Base2B => "google/gemma-2b".to_string(),
+            Which::Base7B => "google/gemma-7b".to_string(),
+            Which::Instruct2B => "google/gemma-2b-it".to_string(),
+            Which::Instruct7B => "google/gemma-7b-it".to_string(),
         },
-        None => "google/gemma-2b".to_string(),
     };
     let repo = api.repo(Repo::with_revision(
         model_id,
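
With this refactor an explicit --model-id is passed through verbatim, and the default repo is derived from the new --which flag instead of pattern-matching the model id string. An assumed invocation (only --which and its values are confirmed by this diff; the other flags follow the usual layout of candle's text-generation examples):

cargo run --example gemma --release -- --which 1.1-2b-it --prompt "Here is a story:"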

candle-metal-kernels/src/lib.rs

@@ -406,7 +406,7 @@ pub fn call_copy2d(
     );
     let width: usize = d1 * d2;
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);

candle-metal-kernels/src/unary.metal

@@ -112,7 +112,6 @@ kernel void FN_NAME( \
     device TYPENAME *output, \
     uint tid [[ thread_position_in_grid ]] \
 ) { \
-    tid *= 4; \
     if (tid >= d1 * d2) { \
         return; \
     } \
@@ -121,9 +120,6 @@
     size_t src_idx = idx1 * src_s + idx2; \
     size_t dst_idx = idx1 * dst_s + idx2; \
     output[dst_idx] = input[src_idx]; \
-    output[dst_idx+1] = input[src_idx+1]; \
-    output[dst_idx+2] = input[src_idx+2]; \
-    output[dst_idx+3] = input[src_idx+3]; \
 }
 
 COPY2D(copy2d_f32, float)
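
Together with the linear_split(&pipeline, width) change in candle-metal-kernels/src/lib.rs above, this reverts copy2d to one element per thread. The 4-wide variant derived the row and column from the flattened index but then copied four consecutive memory offsets, which is only correct while those offsets stay inside one row. A CPU re-creation of the old indexing (plain Rust, illustrative only, not candle code) shows an element being skipped once the row stride exceeds the row width:

// 2 rows (d1) of 3 columns (d2), with a row stride of 5 in both buffers.
// The old scheme launched ceil(width / 4) threads, each copying 4
// consecutive offsets starting from 4 * its thread index.
fn copy2d_vec4(input: &[f32], output: &mut [f32], d1: usize, d2: usize, src_s: usize, dst_s: usize) {
    let width = d1 * d2;
    for t in 0..width.div_ceil(4) {
        let tid = t * 4; // old kernel: tid *= 4;
        if tid >= width {
            continue;
        }
        let idx1 = tid / d2;
        let idx2 = tid - idx1 * d2;
        let src_idx = idx1 * src_s + idx2;
        let dst_idx = idx1 * dst_s + idx2;
        // The three extra copies assume the next offsets belong to the same
        // row, which fails once idx2 + 4 crosses d2.
        for k in 0..4 {
            output[dst_idx + k] = input[src_idx + k];
        }
    }
}

fn main() {
    let input: Vec<f32> = (0..10).map(|x| x as f32).collect();
    let mut output = vec![0.0f32; 10];
    copy2d_vec4(&input, &mut output, 2, 3, 5, 5);
    // Element (1, 0) sits at offset 5 but no thread ever starts there: the
    // second thread maps flat index 4 to (1, 1) and then runs past the row.
    assert_eq!(output[5], 0.0);
    println!("{output:?}");
}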

candle-transformers/src/models/gemma.rs

@@ -11,6 +11,7 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
+    #[serde(alias = "hidden_activation")]
     pub hidden_act: candle_nn::Activation,
     pub hidden_size: usize,
     pub intermediate_size: usize,
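
The serde alias lets Config deserialize config.json files that spell the field either way; the gemma 1.1 checkpoints appear to use the hidden_activation spelling. A minimal standalone illustration (hypothetical struct, not candle's Config; assumes serde with the derive feature plus serde_json):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    // Accepts both `hidden_act` and the aliased `hidden_activation` key.
    #[serde(alias = "hidden_activation")]
    hidden_act: String,
}

fn main() -> Result<(), serde_json::Error> {
    let a: Config = serde_json::from_str(r#"{"hidden_act": "gelu"}"#)?;
    let b: Config = serde_json::from_str(r#"{"hidden_activation": "gelu"}"#)?;
    println!("{a:?} {b:?}");
    Ok(())
}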

candle-transformers/src/models/metavoice.rs

@@ -235,6 +235,7 @@ pub mod transformer {
                 xs = layer.forward(&xs, pos, &mask)?
             }
             xs.narrow(1, seqlen - 1, 1)?
+                .contiguous()?
                 .apply(&self.norm)?
                 .apply(&self.output)
         }
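
narrow returns a strided view rather than a packed copy, so the tensor reaching the final rmsnorm could be non-contiguous, which the quantized norm path apparently mishandles (hence the fix); the added .contiguous()? forces a packed layout first. A small sketch against candle's tensor API (the shape is an assumption; contiguity of narrowed views follows candle's own rules):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::zeros((2, 5, 4), DType::F32, &Device::Cpu)?;
    let last = xs.narrow(1, 4, 1)?; // keep only the final position, as above
    println!("narrowed: {}", last.is_contiguous()); // false for this shape
    let packed = last.contiguous()?;
    println!("packed: {}", packed.is_contiguous()); // true
    Ok(())
}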