Mirror of https://github.com/huggingface/candle.git, synced 2025-06-23 04:46:15 +00:00.

Compare commits: branch metal5, 11 commits by ivarflakst.

Commit SHA1s:
67d93b4f42
c35d7d50db
9694671bbf
3dbf65ef20
b2db5adf82
9ef040338d
3aefc709c7
c8c603ce96
61ad8d91cc
2cd1e59c9e
9c4b4f0da0
Cargo.toml
@@ -53,12 +53,12 @@ log = "0.4"
 memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
-parquet = { version = "50.0.0" }
+parquet = { version = "45.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
 rusttype = { version = "0.9", default-features = false }
-safetensors = "0.4.1"
+safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
 serde_plain = "1.0.2"
 serde_json = "1.0.99"
candle-core/benches/bench_main.rs
@@ -2,8 +2,7 @@ mod benchmarks;

 use criterion::criterion_main;

 criterion_main!(
-    benchmarks::affine::benches,
     benchmarks::matmul::benches,
-    benchmarks::random::benches,
+    benchmarks::affine::benches,
     benchmarks::where_cond::benches
 );
candle-core/benches/benchmarks/mod.rs
@@ -1,6 +1,5 @@
 pub(crate) mod affine;
 pub(crate) mod matmul;
-pub(crate) mod random;
 pub(crate) mod where_cond;

 use candle_core::{Device, Result};
candle-core/benches/benchmarks/random.rs
@@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn rand_uniform(a: &Tensor) {
    a.rand_like(-1.0, 123.0).unwrap();
}

fn rand_normal(a: &Tensor) {
    a.randn_like(100.0, 15.0).unwrap();
}

fn run_random_bench(c: &mut Criterion, device: &Device) {
    let b = 1;

    let rows = 2048;
    let cols = 2048;

    let dtype = DType::F32;
    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();

    let flops = b * rows * cols * dtype.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name("random_uniform"));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |benches| {
        benches.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                rand_uniform(black_box(&tensor));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();

    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();

    let mut group = c.benchmark_group(device.bench_name("random_normal"));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |benches| {
        benches.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                rand_normal(black_box(&tensor));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        run_random_bench(c, &device);
    }
}

criterion_group!(benches, criterion_benchmark);
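As a sanity check on the `Throughput::Bytes` value above: each iteration regenerates one (1, 2048, 2048) f32 tensor, so the reported figure works out to (simple arithmetic, not part of the diff):

```
1 \times 2048 \times 2048 \times 4\ \text{bytes} = 16\,777\,216\ \text{bytes} = 16\ \text{MiB per iteration}
```

Note that the `flops` variable actually counts bytes, which is why it feeds `Throughput::Bytes` rather than a FLOP-based throughput.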
candle-core/src/metal_backend.rs
@@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
 use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, TryLockError};
+use std::sync::{Arc, RwLock, TryLockError};

 /// Simple way to catch lock error without
 /// depending on T
@@ -102,8 +101,6 @@ pub struct MetalDevice {
     /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
     /// (strong_count = 1).
     buffers: AllocatedBuffers,
-    /// Seed for random number generation.
-    seed: Arc<Mutex<Buffer>>,
 }

 impl std::fmt::Debug for MetalDevice {
@@ -228,7 +225,7 @@ impl MetalDevice {
         // The slice might not live long enough for metal
         // To actually fill the GPU buffer.
         // Putting this wait forces the GPU buffer to be filled
-        // with the actual data allowing the CPU storage to do
+        // with the actual data allowing the CPU storage todo
         // deallocate properly.
         self.wait_until_completed()?;
         Ok(real)
@@ -822,13 +819,10 @@ impl BackendStorage for MetalStorage {
                 layout.stride(),
                 layout.start_offset() * self.dtype.size_in_bytes(),
             ),
-            !layout.is_contiguous(),
             &t.buffer,
             (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
-            !t_l.is_contiguous(),
             &f.buffer,
             (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
-            !f_l.is_contiguous(),
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -1560,11 +1554,6 @@ impl BackendDevice for MetalDevice {
             Ok(val) => val.parse()?,
             _ => 10,
         };
-        let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
-            [299792458].as_ptr() as *const c_void,
-            4,
-            MTLResourceOptions::StorageModeManaged,
-        )));
         Ok(Self {
             device,
             command_queue,
@@ -1573,10 +1562,13 @@ impl BackendDevice for MetalDevice {
             compute_per_buffer,
             buffers,
             kernels,
-            seed,
         })
     }

+    fn set_seed(&self, _seed: u64) -> Result<()> {
+        crate::bail!("Metal set_seed not implemented")
+    }
+
     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Metal {
             gpu_id: self.registry_id() as usize,
@@ -1616,31 +1608,12 @@
         &self,
         shape: &Shape,
         dtype: DType,
-        min: f64,
-        max: f64,
+        mean: f64,
+        stddev: f64,
     ) -> Result<Self::Storage> {
-        let name = match dtype {
-            DType::F32 => "rand_uniform_f32",
-            DType::F16 => "rand_uniform_f16",
-            DType::BF16 => "rand_uniform_bf16",
-            dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
-        };
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
-        let command_buffer = self.command_buffer()?;
-        candle_metal_kernels::call_random_uniform(
-            &self.device,
-            &command_buffer,
-            &self.kernels,
-            name,
-            min as f32,
-            max as f32,
-            shape.elem_count(),
-            &*self.seed.lock().unwrap(),
-            &buffer,
-        )
-        .map_err(MetalError::from)?;
-
-        Ok(Self::Storage::new(buffer, self.clone(), dtype))
+        // TODO is there a better way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }

     fn rand_normal(
@@ -1650,43 +1623,9 @@ impl BackendDevice for MetalDevice {
         mean: f64,
         stddev: f64,
     ) -> Result<Self::Storage> {
-        let name = match dtype {
-            DType::F32 => "rand_normal_f32",
-            DType::F16 => "rand_normal_f16",
-            DType::BF16 => "rand_normal_bf16",
-            dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
-        };
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
-        let command_buffer = self.command_buffer()?;
-        candle_metal_kernels::call_random_normal(
-            &self.device,
-            &command_buffer,
-            &self.kernels,
-            name,
-            mean as f32,
-            stddev as f32,
-            shape.elem_count(),
-            &*self.seed.lock().unwrap(),
-            &buffer,
-        )
-        .map_err(MetalError::from)?;
-
-        Ok(Self::Storage::new(buffer, self.clone(), dtype))
-    }
-
-    fn set_seed(&self, seed: u64) -> Result<()> {
-        let seed: u32 = seed.try_into().map_err(|_| {
-            MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
-        })?;
-
-        let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
-        let contents = seed_buffer.contents();
-        unsafe {
-            std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
-        }
-        seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
-
-        Ok(())
+        // TODO is there a better way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }
 }
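One detail worth flagging in the removed `set_seed`: `std::ptr::copy` takes a count of elements of the pointee type, not bytes, so a count of 4 copies four u32 values (16 bytes) even though the source array `[seed]` holds a single u32. A minimal stand-alone sketch of a single-u32 seed write against the metal crate (hypothetical helper, not part of the diff):

```
use metal::{Buffer, NSRange};

// Hypothetical helper: write one u32 seed into a 4-byte managed Metal buffer.
fn write_seed(seed_buffer: &Buffer, seed: u32) {
    let contents = seed_buffer.contents() as *mut u32;
    // The count is in elements of u32: 1 element == 4 bytes.
    unsafe { std::ptr::copy_nonoverlapping(&seed, contents, 1) };
    // Managed buffers require flushing CPU-side writes to the GPU copy.
    seed_buffer.did_modify_range(NSRange::new(0, std::mem::size_of::<u32>() as u64));
}
```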
candle-examples/examples/mobileone/README.md
@@ -1,22 +0,0 @@
# candle-mobileone

[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).

This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet            : 2.58%
unicycle, monocycle     : 1.70%
alp                     : 0.21%
```
candle-examples/examples/mobileone/main.rs
@@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    S0,
    S1,
    S2,
    S3,
    S4,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::S0 => "s0",
            Self::S1 => "s1",
            Self::S2 => "s2",
            Self::S3 => "s3",
            Self::S4 => "s4",
        };
        format!("timm/mobileone_{}.apple_in1k", name)
    }

    fn config(&self) -> mobileone::Config {
        match self {
            Self::S0 => mobileone::Config::s0(),
            Self::S1 => mobileone::Config::s1(),
            Self::S2 => mobileone::Config::s2(),
            Self::S3 => mobileone::Config::s3(),
            Self::S4 => mobileone::Config::s4(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::S0)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
candle-metal-kernels/src/binary.metal
@@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
 BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);

 #define INT64_BINARY_OP_OUT(NAME, FN) \
-BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
+BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);

 BINARY_OP(x + y, add)
 BINARY_OP(x - y, sub)
candle-metal-kernels/src/lib.rs
@@ -12,11 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
 const BINARY: &str = include_str!("binary.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
-const CONV: &str = include_str!("conv.metal");
 const REDUCE: &str = include_str!("reduce.metal");
-const RANDOM: &str = include_str!("random.metal");
+const CONV: &str = include_str!("conv.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 const QUANTIZED: &str = include_str!("quantized.metal");

 /// Most kernels apply similarly across the tensors
 /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@@ -63,12 +61,10 @@ macro_rules! primitive {
         }
     };
 }
-primitive!(bool);
 primitive!(usize);
 primitive!(i32);
-primitive!(i64);
 primitive!(u32);
 primitive!(u64);
 primitive!(f32);

 impl<T> EncoderParam for &[T] {
@@ -123,7 +119,6 @@ pub enum Source {
     Reduce,
     Mfa,
     Conv,
-    Random,
     Quantized,
 }
@@ -245,8 +240,6 @@ impl Kernels {
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
             Source::Conv => CONV,
-            Source::Random => RANDOM,
             Source::Quantized => QUANTIZED,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -909,22 +902,13 @@ pub fn call_where_cond_strided(
     shape: &[usize],
     cond: &Buffer,
     (cond_stride, cond_offset): (&[usize], usize),
-    cond_is_strided: bool,
     left: &Buffer,
     (left_stride, left_offset): (&[usize], usize),
-    left_is_strided: bool,
     right: &Buffer,
     (right_stride, right_offset): (&[usize], usize),
-    right_is_strided: bool,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let constants = Some(ConstantValues::new(vec![
-        (0, Value::Bool(cond_is_strided)),
-        (1, Value::Bool(left_is_strided)),
-        (2, Value::Bool(right_is_strided)),
-    ]));
-    let pipeline =
-        kernels.load_pipeline_with_constants(device, Source::Ternary, name, constants)?;
+    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;

     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
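The removed `ConstantValues` block feeds Metal function constants, the `[[function_constant(n)]]` slots declared in ternary.metal further down, so the compiler can specialize the pipeline for each stridedness combination. A rough sketch of the host side written directly against the metal crate (names from metal-rs, not candle's internal `ConstantValues` wrapper, so treat the shape of this as an assumption):

```
use metal::{FunctionConstantValues, MTLDataType};

// Sketch: build function-constant values for the three stridedness flags.
fn strided_flags(cond: bool, left: bool, right: bool) -> FunctionConstantValues {
    let values = FunctionConstantValues::new();
    for (index, flag) in [cond, left, right].iter().enumerate() {
        // Each flag lands in the [[function_constant(index)]] slot of the kernel.
        values.set_constant_value_at_index(
            flag as *const bool as *const std::ffi::c_void,
            MTLDataType::Bool,
            index as u64,
        );
    }
    values
}
```

The payoff of this design is that contiguous inputs get a pipeline with the strided-index branches compiled out entirely, rather than checked per element at run time.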
@@ -1541,73 +1525,6 @@ pub fn call_upsample_nearest_2d(
     Ok(())
 }

-#[allow(clippy::too_many_arguments)]
-pub fn call_random_uniform(
-    device: &Device,
-    command_buffer: &CommandBufferRef,
-    kernels: &Kernels,
-    name: &'static str,
-    min: f32,
-    max: f32,
-    length: usize,
-    seed: &Buffer,
-    buffer: &Buffer,
-) -> Result<(), MetalKernelError> {
-    if min >= max {
-        return Err(MetalKernelError::LoadLibraryError(
-            "min must be less than max".to_string(),
-        ));
-    }
-    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
-    let encoder = command_buffer.new_compute_command_encoder();
-
-    let odd = (length % 2 != 0) as usize;
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
-
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    set_params!(encoder, (length, min, max, seed, buffer));
-
-    encoder.use_resource(seed, metal::MTLResourceUsage::Read);
-    encoder.use_resource(seed, metal::MTLResourceUsage::Write);
-    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
-    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.end_encoding();
-
-    Ok(())
-}
-
-#[allow(clippy::too_many_arguments)]
-pub fn call_random_normal(
-    device: &Device,
-    command_buffer: &CommandBufferRef,
-    kernels: &Kernels,
-    name: &'static str,
-    mean: f32,
-    stddev: f32,
-    length: usize,
-    seed: &Buffer,
-    buffer: &Buffer,
-) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
-    let encoder = command_buffer.new_compute_command_encoder();
-
-    let odd = (length % 2 != 0) as usize;
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
-
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    set_params!(encoder, (length, mean, stddev, seed, buffer));
-
-    encoder.use_resource(seed, metal::MTLResourceUsage::Read);
-    encoder.use_resource(seed, metal::MTLResourceUsage::Write);
-    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
-    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.end_encoding();
-
-    Ok(())
-}
-
 #[derive(Debug, Clone, Copy)]
 pub enum GgmlDType {
     Q4_0,
@@ -1637,145 +1554,7 @@ pub fn call_quantized_matmul_t(
     rhs: &Buffer,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    // Everything is in reverse
-    let ne00 = k as i64;
-    let ne01 = n as i64;
-    let ne02 = b as i64;
-    let ne03 = 1 as i64;
-
-    let nb00 = 0i64;
-    let nb01 = 0 as i64;
-    let nb02 = 0 as i64;
-
-    let ne10 = k as i64;
-    let ne11 = m as i64;
-    let ne12 = b as i64;
-    let ne13 = 1 as i64;
-
-    let nb10 = 0i64;
-    let nb11 = 0i64;
-    let nb12 = 0i64;
-
-    let ne0 = n as i64;
-    let ne1 = m as i64;
-    let r2: u32 = (ne12 / ne02) as u32;
-    let r3: u32 = (ne13 / ne03) as u32;
-
-    let (nth0, nth1, align) = match dtype {
-        GgmlDType::Q4_0
-        | GgmlDType::Q4_1
-        | GgmlDType::Q5_0
-        | GgmlDType::Q5_1
-        | GgmlDType::Q8_0
-        | GgmlDType::Q8_1 => {
-            let nth0 = 8;
-            let nth1 = 8;
-            let align = 8;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q2K => {
-            // Fixing a bug in Metal for GGML
-            let nth0 = 4;
-            let nth1 = 8;
-            let align = 4;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q4K => {
-            let nth0 = 4;
-            let nth1 = 8;
-            let align = 4;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q3K | GgmlDType::Q5K => {
-            let nth0 = 2;
-            let nth1 = 32;
-            let align = 4;
-            (nth0, nth1, align)
-        }
-        GgmlDType::Q6K => {
-            let nth0 = 2;
-            let nth1 = 32;
-            let align = 2;
-            (nth0, nth1, align)
-        }
-        GgmlDType::F16 | GgmlDType::Q8K => {
-            // Original implem uses rows
-            let nth0 = 32;
-            let nth1 = 1;
-            let align = 8;
-            (nth0, nth1, align)
-        }
-        GgmlDType::F32 => {
-            let nth0 = 32;
-            let nth1 = 1;
-            let align = 8;
-            (nth0, nth1, align)
-        }
-    };
-    let thread_groups_count = MTLSize {
-        width: divide(ne01 as usize, align),
-        height: ne11 as u64,
-        depth: (ne12 * ne13) as u64,
-    };
-    let threads_per_threadgroup = MTLSize {
-        width: nth0,
-        height: nth1,
-        depth: 1,
-    };
-    let name = match dtype {
-        GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
-        GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
-        GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
-        GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
-        GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
-        GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
-        GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
-        GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
-        GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
-        GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
-        GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
-        GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
-        GgmlDType::F16 => "kernel_mul_mv_f16_f32",
-        GgmlDType::F32 => "kernel_mul_mv_f32_f32",
-    };
-
-    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.set_compute_pipeline_state(&pipeline);
-
-    set_params!(
-        encoder,
-        (
-            rhs,
-            (lhs, lhs_offset),
-            output,
-            ne00,
-            ne01,
-            ne02,
-            nb00,
-            nb01,
-            nb02,
-            ne10,
-            ne11,
-            ne12,
-            nb10,
-            nb11,
-            nb12,
-            ne0,
-            ne1,
-            r2,
-            r3
-        )
-    );
-    encoder.set_threadgroup_memory_length(0, 8192);
-    encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
-    encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
-
-    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
-    encoder.end_encoding();
-
-    Ok(())
+    todo!("Not implemented yet");
 }

 fn divide(m: usize, b: usize) -> NSUInteger {
(One additional file diff was suppressed because it is too large.)
candle-metal-kernels/src/random.metal
@@ -1,206 +0,0 @@
#include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>

using namespace metal;

// Constants
// 2^32 and 1/2^32. Useful for converting between float and uint.
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
// 2 * pi
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17};

// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = {
    0x9E3779B97F4A7C15,
    0xF39CC0605CEDC834,
    0x1082276BF3A27251,
    0xF86C6A11D0C18E95,
    0x2767F0B153D27B7F,
    0x0347045B5BF1827F,
    0x01886F0928403002,
    0xC1D64BA40F335E36,
    0xF06AD7AE9717877E,
    0x85839D6EFFBD7DC6,
    0x64D325D1C5371682,
    0xCADD0CCCFDFFBBE1,
    0x626E33B8D04B4331,
    0xBBF73C790D94F79D,
    0x471C4AB3ED3D82A5,
    0xFEC507705E4AE6E5,
};

// Combined Tausworthe and LCG Random Number Generator.
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
struct HybridTaus {

    float state;

    HybridTaus() thread = default;
    HybridTaus() threadgroup = default;
    HybridTaus() device = default;
    HybridTaus() constant = default;

    // Generate seeds for each thread.
    METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
        return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
    }

    // Tausworthe generator.
    METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
        uint b = (((z << s.x) ^ z) >> s.y);
        return (((z & M) << s.z) ^ b);
    }

    // LCG generator.
    METAL_FUNC static uint lcg(const uint z) {
        return (1664525 * z + 1013904223UL);
    }

    // Initialize the RNG state.
    METAL_FUNC static HybridTaus init(const ulong4 seeds) {
        uint4 seed = seed_per_thread(seeds);

        // Seed #1
        uint z1 = taus(seed.x, S1, 4294967294UL);
        uint z2 = taus(seed.y, S2, 4294967288UL);
        uint z3 = taus(seed.z, S3, 4294967280UL);
        uint z4 = lcg(seed.x);

        // Seed #2
        uint r1 = (z1^z2^z3^z4^seed.y);
        z1 = taus(r1, S1, 429496729UL);
        z2 = taus(r1, S2, 4294967288UL);
        z3 = taus(r1, S3, 429496280UL);
        z4 = lcg(r1);

        // Seed #3
        r1 = (z1^z2^z3^z4^seed.z);
        z1 = taus(r1, S1, 429496729UL);
        z2 = taus(r1, S2, 4294967288UL);
        z3 = taus(r1, S3, 429496280UL);
        z4 = lcg(r1);

        // Seed #4
        r1 = (z1^z2^z3^z4^seed.w);
        z1 = taus(r1, S1, 429496729UL);
        z2 = taus(r1, S2, 4294967288UL);
        z3 = taus(r1, S3, 429496280UL);
        z4 = lcg(r1);

        HybridTaus rng;
        rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
        return rng;
    }

    METAL_FUNC float rand() {
        uint seed = this->state * UNIF01_NORM32;
        uint z1 = taus(seed, S1, 429496729UL);
        uint z2 = taus(seed, S2, 4294967288UL);
        uint z3 = taus(seed, S3, 429496280UL);
        uint z4 = lcg(seed);

        thread float result = this->state;
        this->state = (z1^z2^z3^z4) * UNIF01_INV32;
        return result;
    }
};

template<typename T> METAL_FUNC void rand_uniform(
    constant size_t &size,
    constant float &min,
    constant float &max,
    device atomic_uint *seed,
    device T *out,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= size) {
        return;
    }

    float diff = abs(min - max);
    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
    out[tid] = static_cast<T>(rng.rand() * diff + min);
    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0, otherwise we will write to out[size].
        return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}

// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/Box–Muller_transform
template<typename T> METAL_FUNC void normal(
    constant size_t &size,
    constant float &mean,
    constant float &stddev,
    device atomic_uint *seed,
    device T *out,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= size) {
        return;
    }
    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
    float u1 = rng.rand();
    float u2 = rng.rand();

    float cosval;
    float sinval = sincos(TWO_PI * u2, cosval);
    float mag = stddev * sqrt(-2.0 * log(u1));
    float z0 = mag * cosval + mean;
    float z1 = mag * sinval + mean;

    out[tid] = static_cast<T>(z0);

    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0, otherwise we will write to out[size].
        return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - tid] = static_cast<T>(z1);
}

#define UNIFORM_OP(NAME, T)                           \
kernel void rand_uniform_##NAME(                      \
    constant size_t &size,                            \
    constant float &min,                              \
    constant float &max,                              \
    device atomic_uint *seed,                         \
    device T *out,                                    \
    uint tid [[thread_position_in_grid]]              \
) {                                                   \
    rand_uniform<T>(size, min, max, seed, out, tid);  \
}                                                     \

#define NORMAL_OP(NAME, T)                            \
kernel void rand_normal_##NAME(                       \
    constant size_t &size,                            \
    constant float &mean,                             \
    constant float &stddev,                           \
    device atomic_uint *seed,                         \
    device T *out,                                    \
    uint tid [[thread_position_in_grid]]              \
) {                                                   \
    normal<T>(size, mean, stddev, seed, out, tid);    \
}                                                     \


#define RANDOM_OPS(NAME, T) \
UNIFORM_OP(NAME, T)         \
NORMAL_OP(NAME, T)          \

RANDOM_OPS(f32, float)
RANDOM_OPS(f16, half)

#if __METAL_VERSION__ >= 310
RANDOM_OPS(bf16, bfloat)
#endif
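For reference, the `normal` kernel above is the textbook Box-Muller transform: from two uniform variates u1, u2 it produces the pair

```
z_0 = \mu + \sigma\sqrt{-2\ln u_1}\,\cos(2\pi u_2), \qquad
z_1 = \mu + \sigma\sqrt{-2\ln u_1}\,\sin(2\pi u_2)
```

which matches `mag = stddev * sqrt(-2.0 * log(u1))` and the `sincos` call in the code (Metal's `log` is the natural logarithm). Generating both variates per thread is what lets each thread fill two output slots, `out[tid]` and `out[size - tid]`.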
candle-metal-kernels/src/ternary.metal
@@ -1,20 +1,14 @@
 #include <metal_stdlib>

 using namespace metal;

-constant bool IDS_STRIDED [[function_constant(0)]];
-constant bool T_STRIDED [[function_constant(1)]];
-constant bool F_STRIDED [[function_constant(2)]];
-
-
 METAL_FUNC uint get_strided_index(
     uint idx,
-    constant const size_t &num_dims,
-    constant const size_t *dims,
-    constant const size_t *strides
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
 ) {
     uint strided_i = 0;
-    #pragma clang loop unroll(full)
     for (uint d = 0; d < num_dims; d++) {
         uint dim_idx = num_dims - 1 - d;
         strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@@ -23,7 +17,6 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }

-
 template<typename T, typename ID>
 METAL_FUNC void where_cond(
     constant size_t &numel,
@@ -41,20 +34,10 @@ METAL_FUNC void where_cond(
     if (i >= numel){
         return;
     }
-    uint strided_i = i;
-    uint strided_i_t = i;
-    uint strided_i_f = i;
-    if (IDS_STRIDED) {
-        strided_i = get_strided_index(i, num_dims, dims, strides);
-    }
-    if (T_STRIDED) {
-        strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
-    }
-    if (F_STRIDED) {
-        strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
-    }
-
-    out[i] = select(f[strided_i_t], t[strided_i_f], ids[strided_i]);
+    uint strided_i = get_strided_index(i, num_dims, dims, strides);
+    uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
+    uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
+    out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
 }

 #define WHERE_OP(T, ID, FN_NAME) \
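Both sides of this change rely on `get_strided_index`, which unravels the flat index i into per-dimension coordinates starting from the innermost dimension and re-linearizes them with the tensor's strides. In closed form, for dimension sizes n_d and strides s_d:

```
\mathrm{strided}(i) \;=\; \sum_{d=0}^{D-1}
\left( \left\lfloor \frac{i}{\prod_{e=d+1}^{D-1} n_e} \right\rfloor \bmod n_d \right) s_d
```

This is the standard row-major unravel/re-index computation; it is what makes the kernel correct for transposed or otherwise non-contiguous inputs.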
candle-metal-kernels/src/tests.rs
@@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

 fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
     let options = MTLResourceOptions::StorageModeManaged;
-    let ptr = data.as_ptr() as *const c_void;
+    let ptr = data.as_ptr() as *const core::ffi::c_void;
     let size = (data.len() * std::mem::size_of::<T>()) as u64;
     device.new_buffer_with_data(ptr, size, options)
 }
@@ -713,6 +713,7 @@ fn softmax() {
     }
     let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
+    println!("{results:?}");
     assert_eq!(
         results.iter().map(|&s| s.round() as usize).sum::<usize>(),
         n
@@ -803,13 +804,10 @@ fn run_where_cond<I: Clone, T: Clone>(
         shape,
         &cond,
         (&cond_stride, cond_offset),
-        true,
         &left,
         (&left_stride, left_offset),
-        true,
         &right,
         (&cond_stride, cond_offset),
-        true,
         &output,
     )
     .unwrap();
@@ -929,124 +927,3 @@ fn gemm() {
         vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
     );
 }
-
-fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
-    let device = device();
-    let kernels = Kernels::new();
-    let command_queue = device.new_command_queue();
-    let command_buffer = command_queue.new_command_buffer();
-
-    let options = MTLResourceOptions::StorageModeManaged;
-    let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
-
-    let seed = device.new_buffer_with_data(
-        &seed as *const u32 as *const core::ffi::c_void,
-        std::mem::size_of::<u32>() as NSUInteger,
-        options,
-    );
-
-    if name.starts_with("rand_uniform") {
-        call_random_uniform(
-            &device,
-            command_buffer,
-            &kernels,
-            name,
-            a,
-            b,
-            length,
-            &seed,
-            &output,
-        )
-        .unwrap();
-    } else {
-        call_random_normal(
-            &device,
-            command_buffer,
-            &kernels,
-            name,
-            a,
-            b,
-            length,
-            &seed,
-            &output,
-        )
-        .unwrap();
-    }
-    command_buffer.commit();
-    command_buffer.wait_until_completed();
-
-    read_to_vec(&output, length)
-}
-
-#[test]
-fn random() {
-    fn calc_mean(data: &[f32]) -> f32 {
-        let sum = data.iter().sum::<f32>() as f32;
-        let count = data.len();
-        assert!(count > 0);
-        sum / count as f32
-    }
-
-    fn calc_stddev(data: &[f32]) -> f32 {
-        let mean = calc_mean(data);
-        let count = data.len();
-        assert!(count > 0);
-
-        let variance = data
-            .iter()
-            .map(|value| {
-                let diff = mean - (*value as f32);
-                diff * diff
-            })
-            .sum::<f32>()
-            / count as f32;
-
-        variance.sqrt()
-    }
-
-    let shape = vec![1024, 10];
-
-    let length = shape.iter().product::<usize>();
-    let seed = 299792458;
-
-    let min = -30.0;
-    let max = 30.0;
-    let mean = 100.0;
-    let stddev = 50.0;
-
-    macro_rules! validate_random {
-        ($type:ty) => {
-            let results: Vec<f32> = run_random::<$type>(
-                concat!("rand_uniform_", stringify!($type)),
-                seed,
-                length,
-                min,
-                max,
-            )
-            .into_iter()
-            .map(f32::from)
-            .collect();
-            results.iter().for_each(|v| {
-                assert!(*v >= min && *v <= max);
-            });
-            assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
-
-            let results: Vec<f32> = run_random::<$type>(
-                concat!("rand_normal_", stringify!($type)),
-                seed,
-                length,
-                mean,
-                stddev,
-            )
-            .into_iter()
-            .map(f32::from)
-            .collect();
-            assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
-            assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
-        };
-    }
-
-    validate_random!(f32);
-    validate_random!(f16);
-    validate_random!(bf16);
-}
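These kernel-level tests exercise the same code paths as candle's public tensor constructors. A rough end-to-end sketch of the user-facing equivalent, assuming a Metal-capable machine and the `metal` feature of candle-core (not part of the test file):

```
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Device::new_metal routes rand/randn through the kernels tested above.
    let device = Device::new_metal(0)?;
    let uniform = Tensor::rand(-30f32, 30f32, (1024, 10), &device)?;
    let normal = Tensor::randn(100f32, 50f32, (1024, 10), &device)?;
    println!("uniform: {uniform}");
    println!("normal: {normal}");
    Ok(())
}
```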
candle-transformers/src/models/mobileone.rs
@@ -1,333 +0,0 @@
//! MobileOne inference implementation based on timm and candle-repvgg
//!
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
//! https://arxiv.org/abs/2206.04040

use candle::{DType, Result, Tensor, D};
use candle_nn::{
    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
    Func, VarBuilder,
};

struct StageConfig {
    blocks: usize,
    channels: usize,
}

// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
// by concatenating the 5th stage (starts with stride 1) to the previous one.
const STAGES: [StageConfig; 5] = [
    StageConfig {
        blocks: 1,
        channels: 64,
    },
    StageConfig {
        blocks: 2,
        channels: 64,
    },
    StageConfig {
        blocks: 8,
        channels: 128,
    },
    StageConfig {
        blocks: 10,
        channels: 256,
    },
    StageConfig {
        blocks: 1,
        channels: 512,
    },
];

#[derive(Clone)]
pub struct Config {
    /// overparameterization factor
    k: usize,
    /// per-stage channel number multipliers
    alphas: [f32; 5],
}

impl Config {
    pub fn s0() -> Self {
        Self {
            k: 4,
            alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
        }
    }
    pub fn s1() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
        }
    }
    pub fn s2() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
        }
    }
    pub fn s3() -> Self {
        Self {
            k: 1,
            alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
        }
    }
    pub fn s4() -> Self {
        Self {
            k: 1,
            alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
        }
    }
}

// SE blocks are used in the last stages of the s4 variant.
fn squeeze_and_excitation(
    in_channels: usize,
    squeeze_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        ..Default::default()
    };
    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;

    Ok(Func::new(move |xs| {
        let residual = xs;
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;

        residual.broadcast_mul(&xs)
    }))
}

// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
    let (gamma, beta) = bn.weight_and_bias().unwrap();
    let mu = bn.running_mean();
    let sigma = (bn.running_var() + bn.eps())?.sqrt();
    let gps = (gamma / sigma)?;
    let bias = (beta - mu * &gps)?;
    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;

    Ok((weights, bias))
}

// A mobileone block has a different training time and inference time architecture.
// The latter is a simple and efficient equivalent transformation of the former
// realized by a structural reparameterization technique, where convolutions
// along with identity branches and batchnorm layers are fused into a single convolution.
#[allow(clippy::too_many_arguments)]
fn mobileone_block(
    has_identity: bool,
    k: usize,
    dim: usize,
    stride: usize,
    padding: usize,
    groups: usize,
    kernel: usize,
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride,
        padding,
        groups,
        ..Default::default()
    };

    let mut w = Tensor::zeros(
        (out_channels, in_channels / groups, kernel, kernel),
        DType::F32,
        vb.device(),
    )?;
    let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;

    // k is the training-time overparameterization factor, larger than 1 only in the s0 variant
    for i in 0..k {
        let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
        let conv_kxk = conv2d_no_bias(
            in_channels,
            out_channels,
            kernel,
            conv2d_cfg,
            vb.pp(format!("conv_kxk.{i}.conv")),
        )?;
        let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
        w = (w + wk)?;
        b = (b + bk)?;
    }

    if kernel > 1 {
        let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
        let conv_scale = conv2d_no_bias(
            in_channels,
            out_channels,
            1,
            conv2d_cfg,
            vb.pp("conv_scale.conv"),
        )?;

        let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
        // resize to 3x3
        ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
        ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;

        w = (w + ws)?;
        b = (b + bs)?;
    }

    // Use SE blocks if present (last layers of the s4 variant)
    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));

    // read and reparameterize the identity bn into wi and bi
    if has_identity {
        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;

        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];

        let id = in_channels / groups;
        // See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
        for i in 0..in_channels {
            if kernel > 1 {
                weights[i * kernel * kernel + 4] = 1.0;
            } else {
                weights[i * (id + 1)] = 1.0;
            }
        }

        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
        let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;

        w = (w + wi)?;
        b = (b + bi)?;
    }

    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);

    Ok(Func::new(move |xs| {
        let mut xs = xs.apply(&reparam_conv)?;
        if let Ok(f) = &se {
            xs = xs.apply(f)?;
        }
        xs = xs.relu()?;
        Ok(xs)
    }))
}

// Get the number of output channels per stage taking into account the multipliers
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
    let channels = STAGES[stage].channels as f32;
    let alpha = cfg.alphas[stage];

    match stage {
        0 => std::cmp::min(64, (channels * alpha) as usize),
        _ => (channels * alpha) as usize,
    }
}

// Each stage is made of blocks. The first layer always downsamples with stride 2.
// All but the first block have a residual connection.
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let nblocks = STAGES[idx].blocks;
    let mut blocks = Vec::with_capacity(nblocks);

    let mut in_channels = output_channels_per_stage(cfg, idx - 1);

    for block_idx in 0..nblocks {
        let out_channels = output_channels_per_stage(cfg, idx);
        let (has_identity, stride) = if block_idx == 0 {
            (false, 2)
        } else {
            (true, 1)
        };

        // depthwise convolution layer
        blocks.push(mobileone_block(
            has_identity,
            cfg.k,
            in_channels,
            stride,
            1,
            in_channels,
            3,
            in_channels,
            in_channels,
            vb.pp(block_idx * 2),
        )?);

        // pointwise convolution layer
        blocks.push(mobileone_block(
            has_identity,
            cfg.k,
            out_channels,
            1, // stride
            0, // padding
            1, // groups
            1, // kernel
            in_channels,
            out_channels,
            vb.pp(block_idx * 2 + 1),
        )?);

        in_channels = out_channels;
    }

    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();
        for block in blocks.iter() {
            xs = xs.apply(block)?
        }
        Ok(xs)
    }))
}

// Build a mobileone model for a given configuration.
fn mobileone_model(
    config: &Config,
    nclasses: Option<usize>,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cls = match nclasses {
        None => None,
        Some(nclasses) => {
            let outputs = output_channels_per_stage(config, 4);
            let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
            Some(linear)
        }
    };

    let stem_dim = output_channels_per_stage(config, 0);
    let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;

    let vb = vb.pp("stages");
    let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
    let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
    let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
    let stage4 = mobileone_stage(config, 4, vb.pp(3))?;

    Ok(Func::new(move |xs| {
        let xs = xs
            .apply(&stem)?
            .apply(&stage1)?
            .apply(&stage2)?
            .apply(&stage3)?
            .apply(&stage4)?
            .mean(D::Minus2)?
            .mean(D::Minus1)?;
        match &cls {
            None => Ok(xs),
            Some(cls) => xs.apply(cls),
        }
    }))
}

pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    mobileone_model(cfg, Some(nclasses), vb)
}

pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    mobileone_model(cfg, None, vb)
}
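The reparameterization in `fuse_conv_bn` above folds a BatchNorm with running statistics (mu, Var), scale gamma, and shift beta into the preceding convolution's weights and bias; written out:

```
\hat{W} = \frac{\gamma}{\sqrt{\mathrm{Var} + \varepsilon}}\, W,
\qquad
\hat{b} = \beta - \frac{\gamma\,\mu}{\sqrt{\mathrm{Var} + \varepsilon}}
```

which is exactly `gps = gamma / sigma`, `bias = beta - mu * gps`, and the broadcast multiply of the weights by `gps` in the code. Summing the fused branches (the k overparameterized kxk convolutions, the 1x1 scale branch padded to kxk, and the identity branch) yields the single `reparam_conv` used at inference time.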
candle-transformers/src/models/mod.rs
@@ -15,7 +15,6 @@ pub mod marian;
 pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
-pub mod mobileone;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;