diff --git a/README.md b/README.md index c12a474..6db59ba 100644 --- a/README.md +++ b/README.md @@ -95,12 +95,13 @@ nostr: To create PTR records automatically use the following config: ```yaml dns: - cloudflare: - # The zone where forward (A/AAAA) entries are added (eg. lnvps.cloud zone) - # We create forward entries with the format vm-.lnvps.cloud - forward-zone-id: "my-forward-zone-id" - # API token to add/remove DNS records to this zone - token: "my-api-token" + # The zone where forward (A/AAAA) entries are added (eg. lnvps.cloud zone) + # We create forward entries with the format vm-.lnvps.cloud + forward-zone-id: "my-forward-zone-id" + api: + cloudflare: + # API token to add/remove DNS records to this zone + token: "my-api-token" ``` ### Taxes diff --git a/src/data_migration/dns.rs b/src/data_migration/dns.rs index 26c134c..2769445 100644 --- a/src/data_migration/dns.rs +++ b/src/data_migration/dns.rs @@ -10,12 +10,17 @@ use std::sync::Arc; pub struct DnsDataMigration { db: Arc, dns: Arc, + forward_zone_id: Option, } impl DnsDataMigration { pub fn new(db: Arc, settings: &Settings) -> Option { let dns = settings.get_dns().ok().flatten()?; - Some(Self { db, dns }) + Some(Self { + db, + dns, + forward_zone_id: settings.dns.as_ref().map(|z| z.forward_zone_id.to_string()), + }) } } @@ -23,7 +28,13 @@ impl DataMigration for DnsDataMigration { fn migrate(&self) -> Pin> + Send>> { let db = self.db.clone(); let dns = self.dns.clone(); + let forward_zone_id = self.forward_zone_id.clone(); Box::pin(async move { + let zone_id = if let Some(z) = forward_zone_id { + z + } else { + return Ok(()); + }; let vms = db.list_vms().await?; for vm in vms { @@ -32,14 +43,14 @@ impl DataMigration for DnsDataMigration { let mut did_change = false; if ip.dns_forward.is_none() { let rec = BasicRecord::forward(ip)?; - let r = dns.add_record(&rec).await?; + let r = dns.add_record(&zone_id, &rec).await?; ip.dns_forward = Some(r.name); ip.dns_forward_ref = r.id; did_change = true; } if ip.dns_reverse.is_none() { let rec = BasicRecord::reverse_to_fwd(ip)?; - let r = dns.add_record(&rec).await?; + let r = dns.add_record(&zone_id, &rec).await?; ip.dns_reverse = Some(r.value); ip.dns_reverse_ref = r.id; did_change = true; diff --git a/src/dns/cloudflare.rs b/src/dns/cloudflare.rs index d42dd8f..60b665d 100644 --- a/src/dns/cloudflare.rs +++ b/src/dns/cloudflare.rs @@ -1,4 +1,4 @@ -use crate::dns::{BasicRecord, DnsServer, RecordType}; +use crate::dns::{BasicRecord, DnsServer}; use crate::json_api::JsonApi; use anyhow::Context; use lnvps_db::async_trait; @@ -7,12 +7,10 @@ use serde::{Deserialize, Serialize}; pub struct Cloudflare { api: JsonApi, - reverse_zone_id: String, - forward_zone_id: String, } impl Cloudflare { - pub fn new(token: &str, reverse_zone_id: &str, forward_zone_id: &str) -> Cloudflare { + pub fn new(token: &str) -> Cloudflare { Self { api: JsonApi::token( "https://api.cloudflare.com", @@ -20,8 +18,6 @@ impl Cloudflare { false, ) .unwrap(), - reverse_zone_id: reverse_zone_id.to_owned(), - forward_zone_id: forward_zone_id.to_owned(), } } @@ -45,11 +41,7 @@ impl Cloudflare { #[async_trait] impl DnsServer for Cloudflare { - async fn add_record(&self, record: &BasicRecord) -> anyhow::Result { - let zone_id = match &record.kind { - RecordType::PTR => &self.reverse_zone_id, - _ => &self.forward_zone_id, - }; + async fn add_record(&self, zone_id: &str, record: &BasicRecord) -> anyhow::Result { info!( "Adding record: [{}] {} => {}", record.kind, record.name, record.value @@ -75,11 +67,7 @@ impl DnsServer for Cloudflare { }) } - async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> { - let zone_id = match &record.kind { - RecordType::PTR => &self.reverse_zone_id, - _ => &self.forward_zone_id, - }; + async fn delete_record(&self, zone_id: &str, record: &BasicRecord) -> anyhow::Result<()> { let record_id = record.id.as_ref().context("record id missing")?; info!( "Deleting record: [{}] {} => {}", @@ -102,11 +90,11 @@ impl DnsServer for Cloudflare { Ok(()) } - async fn update_record(&self, record: &BasicRecord) -> anyhow::Result { - let zone_id = match &record.kind { - RecordType::PTR => &self.reverse_zone_id, - _ => &self.forward_zone_id, - }; + async fn update_record( + &self, + zone_id: &str, + record: &BasicRecord, + ) -> anyhow::Result { info!( "Updating record: [{}] {} => {}", record.kind, record.name, record.value diff --git a/src/dns/mod.rs b/src/dns/mod.rs index 957e5fc..0234ab2 100644 --- a/src/dns/mod.rs +++ b/src/dns/mod.rs @@ -12,13 +12,13 @@ pub use cloudflare::*; #[async_trait] pub trait DnsServer: Send + Sync { /// Add PTR record to the reverse zone - async fn add_record(&self, record: &BasicRecord) -> Result; + async fn add_record(&self, zone_id: &str, record: &BasicRecord) -> Result; /// Delete PTR record from the reverse zone - async fn delete_record(&self, record: &BasicRecord) -> Result<()>; + async fn delete_record(&self, zone_id: &str, record: &BasicRecord) -> Result<()>; /// Update a record - async fn update_record(&self, record: &BasicRecord) -> Result; + async fn update_record(&self, zone_id: &str, record: &BasicRecord) -> Result; } #[derive(Clone, Debug)] diff --git a/src/mocks.rs b/src/mocks.rs index 1cf5ecb..1d06b13 100644 --- a/src/mocks.rs +++ b/src/mocks.rs @@ -137,7 +137,7 @@ impl Default for MockDb { load_cpu: 1.5, load_memory: 2.0, load_disk: 3.0, - vlan_id: Some(100) + vlan_id: Some(100), }, ); let mut host_disks = HashMap::new(); @@ -847,8 +847,7 @@ impl VmHostClient for MockVmHost { } pub struct MockDnsServer { - pub forward: Arc>>, - pub reverse: Arc>>, + pub zones: Arc>>>, } pub struct MockDnsEntry { @@ -859,22 +858,22 @@ pub struct MockDnsEntry { impl MockDnsServer { pub fn new() -> Self { - static LAZY_FWD: LazyLock>>> = - LazyLock::new(|| Arc::new(Mutex::new(HashMap::new()))); - static LAZY_REV: LazyLock>>> = + static LAZY_ZONES: LazyLock>>>> = LazyLock::new(|| Arc::new(Mutex::new(HashMap::new()))); Self { - forward: LAZY_FWD.clone(), - reverse: LAZY_REV.clone(), + zones: LAZY_ZONES.clone(), } } } #[async_trait] impl DnsServer for MockDnsServer { - async fn add_record(&self, record: &BasicRecord) -> anyhow::Result { - let mut table = match record.kind { - RecordType::PTR => self.reverse.lock().await, - _ => self.forward.lock().await, + async fn add_record(&self, zone_id: &str, record: &BasicRecord) -> anyhow::Result { + let mut zones = self.zones.lock().await; + let table = if let Some(t) = zones.get_mut(zone_id) { + t + } else { + zones.insert(zone_id.to_string(), HashMap::new()); + zones.get_mut(zone_id).unwrap() }; if table.values().any(|v| v.name == record.name) { @@ -902,20 +901,30 @@ impl DnsServer for MockDnsServer { }) } - async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> { - let mut table = match record.kind { - RecordType::PTR => self.reverse.lock().await, - _ => self.forward.lock().await, + async fn delete_record(&self, zone_id: &str, record: &BasicRecord) -> anyhow::Result<()> { + let mut zones = self.zones.lock().await; + let table = if let Some(t) = zones.get_mut(zone_id) { + t + } else { + zones.insert(zone_id.to_string(), HashMap::new()); + zones.get_mut(zone_id).unwrap() }; ensure!(record.id.is_some(), "Id is missing"); table.remove(record.id.as_ref().unwrap()); Ok(()) } - async fn update_record(&self, record: &BasicRecord) -> anyhow::Result { - let mut table = match record.kind { - RecordType::PTR => self.reverse.lock().await, - _ => self.forward.lock().await, + async fn update_record( + &self, + zone_id: &str, + record: &BasicRecord, + ) -> anyhow::Result { + let mut zones = self.zones.lock().await; + let table = if let Some(t) = zones.get_mut(zone_id) { + t + } else { + zones.insert(zone_id.to_string(), HashMap::new()); + zones.get_mut(zone_id).unwrap() }; ensure!(record.id.is_some(), "Id is missing"); if let Some(mut r) = table.get_mut(record.id.as_ref().unwrap()) { diff --git a/src/provisioner/lnvps.rs b/src/provisioner/lnvps.rs index 0aa65f4..92a6aaf 100644 --- a/src/provisioner/lnvps.rs +++ b/src/provisioner/lnvps.rs @@ -36,6 +36,9 @@ pub struct LNVpsProvisioner { dns: Option>, revolut: Option>, + /// Forward zone ID used for all VM's + /// passed to the DNSServer type + forward_zone_id: Option, provisioner_config: ProvisionerConfig, } @@ -55,6 +58,7 @@ impl LNVpsProvisioner { tax_rates: settings.tax_rate, provisioner_config: settings.provisioner, read_only: settings.read_only, + forward_zone_id: settings.dns.map(|z| z.forward_zone_id), } } @@ -149,17 +153,19 @@ impl LNVpsProvisioner { pub async fn remove_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> { // Delete forward/reverse dns if let Some(dns) = &self.dns { - if let Some(_r) = &assignment.dns_reverse_ref { + let range = self.db.get_ip_range(assignment.ip_range_id).await?; + + if let (Some(z), Some(_ref)) = (&range.reverse_zone_id, &assignment.dns_reverse_ref) { let rev = BasicRecord::reverse(assignment)?; - if let Err(e) = dns.delete_record(&rev).await { + if let Err(e) = dns.delete_record(z, &rev).await { warn!("Failed to delete reverse record: {}", e); } assignment.dns_reverse_ref = None; assignment.dns_reverse = None; } - if let Some(_r) = &assignment.dns_forward_ref { + if let (Some(z), Some(_ref)) = (&self.forward_zone_id, &assignment.dns_forward_ref) { let rev = BasicRecord::forward(assignment)?; - if let Err(e) = dns.delete_record(&rev).await { + if let Err(e) = dns.delete_record(z, &rev).await { warn!("Failed to delete forward record: {}", e); } assignment.dns_forward_ref = None; @@ -171,12 +177,12 @@ impl LNVpsProvisioner { /// Update DNS on the dns server, does not save to database! pub async fn update_forward_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> { - if let Some(dns) = &self.dns { + if let (Some(z), Some(dns)) = (&self.forward_zone_id, &self.dns) { let fwd = BasicRecord::forward(assignment)?; let ret_fwd = if fwd.id.is_some() { - dns.update_record(&fwd).await? + dns.update_record(z, &fwd).await? } else { - dns.add_record(&fwd).await? + dns.add_record(z, &fwd).await? }; assignment.dns_forward = Some(ret_fwd.name); assignment.dns_forward_ref = Some(ret_fwd.id.context("Record id is missing")?); @@ -187,15 +193,18 @@ impl LNVpsProvisioner { /// Update DNS on the dns server, does not save to database! pub async fn update_reverse_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> { if let Some(dns) = &self.dns { - let ret_rev = if assignment.dns_reverse_ref.is_some() { - dns.update_record(&BasicRecord::reverse(assignment)?) - .await? - } else { - dns.add_record(&BasicRecord::reverse_to_fwd(assignment)?) - .await? - }; - assignment.dns_reverse = Some(ret_rev.value); - assignment.dns_reverse_ref = Some(ret_rev.id.context("Record id is missing")?); + let range = self.db.get_ip_range(assignment.ip_range_id).await?; + if let Some(z) = &range.reverse_zone_id { + let ret_rev = if assignment.dns_reverse_ref.is_some() { + dns.update_record(z, &BasicRecord::reverse(assignment)?) + .await? + } else { + dns.add_record(z, &BasicRecord::reverse_to_fwd(assignment)?) + .await? + }; + assignment.dns_reverse = Some(ret_rev.value); + assignment.dns_reverse_ref = Some(ret_rev.id.context("Record id is missing")?); + } } Ok(()) } @@ -636,6 +645,7 @@ mod tests { let mut i = db.ip_range.lock().await; let r = i.get_mut(&1).unwrap(); r.access_policy_id = Some(1); + r.reverse_zone_id = Some("mock-rev-zone-id".to_string()); } let dns = MockDnsServer::new(); @@ -691,14 +701,26 @@ mod tests { assert!(!ip.ip.ends_with("/8")); assert!(!ip.ip.ends_with("/24")); + // test zones have dns entries + { + let zones = dns.zones.lock().await; + assert_eq!(zones.get("mock-rev-zone-id").unwrap().len(), 1); + assert_eq!(zones.get("mock-forward-zone-id").unwrap().len(), 1); + } + // now expire provisioner.delete_vm(vm.id).await?; // test arp/dns is removed let arp = router.list_arp_entry().await?; assert!(arp.is_empty()); - assert_eq!(dns.forward.lock().await.len(), 0); - assert_eq!(dns.reverse.lock().await.len(), 0); + + // test dns entries are deleted + { + let zones = dns.zones.lock().await; + assert_eq!(zones.get("mock-rev-zone-id").unwrap().len(), 0); + assert_eq!(zones.get("mock-forward-zone-id").unwrap().len(), 0); + } // ensure IPS are deleted let ips = db.ip_assignments.lock().await; diff --git a/src/settings.rs b/src/settings.rs index 496843d..a6e66f9 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -76,29 +76,28 @@ pub struct NostrConfig { #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] -pub enum DnsServerConfig { +pub struct DnsServerConfig { + pub forward_zone_id: String, + pub api: DnsServerApi, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub enum DnsServerApi { #[serde(rename_all = "kebab-case")] - Cloudflare { - token: String, - forward_zone_id: String, - reverse_zone_id: String, - }, + Cloudflare { token: String }, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct SmtpConfig { /// Admin user id, for sending system notifications pub admin: Option, - /// Email server host:port pub server: String, - /// From header to use, otherwise empty pub from: Option, - /// Username for SMTP connection pub username: String, - /// Password for SMTP connection pub password: String, } @@ -168,16 +167,12 @@ impl Settings { { match &self.dns { None => Ok(None), - #[cfg(feature = "cloudflare")] - Some(DnsServerConfig::Cloudflare { - token, - forward_zone_id, - reverse_zone_id, - }) => Ok(Some(Arc::new(crate::dns::Cloudflare::new( - token, - reverse_zone_id, - forward_zone_id, - )))), + Some(c) => match &c.api { + #[cfg(feature = "cloudflare")] + DnsServerApi::Cloudflare { token } => { + Ok(Some(Arc::new(crate::dns::Cloudflare::new(token)))) + } + }, } } } @@ -216,10 +211,11 @@ pub fn mock_settings() -> Settings { }, delete_after: 0, smtp: None, - dns: Some(DnsServerConfig::Cloudflare { - token: "abc".to_string(), - forward_zone_id: "123".to_string(), - reverse_zone_id: "456".to_string(), + dns: Some(DnsServerConfig { + forward_zone_id: "mock-forward-zone-id".to_string(), + api: DnsServerApi::Cloudflare { + token: "abc".to_string(), + }, }), nostr: None, revolut: None,