Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.
* Use the module trait.
* Implement module for qmatmul.
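Taken together, these changes give every layer a single shared `forward` entry point. A minimal sketch (not part of this commit; names and values are illustrative) of what the trait enables, using `Linear::new` as shown in the diff below:

    use candle::{Device, Result, Tensor};
    use candle_nn::{Linear, Module};

    // Apply any module twice; the same helper works for Linear, LayerNorm, Conv1d, ...
    fn run_twice<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
        let ys = m.forward(xs)?;
        m.forward(&ys)
    }

    fn main() -> Result<()> {
        // 2x2 identity weight, so the layer maps xs to itself.
        let w = Tensor::new(&[[1f32, 0.], [0., 1.]], &Device::Cpu)?;
        let layer = Linear::new(w, None);
        let xs = Tensor::new(&[[3f32, 5.]], &Device::Cpu)?;
        let ys = run_twice(&layer, &xs)?;
        assert_eq!(ys.to_vec2::<f32>()?, [[3., 5.]]);
        Ok(())
    }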
@@ -184,6 +184,7 @@ impl QTensor {
     }
 }
 
+#[derive(Debug)]
 pub struct QMatMul(QTensor);
 
 impl QMatMul {
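The `#[derive(Debug)]` here is not cosmetic: the `Module` trait added later in this commit declares `std::fmt::Debug` as a supertrait, so `QMatMul` needs a `Debug` impl to qualify. A small illustrative sketch (not from the commit) of why the supertrait matters:

    use candle_nn::Module;

    // Any module, QMatMul included, can be described through the trait object;
    // formatting with {:?} is exactly what the Debug supertrait guarantees.
    fn describe(m: &dyn Module) -> String {
        format!("{m:?}")
    }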
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
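This one-line import change repeats through the rest of the diff for a simple reason: once `forward` is a trait method rather than an inherent one, Rust only resolves the call when the `Module` trait is in scope at the call site. A minimal sketch of the pattern:

    use candle::{Result, Tensor};
    use candle_nn::{Linear, Module}; // without `Module`, `.forward(...)` fails to resolve

    fn apply(layer: &Linear, xs: &Tensor) -> Result<Tensor> {
        layer.forward(xs) // resolves to <Linear as Module>::forward
    }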
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
     let weight = vb.get((size2, size1), "weight")?;
@@ -1,6 +1,6 @@
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;
 
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, VarBuilder};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 use clap::{Parser, ValueEnum};
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -1,6 +1,7 @@
 use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
 use anyhow::Result;
 use candle::{DType, IndexOp, Tensor};
+use candle_nn::Module;
 
 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -4,6 +4,7 @@ use crate::nn::{
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
+use candle_nn::Module;
 
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -4,6 +4,7 @@
 use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::{DType, Tensor, D};
+use candle_nn::Module;
 use std::sync::Arc;
 
 #[derive(Debug, Clone, PartialEq)]
@@ -7,7 +7,7 @@ use tokenizers::Tokenizer;
 use candle::quantized::ggml_file::Content;
 use candle::quantized::QTensor;
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::Embedding;
+use candle_nn::{Embedding, Module};
 use candle_transformers::generation::LogitsProcessor;
 
 const MAX_SEQ_LEN: usize = 4096;
@@ -1,6 +1,7 @@
 //! Attention Based Building Blocks
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug)]
 struct GeGlu {
@@ -7,6 +7,7 @@
 //! https://github.com/openai/CLIP
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug, Clone, Copy)]
 pub enum Activation {
@@ -1,6 +1,7 @@
 #![allow(dead_code)]
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug)]
 pub struct TimestepEmbedding {
@@ -8,6 +8,7 @@
 use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 /// Configuration for a ResNet block.
 #[derive(Debug, Clone, Copy)]
@@ -7,6 +7,7 @@ use crate::unet_2d_blocks::*;
 use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug, Clone, Copy)]
 pub struct BlockConfig {
@@ -1,4 +1,5 @@
 use candle::{Device, Result, Tensor};
+use candle_nn::Module;
 
 pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
     if steps < 1 {
@@ -10,6 +10,7 @@ use crate::unet_2d_blocks::{
 };
 use candle::{Result, Tensor};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug, Clone)]
 struct EncoderConfig {
@@ -1,5 +1,5 @@
 use candle::{Device, IndexOp, Result, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation:
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{linear, AdamW, Linear, ParamsAdamW, VarBuilder, VarMap};
+use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap};
 
 fn gen_data() -> Result<(Tensor, Tensor)> {
     // Generate some sample linear data.
@@ -7,8 +7,8 @@ pub enum Activation {
     Elu(f64),
 }
 
-impl Activation {
-    pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+impl super::Module for Activation {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu(),
             Self::Relu => xs.relu(),
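Inside candle-nn the trait is referenced by path (`super::Module` here, `crate::Module` elsewhere) since it lives at the crate root, while downstream code imports `candle_nn::Module`. A hedged usage sketch, assuming `Activation` stays reachable through the public `activation` module declared in lib.rs:

    use candle::{Device, Result, Tensor};
    use candle_nn::activation::Activation;
    use candle_nn::Module;

    fn main() -> Result<()> {
        let xs = Tensor::new(&[-1f32, 0., 1.], &Device::Cpu)?;
        // `forward` is now a trait method on Activation rather than an inherent one.
        let ys = Activation::Relu.forward(&xs)?;
        assert_eq!(ys.to_vec1::<f32>()?, [0., 0., 1.]);
        Ok(())
    }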
@@ -35,8 +35,10 @@ impl Conv1d {
     pub fn config(&self) -> &Conv1dConfig {
         &self.config
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for Conv1d {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
         match &self.bias {
             None => Ok(x),
@@ -84,8 +86,10 @@ impl Conv2d {
     pub fn config(&self) -> &Conv2dConfig {
         &self.config
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for Conv2d {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
         match &self.bias {
             None => Ok(x),
@@ -18,8 +18,10 @@ impl Embedding {
     pub fn embeddings(&self) -> &Tensor {
         &self.embeddings
     }
+}
 
-    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+impl crate::Module for Embedding {
+    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
         let mut final_dims = indexes.dims().to_vec();
         final_dims.push(self.hidden_size);
         let indexes = indexes.flatten_all()?;
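The lines above are the start of a flatten-gather-reshape scheme that lets `Embedding` accept index tensors of any rank. A shape walkthrough with illustrative sizes (the sizes are hypothetical, not from the commit):

    // indexes: [2, 5] of u32 ids; embeddings: [vocab, hidden] = [1000, 64]
    // final_dims = indexes.dims() + push(hidden) -> [2, 5, 64]
    // flatten_all()                              -> [10]
    // gather the 10 rows from embeddings         -> [10, 64]
    // reshape(final_dims)                        -> [2, 5, 64]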
@@ -34,8 +34,10 @@ impl GroupNorm {
             num_groups,
         })
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for GroupNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_shape = x.dims();
         if x_shape.len() <= 2 {
             candle::bail!("input rank for GroupNorm should be at least 3");
@@ -8,7 +8,7 @@
 //!
 //! ```rust
 //! use candle::{Tensor, Device::Cpu};
-//! use candle_nn::LayerNorm;
+//! use candle_nn::{LayerNorm, Module};
 //! # fn main() -> candle::Result<()> {
 //!
 //! let w = Tensor::new(1f32, &Cpu)?;
@@ -95,8 +95,10 @@ impl LayerNorm {
             eps,
         }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
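The `internal_dtype` match upcasts f16/bf16 inputs to f32 before the normalization statistics are computed, then the result is cast back; low-precision mean and variance reductions drift badly otherwise. A simplified sketch of the pattern (mean only; the real forward also handles variance, eps, weight, and bias):

    use candle::{DType, Result, Tensor, D};

    fn stable_mean(x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32, // upcast half precision
            d => d,
        };
        let x = x.to_dtype(internal_dtype)?;
        let mean = x.mean_keepdim(D::Minus1)?; // reduce in f32
        mean.to_dtype(x_dtype) // cast back to the caller's dtype
    }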
@@ -152,8 +154,10 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
     }
+}
 
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl crate::Module for RmsNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         self.0.forward(xs)
     }
 }
@@ -1,5 +1,5 @@
 // For now this crate shares its error type with candle-core. We may introduce some separate
 // error type if needed or add some specialized cases on the candle-core side.
+use candle::{Result, Tensor};
 
 pub mod activation;
 pub mod conv;
 pub mod embedding;
@@ -21,3 +21,20 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
 pub use var_builder::{VarBuilder, VarMap};
+
+// A simple trait defining a module with forward method using a single argument.
+pub trait Module: std::fmt::Debug {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
+
+    /// Change the module to use training mode vs eval mode.
+    ///
+    /// The default implementation does nothing as this is only used for a couple modules such as
+    /// dropout or batch-normalization.
+    fn set_training(&mut self, _training: bool) {}
+}
+
+impl Module for candle::quantized::QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
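Two details worth noting. First, the `QMatMul` impl is not infinitely recursive: `self.forward(xs)` resolves to the inherent `QMatMul::forward`, since inherent methods take precedence over trait methods in Rust's resolution order. Second, `set_training` gives stateful layers a uniform train/eval toggle; a sketch of a user-defined module using it (the `MyDropout` type is hypothetical, with deterministic scaling standing in for random masking):

    use candle::{Result, Tensor};
    use candle_nn::Module;

    #[derive(Debug)] // required: Module declares std::fmt::Debug as a supertrait
    struct MyDropout {
        p: f64,
        training: bool,
    }

    impl Module for MyDropout {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            if self.training {
                // A real dropout would sample a random mask; affine scaling stands in here.
                xs.affine(1.0 - self.p, 0.0)
            } else {
                Ok(xs.clone())
            }
        }

        fn set_training(&mut self, training: bool) {
            self.training = training;
        }
    }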
@@ -7,7 +7,7 @@
 //!
 //! ```rust
 //! use candle::{Tensor, Device::Cpu};
-//! use candle_nn::Linear;
+//! use candle_nn::{Linear, Module};
 //! # fn main() -> candle::Result<()> {
 //!
 //! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
@@ -29,8 +29,10 @@ impl Linear {
     pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
         Self { weight, bias }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+impl super::Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let w = match x.dims() {
             &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
             _ => self.weight.t()?,
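The match on `x.dims()` handles batched input: for a rank-3 `xs` the weight is broadcast across the batch before transposing, so the same `Linear` serves both `[batch, features]` and `[batch, seq, features]` shapes. A shape sketch (sizes are illustrative):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{Linear, Module};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let w = Tensor::zeros((16, 8), DType::F32, &dev)?; // [out_features, in_features]
        let layer = Linear::new(w, None);
        let xs = Tensor::zeros((2, 4, 8), DType::F32, &dev)?; // [batch, seq, in_features]
        // broadcast_left(2): [16, 8] -> [2, 16, 8]; t(): -> [2, 8, 16]; matmul: -> [2, 4, 16]
        let ys = layer.forward(&xs)?;
        assert_eq!(ys.dims(), &[2, 4, 16]);
        Ok(())
    }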
@@ -23,7 +23,7 @@ extern crate intel_mkl_src;
 
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::GroupNorm;
+use candle_nn::{GroupNorm, Module};
 mod test_utils;
 use test_utils::to_vec3_round;
 
@@ -3,7 +3,7 @@ extern crate intel_mkl_src;
 
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::LayerNorm;
+use candle_nn::{LayerNorm, Module};
 
 #[test]
 fn layer_norm() -> Result<()> {
@@ -6,7 +6,7 @@ use test_utils::{to_vec0_round, to_vec2_round};
 
 use anyhow::Result;
 use candle::{Device, Tensor, Var};
-use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
+use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
 
 #[test]
 fn sgd_optim() -> Result<()> {
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -3,7 +3,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation: