From fc735590d621999e6bc4ca15b8fab86b952225ce Mon Sep 17 00:00:00 2001 From: kieran Date: Tue, 4 Mar 2025 14:28:50 +0000 Subject: [PATCH] refactor: simplify dns server trait feat: update reverse dns --- src/api/model.rs | 2 + src/api/routes.rs | 22 +++++-- src/bin/api.rs | 1 - src/dns/cloudflare.rs | 91 ++++++++++++++++++--------- src/dns/mod.rs | 131 ++++++++++++++++++++++++++++++++++++--- src/mocks.rs | 60 ++++++------------ src/provisioner/lnvps.rs | 70 ++++++++++++++------- 7 files changed, 271 insertions(+), 106 deletions(-) diff --git a/src/api/model.rs b/src/api/model.rs index 033171c..e2db237 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -191,6 +191,8 @@ pub struct ApiVmHostRegion { pub struct VMPatchRequest { /// SSH key assigned to vm pub ssh_key_id: Option, + /// Reverse DNS PTR domain + pub reverse_dns: Option, } #[derive(Serialize, Deserialize, JsonSchema)] diff --git a/src/api/routes.rs b/src/api/routes.rs index 894e25e..e6ee000 100644 --- a/src/api/routes.rs +++ b/src/api/routes.rs @@ -205,6 +205,7 @@ async fn v1_get_vm( async fn v1_patch_vm( auth: Nip98Auth, db: &State>, + provisioner: &State>, settings: &State, id: u64, data: Json, @@ -216,20 +217,31 @@ async fn v1_patch_vm( return ApiData::err("VM doesnt belong to you"); } + let mut vm_config = false; if let Some(k) = data.ssh_key_id { let ssh_key = db.get_user_ssh_key(k).await?; if ssh_key.user_id != uid { return ApiData::err("SSH key doesnt belong to you"); } vm.ssh_key_id = ssh_key.id; + vm_config = true; } - db.update_vm(&vm).await?; + if let Some(ptr) = &data.reverse_dns { + let mut ips = db.list_vm_ip_assignments(vm.id).await?; + for mut ip in ips.iter_mut() { + ip.dns_reverse = Some(ptr.to_string()); + provisioner.update_reverse_ip_dns(&mut ip).await?; + } + } - let info = FullVmInfo::load(vm.id, (*db).clone()).await?; - let host = db.get_host(vm.host_id).await?; - let client = get_host_client(&host, &settings.provisioner)?; - client.configure_vm(&info).await?; + if vm_config { + db.update_vm(&vm).await?; + let info = FullVmInfo::load(vm.id, (*db).clone()).await?; + let host = db.get_host(vm.host_id).await?; + let client = get_host_client(&host, &settings.provisioner)?; + client.configure_vm(&info).await?; + } ApiData::ok(()) } diff --git a/src/bin/api.rs b/src/bin/api.rs index 84a7e57..6b47036 100644 --- a/src/bin/api.rs +++ b/src/bin/api.rs @@ -19,7 +19,6 @@ use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use tokio::fs::create_dir_all; use tokio::time::sleep; #[derive(Parser)] diff --git a/src/dns/cloudflare.rs b/src/dns/cloudflare.rs index 330bca8..80487e1 100644 --- a/src/dns/cloudflare.rs +++ b/src/dns/cloudflare.rs @@ -1,8 +1,9 @@ use crate::dns::{BasicRecord, DnsServer, RecordType}; use crate::json_api::JsonApi; +use anyhow::Context; use lnvps_db::async_trait; +use log::info; use serde::{Deserialize, Serialize}; -use std::net::IpAddr; pub struct Cloudflare { api: JsonApi, @@ -23,67 +24,101 @@ impl Cloudflare { #[async_trait] impl DnsServer for Cloudflare { - async fn add_ptr_record(&self, key: &str, value: &str) -> anyhow::Result { + async fn add_record(&self, record: &BasicRecord) -> anyhow::Result { + let zone_id = match &record.kind { + RecordType::PTR => &self.reverse_zone_id, + _ => &self.forward_zone_id, + }; + info!( + "Adding record: [{}] {} => {}", + record.kind, record.name, record.value + ); let id_response: CfResult = self .api .post( - &format!("/client/v4/zones/{}/dns_records", self.reverse_zone_id), + &format!("/client/v4/zones/{zone_id}/dns_records"), CfRecord { - content: value.to_string(), - name: key.to_string(), - r_type: "PTR".to_string(), + content: record.value.to_string(), + name: record.name.to_string(), + r_type: Some(record.kind.to_string()), id: None, }, ) .await?; Ok(BasicRecord { name: id_response.result.name, - value: value.to_string(), + value: id_response.result.content, id: id_response.result.id, - kind: RecordType::PTR, + kind: record.kind.clone(), }) } - async fn delete_ptr_record(&self, key: &str) -> anyhow::Result<()> { - todo!() + async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> { + let zone_id = match &record.kind { + RecordType::PTR => &self.reverse_zone_id, + _ => &self.forward_zone_id, + }; + let record_id = record.id.as_ref().context("record id missing")?; + info!( + "Deleting record: [{}] {} => {}", + record.kind, record.name, record.value + ); + self.api + .req( + reqwest::Method::DELETE, + &format!("/client/v4/zones/{}/dns_records/{}", zone_id, record_id), + CfRecord { + content: record.value.to_string(), + name: record.name.to_string(), + r_type: None, + id: None, + }, + ) + .await?; + Ok(()) } - async fn add_a_record(&self, name: &str, ip: IpAddr) -> anyhow::Result { + async fn update_record(&self, record: &BasicRecord) -> anyhow::Result { + let zone_id = match &record.kind { + RecordType::PTR => &self.reverse_zone_id, + _ => &self.forward_zone_id, + }; + info!( + "Updating record: [{}] {} => {}", + record.kind, record.name, record.value + ); + let record_id = record.id.as_ref().context("record id missing")?; let id_response: CfResult = self .api - .post( - &format!("/client/v4/zones/{}/dns_records", self.forward_zone_id), + .req( + reqwest::Method::PATCH, + &format!("/client/v4/zones/{}/dns_records/{}", zone_id, record_id), CfRecord { - content: ip.to_string(), - name: name.to_string(), - r_type: if ip.is_ipv4() { - "A".to_string() - } else { - "AAAA".to_string() - }, - id: None, + content: record.value.to_string(), + name: record.name.to_string(), + r_type: Some(record.kind.to_string()), + id: Some(record_id.to_string()), }, ) .await?; Ok(BasicRecord { name: id_response.result.name, - value: ip.to_string(), + value: id_response.result.content, id: id_response.result.id, - kind: RecordType::A, + kind: record.kind.clone(), }) } - - async fn delete_a_record(&self, name: &str) -> anyhow::Result<()> { - todo!() - } } #[derive(Debug, Serialize, Deserialize)] struct CfRecord { pub content: String, pub name: String, + + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "type")] - pub r_type: String, + pub r_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub id: Option, } diff --git a/src/dns/mod.rs b/src/dns/mod.rs index 952b418..d1027db 100644 --- a/src/dns/mod.rs +++ b/src/dns/mod.rs @@ -1,7 +1,9 @@ -use anyhow::Result; -use lnvps_db::async_trait; +use anyhow::{bail, Context, Result}; +use lnvps_db::{async_trait, VmIpAssignment}; use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; use std::net::IpAddr; +use std::str::FromStr; #[cfg(feature = "cloudflare")] mod cloudflare; @@ -11,16 +13,13 @@ pub use cloudflare::*; #[async_trait] pub trait DnsServer: Send + Sync { /// Add PTR record to the reverse zone - async fn add_ptr_record(&self, key: &str, value: &str) -> Result; + async fn add_record(&self, record: &BasicRecord) -> Result; /// Delete PTR record from the reverse zone - async fn delete_ptr_record(&self, key: &str) -> Result<()>; + async fn delete_record(&self, record: &BasicRecord) -> Result<()>; - /// Add A/AAAA record onto the forward zone - async fn add_a_record(&self, name: &str, ip: IpAddr) -> Result; - - /// Delete A/AAAA record from the forward zone - async fn delete_a_record(&self, name: &str) -> Result<()>; + /// Update a record + async fn update_record(&self, record: &BasicRecord) -> Result; } #[derive(Clone, Debug)] @@ -37,3 +36,117 @@ pub struct BasicRecord { pub id: Option, pub kind: RecordType, } + +impl Display for RecordType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + RecordType::A => write!(f, "A"), + RecordType::AAAA => write!(f, "AAAA"), + RecordType::PTR => write!(f, "PTR"), + } + } +} + +impl BasicRecord { + pub fn forward(ip: &VmIpAssignment) -> Result { + let addr = IpAddr::from_str(&ip.ip)?; + Ok(Self { + name: format!("vm-{}", &ip.vm_id), + value: addr.to_string(), + id: ip.dns_forward_ref.clone(), + kind: match addr { + IpAddr::V4(_) => RecordType::A, + IpAddr::V6(_) => RecordType::AAAA, + }, + }) + } + + pub fn reverse_to_fwd(ip: &VmIpAssignment) -> Result { + let addr = IpAddr::from_str(&ip.ip)?; + + // Use explicit reverse entry or use the forward entry + let fwd = ip + .dns_reverse + .as_ref() + .or(ip.dns_forward.as_ref()) + .context("Reverse/Forward DNS name required for reverse entry")? + .to_string(); + + if !is_valid_fqdn(fwd.as_str()) { + bail!("Forward DNS name is not a valid FQDN"); + } + Ok(Self { + name: match addr { + IpAddr::V4(i) => i.octets()[3].to_string(), + IpAddr::V6(_) => bail!("IPv6 PTR not supported"), + }, + value: fwd, + id: ip.dns_reverse_ref.clone(), + kind: RecordType::PTR, + }) + } + + pub fn reverse(ip: &VmIpAssignment) -> Result { + let addr = IpAddr::from_str(&ip.ip)?; + let rev = ip + .dns_reverse + .as_ref() + .context("Reverse DNS name required for reverse entry")? + .to_string(); + if !is_valid_fqdn(&rev) { + bail!("Reverse DNS name is not a valid FQDN"); + } + Ok(Self { + name: match addr { + IpAddr::V4(i) => i.octets()[3].to_string(), + IpAddr::V6(_) => bail!("IPv6 PTR not supported"), + }, + value: rev, + id: ip.dns_reverse_ref.clone(), + kind: RecordType::PTR, + }) + } +} + +/// Grok 3 +pub fn is_valid_fqdn(s: &str) -> bool { + // Remove trailing dot if present (optional in practice) + let s = s.strip_suffix('.').unwrap_or(s); + + // Check total length (max 255 chars, including dots) + if s.len() > 255 || s.is_empty() { + return false; + } + + // Split into labels and validate each + let labels: Vec<&str> = s.split('.').collect(); + + // Must have at least two labels (e.g., "example.com") + if labels.len() < 2 { + return false; + } + + for label in labels { + // Each label must be 1-63 chars + if label.len() > 63 || label.is_empty() { + return false; + } + + // Must start with a letter or digit + if !label.chars().next().unwrap().is_alphanumeric() { + return false; + } + + // Must end with a letter or digit + if !label.chars().last().unwrap().is_alphanumeric() { + return false; + } + + // Only letters, digits, and hyphens allowed + if !label.chars().all(|c| c.is_alphanumeric() || c == '-') { + return false; + } + } + + true +} diff --git a/src/mocks.rs b/src/mocks.rs index e4b6051..560f807 100644 --- a/src/mocks.rs +++ b/src/mocks.rs @@ -666,7 +666,7 @@ impl VmHostClient for MockVmHost { } } - async fn configure_vm(&self, vm: &Vm) -> anyhow::Result<()> { + async fn configure_vm(&self, vm: &FullVmInfo) -> anyhow::Result<()> { Ok(()) } } @@ -696,62 +696,42 @@ impl MockDnsServer { } #[async_trait] impl DnsServer for MockDnsServer { - async fn add_ptr_record(&self, key: &str, value: &str) -> anyhow::Result { - let mut rev = self.reverse.lock().await; + async fn add_record(&self, record: &BasicRecord) -> anyhow::Result { + let mut table = match record.kind { + RecordType::PTR => self.reverse.lock().await, + _ => self.forward.lock().await, + }; - if rev.values().any(|v| v.name == key) { - bail!("Duplicate record with name {}", key); + if table.values().any(|v| v.name == record.name) { + bail!("Duplicate record with name {}", record.name); } let rnd_id: [u8; 12] = rand::random(); let id = hex::encode(rnd_id); - rev.insert( + table.insert( id.clone(), MockDnsEntry { - name: key.to_string(), - value: value.to_string(), - kind: "PTR".to_string(), + name: record.name.to_string(), + value: record.value.to_string(), + kind: record.kind.to_string(), }, ); Ok(BasicRecord { - name: format!("{}.X.Y.Z.in-addr.arpa", key), - value: value.to_string(), + name: match record.kind { + RecordType::PTR => format!("{}.X.Y.Z.addr.in-arpa", record.name), + _ => format!("{}.lnvps.mock", record.name), + }, + value: record.value.clone(), id: Some(id), - kind: RecordType::PTR, + kind: record.kind.clone(), }) } - async fn delete_ptr_record(&self, key: &str) -> anyhow::Result<()> { + async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> { todo!() } - async fn add_a_record(&self, name: &str, ip: IpAddr) -> anyhow::Result { - let mut rev = self.forward.lock().await; - - if rev.values().any(|v| v.name == name) { - bail!("Duplicate record with name {}", name); - } - - let fqdn = format!("{}.lnvps.mock", name); - let rnd_id: [u8; 12] = rand::random(); - let id = hex::encode(rnd_id); - rev.insert( - id.clone(), - MockDnsEntry { - name: fqdn.clone(), - value: ip.to_string(), - kind: "A".to_string(), - }, - ); - Ok(BasicRecord { - name: fqdn, - value: ip.to_string(), - id: Some(id), - kind: RecordType::A, - }) - } - - async fn delete_a_record(&self, name: &str) -> anyhow::Result<()> { + async fn update_record(&self, name: &BasicRecord) -> anyhow::Result { todo!() } } diff --git a/src/provisioner/lnvps.rs b/src/provisioner/lnvps.rs index 5957d7a..7b04735 100644 --- a/src/provisioner/lnvps.rs +++ b/src/provisioner/lnvps.rs @@ -1,11 +1,11 @@ -use crate::dns::DnsServer; +use crate::dns::{BasicRecord, DnsServer}; use crate::exchange::{ExchangeRateService, Ticker}; use crate::host::{get_host_client, FullVmInfo}; use crate::lightning::{AddInvoiceRequest, LightningNode}; use crate::provisioner::{NetworkProvisioner, ProvisionerMethod}; use crate::router::Router; use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings}; -use anyhow::{bail, Result}; +use anyhow::{bail, Context, Result}; use chrono::{Days, Months, Utc}; use futures::future::join_all; use lnvps_db::{IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment}; @@ -55,7 +55,8 @@ impl LNVpsProvisioner { } } - async fn delete_ip_assignment(&self, vm: &Vm) -> Result<()> { + pub async fn delete_ip_assignment(&self, vm: &Vm) -> Result<()> { + // Delete access policy if let NetworkAccessPolicy::StaticArp { .. } = &self.network_policy.access { if let Some(r) = self.router.as_ref() { let ent = r.list_arp_entry().await?; @@ -72,6 +73,45 @@ impl LNVpsProvisioner { Ok(()) } + pub async fn update_forward_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> { + if let Some(dns) = &self.dns { + let fwd = BasicRecord::forward(assignment)?; + let ret_fwd = if fwd.id.is_some() { + dns.update_record(&fwd).await? + } else { + dns.add_record(&fwd).await? + }; + assignment.dns_forward = Some(ret_fwd.name); + assignment.dns_forward_ref = Some(ret_fwd.id.context("Record id is missing")?); + + // save to db + if assignment.id != 0 { + self.db.update_vm_ip_assignment(assignment).await?; + } + } + Ok(()) + } + + pub async fn update_reverse_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> { + if let Some(dns) = &self.dns { + let ret_rev = if assignment.dns_reverse_ref.is_some() { + dns.update_record(&BasicRecord::reverse(assignment)?) + .await? + } else { + dns.add_record(&BasicRecord::reverse_to_fwd(assignment)?) + .await? + }; + assignment.dns_reverse = Some(ret_rev.value); + assignment.dns_reverse_ref = Some(ret_rev.id.context("Record id is missing")?); + + // save to db + if assignment.id != 0 { + self.db.update_vm_ip_assignment(assignment).await?; + } + } + Ok(()) + } + async fn save_ip_assignment(&self, vm: &Vm, assignment: &mut VmIpAssignment) -> Result<()> { let ip = IpAddr::from_str(&assignment.ip)?; @@ -91,27 +131,11 @@ impl LNVpsProvisioner { } // Add DNS records - if let Some(dns) = &self.dns { - let sub_name = format!("vm-{}", vm.id); - let fwd = dns.add_a_record(&sub_name, ip.clone()).await?; - assignment.dns_forward = Some(fwd.name.clone()); - assignment.dns_forward_ref = fwd.id; - - match ip { - IpAddr::V4(ip) => { - let last_octet = ip.octets()[3].to_string(); - let rev = dns.add_ptr_record(&last_octet, &fwd.name).await?; - assignment.dns_reverse = Some(fwd.name.clone()); - assignment.dns_reverse_ref = rev.id; - } - IpAddr::V6(_) => { - warn!("IPv6 forward DNS not supported yet") - } - } - } + self.update_forward_ip_dns(assignment).await?; + self.update_reverse_ip_dns(assignment).await?; // save to db - self.db.insert_vm_ip_assignment(&assignment).await?; + self.db.insert_vm_ip_assignment(assignment).await?; Ok(()) } @@ -416,7 +440,7 @@ mod tests { assert!(ip.dns_forward.is_some()); assert!(ip.dns_reverse.is_some()); assert!(ip.dns_reverse_ref.is_some()); - assert!(ip.dns_forward.is_some()); + assert!(ip.dns_forward_ref.is_some()); assert_eq!(ip.dns_reverse, ip.dns_forward); // assert IP address is not CIDR