Add a simple Module trait and implement it for the various nn layers (#500)

* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
commit c78ce76501 (parent 13401df4d1)
Author: Laurent Mazare (committed by GitHub)
Date: 2023-08-18 09:38:22 +01:00
33 changed files with 70 additions and 28 deletions
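
For context, the trait introduced here is a single-method abstraction over anything that maps a tensor to a tensor. A minimal sketch of its shape, assuming the one-method forward interface that every diff below relies on (the exact definition lives in candle-nn at this commit):

use candle::{Result, Tensor};

// Anything that can run a forward pass on a tensor.
pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

Because Embedding, Linear, LayerNorm, and the other nn layers all implement this trait, generic code can now be written over any Module instead of naming each layer type.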

View File

@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, VarBuilder};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
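
Note that every hunk in this commit follows the same pattern: the file already imported the layer types, and only Module is new. That is because Rust resolves trait methods only when the trait is in scope; once forward is provided through the Module trait rather than an inherent method, each call site must import the trait. A hypothetical call site showing the requirement:

use candle::{Result, Tensor};
use candle_nn::{Embedding, Module};

// Without the Module import above, this fails to compile with
// "no method named `forward` found", even though Embedding is imported.
fn embed(emb: &Embedding, token_ids: &Tensor) -> Result<Tensor> {
    emb.forward(token_ids)
}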

View File

@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
     let weight = vb.get((size2, size1), "weight")?;

View File

@@ -1,6 +1,6 @@
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;

View File

@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, VarBuilder};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

View File

@@ -1,6 +1,6 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

View File

@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;

View File

@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 use clap::{Parser, ValueEnum};
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;

View File

@@ -1,6 +1,7 @@
 use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
 use anyhow::Result;
 use candle::{DType, IndexOp, Tensor};
+use candle_nn::Module;
 
 // Encodec Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py

View File

@@ -4,6 +4,7 @@ use crate::nn::{
 use crate::{encodec_model, t5_model};
 use anyhow::Result;
 use candle::{DType, Device, Tensor, D};
+use candle_nn::Module;
 
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]

View File

@@ -4,6 +4,7 @@
 use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::{DType, Tensor, D};
+use candle_nn::Module;
 use std::sync::Arc;
 
 #[derive(Debug, Clone, PartialEq)]

View File

@@ -7,7 +7,7 @@ use tokenizers::Tokenizer;
 use candle::quantized::ggml_file::Content;
 use candle::quantized::QTensor;
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::Embedding;
+use candle_nn::{Embedding, Module};
 use candle_transformers::generation::LogitsProcessor;
 
 const MAX_SEQ_LEN: usize = 4096;
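
The third commit bullet ("Implement module for qmatmul") gives quantized matmul the same interface, which is what this quantized example relies on. A hypothetical generic helper, assuming the QMatMul Module impl added at this commit, showing the payoff of the shared trait:

use candle::quantized::QMatMul;
use candle::{Result, Tensor};
use candle_nn::Module;

// Chain any layers that implement Module; QMatMul now qualifies
// alongside Embedding, Linear, LayerNorm, and the rest.
fn apply_all<M: Module>(layers: &[M], xs: &Tensor) -> Result<Tensor> {
    let mut xs = xs.clone();
    for layer in layers {
        xs = layer.forward(&xs)?;
    }
    Ok(xs)
}

fn project(qm: &QMatMul, xs: &Tensor) -> Result<Tensor> {
    apply_all(std::slice::from_ref(qm), xs)
}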

View File

@@ -1,6 +1,7 @@
 //! Attention Based Building Blocks
 use candle::{DType, IndexOp, Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug)]
 struct GeGlu {

View File

@@ -7,6 +7,7 @@
 //! https://github.com/openai/CLIP
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug, Clone, Copy)]
 pub enum Activation {

View File

@@ -1,6 +1,7 @@
 #![allow(dead_code)]
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug)]
 pub struct TimestepEmbedding {

View File

@@ -8,6 +8,7 @@
 use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
+use candle_nn::Module;
 
 /// Configuration for a ResNet block.
 #[derive(Debug, Clone, Copy)]

View File

@@ -7,6 +7,7 @@ use crate::unet_2d_blocks::*;
 use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug, Clone, Copy)]
 pub struct BlockConfig {

View File

@@ -1,4 +1,5 @@
 use candle::{Device, Result, Tensor};
+use candle_nn::Module;
 
 pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
     if steps < 1 {

View File

@@ -10,6 +10,7 @@ use crate::unet_2d_blocks::{
 };
 use candle::{Result, Tensor};
 use candle_nn as nn;
+use candle_nn::Module;
 
 #[derive(Debug, Clone)]
 struct EncoderConfig {

View File

@@ -1,5 +1,5 @@
 use candle::{Device, IndexOp, Result, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation: