mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Compare commits
3 Commits
copy-multi
...
copy2d-met
Author | SHA1 | Date | |
---|---|---|---|
33c9b66554 | |||
9fd52b3b71 | |||
e662431acf |
@ -149,8 +149,11 @@ impl QMetalStorage {
|
|||||||
let (n, k) = self_shape.dims2()?;
|
let (n, k) = self_shape.dims2()?;
|
||||||
let mut dst_shape = src_shape.dims().to_vec();
|
let mut dst_shape = src_shape.dims().to_vec();
|
||||||
|
|
||||||
|
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||||
|
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||||
|
// properly.
|
||||||
let (b, m) = match dst_shape.len() {
|
let (b, m) = match dst_shape.len() {
|
||||||
3 => (dst_shape[0], dst_shape[1]),
|
3 => (1, dst_shape[0] * dst_shape[1]),
|
||||||
2 => (1, dst_shape[0]),
|
2 => (1, dst_shape[0]),
|
||||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||||
};
|
};
|
||||||
|
@ -16,6 +16,22 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "2b")]
|
||||||
|
Base2B,
|
||||||
|
#[value(name = "7b")]
|
||||||
|
Base7B,
|
||||||
|
#[value(name = "2b-it")]
|
||||||
|
Instruct2B,
|
||||||
|
#[value(name = "7b-it")]
|
||||||
|
Instruct7B,
|
||||||
|
#[value(name = "1.1-2b-it")]
|
||||||
|
InstructV1_1_2B,
|
||||||
|
#[value(name = "1.1-7b-it")]
|
||||||
|
InstructV1_1_7B,
|
||||||
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Model,
|
model: Model,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -165,6 +181,10 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model to use.
|
||||||
|
#[arg(long, default_value = "2b")]
|
||||||
|
which: Which,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -196,14 +216,15 @@ fn main() -> Result<()> {
|
|||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let model_id = match &args.model_id {
|
let model_id = match &args.model_id {
|
||||||
Some(model_id) => match model_id.as_str() {
|
Some(model_id) => model_id.to_string(),
|
||||||
"7b-it" => "google/gemma-7b-it".to_string(),
|
None => match args.which {
|
||||||
"7b" => "google/gemma-7b".to_string(),
|
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||||
"2b-it" => "google/gemma-2b-it".to_string(),
|
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||||
"2b" => "google/gemma-2b".to_string(),
|
Which::Base2B => "google/gemma-2b".to_string(),
|
||||||
_ => model_id.to_string(),
|
Which::Base7B => "google/gemma-7b".to_string(),
|
||||||
|
Which::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||||
|
Which::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||||
},
|
},
|
||||||
None => "google/gemma-2b".to_string(),
|
|
||||||
};
|
};
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -406,7 +406,7 @@ pub fn call_copy2d(
|
|||||||
);
|
);
|
||||||
|
|
||||||
let width: usize = d1 * d2;
|
let width: usize = d1 * d2;
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
@ -112,7 +112,6 @@ kernel void FN_NAME( \
|
|||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint tid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
tid *= 4; \
|
|
||||||
if (tid >= d1 * d2) { \
|
if (tid >= d1 * d2) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
@ -121,9 +120,6 @@ kernel void FN_NAME( \
|
|||||||
size_t src_idx = idx1 * src_s + idx2; \
|
size_t src_idx = idx1 * src_s + idx2; \
|
||||||
size_t dst_idx = idx1 * dst_s + idx2; \
|
size_t dst_idx = idx1 * dst_s + idx2; \
|
||||||
output[dst_idx] = input[src_idx]; \
|
output[dst_idx] = input[src_idx]; \
|
||||||
output[dst_idx+1] = input[src_idx+1]; \
|
|
||||||
output[dst_idx+2] = input[src_idx+2]; \
|
|
||||||
output[dst_idx+3] = input[src_idx+3]; \
|
|
||||||
}
|
}
|
||||||
|
|
||||||
COPY2D(copy2d_f32, float)
|
COPY2D(copy2d_f32, float)
|
||||||
|
@ -11,6 +11,7 @@ fn default_max_position_embeddings() -> usize {
|
|||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub attention_bias: bool,
|
pub attention_bias: bool,
|
||||||
pub head_dim: usize,
|
pub head_dim: usize,
|
||||||
|
#[serde(alias = "hidden_activation")]
|
||||||
pub hidden_act: candle_nn::Activation,
|
pub hidden_act: candle_nn::Activation,
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
|
@ -235,6 +235,7 @@ pub mod transformer {
|
|||||||
xs = layer.forward(&xs, pos, &mask)?
|
xs = layer.forward(&xs, pos, &mask)?
|
||||||
}
|
}
|
||||||
xs.narrow(1, seqlen - 1, 1)?
|
xs.narrow(1, seqlen - 1, 1)?
|
||||||
|
.contiguous()?
|
||||||
.apply(&self.norm)?
|
.apply(&self.norm)?
|
||||||
.apply(&self.output)
|
.apply(&self.output)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user