feat: add network policy

This commit is contained in:
2025-02-28 12:40:45 +00:00
parent 29488d75a3
commit 5e2088f09c
21 changed files with 700 additions and 155 deletions

3
Cargo.lock generated
View File

@ -1841,8 +1841,7 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]] [[package]]
name = "ipnetwork" name = "ipnetwork"
version = "0.21.1" version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/v0l/ipnetwork.git?rev=51b4816358f255ecbcff90b739fbbd4b1cbc2d6a#51b4816358f255ecbcff90b739fbbd4b1cbc2d6a"
checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763"
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"

View File

@ -29,7 +29,7 @@ nostr = { version = "0.39.0", default-features = false, features = ["std"] }
base64 = { version = "0.22.1", features = ["alloc"] } base64 = { version = "0.22.1", features = ["alloc"] }
urlencoding = "2.1.3" urlencoding = "2.1.3"
fedimint-tonic-lnd = { version = "0.2.0", default-features = false, features = ["invoicesrpc"] } fedimint-tonic-lnd = { version = "0.2.0", default-features = false, features = ["invoicesrpc"] }
ipnetwork = "0.21.1" ipnetwork = { git = "https://github.com/v0l/ipnetwork.git", rev = "51b4816358f255ecbcff90b739fbbd4b1cbc2d6a" }
rand = "0.9.0" rand = "0.9.0"
clap = { version = "4.5.21", features = ["derive"] } clap = { version = "4.5.21", features = ["derive"] }
ssh2 = "0.9.4" ssh2 = "0.9.4"

View File

@ -21,19 +21,19 @@ lnd:
macaroon: "$HOME/.lnd/data/chain/bitcoin/mainnet/admin.macaroon" macaroon: "$HOME/.lnd/data/chain/bitcoin/mainnet/admin.macaroon"
# Number of days after a VM expires to delete # Number of days after a VM expires to delete
delete_after: 3 delete-after: 3
# Provisioner is the main process which handles creating/deleting VM's # Provisioner is the main process which handles creating/deleting VM's
# Currently supports: Proxmox # Currently supports: Proxmox
provisioner: provisioner:
proxmox: proxmox:
# Read-only mode prevents spawning VM's # Read-only mode prevents spawning VM's
read_only: false read-only: false
# Proxmox (QEMU) settings used for spawning VM's # Proxmox (QEMU) settings used for spawning VM's
qemu: qemu:
bios: "ovmf" bios: "ovmf"
machine: "q35" machine: "q35"
os_type: "l26" os-type: "l26"
bridge: "vmbr0" bridge: "vmbr0"
cpu: "kvm64" cpu: "kvm64"
vlan: 100 vlan: 100
@ -89,6 +89,12 @@ router:
url: "https://my-router.net" url: "https://my-router.net"
username: "admin" username: "admin"
password: "admin" password: "admin"
# Interface where the static ARP entry is added network-policy:
arp_interface: "bridge1" # How packets get to the VM
# (default "auto", nothing to do, packets will always arrive)
access:
# Static ARP entries are added to the router for each provisioned IP
static-arp:
# Interface where the static ARP entry is added
interface: "bridge1"
``` ```

View File

@ -3,14 +3,14 @@ lnd:
url: "https://127.0.0.1:10003" url: "https://127.0.0.1:10003"
cert: "/home/kieran/.polar/networks/2/volumes/lnd/alice/tls.cert" cert: "/home/kieran/.polar/networks/2/volumes/lnd/alice/tls.cert"
macaroon: "/home/kieran/.polar/networks/2/volumes/lnd/alice/data/chain/bitcoin/regtest/admin.macaroon" macaroon: "/home/kieran/.polar/networks/2/volumes/lnd/alice/data/chain/bitcoin/regtest/admin.macaroon"
delete_after: 3 delete-after: 3
provisioner: provisioner:
proxmox: proxmox:
read_only: false read-only: false
qemu: qemu:
bios: "ovmf" bios: "ovmf"
machine: "q35" machine: "q35"
os_type: "l26" os-type: "l26"
bridge: "vmbr0" bridge: "vmbr0"
cpu: "host" cpu: "host"
vlan: 100 vlan: 100

View File

@ -1,6 +1,4 @@
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait;
mod model; mod model;
#[cfg(feature = "mysql")] #[cfg(feature = "mysql")]
mod mysql; mod mysql;
@ -9,6 +7,8 @@ pub use model::*;
#[cfg(feature = "mysql")] #[cfg(feature = "mysql")]
pub use mysql::*; pub use mysql::*;
pub use async_trait::async_trait;
#[async_trait] #[async_trait]
pub trait LNVpsDb: Sync + Send { pub trait LNVpsDb: Sync + Send {
/// Migrate database /// Migrate database
@ -65,6 +65,9 @@ pub trait LNVpsDb: Sync + Send {
/// List available IP Ranges /// List available IP Ranges
async fn list_ip_range(&self) -> Result<Vec<IpRange>>; async fn list_ip_range(&self) -> Result<Vec<IpRange>>;
/// List available IP Ranges in a given region
async fn list_ip_range_in_region(&self, region_id: u64) -> Result<Vec<IpRange>>;
/// Get a VM cost plan by id /// Get a VM cost plan by id
async fn get_cost_plan(&self, id: u64) -> Result<VmCostPlan>; async fn get_cost_plan(&self, id: u64) -> Result<VmCostPlan>;

View File

@ -220,6 +220,7 @@ pub struct VmIpAssignment {
pub vm_id: u64, pub vm_id: u64,
pub ip_range_id: u64, pub ip_range_id: u64,
pub ip: String, pub ip: String,
pub deleted: bool,
} }
impl Display for VmIpAssignment { impl Display for VmIpAssignment {

View File

@ -171,7 +171,15 @@ impl LNVpsDb for LNVpsDbMysql {
} }
async fn list_ip_range(&self) -> Result<Vec<IpRange>> { async fn list_ip_range(&self) -> Result<Vec<IpRange>> {
sqlx::query_as("select * from ip_range") sqlx::query_as("select * from ip_range where enabled = 1")
.fetch_all(&self.db)
.await
.map_err(Error::new)
}
async fn list_ip_range_in_region(&self, region_id: u64) -> Result<Vec<IpRange>> {
sqlx::query_as("select * from ip_range where region_id = ? and enabled = 1")
.bind(region_id)
.fetch_all(&self.db) .fetch_all(&self.db)
.await .await
.map_err(Error::new) .map_err(Error::new)

View File

@ -1,4 +1,4 @@
mod routes;
mod model; mod model;
mod routes;
pub use routes::routes; pub use routes::routes;

View File

@ -1,9 +1,12 @@
use nostr::util::hex;
use crate::status::VmState; use crate::status::VmState;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use ipnetwork::IpNetwork;
use lnvps_db::VmHostRegion; use lnvps_db::VmHostRegion;
use nostr::util::hex;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::str::FromStr;
#[derive(Serialize, Deserialize, JsonSchema)] #[derive(Serialize, Deserialize, JsonSchema)]
pub struct ApiVmStatus { pub struct ApiVmStatus {
@ -48,15 +51,20 @@ impl From<lnvps_db::UserSshKey> for ApiUserSshKey {
pub struct ApiVmIpAssignment { pub struct ApiVmIpAssignment {
pub id: u64, pub id: u64,
pub ip: String, pub ip: String,
pub range: String, pub gateway: String,
} }
impl ApiVmIpAssignment { impl ApiVmIpAssignment {
pub fn from(ip: lnvps_db::VmIpAssignment, range: &lnvps_db::IpRange) -> Self { pub fn from(ip: &lnvps_db::VmIpAssignment, range: &lnvps_db::IpRange) -> Self {
ApiVmIpAssignment { ApiVmIpAssignment {
id: ip.id, id: ip.id,
ip: ip.ip, ip: IpNetwork::new(
range: range.cidr.clone(), IpNetwork::from_str(&ip.ip).unwrap().ip(),
IpNetwork::from_str(&range.cidr).unwrap().prefix(),
)
.unwrap()
.to_string(),
gateway: range.gateway.to_string(),
} }
} }
} }

View File

@ -155,7 +155,7 @@ async fn vm_to_status(
let range = ip_ranges let range = ip_ranges
.get(&i.ip_range_id) .get(&i.ip_range_id)
.expect("ip range id not found"); .expect("ip range id not found");
ApiVmIpAssignment::from(i, range) ApiVmIpAssignment::from(&i, range)
}) })
.collect(), .collect(),
}) })
@ -457,6 +457,7 @@ async fn v1_get_payment(
} else { } else {
return ApiData::err("Invalid payment id"); return ApiData::err("Invalid payment id");
}; };
let payment = db.get_vm_payment(&id).await?; let payment = db.get_vm_payment(&id).await?;
let vm = db.get_vm(payment.vm_id).await?; let vm = db.get_vm(payment.vm_id).await?;
if vm.user_id != uid { if vm.user_id != uid {

View File

@ -64,12 +64,10 @@ async fn main() -> Result<(), Error> {
} else { } else {
None None
}; };
let router = settings.router.as_ref().map(|r| r.get_router()); let router = settings.get_router()?;
let status = VmStateCache::new(); let status = VmStateCache::new();
let worker_provisioner = let worker_provisioner =
settings settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
.provisioner
.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
worker_provisioner.init().await?; worker_provisioner.init().await?;
let mut worker = Worker::new( let mut worker = Worker::new(
@ -82,11 +80,11 @@ async fn main() -> Result<(), Error> {
let sender = worker.sender(); let sender = worker.sender();
// send a startup notification // send a startup notification
if let Some(admin) = &settings.smtp.and_then(|s| s.admin) { if let Some(admin) = settings.smtp.as_ref().and_then(|s| s.admin) {
sender.send(WorkJob::SendNotification { sender.send(WorkJob::SendNotification {
title: Some("Startup".to_string()), title: Some("Startup".to_string()),
message: "System is starting!".to_string(), message: "System is starting!".to_string(),
user_id: *admin, user_id: admin,
})?; })?;
} }
@ -132,11 +130,8 @@ async fn main() -> Result<(), Error> {
} }
}); });
let router = settings.router.as_ref().map(|r| r.get_router()); let router = settings.get_router()?;
let provisioner = let provisioner = settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
settings
.provisioner
.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
let db: Box<dyn LNVpsDb> = Box::new(db.clone()); let db: Box<dyn LNVpsDb> = Box::new(db.clone());
let pv: Box<dyn Provisioner> = Box::new(provisioner); let pv: Box<dyn Provisioner> = Box::new(provisioner);

View File

@ -10,3 +10,6 @@ pub mod settings;
pub mod ssh_client; pub mod ssh_client;
pub mod status; pub mod status;
pub mod worker; pub mod worker;
#[cfg(test)]
pub mod mocks;

337
src/mocks.rs Normal file
View File

@ -0,0 +1,337 @@
use crate::router::{ArpEntry, Router};
use crate::settings::NetworkPolicy;
use anyhow::anyhow;
use chrono::Utc;
use lnvps_db::{
async_trait, IpRange, LNVpsDb, User, UserSshKey, Vm, VmCostPlan, VmHost, VmHostDisk,
VmHostKind, VmHostRegion, VmIpAssignment, VmOsImage, VmPayment, VmTemplate,
};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
pub struct MockDb {
pub regions: Arc<Mutex<HashMap<u64, VmHostRegion>>>,
pub hosts: Arc<Mutex<HashMap<u64, VmHost>>>,
pub users: Arc<Mutex<HashMap<u64, User>>>,
pub vms: Arc<Mutex<HashMap<u64, Vm>>>,
pub ip_range: Arc<Mutex<HashMap<u64, IpRange>>>,
pub ip_assignments: Arc<Mutex<HashMap<u64, VmIpAssignment>>>,
}
impl MockDb {
pub fn empty() -> MockDb {
Self {
..Default::default()
}
}
}
impl Default for MockDb {
fn default() -> Self {
let mut regions = HashMap::new();
regions.insert(
1,
VmHostRegion {
id: 1,
name: "Mock".to_string(),
enabled: true,
},
);
let mut ip_ranges = HashMap::new();
ip_ranges.insert(
1,
IpRange {
id: 1,
cidr: "10.0.0.0/8".to_string(),
gateway: "10.0.0.1".to_string(),
enabled: true,
region_id: 1,
},
);
let mut hosts = HashMap::new();
hosts.insert(
1,
VmHost {
id: 1,
kind: VmHostKind::Proxmox,
region_id: 1,
name: "mock-host".to_string(),
ip: "https://localhost".to_string(),
cpu: 4,
memory: 8192,
enabled: true,
api_token: "".to_string(),
},
);
Self {
regions: Arc::new(Mutex::new(regions)),
ip_range: Arc::new(Mutex::new(ip_ranges)),
hosts: Arc::new(Mutex::new(hosts)),
users: Arc::new(Default::default()),
vms: Arc::new(Default::default()),
ip_assignments: Arc::new(Default::default()),
}
}
}
#[async_trait]
impl LNVpsDb for MockDb {
async fn migrate(&self) -> anyhow::Result<()> {
Ok(())
}
async fn upsert_user(&self, pubkey: &[u8; 32]) -> anyhow::Result<u64> {
let mut users = self.users.lock().await;
if let Some(e) = users.iter().find(|(k, u)| u.pubkey == *pubkey) {
Ok(*e.0)
} else {
let max = *users.keys().max().unwrap_or(&0);
users.insert(
max + 1,
User {
id: max + 1,
pubkey: pubkey.to_vec(),
created: Utc::now(),
email: None,
contact_nip4: false,
contact_nip17: false,
contact_email: false,
},
);
Ok(max + 1)
}
}
async fn get_user(&self, id: u64) -> anyhow::Result<User> {
let mut users = self.users.lock().await;
Ok(users.get(&id).ok_or(anyhow!("no user"))?.clone())
}
async fn update_user(&self, user: &User) -> anyhow::Result<()> {
let mut users = self.users.lock().await;
if let Some(u) = users.get_mut(&user.id) {
u.email = user.email.clone();
u.contact_email = user.contact_email.clone();
u.contact_nip17 = user.contact_nip17.clone();
u.contact_nip4 = user.contact_nip4.clone();
}
Ok(())
}
async fn delete_user(&self, id: u64) -> anyhow::Result<()> {
let mut users = self.users.lock().await;
users.remove(&id);
Ok(())
}
async fn insert_user_ssh_key(&self, new_key: &UserSshKey) -> anyhow::Result<u64> {
todo!()
}
async fn get_user_ssh_key(&self, id: u64) -> anyhow::Result<UserSshKey> {
todo!()
}
async fn delete_user_ssh_key(&self, id: u64) -> anyhow::Result<()> {
todo!()
}
async fn list_user_ssh_key(&self, user_id: u64) -> anyhow::Result<Vec<UserSshKey>> {
todo!()
}
async fn get_host_region(&self, id: u64) -> anyhow::Result<VmHostRegion> {
let regions = self.regions.lock().await;
Ok(regions.get(&id).ok_or(anyhow!("no region"))?.clone())
}
async fn list_hosts(&self) -> anyhow::Result<Vec<VmHost>> {
let hosts = self.hosts.lock().await;
Ok(hosts.values().filter(|h| h.enabled).cloned().collect())
}
async fn get_host(&self, id: u64) -> anyhow::Result<VmHost> {
let hosts = self.hosts.lock().await;
Ok(hosts.get(&id).ok_or(anyhow!("no host"))?.clone())
}
async fn update_host(&self, host: &VmHost) -> anyhow::Result<()> {
let mut hosts = self.hosts.lock().await;
if let Some(h) = hosts.get_mut(&host.id) {
h.enabled = host.enabled;
h.cpu = host.cpu;
h.memory = host.memory;
}
Ok(())
}
async fn list_host_disks(&self, host_id: u64) -> anyhow::Result<Vec<VmHostDisk>> {
todo!()
}
async fn get_os_image(&self, id: u64) -> anyhow::Result<VmOsImage> {
todo!()
}
async fn list_os_image(&self) -> anyhow::Result<Vec<VmOsImage>> {
todo!()
}
async fn get_ip_range(&self, id: u64) -> anyhow::Result<IpRange> {
let ip_range = self.ip_range.lock().await;
Ok(ip_range.get(&id).ok_or(anyhow!("no ip range"))?.clone())
}
async fn list_ip_range(&self) -> anyhow::Result<Vec<IpRange>> {
let ip_range = self.ip_range.lock().await;
Ok(ip_range.values().filter(|r| r.enabled).cloned().collect())
}
async fn list_ip_range_in_region(&self, region_id: u64) -> anyhow::Result<Vec<IpRange>> {
let ip_range = self.ip_range.lock().await;
Ok(ip_range
.values()
.filter(|r| r.enabled && r.region_id == region_id)
.cloned()
.collect())
}
async fn get_cost_plan(&self, id: u64) -> anyhow::Result<VmCostPlan> {
todo!()
}
async fn get_vm_template(&self, id: u64) -> anyhow::Result<VmTemplate> {
todo!()
}
async fn list_vm_templates(&self) -> anyhow::Result<Vec<VmTemplate>> {
todo!()
}
async fn list_vms(&self) -> anyhow::Result<Vec<Vm>> {
todo!()
}
async fn list_expired_vms(&self) -> anyhow::Result<Vec<Vm>> {
todo!()
}
async fn list_user_vms(&self, id: u64) -> anyhow::Result<Vec<Vm>> {
todo!()
}
async fn get_vm(&self, vm_id: u64) -> anyhow::Result<Vm> {
todo!()
}
async fn insert_vm(&self, vm: &Vm) -> anyhow::Result<u64> {
todo!()
}
async fn delete_vm(&self, vm_id: u64) -> anyhow::Result<()> {
todo!()
}
async fn update_vm(&self, vm: &Vm) -> anyhow::Result<()> {
todo!()
}
async fn insert_vm_ip_assignment(&self, ip_assignment: &VmIpAssignment) -> anyhow::Result<u64> {
let mut ip_assignments = self.ip_assignments.lock().await;
let max = *ip_assignments.keys().max().unwrap_or(&0);
ip_assignments.insert(
max + 1,
VmIpAssignment {
id: max + 1,
vm_id: ip_assignment.vm_id,
ip_range_id: ip_assignment.ip_range_id,
ip: ip_assignment.ip.clone(),
deleted: false,
},
);
Ok(max + 1)
}
async fn list_vm_ip_assignments(&self, vm_id: u64) -> anyhow::Result<Vec<VmIpAssignment>> {
let ip_assignments = self.ip_assignments.lock().await;
Ok(ip_assignments
.values()
.filter(|a| a.vm_id == vm_id && !a.deleted)
.cloned()
.collect())
}
async fn list_vm_ip_assignments_in_range(
&self,
range_id: u64,
) -> anyhow::Result<Vec<VmIpAssignment>> {
let ip_assignments = self.ip_assignments.lock().await;
Ok(ip_assignments
.values()
.filter(|a| a.ip_range_id == range_id && !a.deleted)
.cloned()
.collect())
}
async fn delete_vm_ip_assignment(&self, vm_id: u64) -> anyhow::Result<()> {
let mut ip_assignments = self.ip_assignments.lock().await;
for ip_assignment in ip_assignments.values_mut() {
if ip_assignment.vm_id == vm_id {
ip_assignment.deleted = true;
}
}
Ok(())
}
async fn list_vm_payment(&self, vm_id: u64) -> anyhow::Result<Vec<VmPayment>> {
todo!()
}
async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> {
todo!()
}
async fn get_vm_payment(&self, id: &Vec<u8>) -> anyhow::Result<VmPayment> {
todo!()
}
async fn update_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> {
todo!()
}
async fn vm_payment_paid(&self, id: &VmPayment) -> anyhow::Result<()> {
todo!()
}
async fn last_paid_invoice(&self) -> anyhow::Result<Option<VmPayment>> {
todo!()
}
}
struct MockRouter {
pub policy: NetworkPolicy,
}
#[async_trait]
impl Router for MockRouter {
async fn list_arp_entry(&self) -> anyhow::Result<Vec<ArpEntry>> {
todo!()
}
async fn add_arp_entry(
&self,
ip: IpAddr,
mac: &str,
interface: &str,
comment: Option<&str>,
) -> anyhow::Result<()> {
todo!()
}
async fn remove_arp_entry(&self, id: &str) -> anyhow::Result<()> {
todo!()
}
}

View File

@ -4,9 +4,9 @@ use crate::host::proxmox::{
ConfigureVm, CreateVm, DownloadUrlRequest, ProxmoxClient, ResizeDiskRequest, StorageContent, ConfigureVm, CreateVm, DownloadUrlRequest, ProxmoxClient, ResizeDiskRequest, StorageContent,
VmBios, VmConfig, VmBios, VmConfig,
}; };
use crate::provisioner::Provisioner; use crate::provisioner::{NetworkProvisioner, Provisioner, ProvisionerMethod};
use crate::router::Router; use crate::router::Router;
use crate::settings::{QemuConfig, SshConfig}; use crate::settings::{NetworkAccessPolicy, NetworkPolicy, QemuConfig, SshConfig};
use crate::ssh_client::SshClient; use crate::ssh_client::SshClient;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use chrono::{Days, Months, Utc}; use chrono::{Days, Months, Utc};
@ -25,18 +25,21 @@ use rocket::futures::{SinkExt, StreamExt};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::IpAddr; use std::net::IpAddr;
use std::ops::Add; use std::ops::Add;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
pub struct LNVpsProvisioner { pub struct LNVpsProvisioner {
db: Box<dyn LNVpsDb>, db: Arc<dyn LNVpsDb>,
router: Option<Box<dyn Router>>, router: Option<Box<dyn Router>>,
lnd: Client, lnd: Client,
rates: ExchangeRateCache, rates: ExchangeRateCache,
read_only: bool, read_only: bool,
config: QemuConfig, config: QemuConfig,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>, ssh: Option<SshConfig>,
} }
@ -44,6 +47,7 @@ impl LNVpsProvisioner {
pub fn new( pub fn new(
read_only: bool, read_only: bool,
config: QemuConfig, config: QemuConfig,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>, ssh: Option<SshConfig>,
db: impl LNVpsDb + 'static, db: impl LNVpsDb + 'static,
router: Option<impl Router + 'static>, router: Option<impl Router + 'static>,
@ -51,9 +55,10 @@ impl LNVpsProvisioner {
rates: ExchangeRateCache, rates: ExchangeRateCache,
) -> Self { ) -> Self {
Self { Self {
db: Box::new(db), db: Arc::new(db),
router: router.map(|r| Box::new(r) as Box<dyn Router>), router: router.map(|r| Box::new(r) as Box<dyn Router>),
lnd, lnd,
network_policy,
rates, rates,
config, config,
read_only, read_only,
@ -73,14 +78,9 @@ impl LNVpsProvisioner {
} }
} }
async fn get_vm_config(&self, vm: &Vm) -> Result<VmConfig> { async fn get_vm_config(&self, vm: &Vm, ips: &Vec<VmIpAssignment>) -> Result<VmConfig> {
let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?; let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?;
let mut ips = self.db.list_vm_ip_assignments(vm.id).await?;
if ips.is_empty() {
ips = self.allocate_ips(vm.id).await?;
}
let ip_range_ids: HashSet<u64> = ips.iter().map(|i| i.ip_range_id).collect(); let ip_range_ids: HashSet<u64> = ips.iter().map(|i| i.ip_range_id).collect();
let ip_ranges: Vec<_> = ip_range_ids let ip_ranges: Vec<_> = ip_range_ids
.iter() .iter()
@ -148,6 +148,30 @@ impl LNVpsProvisioner {
..Default::default() ..Default::default()
}) })
} }
async fn save_ip_assignment(&self, vm: &Vm, assignment: &VmIpAssignment) -> Result<()> {
// apply network policy
match &self.network_policy.access {
NetworkAccessPolicy::StaticArp { interface } => {
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(
IpAddr::from_str(&assignment.ip)?,
&vm.mac_address,
interface,
Some(&format!("VM{}", vm.id)),
)
.await?;
} else {
bail!("No router found to apply static arp entry!")
}
}
_ => {}
}
// save to db
self.db.insert_vm_ip_assignment(&assignment).await?;
Ok(())
}
} }
#[async_trait] #[async_trait]
@ -303,69 +327,29 @@ impl Provisioner for LNVpsProvisioner {
async fn allocate_ips(&self, vm_id: u64) -> Result<Vec<VmIpAssignment>> { async fn allocate_ips(&self, vm_id: u64) -> Result<Vec<VmIpAssignment>> {
let vm = self.db.get_vm(vm_id).await?; let vm = self.db.get_vm(vm_id).await?;
let existing_ips = self.db.list_vm_ip_assignments(vm_id).await?;
if !existing_ips.is_empty() {
return Ok(existing_ips);
}
let template = self.db.get_vm_template(vm.template_id).await?; let template = self.db.get_vm_template(vm.template_id).await?;
let ips = self.db.list_vm_ip_assignments(vm.id).await?; let prov = NetworkProvisioner::new(
ProvisionerMethod::Random,
self.network_policy.clone(),
self.db.clone(),
);
if !ips.is_empty() { let ip = prov.pick_ip_for_region(template.region_id).await?;
bail!("IP resources are already assigned"); let assignment = VmIpAssignment {
} id: 0,
vm_id,
ip_range_id: ip.range_id,
ip: ip.ip.to_string(),
deleted: false,
};
let ip_ranges = self.db.list_ip_range().await?; self.save_ip_assignment(&vm, &assignment).await?;
let ip_ranges: Vec<IpRange> = ip_ranges Ok(vec![assignment])
.into_iter()
.filter(|i| i.region_id == template.region_id && i.enabled)
.collect();
if ip_ranges.is_empty() {
bail!("No ip range found in this region");
}
let mut ret = vec![];
// Try all ranges
// TODO: pick round-robin ranges
// TODO: pick one of each type
'ranges: for range in ip_ranges {
let range_cidr: IpNetwork = range.cidr.parse()?;
let ips = self.db.list_vm_ip_assignments_in_range(range.id).await?;
let ips: HashSet<IpNetwork> = ips.iter().map_while(|i| i.ip.parse().ok()).collect();
// pick an IP at random
let cidr: Vec<IpAddr> = {
let mut rng = rand::thread_rng();
range_cidr.iter().choose(&mut rng).into_iter().collect()
};
for ip in cidr {
let ip_net = IpNetwork::new(ip, range_cidr.prefix())?;
if !ips.contains(&ip_net) {
info!("Attempting to allocate IP for {vm_id} to {ip}");
let mut assignment = VmIpAssignment {
id: 0,
vm_id,
ip_range_id: range.id,
ip: ip_net.to_string(),
..Default::default()
};
// add arp entry for router
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(ip, &vm.mac_address, Some(&format!("VM{}", vm.id)))
.await?;
}
let id = self.db.insert_vm_ip_assignment(&assignment).await?;
assignment.id = id;
ret.push(assignment);
break 'ranges;
}
}
}
if ret.is_empty() {
bail!("No ip ranges found in this region");
}
Ok(ret)
} }
async fn spawn_vm(&self, vm_id: u64) -> Result<()> { async fn spawn_vm(&self, vm_id: u64) -> Result<()> {
@ -378,16 +362,21 @@ impl Provisioner for LNVpsProvisioner {
let client = get_host_client(&host)?; let client = get_host_client(&host)?;
let vm_id = 100 + vm.id as i32; let vm_id = 100 + vm.id as i32;
// setup network by allocating some IP space
let ips = self.allocate_ips(vm.id).await?;
// create VM // create VM
let config = self.get_vm_config(&vm, &ips).await?;
let t_create = client let t_create = client
.create_vm(CreateVm { .create_vm(CreateVm {
node: host.name.clone(), node: host.name.clone(),
vm_id, vm_id,
config: self.get_vm_config(&vm).await?, config,
}) })
.await?; .await?;
client.wait_for_task(&t_create).await?; client.wait_for_task(&t_create).await?;
// save
// import the disk // import the disk
// TODO: find a way to avoid using SSH // TODO: find a way to avoid using SSH
if let Some(ssh_config) = &self.ssh { if let Some(ssh_config) = &self.ssh {
@ -451,8 +440,10 @@ impl Provisioner for LNVpsProvisioner {
.await?; .await?;
client.wait_for_task(&j_resize).await?; client.wait_for_task(&j_resize).await?;
let j_start = client.start_vm(&host.name, vm_id as u64).await?; // try start, otherwise ignore error (maybe its already running)
client.wait_for_task(&j_start).await?; if let Ok(j_start) = client.start_vm(&host.name, vm_id as u64).await {
client.wait_for_task(&j_start).await?;
}
Ok(()) Ok(())
} }
@ -549,6 +540,7 @@ impl Provisioner for LNVpsProvisioner {
async fn patch_vm(&self, vm_id: u64) -> Result<()> { async fn patch_vm(&self, vm_id: u64) -> Result<()> {
let vm = self.db.get_vm(vm_id).await?; let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?; let host = self.db.get_host(vm.host_id).await?;
let ips = self.db.list_vm_ip_assignments(vm.id).await?;
let client = get_host_client(&host)?; let client = get_host_client(&host)?;
let host_vm_id = vm.id + 100; let host_vm_id = vm.id + 100;
@ -562,7 +554,7 @@ impl Provisioner for LNVpsProvisioner {
scsi_0: None, scsi_0: None,
scsi_1: None, scsi_1: None,
efi_disk_0: None, efi_disk_0: None,
..self.get_vm_config(&vm).await? ..self.get_vm_config(&vm, &ips).await?
}, },
}) })
.await?; .await?;

View File

@ -4,7 +4,11 @@ use rocket::async_trait;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
pub mod lnvps; mod lnvps;
mod network;
pub use lnvps::*;
pub use network::*;
#[async_trait] #[async_trait]
pub trait Provisioner: Send + Sync { pub trait Provisioner: Send + Sync {

136
src/provisioner/network.rs Normal file
View File

@ -0,0 +1,136 @@
use crate::settings::NetworkPolicy;
use anyhow::{bail, Result};
use ipnetwork::IpNetwork;
use lnvps_db::LNVpsDb;
use rand::prelude::IteratorRandom;
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc;
#[derive(Debug, Clone, Copy)]
pub enum ProvisionerMethod {
Sequential,
Random,
}
#[derive(Debug, Clone, Copy)]
pub struct AvailableIp {
pub ip: IpAddr,
pub range_id: u64,
pub region_id: u64,
}
#[derive(Clone)]
pub struct NetworkProvisioner {
method: ProvisionerMethod,
settings: NetworkPolicy,
db: Arc<dyn LNVpsDb>,
}
impl NetworkProvisioner {
pub fn new(method: ProvisionerMethod, settings: NetworkPolicy, db: Arc<dyn LNVpsDb>) -> Self {
Self {
method,
settings,
db,
}
}
/// Pick an IP from one of the available ip ranges
/// This method MUST return a free IP which can be used
pub async fn pick_ip_for_region(&self, region_id: u64) -> Result<AvailableIp> {
let ip_ranges = self.db.list_ip_range_in_region(region_id).await?;
if ip_ranges.is_empty() {
bail!("No ip range found in this region");
}
for range in ip_ranges {
let range_cidr: IpNetwork = range.cidr.parse()?;
let ips = self.db.list_vm_ip_assignments_in_range(range.id).await?;
let ips: HashSet<IpAddr> = ips.iter().map_while(|i| i.ip.parse().ok()).collect();
// pick an IP at random
let ip_pick = {
let first_ip = range_cidr.iter().next().unwrap();
let last_ip = range_cidr.iter().last().unwrap();
match self.method {
ProvisionerMethod::Sequential => range_cidr
.iter()
.find(|i| *i != first_ip && *i != last_ip && !ips.contains(i)),
ProvisionerMethod::Random => {
let mut rng = rand::rng();
loop {
if let Some(i) = range_cidr.iter().choose(&mut rng) {
if i != first_ip && i != last_ip && !ips.contains(&i) {
break Some(i);
}
} else {
break None;
}
}
}
}
};
if let Some(ip_pick) = ip_pick {
return Ok(AvailableIp {
range_id: range.id,
ip: ip_pick,
region_id,
});
}
}
bail!("No IPs available in this region");
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mocks::*;
use crate::settings::NetworkAccessPolicy;
use lnvps_db::VmIpAssignment;
use std::str::FromStr;
#[tokio::test]
async fn pick_seq_ip_for_region_test() {
let db: Arc<dyn LNVpsDb> = Arc::new(MockDb::default());
let mgr = NetworkProvisioner::new(
ProvisionerMethod::Sequential,
NetworkPolicy {
access: NetworkAccessPolicy::Auto,
},
db.clone(),
);
let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db");
assert_eq!(1, ip.region_id);
assert_eq!(IpAddr::from_str("10.0.0.1").unwrap(), ip.ip);
db.insert_vm_ip_assignment(&VmIpAssignment {
id: 0,
vm_id: 0,
ip_range_id: ip.range_id,
ip: ip.ip.to_string(),
deleted: false,
})
.await
.expect("Could not insert vm ip");
let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db");
assert_eq!(IpAddr::from_str("10.0.0.2").unwrap(), ip.ip);
}
#[tokio::test]
async fn pick_rng_ip_for_region_test() {
let db: Arc<dyn LNVpsDb> = Arc::new(MockDb::default());
let mgr = NetworkProvisioner::new(
ProvisionerMethod::Random,
NetworkPolicy {
access: NetworkAccessPolicy::Auto,
},
db,
);
let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db");
assert_eq!(1, ip.region_id);
}
}

View File

@ -14,11 +14,10 @@ pub struct MikrotikRouter {
username: String, username: String,
password: String, password: String,
client: Client, client: Client,
arp_interface: String,
} }
impl MikrotikRouter { impl MikrotikRouter {
pub fn new(url: &str, username: &str, password: &str, arp_interface: &str) -> Self { pub fn new(url: &str, username: &str, password: &str) -> Self {
Self { Self {
url: url.parse().unwrap(), url: url.parse().unwrap(),
username: username.to_string(), username: username.to_string(),
@ -27,7 +26,6 @@ impl MikrotikRouter {
.danger_accept_invalid_certs(true) .danger_accept_invalid_certs(true)
.build() .build()
.unwrap(), .unwrap(),
arp_interface: arp_interface.to_string(),
} }
} }
@ -73,7 +71,13 @@ impl Router for MikrotikRouter {
Ok(rsp) Ok(rsp)
} }
async fn add_arp_entry(&self, ip: IpAddr, mac: &str, comment: Option<&str>) -> Result<()> { async fn add_arp_entry(
&self,
ip: IpAddr,
mac: &str,
arp_interface: &str,
comment: Option<&str>,
) -> Result<()> {
let _rsp: ArpEntry = self let _rsp: ArpEntry = self
.req( .req(
Method::PUT, Method::PUT,
@ -81,7 +85,7 @@ impl Router for MikrotikRouter {
ArpEntry { ArpEntry {
address: ip.to_string(), address: ip.to_string(),
mac_address: Some(mac.to_string()), mac_address: Some(mac.to_string()),
interface: self.arp_interface.to_string(), interface: arp_interface.to_string(),
comment: comment.map(|c| c.to_string()), comment: comment.map(|c| c.to_string()),
..Default::default() ..Default::default()
}, },

View File

@ -13,7 +13,13 @@ use std::net::IpAddr;
#[async_trait] #[async_trait]
pub trait Router: Send + Sync { pub trait Router: Send + Sync {
async fn list_arp_entry(&self) -> Result<Vec<ArpEntry>>; async fn list_arp_entry(&self) -> Result<Vec<ArpEntry>>;
async fn add_arp_entry(&self, ip: IpAddr, mac: &str, comment: Option<&str>) -> Result<()>; async fn add_arp_entry(
&self,
ip: IpAddr,
mac: &str,
interface: &str,
comment: Option<&str>,
) -> Result<()>;
async fn remove_arp_entry(&self, id: &str) -> Result<()>; async fn remove_arp_entry(&self, id: &str) -> Result<()>;
} }
@ -35,3 +41,12 @@ pub struct ArpEntry {
mod mikrotik; mod mikrotik;
#[cfg(feature = "mikrotik")] #[cfg(feature = "mikrotik")]
pub use mikrotik::*; pub use mikrotik::*;
#[cfg(test)]
mod tests {
use super::*;
use crate::settings::NetworkPolicy;
#[test]
fn provision_ips_with_arp() {}
}

View File

@ -1,21 +1,27 @@
use crate::exchange::ExchangeRateCache; use crate::exchange::ExchangeRateCache;
use crate::provisioner::lnvps::LNVpsProvisioner; use crate::provisioner::LNVpsProvisioner;
use crate::provisioner::Provisioner; use crate::provisioner::Provisioner;
use crate::router::{MikrotikRouter, Router}; use crate::router::{MikrotikRouter, Router};
use anyhow::{bail, Result};
use fedimint_tonic_lnd::Client; use fedimint_tonic_lnd::Client;
use lnvps_db::LNVpsDb; use lnvps_db::LNVpsDb;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Settings { pub struct Settings {
pub listen: Option<String>, pub listen: Option<String>,
pub db: String, pub db: String,
pub lnd: LndConfig, pub lnd: LndConfig,
/// Main control process impl /// Provisioning profiles
pub provisioner: ProvisionerConfig, pub provisioner: ProvisionerConfig,
/// Network policy
#[serde(default)]
pub network_policy: NetworkPolicy,
/// Number of days after an expired VM is deleted /// Number of days after an expired VM is deleted
pub delete_after: u16, pub delete_after: u16,
@ -25,7 +31,10 @@ pub struct Settings {
/// Network router config /// Network router config
pub router: Option<RouterConfig>, pub router: Option<RouterConfig>,
/// Nostr config for sending DM's /// DNS configurations for PTR records
pub dns: Option<DnsServerConfig>,
/// Nostr config for sending DMs
pub nostr: Option<NostrConfig>, pub nostr: Option<NostrConfig>,
} }
@ -43,18 +52,55 @@ pub struct NostrConfig {
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "kebab-case")]
pub enum RouterConfig { pub enum RouterConfig {
Mikrotik { Mikrotik(ApiConfig),
url: String, }
username: String,
password: String,
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum DnsServerConfig {
Cloudflare { api: ApiConfig, zone_id: String },
}
/// Generic remote API credentials
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApiConfig {
/// unique ID of this router, used in references
pub id: String,
/// http://<my-router>
pub url: String,
/// Login credentials used for this router
pub credentials: Credentials,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum Credentials {
UsernamePassword { username: String, password: String },
ApiToken { token: String },
}
/// Policy that determines how packets arrive at the VM
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum NetworkAccessPolicy {
/// No special procedure required for packets to arrive
#[default]
Auto,
/// ARP entries are added statically on the access router
StaticArp {
/// Interface used to add arp entries /// Interface used to add arp entries
arp_interface: String, interface: String,
}, },
} }
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub struct NetworkPolicy {
pub access: NetworkAccessPolicy,
}
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SmtpConfig { pub struct SmtpConfig {
/// Admin user id, for sending system notifications /// Admin user id, for sending system notifications
@ -74,8 +120,9 @@ pub struct SmtpConfig {
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "kebab-case")]
pub enum ProvisionerConfig { pub enum ProvisionerConfig {
#[serde(rename_all = "kebab-case")]
Proxmox { Proxmox {
/// Readonly mode, don't spawn any VM's /// Readonly mode, don't spawn any VM's
read_only: bool, read_only: bool,
@ -95,27 +142,23 @@ pub struct SshConfig {
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct QemuConfig { pub struct QemuConfig {
/// Machine type (q35) /// Machine type (q35)
pub machine: String, pub machine: String,
/// OS Type /// OS Type
pub os_type: String, pub os_type: String,
/// Network bridge used for the networking interface /// Network bridge used for the networking interface
pub bridge: String, pub bridge: String,
/// CPU type /// CPU type
pub cpu: String, pub cpu: String,
/// VLAN tag all spawned VM's /// VLAN tag all spawned VM's
pub vlan: Option<u16>, pub vlan: Option<u16>,
/// Enable virtualization inside VM /// Enable virtualization inside VM
pub kvm: bool, pub kvm: bool,
} }
impl ProvisionerConfig { impl Settings {
pub fn get_provisioner( pub fn get_provisioner(
&self, &self,
db: impl LNVpsDb + 'static, db: impl LNVpsDb + 'static,
@ -123,7 +166,7 @@ impl ProvisionerConfig {
lnd: Client, lnd: Client,
exchange: ExchangeRateCache, exchange: ExchangeRateCache,
) -> impl Provisioner + 'static { ) -> impl Provisioner + 'static {
match self { match &self.provisioner {
ProvisionerConfig::Proxmox { ProvisionerConfig::Proxmox {
qemu, qemu,
ssh, ssh,
@ -131,6 +174,7 @@ impl ProvisionerConfig {
} => LNVpsProvisioner::new( } => LNVpsProvisioner::new(
*read_only, *read_only,
qemu.clone(), qemu.clone(),
self.network_policy.clone(),
ssh.clone(), ssh.clone(),
db, db,
router, router,
@ -139,17 +183,15 @@ impl ProvisionerConfig {
), ),
} }
} }
} pub fn get_router<'a>(&'a self) -> Result<Option<impl Router + 'static>> {
match &self.router {
impl RouterConfig { Some(RouterConfig::Mikrotik(api)) => match &api.credentials {
pub fn get_router(&self) -> impl Router + 'static { Credentials::UsernamePassword { username, password } => {
match self { Ok(Some(MikrotikRouter::new(&api.url, username, password)))
RouterConfig::Mikrotik { }
url, _ => bail!("Only username/password is supported for Mikrotik routers"),
username, },
password, _ => Ok(None),
arp_interface,
} => MikrotikRouter::new(url, username, password, arp_interface),
} }
} }
} }

View File

@ -1,9 +1,9 @@
use anyhow::Result; use anyhow::Result;
use rocket::serde::Deserialize;
use schemars::JsonSchema;
use serde::Serialize; use serde::Serialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use rocket::serde::Deserialize;
use schemars::JsonSchema;
use tokio::sync::RwLock; use tokio::sync::RwLock;
#[derive(Clone, Serialize, Deserialize, Default, JsonSchema)] #[derive(Clone, Serialize, Deserialize, Default, JsonSchema)]

View File

@ -327,12 +327,7 @@ impl Worker {
if let Err(e) = self.check_vm(*vm_id).await { if let Err(e) = self.check_vm(*vm_id).await {
error!("Failed to check VM {}: {}", vm_id, e); error!("Failed to check VM {}: {}", vm_id, e);
self.queue_admin_notification( self.queue_admin_notification(
format!( format!("Failed to check VM {}:\n{:?}\n{}", vm_id, &job, e),
"Failed to check VM {}:\n{:?}\n{}",
vm_id,
&job,
e
),
Some("Job Failed".to_string()), Some("Job Failed".to_string()),
)? )?
} }
@ -348,11 +343,7 @@ impl Worker {
{ {
error!("Failed to send notification {}: {}", user_id, e); error!("Failed to send notification {}: {}", user_id, e);
self.queue_admin_notification( self.queue_admin_notification(
format!( format!("Failed to send notification:\n{:?}\n{}", &job, e),
"Failed to send notification:\n{:?}\n{}",
&job,
e
),
Some("Job Failed".to_string()), Some("Job Failed".to_string()),
)? )?
} }