feat: taxes

closes #18
This commit is contained in:
2025-03-11 15:58:34 +00:00
parent 029f2cb6e1
commit 02d606d60c
13 changed files with 222 additions and 63 deletions

11
Cargo.lock generated
View File

@ -1863,6 +1863,16 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "isocountry"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ea1dc4bf0fb4904ba83ffdb98af3d9c325274e92e6e295e4151e86c96363e04"
dependencies = [
"serde",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.12.1" version = "0.12.1"
@ -2011,6 +2021,7 @@ dependencies = [
"hex", "hex",
"hmac", "hmac",
"ipnetwork", "ipnetwork",
"isocountry",
"lettre", "lettre",
"lnvps_db", "lnvps_db",
"log", "log",

View File

@ -41,6 +41,7 @@ ws = { package = "rocket_ws", version = "0.1.0" }
native-tls = "0.2.12" native-tls = "0.2.12"
hex = "0.4.3" hex = "0.4.3"
futures = "0.3.31" futures = "0.3.31"
isocountry = "0.3.2"
#nostr-dm #nostr-dm
nostr = { version = "0.39.0", default-features = false, features = ["std"] } nostr = { version = "0.39.0", default-features = false, features = ["std"] }

View File

@ -127,4 +127,14 @@ dns:
forward-zone-id: "my-forward-zone-id" forward-zone-id: "my-forward-zone-id"
# API token to add/remove DNS records to this zone # API token to add/remove DNS records to this zone
token: "my-api-token" token: "my-api-token"
``` ```
### Taxes
To charge taxes add the following config, the values are percentage whole numbers:
```yaml
tax-rate:
IE: 23
US: 15
```
Taxes are charged based on the users specified country

View File

@ -0,0 +1,4 @@
alter table vm_payment
add column tax bigint unsigned not null;
alter table users
add column country_code varchar(3) not null default 'USA';

View File

@ -21,6 +21,8 @@ pub struct User {
pub contact_nip17: bool, pub contact_nip17: bool,
/// If user should be contacted via email for notifications /// If user should be contacted via email for notifications
pub contact_email: bool, pub contact_email: bool,
/// Users country
pub country_code: String,
} }
#[derive(FromRow, Clone, Debug, Default)] #[derive(FromRow, Clone, Debug, Default)]
@ -327,6 +329,8 @@ pub struct VmPayment {
pub rate: f32, pub rate: f32,
/// Number of seconds this payment will add to vm expiry /// Number of seconds this payment will add to vm expiry
pub time_value: u64, pub time_value: u64,
/// Taxes to charge on payment
pub tax: u64,
} }
#[derive(Type, Clone, Copy, Debug, Default, PartialEq)] #[derive(Type, Clone, Copy, Debug, Default, PartialEq)]

View File

@ -60,11 +60,12 @@ impl LNVpsDb for LNVpsDbMysql {
async fn update_user(&self, user: &User) -> Result<()> { async fn update_user(&self, user: &User) -> Result<()> {
sqlx::query( sqlx::query(
"update users set email = ?, contact_nip17 = ?, contact_email = ? where id = ?", "update users set email=?, contact_nip17=?, contact_email=?, country_code=? where id = ?",
) )
.bind(&user.email) .bind(&user.email)
.bind(user.contact_nip17) .bind(user.contact_nip17)
.bind(user.contact_email) .bind(user.contact_email)
.bind(&user.country_code)
.bind(user.id) .bind(user.id)
.execute(&self.db) .execute(&self.db)
.await?; .await?;
@ -387,12 +388,13 @@ impl LNVpsDb for LNVpsDbMysql {
} }
async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> Result<()> { async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> Result<()> {
sqlx::query("insert into vm_payment(id,vm_id,created,expires,amount,currency,payment_method,time_value,is_paid,rate,external_id,external_data) values(?,?,?,?,?,?,?,?,?,?,?,?)") sqlx::query("insert into vm_payment(id,vm_id,created,expires,amount,tax,currency,payment_method,time_value,is_paid,rate,external_id,external_data) values(?,?,?,?,?,?,?,?,?,?,?,?,?)")
.bind(&vm_payment.id) .bind(&vm_payment.id)
.bind(vm_payment.vm_id) .bind(vm_payment.vm_id)
.bind(vm_payment.created) .bind(vm_payment.created)
.bind(vm_payment.expires) .bind(vm_payment.expires)
.bind(vm_payment.amount) .bind(vm_payment.amount)
.bind(vm_payment.tax)
.bind(&vm_payment.currency) .bind(&vm_payment.currency)
.bind(&vm_payment.payment_method) .bind(&vm_payment.payment_method)
.bind(vm_payment.time_value) .bind(vm_payment.time_value)

View File

@ -402,6 +402,7 @@ pub struct AccountPatchRequest {
pub email: Option<String>, pub email: Option<String>,
pub contact_nip17: bool, pub contact_nip17: bool,
pub contact_email: bool, pub contact_email: bool,
pub country_code: String,
} }
#[derive(Serialize, Deserialize, JsonSchema)] #[derive(Serialize, Deserialize, JsonSchema)]
@ -473,6 +474,7 @@ pub struct ApiVmPayment {
pub created: DateTime<Utc>, pub created: DateTime<Utc>,
pub expires: DateTime<Utc>, pub expires: DateTime<Utc>,
pub amount: u64, pub amount: u64,
pub tax: u64,
pub currency: String, pub currency: String,
pub is_paid: bool, pub is_paid: bool,
pub data: ApiPaymentData, pub data: ApiPaymentData,
@ -486,6 +488,7 @@ impl From<lnvps_db::VmPayment> for ApiVmPayment {
created: value.created, created: value.created,
expires: value.expires, expires: value.expires,
amount: value.amount, amount: value.amount,
tax: value.tax,
currency: value.currency, currency: value.currency,
is_paid: value.is_paid, is_paid: value.is_paid,
data: match &value.payment_method { data: match &value.payment_method {

View File

@ -1,8 +1,8 @@
use crate::api::model::{ use crate::api::model::{
AccountPatchRequest, ApiCustomTemplateParams, ApiCustomVmOrder, AccountPatchRequest, ApiCustomTemplateParams, ApiCustomVmOrder, ApiCustomVmRequest,
ApiCustomVmRequest, ApiPaymentInfo, ApiPaymentMethod, ApiPrice, ApiTemplatesResponse, ApiPaymentInfo, ApiPaymentMethod, ApiPrice, ApiTemplatesResponse, ApiUserSshKey,
ApiUserSshKey, ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus, ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus, ApiVmTemplate, CreateSshKey,
ApiVmTemplate, CreateSshKey, CreateVmRequest, VMPatchRequest, CreateVmRequest, VMPatchRequest,
}; };
use crate::exchange::{Currency, ExchangeRateService}; use crate::exchange::{Currency, ExchangeRateService};
use crate::host::{get_host_client, FullVmInfo, TimeSeries, TimeSeriesData}; use crate::host::{get_host_client, FullVmInfo, TimeSeries, TimeSeriesData};
@ -13,6 +13,7 @@ use crate::status::{VmState, VmStateCache};
use crate::worker::WorkJob; use crate::worker::WorkJob;
use anyhow::Result; use anyhow::Result;
use futures::future::join_all; use futures::future::join_all;
use isocountry::CountryCode;
use lnvps_db::{ use lnvps_db::{
IpRange, LNVpsDb, PaymentMethod, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate, IpRange, LNVpsDb, PaymentMethod, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate,
}; };
@ -107,6 +108,9 @@ async fn v1_patch_account(
user.email = req.email.clone(); user.email = req.email.clone();
user.contact_nip17 = req.contact_nip17; user.contact_nip17 = req.contact_nip17;
user.contact_email = req.contact_email; user.contact_email = req.contact_email;
user.country_code = CountryCode::for_alpha3(&req.country_code)?
.alpha3()
.to_owned();
db.update_user(&user).await?; db.update_user(&user).await?;
ApiData::ok(()) ApiData::ok(())
@ -127,6 +131,7 @@ async fn v1_get_account(
email: user.email, email: user.email,
contact_nip17: user.contact_nip17, contact_nip17: user.contact_nip17,
contact_email: user.contact_email, contact_email: user.contact_email,
country_code: user.country_code,
}) })
} }
@ -349,14 +354,14 @@ async fn v1_list_vm_templates(
.filter_map(|t| { .filter_map(|t| {
let region = regions.get(&t.region_id)?; let region = regions.get(&t.region_id)?;
ApiCustomTemplateParams::from( ApiCustomTemplateParams::from(
&t, &t,
&custom_template_disks, &custom_template_disks,
region, region,
max_cpu, max_cpu,
max_memory, max_memory,
max_disk, max_disk,
) )
.ok() .ok()
}) })
.collect(), .collect(),
) )

View File

@ -19,3 +19,11 @@ pub mod worker;
#[cfg(test)] #[cfg(test)]
pub mod mocks; pub mod mocks;
/// SATS per BTC
pub const BTC_SATS: f64 = 100_000_000.0;
pub const KB: u64 = 1024;
pub const MB: u64 = KB * 1024;
pub const GB: u64 = MB * 1024;
pub const TB: u64 = GB * 1024;

View File

@ -40,11 +40,6 @@ pub struct MockDb {
} }
impl MockDb { impl MockDb {
pub const KB: u64 = 1024;
pub const MB: u64 = Self::KB * 1024;
pub const GB: u64 = Self::MB * 1024;
pub const TB: u64 = Self::GB * 1024;
pub fn empty() -> MockDb { pub fn empty() -> MockDb {
Self { Self {
..Default::default() ..Default::default()
@ -71,8 +66,8 @@ impl MockDb {
created: Utc::now(), created: Utc::now(),
expires: None, expires: None,
cpu: 2, cpu: 2,
memory: Self::GB * 2, memory: crate::GB * 2,
disk_size: Self::GB * 64, disk_size: crate::GB * 64,
disk_type: DiskType::SSD, disk_type: DiskType::SSD,
disk_interface: DiskInterface::PCIe, disk_interface: DiskInterface::PCIe,
cost_plan_id: 1, cost_plan_id: 1,
@ -132,7 +127,7 @@ impl Default for MockDb {
name: "mock-host".to_string(), name: "mock-host".to_string(),
ip: "https://localhost".to_string(), ip: "https://localhost".to_string(),
cpu: 4, cpu: 4,
memory: 8 * Self::GB, memory: 8 * crate::GB,
enabled: true, enabled: true,
api_token: "".to_string(), api_token: "".to_string(),
load_factor: 1.5, load_factor: 1.5,
@ -145,7 +140,7 @@ impl Default for MockDb {
id: 1, id: 1,
host_id: 1, host_id: 1,
name: "mock-disk".to_string(), name: "mock-disk".to_string(),
size: Self::TB * 10, size: crate::TB * 10,
kind: DiskType::SSD, kind: DiskType::SSD,
interface: DiskInterface::PCIe, interface: DiskInterface::PCIe,
enabled: true, enabled: true,
@ -209,6 +204,7 @@ impl LNVpsDb for MockDb {
email: None, email: None,
contact_nip17: false, contact_nip17: false,
contact_email: false, contact_email: false,
country_code: "USA".to_string(),
}, },
); );
Ok(max + 1) Ok(max + 1)
@ -650,14 +646,15 @@ impl Router for MockRouter {
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct MockNode { pub struct MockNode {
invoices: Arc<Mutex<HashMap<String, MockInvoice>>>, pub invoices: Arc<Mutex<HashMap<String, MockInvoice>>>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct MockInvoice { pub struct MockInvoice {
pr: String, pub pr: String,
expiry: DateTime<Utc>, pub amount: u64,
settle_index: u64, pub expiry: DateTime<Utc>,
pub is_paid: bool,
} }
impl MockNode { impl MockNode {
@ -673,7 +670,22 @@ impl MockNode {
#[async_trait] #[async_trait]
impl LightningNode for MockNode { impl LightningNode for MockNode {
async fn add_invoice(&self, req: AddInvoiceRequest) -> anyhow::Result<AddInvoiceResult> { async fn add_invoice(&self, req: AddInvoiceRequest) -> anyhow::Result<AddInvoiceResult> {
todo!() let mut invoices = self.invoices.lock().await;
let id: [u8; 32] = rand::random();
let hex_id = hex::encode(id);
invoices.insert(
hex_id.clone(),
MockInvoice {
pr: format!("lnrt1{}", hex_id),
amount: req.amount,
expiry: Utc::now().add(TimeDelta::seconds(req.expire.unwrap_or(3600) as i64)),
is_paid: false,
},
);
Ok(AddInvoiceResult {
pr: format!("lnrt1{}", hex_id),
payment_hash: hex_id.clone(),
})
} }
async fn subscribe_invoices( async fn subscribe_invoices(

View File

@ -10,12 +10,11 @@ use crate::router::{ArpEntry, Router};
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings}; use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings};
use anyhow::{bail, ensure, Context, Result}; use anyhow::{bail, ensure, Context, Result};
use chrono::Utc; use chrono::Utc;
use lnvps_db::{ use isocountry::CountryCode;
LNVpsDb, PaymentMethod, Vm, VmCustomTemplate, VmIpAssignment, use lnvps_db::{LNVpsDb, PaymentMethod, Vm, VmCustomTemplate, VmIpAssignment, VmPayment};
VmPayment,
};
use log::{info, warn}; use log::{info, warn};
use nostr::util::hex; use nostr::util::hex;
use std::collections::HashMap;
use std::ops::Add; use std::ops::Add;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -29,6 +28,7 @@ pub struct LNVpsProvisioner {
db: Arc<dyn LNVpsDb>, db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>, node: Arc<dyn LightningNode>,
rates: Arc<dyn ExchangeRateService>, rates: Arc<dyn ExchangeRateService>,
tax_rates: HashMap<CountryCode, f32>,
router: Option<Arc<dyn Router>>, router: Option<Arc<dyn Router>>,
dns: Option<Arc<dyn DnsServer>>, dns: Option<Arc<dyn DnsServer>>,
@ -52,6 +52,7 @@ impl LNVpsProvisioner {
router: settings.get_router().expect("router config"), router: settings.get_router().expect("router config"),
dns: settings.get_dns().expect("dns config"), dns: settings.get_dns().expect("dns config"),
revolut: settings.get_revolut().expect("revolut config"), revolut: settings.get_revolut().expect("revolut config"),
tax_rates: settings.tax_rate,
network_policy: settings.network_policy, network_policy: settings.network_policy,
provisioner_config: settings.provisioner, provisioner_config: settings.provisioner,
read_only: settings.read_only, read_only: settings.read_only,
@ -356,7 +357,7 @@ impl LNVpsProvisioner {
/// Create a renewal payment /// Create a renewal payment
pub async fn renew(&self, vm_id: u64, method: PaymentMethod) -> Result<VmPayment> { pub async fn renew(&self, vm_id: u64, method: PaymentMethod) -> Result<VmPayment> {
let pe = PricingEngine::new(self.db.clone(), self.rates.clone()); let pe = PricingEngine::new(self.db.clone(), self.rates.clone(), self.tax_rates.clone());
let price = pe.get_vm_cost(vm_id, method).await?; let price = pe.get_vm_cost(vm_id, method).await?;
match price { match price {
@ -367,6 +368,7 @@ impl LNVpsProvisioner {
time_value, time_value,
new_expiry, new_expiry,
rate, rate,
tax,
} => { } => {
let desc = format!("VM renewal {vm_id} to {new_expiry}"); let desc = format!("VM renewal {vm_id} to {new_expiry}");
let vm_payment = match method { let vm_payment = match method {
@ -376,12 +378,16 @@ impl LNVpsProvisioner {
"Cannot create invoices for non-BTC currency" "Cannot create invoices for non-BTC currency"
); );
const INVOICE_EXPIRE: u64 = 600; const INVOICE_EXPIRE: u64 = 600;
info!("Creating invoice for {vm_id} for {} sats", amount / 1000); let total_amount = amount + tax;
info!(
"Creating invoice for {vm_id} for {} sats",
total_amount / 1000
);
let invoice = self let invoice = self
.node .node
.add_invoice(AddInvoiceRequest { .add_invoice(AddInvoiceRequest {
memo: Some(desc), memo: Some(desc),
amount, amount: total_amount,
expire: Some(INVOICE_EXPIRE as u32), expire: Some(INVOICE_EXPIRE as u32),
}) })
.await?; .await?;
@ -391,6 +397,7 @@ impl LNVpsProvisioner {
created: Utc::now(), created: Utc::now(),
expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE)), expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE)),
amount, amount,
tax,
currency: currency.to_string(), currency: currency.to_string(),
payment_method: method, payment_method: method,
time_value, time_value,
@ -411,7 +418,7 @@ impl LNVpsProvisioner {
"Cannot create revolut orders for BTC currency" "Cannot create revolut orders for BTC currency"
); );
let order = rev let order = rev
.create_order(&desc, CurrencyAmount::from_u64(currency, amount)) .create_order(&desc, CurrencyAmount::from_u64(currency, amount + tax))
.await?; .await?;
let new_id: [u8; 32] = rand::random(); let new_id: [u8; 32] = rand::random();
VmPayment { VmPayment {
@ -420,6 +427,7 @@ impl LNVpsProvisioner {
created: Utc::now(), created: Utc::now(),
expires: Utc::now().add(Duration::from_secs(3600)), expires: Utc::now().add(Duration::from_secs(3600)),
amount, amount,
tax,
currency: currency.to_string(), currency: currency.to_string(),
payment_method: method, payment_method: method,
time_value, time_value,
@ -483,8 +491,8 @@ impl LNVpsProvisioner {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::exchange::DefaultRateCache; use crate::exchange::{DefaultRateCache, Ticker};
use crate::mocks::{MockDb, MockDnsServer, MockNode, MockRouter}; use crate::mocks::{MockDb, MockDnsServer, MockExchangeRate, MockNode, MockRouter};
use crate::settings::{DnsServerConfig, LightningConfig, QemuConfig, RouterConfig}; use crate::settings::{DnsServerConfig, LightningConfig, QemuConfig, RouterConfig};
use lnvps_db::{DiskInterface, DiskType, User, UserSshKey, VmTemplate}; use lnvps_db::{DiskInterface, DiskType, User, UserSshKey, VmTemplate};
use std::net::IpAddr; use std::net::IpAddr;
@ -535,6 +543,7 @@ mod tests {
}), }),
nostr: None, nostr: None,
revolut: None, revolut: None,
tax_rate: HashMap::from([(CountryCode::IRL, 23.0), (CountryCode::USA, 1.0)]),
} }
} }
@ -559,7 +568,10 @@ mod tests {
let settings = settings(); let settings = settings();
let db = Arc::new(MockDb::default()); let db = Arc::new(MockDb::default());
let node = Arc::new(MockNode::default()); let node = Arc::new(MockNode::default());
let rates = Arc::new(DefaultRateCache::default()); let rates = Arc::new(MockExchangeRate::new());
const MOCK_RATE: f32 = 69_420.0;
rates.set_rate(Ticker::btc_rate("EUR")?, MOCK_RATE).await;
let router = MockRouter::new(settings.network_policy.clone()); let router = MockRouter::new(settings.network_policy.clone());
let dns = MockDnsServer::new(); let dns = MockDnsServer::new();
let provisioner = LNVpsProvisioner::new(settings, db.clone(), node.clone(), rates.clone()); let provisioner = LNVpsProvisioner::new(settings, db.clone(), node.clone(), rates.clone());
@ -569,6 +581,21 @@ mod tests {
.provision(user.id, 1, 1, ssh_key.id, Some("mock-ref".to_string())) .provision(user.id, 1, 1, ssh_key.id, Some("mock-ref".to_string()))
.await?; .await?;
println!("{:?}", vm); println!("{:?}", vm);
// renew vm
let payment = provisioner.renew(vm.id, PaymentMethod::Lightning).await?;
assert_eq!(vm.id, payment.vm_id);
assert_eq!(payment.tax, (payment.amount as f64 * 0.01).floor() as u64);
// check invoice amount matches amount+tax
let inv = node.invoices.lock().await;
if let Some(i) = inv.get(&hex::encode(payment.id)) {
assert_eq!(i.amount, payment.amount + payment.tax);
} else {
bail!("Invoice doesnt exist");
}
// spawn vm
provisioner.spawn_vm(vm.id).await?; provisioner.spawn_vm(vm.id).await?;
// check resources // check resources
@ -636,8 +663,8 @@ mod tests {
created: Default::default(), created: Default::default(),
expires: None, expires: None,
cpu: 64, cpu: 64,
memory: 512 * MockDb::GB, memory: 512 * crate::GB,
disk_size: 20 * MockDb::TB, disk_size: 20 * crate::TB,
disk_type: DiskType::SSD, disk_type: DiskType::SSD,
disk_interface: DiskInterface::PCIe, disk_interface: DiskInterface::PCIe,
cost_plan_id: 1, cost_plan_id: 1,

View File

@ -1,11 +1,13 @@
use crate::exchange::{Currency, CurrencyAmount, ExchangeRateService, Ticker, TickerRate}; use crate::exchange::{Currency, CurrencyAmount, ExchangeRateService, Ticker, TickerRate};
use anyhow::{bail, Result}; use anyhow::{bail, Context, Result};
use chrono::{DateTime, Days, Months, TimeDelta, Utc}; use chrono::{DateTime, Days, Months, TimeDelta, Utc};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use isocountry::CountryCode;
use lnvps_db::{ use lnvps_db::{
LNVpsDb, PaymentMethod, Vm, VmCostPlan, VmCostPlanIntervalType, VmCustomTemplate, VmPayment, LNVpsDb, PaymentMethod, Vm, VmCostPlan, VmCostPlanIntervalType, VmCustomTemplate, VmPayment,
}; };
use log::info; use log::info;
use std::collections::HashMap;
use std::ops::Add; use std::ops::Add;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
@ -16,17 +18,20 @@ use std::sync::Arc;
pub struct PricingEngine { pub struct PricingEngine {
db: Arc<dyn LNVpsDb>, db: Arc<dyn LNVpsDb>,
rates: Arc<dyn ExchangeRateService>, rates: Arc<dyn ExchangeRateService>,
tax_rates: HashMap<CountryCode, f32>,
} }
impl PricingEngine { impl PricingEngine {
/// SATS per BTC pub fn new(
const BTC_SATS: f64 = 100_000_000.0; db: Arc<dyn LNVpsDb>,
const KB: u64 = 1024; rates: Arc<dyn ExchangeRateService>,
const MB: u64 = Self::KB * 1024; tax_rates: HashMap<CountryCode, f32>,
const GB: u64 = Self::MB * 1024; ) -> Self {
Self {
pub fn new(db: Arc<dyn LNVpsDb>, rates: Arc<dyn ExchangeRateService>) -> Self { db,
Self { db, rates } rates,
tax_rates,
}
} }
/// Get VM cost (for renewal) /// Get VM cost (for renewal)
@ -82,9 +87,9 @@ impl PricingEngine {
} else { } else {
bail!("No disk price found") bail!("No disk price found")
}; };
let disk_cost = (template.disk_size / Self::GB) as f32 * disk_pricing.cost; let disk_cost = (template.disk_size / crate::GB) as f32 * disk_pricing.cost;
let cpu_cost = pricing.cpu_cost * template.cpu as f32; let cpu_cost = pricing.cpu_cost * template.cpu as f32;
let memory_cost = pricing.memory_cost * (template.memory / Self::GB) as f32; let memory_cost = pricing.memory_cost * (template.memory / crate::GB) as f32;
let ip4_cost = pricing.ip4_cost * v4s as f32; let ip4_cost = pricing.ip4_cost * v4s as f32;
let ip6_cost = pricing.ip6_cost * v6s as f32; let ip6_cost = pricing.ip6_cost * v6s as f32;
@ -121,6 +126,7 @@ impl PricingEngine {
.await?; .await?;
Ok(CostResult::New { Ok(CostResult::New {
amount, amount,
tax: self.get_tax_for_user(vm.user_id, amount).await?,
currency, currency,
rate, rate,
time_value, time_value,
@ -128,6 +134,16 @@ impl PricingEngine {
}) })
} }
async fn get_tax_for_user(&self, user_id: u64, amount: u64) -> Result<u64> {
let user = self.db.get_user(user_id).await?;
let cc = CountryCode::for_alpha3(&user.country_code).context("Invalid country code")?;
if let Some(c) = self.tax_rates.get(&cc) {
Ok((amount as f64 * (*c as f64 / 100f64)).floor() as u64)
} else {
Ok(0)
}
}
async fn get_ticker(&self, currency: Currency) -> Result<TickerRate> { async fn get_ticker(&self, currency: Currency) -> Result<TickerRate> {
let ticker = Ticker(Currency::BTC, currency); let ticker = Ticker(Currency::BTC, currency);
if let Some(r) = self.rates.get_rate(ticker).await { if let Some(r) = self.rates.get_rate(ticker).await {
@ -140,7 +156,7 @@ impl PricingEngine {
async fn get_msats_amount(&self, amount: CurrencyAmount) -> Result<(u64, f32)> { async fn get_msats_amount(&self, amount: CurrencyAmount) -> Result<(u64, f32)> {
let rate = self.get_ticker(amount.0).await?; let rate = self.get_ticker(amount.0).await?;
let cost_btc = amount.1 / rate.1; let cost_btc = amount.1 / rate.1;
let cost_msats = (cost_btc as f64 * Self::BTC_SATS) as u64 * 1000; let cost_msats = (cost_btc as f64 * crate::BTC_SATS) as u64 * 1000;
Ok((cost_msats, rate.1)) Ok((cost_msats, rate.1))
} }
@ -174,6 +190,7 @@ impl PricingEngine {
let time_value = Self::next_template_expire(vm, &cost_plan); let time_value = Self::next_template_expire(vm, &cost_plan);
Ok(CostResult::New { Ok(CostResult::New {
amount, amount,
tax: self.get_tax_for_user(vm.user_id, amount).await?,
currency, currency,
rate, rate,
time_value, time_value,
@ -215,6 +232,8 @@ pub enum CostResult {
time_value: u64, time_value: u64,
/// The absolute expiry time of the vm if renewed /// The absolute expiry time of the vm if renewed
new_expiry: DateTime<Utc>, new_expiry: DateTime<Utc>,
/// Taxes to charge
tax: u64,
}, },
} }
@ -238,8 +257,7 @@ impl PricingData {
mod tests { mod tests {
use super::*; use super::*;
use crate::mocks::{MockDb, MockExchangeRate}; use crate::mocks::{MockDb, MockExchangeRate};
use lnvps_db::{DiskType, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate}; use lnvps_db::{DiskType, User, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate};
const GB: u64 = 1024 * 1024 * 1024;
const MOCK_RATE: f32 = 100_000.0; const MOCK_RATE: f32 = 100_000.0;
async fn add_custom_pricing(db: &MockDb) { async fn add_custom_pricing(db: &MockDb) {
@ -266,8 +284,8 @@ mod tests {
VmCustomTemplate { VmCustomTemplate {
id: 1, id: 1,
cpu: 2, cpu: 2,
memory: 2 * GB, memory: 2 * crate::GB,
disk_size: 80 * GB, disk_size: 80 * crate::GB,
disk_type: DiskType::SSD, disk_type: DiskType::SSD,
disk_interface: Default::default(), disk_interface: Default::default(),
pricing_id: 1, pricing_id: 1,
@ -313,17 +331,65 @@ mod tests {
{ {
let mut v = db.vms.lock().await; let mut v = db.vms.lock().await;
v.insert(1, MockDb::mock_vm()); v.insert(1, MockDb::mock_vm());
v.insert(
2,
Vm {
user_id: 2,
..MockDb::mock_vm()
},
);
let mut u = db.users.lock().await;
u.insert(
1,
User {
id: 1,
pubkey: vec![],
created: Default::default(),
email: None,
contact_nip17: false,
contact_email: false,
country_code: "USA".to_string(),
},
);
u.insert(
2,
User {
id: 2,
pubkey: vec![],
created: Default::default(),
email: None,
contact_nip17: false,
contact_email: false,
country_code: "IRL".to_string(),
},
);
} }
let db: Arc<dyn LNVpsDb> = Arc::new(db); let db: Arc<dyn LNVpsDb> = Arc::new(db);
let pe = PricingEngine::new(db.clone(), rates); let taxes = HashMap::from([(CountryCode::IRL, 23.0)]);
let price = pe.get_vm_cost(1, PaymentMethod::Lightning).await?;
let pe = PricingEngine::new(db.clone(), rates, taxes);
let plan = MockDb::mock_cost_plan(); let plan = MockDb::mock_cost_plan();
let price = pe.get_vm_cost(1, PaymentMethod::Lightning).await?;
match price { match price {
CostResult::New { amount, .. } => { CostResult::New { amount, tax, .. } => {
let expect_price = (plan.amount / MOCK_RATE * 1.0e11) as u64; let expect_price = (plan.amount / MOCK_RATE * 1.0e11) as u64;
assert_eq!(expect_price, amount); assert_eq!(expect_price, amount);
assert_eq!(0, tax);
}
_ => bail!("??"),
}
// with taxes
let price = pe.get_vm_cost(2, PaymentMethod::Lightning).await?;
match price {
CostResult::New { amount, tax, .. } => {
let expect_price = (plan.amount / MOCK_RATE * 1.0e11) as u64;
assert_eq!(expect_price, amount);
assert_eq!((expect_price as f64 * 0.23).floor() as u64, tax);
} }
_ => bail!("??"), _ => bail!("??"),
} }

View File

@ -1,3 +1,4 @@
use std::collections::HashMap;
use crate::dns::DnsServer; use crate::dns::DnsServer;
use crate::exchange::ExchangeRateService; use crate::exchange::ExchangeRateService;
use crate::fiat::FiatPaymentService; use crate::fiat::FiatPaymentService;
@ -9,6 +10,7 @@ use lnvps_db::LNVpsDb;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use isocountry::CountryCode;
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
@ -31,8 +33,8 @@ pub struct Settings {
/// Provisioning profiles /// Provisioning profiles
pub provisioner: ProvisionerConfig, pub provisioner: ProvisionerConfig,
/// Network policy
#[serde(default)] #[serde(default)]
/// Network policy
pub network_policy: NetworkPolicy, pub network_policy: NetworkPolicy,
/// Number of days after an expired VM is deleted /// Number of days after an expired VM is deleted
@ -52,6 +54,10 @@ pub struct Settings {
/// Config for accepting revolut payments /// Config for accepting revolut payments
pub revolut: Option<RevolutConfig>, pub revolut: Option<RevolutConfig>,
#[serde(default)]
/// Tax rates to change per country as a percent of the amount
pub tax_rate: HashMap<CountryCode, f32>,
} }
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]