mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Implement the module trait directly for QMatMul. (#1372)
This commit is contained in:
@ -8,11 +8,10 @@ use anyhow::Result;
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
|
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
||||||
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
|
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
||||||
let start = std::time::Instant::now();
|
let new_a = a.slice_scatter(&b, 1, 2)?;
|
||||||
let res = inp.conv2d(&w, 0, 1, 1, 1)?;
|
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
println!("{:?}", start.elapsed());
|
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
println!("{res:?}");
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -123,12 +123,6 @@ pub trait Module {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for quantized::QMatMul {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
self.forward(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
self(xs)
|
self(xs)
|
||||||
|
@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl QMatMul {
|
impl crate::Module for QMatMul {
|
||||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
|
||||||
Self::Tensor(w) => {
|
Self::Tensor(w) => {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use candle_core::{
|
use candle_core::{
|
||||||
quantized::{self, GgmlDType},
|
quantized::{self, GgmlDType},
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Result, Tensor,
|
Device, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
@ -6,7 +6,7 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::quantized::GgmlType;
|
use candle::quantized::GgmlType;
|
||||||
use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
|
use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
const CHECK_CONV2D: bool = false;
|
const CHECK_CONV2D: bool = false;
|
||||||
|
@ -17,7 +17,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
|
||||||
|
|
||||||
mod utils;
|
mod utils;
|
||||||
use utils::wrap_err;
|
use utils::wrap_err;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use candle::{
|
use candle::{
|
||||||
quantized::{self, k_quants, GgmlDType, GgmlType},
|
quantized::{self, k_quants, GgmlDType, GgmlType},
|
||||||
test_utils::to_vec2_round,
|
test_utils::to_vec2_round,
|
||||||
Device, Result, Tensor,
|
Device, Module, Result, Tensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
use wasm_bindgen_test::*;
|
use wasm_bindgen_test::*;
|
||||||
|
Reference in New Issue
Block a user