diff --git a/Cargo.lock b/Cargo.lock index 458d6c8..0ca5de2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1841,8 +1841,7 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "ipnetwork" version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" +source = "git+https://github.com/v0l/ipnetwork.git?rev=51b4816358f255ecbcff90b739fbbd4b1cbc2d6a#51b4816358f255ecbcff90b739fbbd4b1cbc2d6a" [[package]] name = "is-terminal" diff --git a/Cargo.toml b/Cargo.toml index f8ca2c8..df831c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ nostr = { version = "0.39.0", default-features = false, features = ["std"] } base64 = { version = "0.22.1", features = ["alloc"] } urlencoding = "2.1.3" fedimint-tonic-lnd = { version = "0.2.0", default-features = false, features = ["invoicesrpc"] } -ipnetwork = "0.21.1" +ipnetwork = { git = "https://github.com/v0l/ipnetwork.git", rev = "51b4816358f255ecbcff90b739fbbd4b1cbc2d6a" } rand = "0.9.0" clap = { version = "4.5.21", features = ["derive"] } ssh2 = "0.9.4" diff --git a/README.md b/README.md index 94e1182..1066fc5 100644 --- a/README.md +++ b/README.md @@ -21,19 +21,19 @@ lnd: macaroon: "$HOME/.lnd/data/chain/bitcoin/mainnet/admin.macaroon" # Number of days after a VM expires to delete -delete_after: 3 +delete-after: 3 # Provisioner is the main process which handles creating/deleting VM's # Currently supports: Proxmox provisioner: proxmox: # Read-only mode prevents spawning VM's - read_only: false + read-only: false # Proxmox (QEMU) settings used for spawning VM's qemu: bios: "ovmf" machine: "q35" - os_type: "l26" + os-type: "l26" bridge: "vmbr0" cpu: "kvm64" vlan: 100 @@ -89,6 +89,12 @@ router: url: "https://my-router.net" username: "admin" password: "admin" - # Interface where the static ARP entry is added - arp_interface: "bridge1" +network-policy: + # How packets get to the VM + # (default "auto", nothing to do, packets will always arrive) + access: + # Static ARP entries are added to the router for each provisioned IP + static-arp: + # Interface where the static ARP entry is added + interface: "bridge1" ``` \ No newline at end of file diff --git a/config.yaml b/config.yaml index b59fac6..371de64 100644 --- a/config.yaml +++ b/config.yaml @@ -3,14 +3,14 @@ lnd: url: "https://127.0.0.1:10003" cert: "/home/kieran/.polar/networks/2/volumes/lnd/alice/tls.cert" macaroon: "/home/kieran/.polar/networks/2/volumes/lnd/alice/data/chain/bitcoin/regtest/admin.macaroon" -delete_after: 3 +delete-after: 3 provisioner: proxmox: - read_only: false + read-only: false qemu: bios: "ovmf" machine: "q35" - os_type: "l26" + os-type: "l26" bridge: "vmbr0" cpu: "host" vlan: 100 diff --git a/lnvps_db/src/lib.rs b/lnvps_db/src/lib.rs index 5b0ccc0..f62e475 100644 --- a/lnvps_db/src/lib.rs +++ b/lnvps_db/src/lib.rs @@ -1,6 +1,4 @@ use anyhow::Result; -use async_trait::async_trait; - mod model; #[cfg(feature = "mysql")] mod mysql; @@ -9,6 +7,8 @@ pub use model::*; #[cfg(feature = "mysql")] pub use mysql::*; +pub use async_trait::async_trait; + #[async_trait] pub trait LNVpsDb: Sync + Send { /// Migrate database @@ -65,6 +65,9 @@ pub trait LNVpsDb: Sync + Send { /// List available IP Ranges async fn list_ip_range(&self) -> Result>; + /// List available IP Ranges in a given region + async fn list_ip_range_in_region(&self, region_id: u64) -> Result>; + /// Get a VM cost plan by id async fn get_cost_plan(&self, id: u64) -> Result; diff --git a/lnvps_db/src/model.rs b/lnvps_db/src/model.rs index ea6922d..6c0ca86 100644 --- a/lnvps_db/src/model.rs +++ b/lnvps_db/src/model.rs @@ -220,6 +220,7 @@ pub struct VmIpAssignment { pub vm_id: u64, pub ip_range_id: u64, pub ip: String, + pub deleted: bool, } impl Display for VmIpAssignment { diff --git a/lnvps_db/src/mysql.rs b/lnvps_db/src/mysql.rs index 5c2a479..9d2b134 100644 --- a/lnvps_db/src/mysql.rs +++ b/lnvps_db/src/mysql.rs @@ -171,7 +171,15 @@ impl LNVpsDb for LNVpsDbMysql { } async fn list_ip_range(&self) -> Result> { - sqlx::query_as("select * from ip_range") + sqlx::query_as("select * from ip_range where enabled = 1") + .fetch_all(&self.db) + .await + .map_err(Error::new) + } + + async fn list_ip_range_in_region(&self, region_id: u64) -> Result> { + sqlx::query_as("select * from ip_range where region_id = ? and enabled = 1") + .bind(region_id) .fetch_all(&self.db) .await .map_err(Error::new) diff --git a/src/api/mod.rs b/src/api/mod.rs index feb9043..fb9221d 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,4 @@ -mod routes; mod model; +mod routes; -pub use routes::routes; \ No newline at end of file +pub use routes::routes; diff --git a/src/api/model.rs b/src/api/model.rs index a1b9efe..6843902 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -1,9 +1,12 @@ -use nostr::util::hex; use crate::status::VmState; use chrono::{DateTime, Utc}; +use ipnetwork::IpNetwork; use lnvps_db::VmHostRegion; +use nostr::util::hex; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::net::IpAddr; +use std::str::FromStr; #[derive(Serialize, Deserialize, JsonSchema)] pub struct ApiVmStatus { @@ -48,15 +51,20 @@ impl From for ApiUserSshKey { pub struct ApiVmIpAssignment { pub id: u64, pub ip: String, - pub range: String, + pub gateway: String, } impl ApiVmIpAssignment { - pub fn from(ip: lnvps_db::VmIpAssignment, range: &lnvps_db::IpRange) -> Self { + pub fn from(ip: &lnvps_db::VmIpAssignment, range: &lnvps_db::IpRange) -> Self { ApiVmIpAssignment { id: ip.id, - ip: ip.ip, - range: range.cidr.clone(), + ip: IpNetwork::new( + IpNetwork::from_str(&ip.ip).unwrap().ip(), + IpNetwork::from_str(&range.cidr).unwrap().prefix(), + ) + .unwrap() + .to_string(), + gateway: range.gateway.to_string(), } } } diff --git a/src/api/routes.rs b/src/api/routes.rs index 2b479ff..2d7f33d 100644 --- a/src/api/routes.rs +++ b/src/api/routes.rs @@ -155,7 +155,7 @@ async fn vm_to_status( let range = ip_ranges .get(&i.ip_range_id) .expect("ip range id not found"); - ApiVmIpAssignment::from(i, range) + ApiVmIpAssignment::from(&i, range) }) .collect(), }) @@ -457,6 +457,7 @@ async fn v1_get_payment( } else { return ApiData::err("Invalid payment id"); }; + let payment = db.get_vm_payment(&id).await?; let vm = db.get_vm(payment.vm_id).await?; if vm.user_id != uid { diff --git a/src/bin/api.rs b/src/bin/api.rs index 7635850..d43b0ac 100644 --- a/src/bin/api.rs +++ b/src/bin/api.rs @@ -64,12 +64,10 @@ async fn main() -> Result<(), Error> { } else { None }; - let router = settings.router.as_ref().map(|r| r.get_router()); + let router = settings.get_router()?; let status = VmStateCache::new(); let worker_provisioner = - settings - .provisioner - .get_provisioner(db.clone(), router, lnd.clone(), exchange.clone()); + settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone()); worker_provisioner.init().await?; let mut worker = Worker::new( @@ -82,11 +80,11 @@ async fn main() -> Result<(), Error> { let sender = worker.sender(); // send a startup notification - if let Some(admin) = &settings.smtp.and_then(|s| s.admin) { + if let Some(admin) = settings.smtp.as_ref().and_then(|s| s.admin) { sender.send(WorkJob::SendNotification { title: Some("Startup".to_string()), message: "System is starting!".to_string(), - user_id: *admin, + user_id: admin, })?; } @@ -132,11 +130,8 @@ async fn main() -> Result<(), Error> { } }); - let router = settings.router.as_ref().map(|r| r.get_router()); - let provisioner = - settings - .provisioner - .get_provisioner(db.clone(), router, lnd.clone(), exchange.clone()); + let router = settings.get_router()?; + let provisioner = settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone()); let db: Box = Box::new(db.clone()); let pv: Box = Box::new(provisioner); diff --git a/src/lib.rs b/src/lib.rs index 6673279..f7fb4e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,3 +10,6 @@ pub mod settings; pub mod ssh_client; pub mod status; pub mod worker; + +#[cfg(test)] +pub mod mocks; diff --git a/src/mocks.rs b/src/mocks.rs new file mode 100644 index 0000000..c3884c6 --- /dev/null +++ b/src/mocks.rs @@ -0,0 +1,337 @@ +use crate::router::{ArpEntry, Router}; +use crate::settings::NetworkPolicy; +use anyhow::anyhow; +use chrono::Utc; +use lnvps_db::{ + async_trait, IpRange, LNVpsDb, User, UserSshKey, Vm, VmCostPlan, VmHost, VmHostDisk, + VmHostKind, VmHostRegion, VmIpAssignment, VmOsImage, VmPayment, VmTemplate, +}; +use std::collections::HashMap; +use std::net::IpAddr; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Debug, Clone)] +pub struct MockDb { + pub regions: Arc>>, + pub hosts: Arc>>, + pub users: Arc>>, + pub vms: Arc>>, + pub ip_range: Arc>>, + pub ip_assignments: Arc>>, +} + +impl MockDb { + pub fn empty() -> MockDb { + Self { + ..Default::default() + } + } +} + +impl Default for MockDb { + fn default() -> Self { + let mut regions = HashMap::new(); + regions.insert( + 1, + VmHostRegion { + id: 1, + name: "Mock".to_string(), + enabled: true, + }, + ); + let mut ip_ranges = HashMap::new(); + ip_ranges.insert( + 1, + IpRange { + id: 1, + cidr: "10.0.0.0/8".to_string(), + gateway: "10.0.0.1".to_string(), + enabled: true, + region_id: 1, + }, + ); + let mut hosts = HashMap::new(); + hosts.insert( + 1, + VmHost { + id: 1, + kind: VmHostKind::Proxmox, + region_id: 1, + name: "mock-host".to_string(), + ip: "https://localhost".to_string(), + cpu: 4, + memory: 8192, + enabled: true, + api_token: "".to_string(), + }, + ); + Self { + regions: Arc::new(Mutex::new(regions)), + ip_range: Arc::new(Mutex::new(ip_ranges)), + hosts: Arc::new(Mutex::new(hosts)), + users: Arc::new(Default::default()), + vms: Arc::new(Default::default()), + ip_assignments: Arc::new(Default::default()), + } + } +} + +#[async_trait] +impl LNVpsDb for MockDb { + async fn migrate(&self) -> anyhow::Result<()> { + Ok(()) + } + + async fn upsert_user(&self, pubkey: &[u8; 32]) -> anyhow::Result { + let mut users = self.users.lock().await; + if let Some(e) = users.iter().find(|(k, u)| u.pubkey == *pubkey) { + Ok(*e.0) + } else { + let max = *users.keys().max().unwrap_or(&0); + users.insert( + max + 1, + User { + id: max + 1, + pubkey: pubkey.to_vec(), + created: Utc::now(), + email: None, + contact_nip4: false, + contact_nip17: false, + contact_email: false, + }, + ); + Ok(max + 1) + } + } + + async fn get_user(&self, id: u64) -> anyhow::Result { + let mut users = self.users.lock().await; + Ok(users.get(&id).ok_or(anyhow!("no user"))?.clone()) + } + + async fn update_user(&self, user: &User) -> anyhow::Result<()> { + let mut users = self.users.lock().await; + if let Some(u) = users.get_mut(&user.id) { + u.email = user.email.clone(); + u.contact_email = user.contact_email.clone(); + u.contact_nip17 = user.contact_nip17.clone(); + u.contact_nip4 = user.contact_nip4.clone(); + } + Ok(()) + } + + async fn delete_user(&self, id: u64) -> anyhow::Result<()> { + let mut users = self.users.lock().await; + users.remove(&id); + Ok(()) + } + + async fn insert_user_ssh_key(&self, new_key: &UserSshKey) -> anyhow::Result { + todo!() + } + + async fn get_user_ssh_key(&self, id: u64) -> anyhow::Result { + todo!() + } + + async fn delete_user_ssh_key(&self, id: u64) -> anyhow::Result<()> { + todo!() + } + + async fn list_user_ssh_key(&self, user_id: u64) -> anyhow::Result> { + todo!() + } + + async fn get_host_region(&self, id: u64) -> anyhow::Result { + let regions = self.regions.lock().await; + Ok(regions.get(&id).ok_or(anyhow!("no region"))?.clone()) + } + + async fn list_hosts(&self) -> anyhow::Result> { + let hosts = self.hosts.lock().await; + Ok(hosts.values().filter(|h| h.enabled).cloned().collect()) + } + + async fn get_host(&self, id: u64) -> anyhow::Result { + let hosts = self.hosts.lock().await; + Ok(hosts.get(&id).ok_or(anyhow!("no host"))?.clone()) + } + + async fn update_host(&self, host: &VmHost) -> anyhow::Result<()> { + let mut hosts = self.hosts.lock().await; + if let Some(h) = hosts.get_mut(&host.id) { + h.enabled = host.enabled; + h.cpu = host.cpu; + h.memory = host.memory; + } + Ok(()) + } + + async fn list_host_disks(&self, host_id: u64) -> anyhow::Result> { + todo!() + } + + async fn get_os_image(&self, id: u64) -> anyhow::Result { + todo!() + } + + async fn list_os_image(&self) -> anyhow::Result> { + todo!() + } + + async fn get_ip_range(&self, id: u64) -> anyhow::Result { + let ip_range = self.ip_range.lock().await; + Ok(ip_range.get(&id).ok_or(anyhow!("no ip range"))?.clone()) + } + + async fn list_ip_range(&self) -> anyhow::Result> { + let ip_range = self.ip_range.lock().await; + Ok(ip_range.values().filter(|r| r.enabled).cloned().collect()) + } + + async fn list_ip_range_in_region(&self, region_id: u64) -> anyhow::Result> { + let ip_range = self.ip_range.lock().await; + Ok(ip_range + .values() + .filter(|r| r.enabled && r.region_id == region_id) + .cloned() + .collect()) + } + + async fn get_cost_plan(&self, id: u64) -> anyhow::Result { + todo!() + } + + async fn get_vm_template(&self, id: u64) -> anyhow::Result { + todo!() + } + + async fn list_vm_templates(&self) -> anyhow::Result> { + todo!() + } + + async fn list_vms(&self) -> anyhow::Result> { + todo!() + } + + async fn list_expired_vms(&self) -> anyhow::Result> { + todo!() + } + + async fn list_user_vms(&self, id: u64) -> anyhow::Result> { + todo!() + } + + async fn get_vm(&self, vm_id: u64) -> anyhow::Result { + todo!() + } + + async fn insert_vm(&self, vm: &Vm) -> anyhow::Result { + todo!() + } + + async fn delete_vm(&self, vm_id: u64) -> anyhow::Result<()> { + todo!() + } + + async fn update_vm(&self, vm: &Vm) -> anyhow::Result<()> { + todo!() + } + + async fn insert_vm_ip_assignment(&self, ip_assignment: &VmIpAssignment) -> anyhow::Result { + let mut ip_assignments = self.ip_assignments.lock().await; + let max = *ip_assignments.keys().max().unwrap_or(&0); + ip_assignments.insert( + max + 1, + VmIpAssignment { + id: max + 1, + vm_id: ip_assignment.vm_id, + ip_range_id: ip_assignment.ip_range_id, + ip: ip_assignment.ip.clone(), + deleted: false, + }, + ); + Ok(max + 1) + } + + async fn list_vm_ip_assignments(&self, vm_id: u64) -> anyhow::Result> { + let ip_assignments = self.ip_assignments.lock().await; + Ok(ip_assignments + .values() + .filter(|a| a.vm_id == vm_id && !a.deleted) + .cloned() + .collect()) + } + + async fn list_vm_ip_assignments_in_range( + &self, + range_id: u64, + ) -> anyhow::Result> { + let ip_assignments = self.ip_assignments.lock().await; + Ok(ip_assignments + .values() + .filter(|a| a.ip_range_id == range_id && !a.deleted) + .cloned() + .collect()) + } + + async fn delete_vm_ip_assignment(&self, vm_id: u64) -> anyhow::Result<()> { + let mut ip_assignments = self.ip_assignments.lock().await; + for ip_assignment in ip_assignments.values_mut() { + if ip_assignment.vm_id == vm_id { + ip_assignment.deleted = true; + } + } + Ok(()) + } + + async fn list_vm_payment(&self, vm_id: u64) -> anyhow::Result> { + todo!() + } + + async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> { + todo!() + } + + async fn get_vm_payment(&self, id: &Vec) -> anyhow::Result { + todo!() + } + + async fn update_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> { + todo!() + } + + async fn vm_payment_paid(&self, id: &VmPayment) -> anyhow::Result<()> { + todo!() + } + + async fn last_paid_invoice(&self) -> anyhow::Result> { + todo!() + } +} + +struct MockRouter { + pub policy: NetworkPolicy, +} + +#[async_trait] +impl Router for MockRouter { + async fn list_arp_entry(&self) -> anyhow::Result> { + todo!() + } + + async fn add_arp_entry( + &self, + ip: IpAddr, + mac: &str, + interface: &str, + comment: Option<&str>, + ) -> anyhow::Result<()> { + todo!() + } + + async fn remove_arp_entry(&self, id: &str) -> anyhow::Result<()> { + todo!() + } +} diff --git a/src/provisioner/lnvps.rs b/src/provisioner/lnvps.rs index 034e380..4abff65 100644 --- a/src/provisioner/lnvps.rs +++ b/src/provisioner/lnvps.rs @@ -4,9 +4,9 @@ use crate::host::proxmox::{ ConfigureVm, CreateVm, DownloadUrlRequest, ProxmoxClient, ResizeDiskRequest, StorageContent, VmBios, VmConfig, }; -use crate::provisioner::Provisioner; +use crate::provisioner::{NetworkProvisioner, Provisioner, ProvisionerMethod}; use crate::router::Router; -use crate::settings::{QemuConfig, SshConfig}; +use crate::settings::{NetworkAccessPolicy, NetworkPolicy, QemuConfig, SshConfig}; use crate::ssh_client::SshClient; use anyhow::{bail, Result}; use chrono::{Days, Months, Utc}; @@ -25,18 +25,21 @@ use rocket::futures::{SinkExt, StreamExt}; use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::ops::Add; +use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; pub struct LNVpsProvisioner { - db: Box, + db: Arc, router: Option>, lnd: Client, rates: ExchangeRateCache, read_only: bool, config: QemuConfig, + network_policy: NetworkPolicy, ssh: Option, } @@ -44,6 +47,7 @@ impl LNVpsProvisioner { pub fn new( read_only: bool, config: QemuConfig, + network_policy: NetworkPolicy, ssh: Option, db: impl LNVpsDb + 'static, router: Option, @@ -51,9 +55,10 @@ impl LNVpsProvisioner { rates: ExchangeRateCache, ) -> Self { Self { - db: Box::new(db), + db: Arc::new(db), router: router.map(|r| Box::new(r) as Box), lnd, + network_policy, rates, config, read_only, @@ -73,14 +78,9 @@ impl LNVpsProvisioner { } } - async fn get_vm_config(&self, vm: &Vm) -> Result { + async fn get_vm_config(&self, vm: &Vm, ips: &Vec) -> Result { let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?; - let mut ips = self.db.list_vm_ip_assignments(vm.id).await?; - if ips.is_empty() { - ips = self.allocate_ips(vm.id).await?; - } - let ip_range_ids: HashSet = ips.iter().map(|i| i.ip_range_id).collect(); let ip_ranges: Vec<_> = ip_range_ids .iter() @@ -148,6 +148,30 @@ impl LNVpsProvisioner { ..Default::default() }) } + + async fn save_ip_assignment(&self, vm: &Vm, assignment: &VmIpAssignment) -> Result<()> { + // apply network policy + match &self.network_policy.access { + NetworkAccessPolicy::StaticArp { interface } => { + if let Some(r) = self.router.as_ref() { + r.add_arp_entry( + IpAddr::from_str(&assignment.ip)?, + &vm.mac_address, + interface, + Some(&format!("VM{}", vm.id)), + ) + .await?; + } else { + bail!("No router found to apply static arp entry!") + } + } + _ => {} + } + + // save to db + self.db.insert_vm_ip_assignment(&assignment).await?; + Ok(()) + } } #[async_trait] @@ -303,69 +327,29 @@ impl Provisioner for LNVpsProvisioner { async fn allocate_ips(&self, vm_id: u64) -> Result> { let vm = self.db.get_vm(vm_id).await?; + let existing_ips = self.db.list_vm_ip_assignments(vm_id).await?; + if !existing_ips.is_empty() { + return Ok(existing_ips); + } + let template = self.db.get_vm_template(vm.template_id).await?; - let ips = self.db.list_vm_ip_assignments(vm.id).await?; + let prov = NetworkProvisioner::new( + ProvisionerMethod::Random, + self.network_policy.clone(), + self.db.clone(), + ); - if !ips.is_empty() { - bail!("IP resources are already assigned"); - } + let ip = prov.pick_ip_for_region(template.region_id).await?; + let assignment = VmIpAssignment { + id: 0, + vm_id, + ip_range_id: ip.range_id, + ip: ip.ip.to_string(), + deleted: false, + }; - let ip_ranges = self.db.list_ip_range().await?; - let ip_ranges: Vec = ip_ranges - .into_iter() - .filter(|i| i.region_id == template.region_id && i.enabled) - .collect(); - - if ip_ranges.is_empty() { - bail!("No ip range found in this region"); - } - - let mut ret = vec![]; - // Try all ranges - // TODO: pick round-robin ranges - // TODO: pick one of each type - 'ranges: for range in ip_ranges { - let range_cidr: IpNetwork = range.cidr.parse()?; - let ips = self.db.list_vm_ip_assignments_in_range(range.id).await?; - let ips: HashSet = ips.iter().map_while(|i| i.ip.parse().ok()).collect(); - - // pick an IP at random - let cidr: Vec = { - let mut rng = rand::thread_rng(); - range_cidr.iter().choose(&mut rng).into_iter().collect() - }; - - for ip in cidr { - let ip_net = IpNetwork::new(ip, range_cidr.prefix())?; - if !ips.contains(&ip_net) { - info!("Attempting to allocate IP for {vm_id} to {ip}"); - let mut assignment = VmIpAssignment { - id: 0, - vm_id, - ip_range_id: range.id, - ip: ip_net.to_string(), - ..Default::default() - }; - - // add arp entry for router - if let Some(r) = self.router.as_ref() { - r.add_arp_entry(ip, &vm.mac_address, Some(&format!("VM{}", vm.id))) - .await?; - } - let id = self.db.insert_vm_ip_assignment(&assignment).await?; - assignment.id = id; - - ret.push(assignment); - break 'ranges; - } - } - } - - if ret.is_empty() { - bail!("No ip ranges found in this region"); - } - - Ok(ret) + self.save_ip_assignment(&vm, &assignment).await?; + Ok(vec![assignment]) } async fn spawn_vm(&self, vm_id: u64) -> Result<()> { @@ -378,16 +362,21 @@ impl Provisioner for LNVpsProvisioner { let client = get_host_client(&host)?; let vm_id = 100 + vm.id as i32; + // setup network by allocating some IP space + let ips = self.allocate_ips(vm.id).await?; + // create VM + let config = self.get_vm_config(&vm, &ips).await?; let t_create = client .create_vm(CreateVm { node: host.name.clone(), vm_id, - config: self.get_vm_config(&vm).await?, + config, }) .await?; client.wait_for_task(&t_create).await?; + // save // import the disk // TODO: find a way to avoid using SSH if let Some(ssh_config) = &self.ssh { @@ -451,8 +440,10 @@ impl Provisioner for LNVpsProvisioner { .await?; client.wait_for_task(&j_resize).await?; - let j_start = client.start_vm(&host.name, vm_id as u64).await?; - client.wait_for_task(&j_start).await?; + // try start, otherwise ignore error (maybe its already running) + if let Ok(j_start) = client.start_vm(&host.name, vm_id as u64).await { + client.wait_for_task(&j_start).await?; + } Ok(()) } @@ -549,6 +540,7 @@ impl Provisioner for LNVpsProvisioner { async fn patch_vm(&self, vm_id: u64) -> Result<()> { let vm = self.db.get_vm(vm_id).await?; let host = self.db.get_host(vm.host_id).await?; + let ips = self.db.list_vm_ip_assignments(vm.id).await?; let client = get_host_client(&host)?; let host_vm_id = vm.id + 100; @@ -562,7 +554,7 @@ impl Provisioner for LNVpsProvisioner { scsi_0: None, scsi_1: None, efi_disk_0: None, - ..self.get_vm_config(&vm).await? + ..self.get_vm_config(&vm, &ips).await? }, }) .await?; diff --git a/src/provisioner/mod.rs b/src/provisioner/mod.rs index d2a0bd8..d2faa34 100644 --- a/src/provisioner/mod.rs +++ b/src/provisioner/mod.rs @@ -4,7 +4,11 @@ use rocket::async_trait; use tokio::net::TcpStream; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; -pub mod lnvps; +mod lnvps; +mod network; + +pub use lnvps::*; +pub use network::*; #[async_trait] pub trait Provisioner: Send + Sync { diff --git a/src/provisioner/network.rs b/src/provisioner/network.rs new file mode 100644 index 0000000..3cf6ec1 --- /dev/null +++ b/src/provisioner/network.rs @@ -0,0 +1,136 @@ +use crate::settings::NetworkPolicy; +use anyhow::{bail, Result}; +use ipnetwork::IpNetwork; +use lnvps_db::LNVpsDb; +use rand::prelude::IteratorRandom; +use std::collections::HashSet; +use std::net::IpAddr; +use std::sync::Arc; + +#[derive(Debug, Clone, Copy)] +pub enum ProvisionerMethod { + Sequential, + Random, +} + +#[derive(Debug, Clone, Copy)] +pub struct AvailableIp { + pub ip: IpAddr, + pub range_id: u64, + pub region_id: u64, +} + +#[derive(Clone)] +pub struct NetworkProvisioner { + method: ProvisionerMethod, + settings: NetworkPolicy, + db: Arc, +} + +impl NetworkProvisioner { + pub fn new(method: ProvisionerMethod, settings: NetworkPolicy, db: Arc) -> Self { + Self { + method, + settings, + db, + } + } + + /// Pick an IP from one of the available ip ranges + /// This method MUST return a free IP which can be used + pub async fn pick_ip_for_region(&self, region_id: u64) -> Result { + let ip_ranges = self.db.list_ip_range_in_region(region_id).await?; + if ip_ranges.is_empty() { + bail!("No ip range found in this region"); + } + + for range in ip_ranges { + let range_cidr: IpNetwork = range.cidr.parse()?; + let ips = self.db.list_vm_ip_assignments_in_range(range.id).await?; + let ips: HashSet = ips.iter().map_while(|i| i.ip.parse().ok()).collect(); + + // pick an IP at random + let ip_pick = { + let first_ip = range_cidr.iter().next().unwrap(); + let last_ip = range_cidr.iter().last().unwrap(); + match self.method { + ProvisionerMethod::Sequential => range_cidr + .iter() + .find(|i| *i != first_ip && *i != last_ip && !ips.contains(i)), + ProvisionerMethod::Random => { + let mut rng = rand::rng(); + loop { + if let Some(i) = range_cidr.iter().choose(&mut rng) { + if i != first_ip && i != last_ip && !ips.contains(&i) { + break Some(i); + } + } else { + break None; + } + } + } + } + }; + + if let Some(ip_pick) = ip_pick { + return Ok(AvailableIp { + range_id: range.id, + ip: ip_pick, + region_id, + }); + } + } + bail!("No IPs available in this region"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mocks::*; + use crate::settings::NetworkAccessPolicy; + use lnvps_db::VmIpAssignment; + use std::str::FromStr; + + #[tokio::test] + async fn pick_seq_ip_for_region_test() { + let db: Arc = Arc::new(MockDb::default()); + let mgr = NetworkProvisioner::new( + ProvisionerMethod::Sequential, + NetworkPolicy { + access: NetworkAccessPolicy::Auto, + }, + db.clone(), + ); + + let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db"); + assert_eq!(1, ip.region_id); + assert_eq!(IpAddr::from_str("10.0.0.1").unwrap(), ip.ip); + db.insert_vm_ip_assignment(&VmIpAssignment { + id: 0, + vm_id: 0, + ip_range_id: ip.range_id, + ip: ip.ip.to_string(), + deleted: false, + }) + .await + .expect("Could not insert vm ip"); + let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db"); + assert_eq!(IpAddr::from_str("10.0.0.2").unwrap(), ip.ip); + } + + #[tokio::test] + async fn pick_rng_ip_for_region_test() { + let db: Arc = Arc::new(MockDb::default()); + let mgr = NetworkProvisioner::new( + ProvisionerMethod::Random, + NetworkPolicy { + access: NetworkAccessPolicy::Auto, + }, + db, + ); + + let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db"); + assert_eq!(1, ip.region_id); + } +} diff --git a/src/router/mikrotik.rs b/src/router/mikrotik.rs index 7104d40..7548891 100644 --- a/src/router/mikrotik.rs +++ b/src/router/mikrotik.rs @@ -14,11 +14,10 @@ pub struct MikrotikRouter { username: String, password: String, client: Client, - arp_interface: String, } impl MikrotikRouter { - pub fn new(url: &str, username: &str, password: &str, arp_interface: &str) -> Self { + pub fn new(url: &str, username: &str, password: &str) -> Self { Self { url: url.parse().unwrap(), username: username.to_string(), @@ -27,7 +26,6 @@ impl MikrotikRouter { .danger_accept_invalid_certs(true) .build() .unwrap(), - arp_interface: arp_interface.to_string(), } } @@ -73,7 +71,13 @@ impl Router for MikrotikRouter { Ok(rsp) } - async fn add_arp_entry(&self, ip: IpAddr, mac: &str, comment: Option<&str>) -> Result<()> { + async fn add_arp_entry( + &self, + ip: IpAddr, + mac: &str, + arp_interface: &str, + comment: Option<&str>, + ) -> Result<()> { let _rsp: ArpEntry = self .req( Method::PUT, @@ -81,7 +85,7 @@ impl Router for MikrotikRouter { ArpEntry { address: ip.to_string(), mac_address: Some(mac.to_string()), - interface: self.arp_interface.to_string(), + interface: arp_interface.to_string(), comment: comment.map(|c| c.to_string()), ..Default::default() }, diff --git a/src/router/mod.rs b/src/router/mod.rs index 8583e2d..c696f82 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -13,7 +13,13 @@ use std::net::IpAddr; #[async_trait] pub trait Router: Send + Sync { async fn list_arp_entry(&self) -> Result>; - async fn add_arp_entry(&self, ip: IpAddr, mac: &str, comment: Option<&str>) -> Result<()>; + async fn add_arp_entry( + &self, + ip: IpAddr, + mac: &str, + interface: &str, + comment: Option<&str>, + ) -> Result<()>; async fn remove_arp_entry(&self, id: &str) -> Result<()>; } @@ -35,3 +41,12 @@ pub struct ArpEntry { mod mikrotik; #[cfg(feature = "mikrotik")] pub use mikrotik::*; + +#[cfg(test)] +mod tests { + use super::*; + use crate::settings::NetworkPolicy; + + #[test] + fn provision_ips_with_arp() {} +} diff --git a/src/settings.rs b/src/settings.rs index 939ae6e..9789aec 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,21 +1,27 @@ use crate::exchange::ExchangeRateCache; -use crate::provisioner::lnvps::LNVpsProvisioner; +use crate::provisioner::LNVpsProvisioner; use crate::provisioner::Provisioner; use crate::router::{MikrotikRouter, Router}; +use anyhow::{bail, Result}; use fedimint_tonic_lnd::Client; use lnvps_db::LNVpsDb; use serde::{Deserialize, Serialize}; use std::path::PathBuf; #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] pub struct Settings { pub listen: Option, pub db: String, pub lnd: LndConfig, - /// Main control process impl + /// Provisioning profiles pub provisioner: ProvisionerConfig, + /// Network policy + #[serde(default)] + pub network_policy: NetworkPolicy, + /// Number of days after an expired VM is deleted pub delete_after: u16, @@ -25,7 +31,10 @@ pub struct Settings { /// Network router config pub router: Option, - /// Nostr config for sending DM's + /// DNS configurations for PTR records + pub dns: Option, + + /// Nostr config for sending DMs pub nostr: Option, } @@ -43,18 +52,55 @@ pub struct NostrConfig { } #[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "kebab-case")] pub enum RouterConfig { - Mikrotik { - url: String, - username: String, - password: String, + Mikrotik(ApiConfig), +} +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum DnsServerConfig { + Cloudflare { api: ApiConfig, zone_id: String }, +} + +/// Generic remote API credentials +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ApiConfig { + /// unique ID of this router, used in references + pub id: String, + /// http:// + pub url: String, + /// Login credentials used for this router + pub credentials: Credentials, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum Credentials { + UsernamePassword { username: String, password: String }, + ApiToken { token: String }, +} + +/// Policy that determines how packets arrive at the VM +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[serde(rename_all = "kebab-case")] +pub enum NetworkAccessPolicy { + /// No special procedure required for packets to arrive + #[default] + Auto, + /// ARP entries are added statically on the access router + StaticArp { /// Interface used to add arp entries - arp_interface: String, + interface: String, }, } +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct NetworkPolicy { + pub access: NetworkAccessPolicy, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct SmtpConfig { /// Admin user id, for sending system notifications @@ -74,8 +120,9 @@ pub struct SmtpConfig { } #[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "kebab-case")] pub enum ProvisionerConfig { + #[serde(rename_all = "kebab-case")] Proxmox { /// Readonly mode, don't spawn any VM's read_only: bool, @@ -95,27 +142,23 @@ pub struct SshConfig { } #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] pub struct QemuConfig { /// Machine type (q35) pub machine: String, - /// OS Type pub os_type: String, - /// Network bridge used for the networking interface pub bridge: String, - /// CPU type pub cpu: String, - /// VLAN tag all spawned VM's pub vlan: Option, - /// Enable virtualization inside VM pub kvm: bool, } -impl ProvisionerConfig { +impl Settings { pub fn get_provisioner( &self, db: impl LNVpsDb + 'static, @@ -123,7 +166,7 @@ impl ProvisionerConfig { lnd: Client, exchange: ExchangeRateCache, ) -> impl Provisioner + 'static { - match self { + match &self.provisioner { ProvisionerConfig::Proxmox { qemu, ssh, @@ -131,6 +174,7 @@ impl ProvisionerConfig { } => LNVpsProvisioner::new( *read_only, qemu.clone(), + self.network_policy.clone(), ssh.clone(), db, router, @@ -139,17 +183,15 @@ impl ProvisionerConfig { ), } } -} - -impl RouterConfig { - pub fn get_router(&self) -> impl Router + 'static { - match self { - RouterConfig::Mikrotik { - url, - username, - password, - arp_interface, - } => MikrotikRouter::new(url, username, password, arp_interface), + pub fn get_router<'a>(&'a self) -> Result> { + match &self.router { + Some(RouterConfig::Mikrotik(api)) => match &api.credentials { + Credentials::UsernamePassword { username, password } => { + Ok(Some(MikrotikRouter::new(&api.url, username, password))) + } + _ => bail!("Only username/password is supported for Mikrotik routers"), + }, + _ => Ok(None), } } } diff --git a/src/status.rs b/src/status.rs index acf49b0..263ef28 100644 --- a/src/status.rs +++ b/src/status.rs @@ -1,9 +1,9 @@ use anyhow::Result; +use rocket::serde::Deserialize; +use schemars::JsonSchema; use serde::Serialize; use std::collections::HashMap; use std::sync::Arc; -use rocket::serde::Deserialize; -use schemars::JsonSchema; use tokio::sync::RwLock; #[derive(Clone, Serialize, Deserialize, Default, JsonSchema)] diff --git a/src/worker.rs b/src/worker.rs index 6a13a88..54f868e 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -327,12 +327,7 @@ impl Worker { if let Err(e) = self.check_vm(*vm_id).await { error!("Failed to check VM {}: {}", vm_id, e); self.queue_admin_notification( - format!( - "Failed to check VM {}:\n{:?}\n{}", - vm_id, - &job, - e - ), + format!("Failed to check VM {}:\n{:?}\n{}", vm_id, &job, e), Some("Job Failed".to_string()), )? } @@ -348,11 +343,7 @@ impl Worker { { error!("Failed to send notification {}: {}", user_id, e); self.queue_admin_notification( - format!( - "Failed to send notification:\n{:?}\n{}", - &job, - e - ), + format!("Failed to send notification:\n{:?}\n{}", &job, e), Some("Job Failed".to_string()), )? }