feat: taxes

closes #18
This commit is contained in:
2025-03-11 15:58:34 +00:00
parent 029f2cb6e1
commit 02d606d60c
13 changed files with 222 additions and 63 deletions

11
Cargo.lock generated
View File

@ -1863,6 +1863,16 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "isocountry"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ea1dc4bf0fb4904ba83ffdb98af3d9c325274e92e6e295e4151e86c96363e04"
dependencies = [
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "itertools"
version = "0.12.1"
@ -2011,6 +2021,7 @@ dependencies = [
"hex",
"hmac",
"ipnetwork",
"isocountry",
"lettre",
"lnvps_db",
"log",

View File

@ -41,6 +41,7 @@ ws = { package = "rocket_ws", version = "0.1.0" }
native-tls = "0.2.12"
hex = "0.4.3"
futures = "0.3.31"
isocountry = "0.3.2"
#nostr-dm
nostr = { version = "0.39.0", default-features = false, features = ["std"] }

View File

@ -128,3 +128,13 @@ dns:
# API token to add/remove DNS records to this zone
token: "my-api-token"
```
### Taxes
To charge taxes add the following config, the values are percentage whole numbers:
```yaml
tax-rate:
IE: 23
US: 15
```
Taxes are charged based on the users specified country

View File

@ -0,0 +1,4 @@
alter table vm_payment
add column tax bigint unsigned not null;
alter table users
add column country_code varchar(3) not null default 'USA';

View File

@ -21,6 +21,8 @@ pub struct User {
pub contact_nip17: bool,
/// If user should be contacted via email for notifications
pub contact_email: bool,
/// Users country
pub country_code: String,
}
#[derive(FromRow, Clone, Debug, Default)]
@ -327,6 +329,8 @@ pub struct VmPayment {
pub rate: f32,
/// Number of seconds this payment will add to vm expiry
pub time_value: u64,
/// Taxes to charge on payment
pub tax: u64,
}
#[derive(Type, Clone, Copy, Debug, Default, PartialEq)]

View File

@ -60,11 +60,12 @@ impl LNVpsDb for LNVpsDbMysql {
async fn update_user(&self, user: &User) -> Result<()> {
sqlx::query(
"update users set email = ?, contact_nip17 = ?, contact_email = ? where id = ?",
"update users set email=?, contact_nip17=?, contact_email=?, country_code=? where id = ?",
)
.bind(&user.email)
.bind(user.contact_nip17)
.bind(user.contact_email)
.bind(&user.country_code)
.bind(user.id)
.execute(&self.db)
.await?;
@ -387,12 +388,13 @@ impl LNVpsDb for LNVpsDbMysql {
}
async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> Result<()> {
sqlx::query("insert into vm_payment(id,vm_id,created,expires,amount,currency,payment_method,time_value,is_paid,rate,external_id,external_data) values(?,?,?,?,?,?,?,?,?,?,?,?)")
sqlx::query("insert into vm_payment(id,vm_id,created,expires,amount,tax,currency,payment_method,time_value,is_paid,rate,external_id,external_data) values(?,?,?,?,?,?,?,?,?,?,?,?,?)")
.bind(&vm_payment.id)
.bind(vm_payment.vm_id)
.bind(vm_payment.created)
.bind(vm_payment.expires)
.bind(vm_payment.amount)
.bind(vm_payment.tax)
.bind(&vm_payment.currency)
.bind(&vm_payment.payment_method)
.bind(vm_payment.time_value)

View File

@ -402,6 +402,7 @@ pub struct AccountPatchRequest {
pub email: Option<String>,
pub contact_nip17: bool,
pub contact_email: bool,
pub country_code: String,
}
#[derive(Serialize, Deserialize, JsonSchema)]
@ -473,6 +474,7 @@ pub struct ApiVmPayment {
pub created: DateTime<Utc>,
pub expires: DateTime<Utc>,
pub amount: u64,
pub tax: u64,
pub currency: String,
pub is_paid: bool,
pub data: ApiPaymentData,
@ -486,6 +488,7 @@ impl From<lnvps_db::VmPayment> for ApiVmPayment {
created: value.created,
expires: value.expires,
amount: value.amount,
tax: value.tax,
currency: value.currency,
is_paid: value.is_paid,
data: match &value.payment_method {

View File

@ -1,8 +1,8 @@
use crate::api::model::{
AccountPatchRequest, ApiCustomTemplateParams, ApiCustomVmOrder,
ApiCustomVmRequest, ApiPaymentInfo, ApiPaymentMethod, ApiPrice, ApiTemplatesResponse,
ApiUserSshKey, ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus,
ApiVmTemplate, CreateSshKey, CreateVmRequest, VMPatchRequest,
AccountPatchRequest, ApiCustomTemplateParams, ApiCustomVmOrder, ApiCustomVmRequest,
ApiPaymentInfo, ApiPaymentMethod, ApiPrice, ApiTemplatesResponse, ApiUserSshKey,
ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus, ApiVmTemplate, CreateSshKey,
CreateVmRequest, VMPatchRequest,
};
use crate::exchange::{Currency, ExchangeRateService};
use crate::host::{get_host_client, FullVmInfo, TimeSeries, TimeSeriesData};
@ -13,6 +13,7 @@ use crate::status::{VmState, VmStateCache};
use crate::worker::WorkJob;
use anyhow::Result;
use futures::future::join_all;
use isocountry::CountryCode;
use lnvps_db::{
IpRange, LNVpsDb, PaymentMethod, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate,
};
@ -107,6 +108,9 @@ async fn v1_patch_account(
user.email = req.email.clone();
user.contact_nip17 = req.contact_nip17;
user.contact_email = req.contact_email;
user.country_code = CountryCode::for_alpha3(&req.country_code)?
.alpha3()
.to_owned();
db.update_user(&user).await?;
ApiData::ok(())
@ -127,6 +131,7 @@ async fn v1_get_account(
email: user.email,
contact_nip17: user.contact_nip17,
contact_email: user.contact_email,
country_code: user.country_code,
})
}

View File

@ -19,3 +19,11 @@ pub mod worker;
#[cfg(test)]
pub mod mocks;
/// SATS per BTC
pub const BTC_SATS: f64 = 100_000_000.0;
pub const KB: u64 = 1024;
pub const MB: u64 = KB * 1024;
pub const GB: u64 = MB * 1024;
pub const TB: u64 = GB * 1024;

View File

@ -40,11 +40,6 @@ pub struct MockDb {
}
impl MockDb {
pub const KB: u64 = 1024;
pub const MB: u64 = Self::KB * 1024;
pub const GB: u64 = Self::MB * 1024;
pub const TB: u64 = Self::GB * 1024;
pub fn empty() -> MockDb {
Self {
..Default::default()
@ -71,8 +66,8 @@ impl MockDb {
created: Utc::now(),
expires: None,
cpu: 2,
memory: Self::GB * 2,
disk_size: Self::GB * 64,
memory: crate::GB * 2,
disk_size: crate::GB * 64,
disk_type: DiskType::SSD,
disk_interface: DiskInterface::PCIe,
cost_plan_id: 1,
@ -132,7 +127,7 @@ impl Default for MockDb {
name: "mock-host".to_string(),
ip: "https://localhost".to_string(),
cpu: 4,
memory: 8 * Self::GB,
memory: 8 * crate::GB,
enabled: true,
api_token: "".to_string(),
load_factor: 1.5,
@ -145,7 +140,7 @@ impl Default for MockDb {
id: 1,
host_id: 1,
name: "mock-disk".to_string(),
size: Self::TB * 10,
size: crate::TB * 10,
kind: DiskType::SSD,
interface: DiskInterface::PCIe,
enabled: true,
@ -209,6 +204,7 @@ impl LNVpsDb for MockDb {
email: None,
contact_nip17: false,
contact_email: false,
country_code: "USA".to_string(),
},
);
Ok(max + 1)
@ -650,14 +646,15 @@ impl Router for MockRouter {
#[derive(Clone, Debug, Default)]
pub struct MockNode {
invoices: Arc<Mutex<HashMap<String, MockInvoice>>>,
pub invoices: Arc<Mutex<HashMap<String, MockInvoice>>>,
}
#[derive(Debug, Clone)]
struct MockInvoice {
pr: String,
expiry: DateTime<Utc>,
settle_index: u64,
pub struct MockInvoice {
pub pr: String,
pub amount: u64,
pub expiry: DateTime<Utc>,
pub is_paid: bool,
}
impl MockNode {
@ -673,7 +670,22 @@ impl MockNode {
#[async_trait]
impl LightningNode for MockNode {
async fn add_invoice(&self, req: AddInvoiceRequest) -> anyhow::Result<AddInvoiceResult> {
todo!()
let mut invoices = self.invoices.lock().await;
let id: [u8; 32] = rand::random();
let hex_id = hex::encode(id);
invoices.insert(
hex_id.clone(),
MockInvoice {
pr: format!("lnrt1{}", hex_id),
amount: req.amount,
expiry: Utc::now().add(TimeDelta::seconds(req.expire.unwrap_or(3600) as i64)),
is_paid: false,
},
);
Ok(AddInvoiceResult {
pr: format!("lnrt1{}", hex_id),
payment_hash: hex_id.clone(),
})
}
async fn subscribe_invoices(

View File

@ -10,12 +10,11 @@ use crate::router::{ArpEntry, Router};
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings};
use anyhow::{bail, ensure, Context, Result};
use chrono::Utc;
use lnvps_db::{
LNVpsDb, PaymentMethod, Vm, VmCustomTemplate, VmIpAssignment,
VmPayment,
};
use isocountry::CountryCode;
use lnvps_db::{LNVpsDb, PaymentMethod, Vm, VmCustomTemplate, VmIpAssignment, VmPayment};
use log::{info, warn};
use nostr::util::hex;
use std::collections::HashMap;
use std::ops::Add;
use std::sync::Arc;
use std::time::Duration;
@ -29,6 +28,7 @@ pub struct LNVpsProvisioner {
db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
rates: Arc<dyn ExchangeRateService>,
tax_rates: HashMap<CountryCode, f32>,
router: Option<Arc<dyn Router>>,
dns: Option<Arc<dyn DnsServer>>,
@ -52,6 +52,7 @@ impl LNVpsProvisioner {
router: settings.get_router().expect("router config"),
dns: settings.get_dns().expect("dns config"),
revolut: settings.get_revolut().expect("revolut config"),
tax_rates: settings.tax_rate,
network_policy: settings.network_policy,
provisioner_config: settings.provisioner,
read_only: settings.read_only,
@ -356,7 +357,7 @@ impl LNVpsProvisioner {
/// Create a renewal payment
pub async fn renew(&self, vm_id: u64, method: PaymentMethod) -> Result<VmPayment> {
let pe = PricingEngine::new(self.db.clone(), self.rates.clone());
let pe = PricingEngine::new(self.db.clone(), self.rates.clone(), self.tax_rates.clone());
let price = pe.get_vm_cost(vm_id, method).await?;
match price {
@ -367,6 +368,7 @@ impl LNVpsProvisioner {
time_value,
new_expiry,
rate,
tax,
} => {
let desc = format!("VM renewal {vm_id} to {new_expiry}");
let vm_payment = match method {
@ -376,12 +378,16 @@ impl LNVpsProvisioner {
"Cannot create invoices for non-BTC currency"
);
const INVOICE_EXPIRE: u64 = 600;
info!("Creating invoice for {vm_id} for {} sats", amount / 1000);
let total_amount = amount + tax;
info!(
"Creating invoice for {vm_id} for {} sats",
total_amount / 1000
);
let invoice = self
.node
.add_invoice(AddInvoiceRequest {
memo: Some(desc),
amount,
amount: total_amount,
expire: Some(INVOICE_EXPIRE as u32),
})
.await?;
@ -391,6 +397,7 @@ impl LNVpsProvisioner {
created: Utc::now(),
expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE)),
amount,
tax,
currency: currency.to_string(),
payment_method: method,
time_value,
@ -411,7 +418,7 @@ impl LNVpsProvisioner {
"Cannot create revolut orders for BTC currency"
);
let order = rev
.create_order(&desc, CurrencyAmount::from_u64(currency, amount))
.create_order(&desc, CurrencyAmount::from_u64(currency, amount + tax))
.await?;
let new_id: [u8; 32] = rand::random();
VmPayment {
@ -420,6 +427,7 @@ impl LNVpsProvisioner {
created: Utc::now(),
expires: Utc::now().add(Duration::from_secs(3600)),
amount,
tax,
currency: currency.to_string(),
payment_method: method,
time_value,
@ -483,8 +491,8 @@ impl LNVpsProvisioner {
#[cfg(test)]
mod tests {
use super::*;
use crate::exchange::DefaultRateCache;
use crate::mocks::{MockDb, MockDnsServer, MockNode, MockRouter};
use crate::exchange::{DefaultRateCache, Ticker};
use crate::mocks::{MockDb, MockDnsServer, MockExchangeRate, MockNode, MockRouter};
use crate::settings::{DnsServerConfig, LightningConfig, QemuConfig, RouterConfig};
use lnvps_db::{DiskInterface, DiskType, User, UserSshKey, VmTemplate};
use std::net::IpAddr;
@ -535,6 +543,7 @@ mod tests {
}),
nostr: None,
revolut: None,
tax_rate: HashMap::from([(CountryCode::IRL, 23.0), (CountryCode::USA, 1.0)]),
}
}
@ -559,7 +568,10 @@ mod tests {
let settings = settings();
let db = Arc::new(MockDb::default());
let node = Arc::new(MockNode::default());
let rates = Arc::new(DefaultRateCache::default());
let rates = Arc::new(MockExchangeRate::new());
const MOCK_RATE: f32 = 69_420.0;
rates.set_rate(Ticker::btc_rate("EUR")?, MOCK_RATE).await;
let router = MockRouter::new(settings.network_policy.clone());
let dns = MockDnsServer::new();
let provisioner = LNVpsProvisioner::new(settings, db.clone(), node.clone(), rates.clone());
@ -569,6 +581,21 @@ mod tests {
.provision(user.id, 1, 1, ssh_key.id, Some("mock-ref".to_string()))
.await?;
println!("{:?}", vm);
// renew vm
let payment = provisioner.renew(vm.id, PaymentMethod::Lightning).await?;
assert_eq!(vm.id, payment.vm_id);
assert_eq!(payment.tax, (payment.amount as f64 * 0.01).floor() as u64);
// check invoice amount matches amount+tax
let inv = node.invoices.lock().await;
if let Some(i) = inv.get(&hex::encode(payment.id)) {
assert_eq!(i.amount, payment.amount + payment.tax);
} else {
bail!("Invoice doesnt exist");
}
// spawn vm
provisioner.spawn_vm(vm.id).await?;
// check resources
@ -636,8 +663,8 @@ mod tests {
created: Default::default(),
expires: None,
cpu: 64,
memory: 512 * MockDb::GB,
disk_size: 20 * MockDb::TB,
memory: 512 * crate::GB,
disk_size: 20 * crate::TB,
disk_type: DiskType::SSD,
disk_interface: DiskInterface::PCIe,
cost_plan_id: 1,

View File

@ -1,11 +1,13 @@
use crate::exchange::{Currency, CurrencyAmount, ExchangeRateService, Ticker, TickerRate};
use anyhow::{bail, Result};
use anyhow::{bail, Context, Result};
use chrono::{DateTime, Days, Months, TimeDelta, Utc};
use ipnetwork::IpNetwork;
use isocountry::CountryCode;
use lnvps_db::{
LNVpsDb, PaymentMethod, Vm, VmCostPlan, VmCostPlanIntervalType, VmCustomTemplate, VmPayment,
};
use log::info;
use std::collections::HashMap;
use std::ops::Add;
use std::str::FromStr;
use std::sync::Arc;
@ -16,17 +18,20 @@ use std::sync::Arc;
pub struct PricingEngine {
db: Arc<dyn LNVpsDb>,
rates: Arc<dyn ExchangeRateService>,
tax_rates: HashMap<CountryCode, f32>,
}
impl PricingEngine {
/// SATS per BTC
const BTC_SATS: f64 = 100_000_000.0;
const KB: u64 = 1024;
const MB: u64 = Self::KB * 1024;
const GB: u64 = Self::MB * 1024;
pub fn new(db: Arc<dyn LNVpsDb>, rates: Arc<dyn ExchangeRateService>) -> Self {
Self { db, rates }
pub fn new(
db: Arc<dyn LNVpsDb>,
rates: Arc<dyn ExchangeRateService>,
tax_rates: HashMap<CountryCode, f32>,
) -> Self {
Self {
db,
rates,
tax_rates,
}
}
/// Get VM cost (for renewal)
@ -82,9 +87,9 @@ impl PricingEngine {
} else {
bail!("No disk price found")
};
let disk_cost = (template.disk_size / Self::GB) as f32 * disk_pricing.cost;
let disk_cost = (template.disk_size / crate::GB) as f32 * disk_pricing.cost;
let cpu_cost = pricing.cpu_cost * template.cpu as f32;
let memory_cost = pricing.memory_cost * (template.memory / Self::GB) as f32;
let memory_cost = pricing.memory_cost * (template.memory / crate::GB) as f32;
let ip4_cost = pricing.ip4_cost * v4s as f32;
let ip6_cost = pricing.ip6_cost * v6s as f32;
@ -121,6 +126,7 @@ impl PricingEngine {
.await?;
Ok(CostResult::New {
amount,
tax: self.get_tax_for_user(vm.user_id, amount).await?,
currency,
rate,
time_value,
@ -128,6 +134,16 @@ impl PricingEngine {
})
}
async fn get_tax_for_user(&self, user_id: u64, amount: u64) -> Result<u64> {
let user = self.db.get_user(user_id).await?;
let cc = CountryCode::for_alpha3(&user.country_code).context("Invalid country code")?;
if let Some(c) = self.tax_rates.get(&cc) {
Ok((amount as f64 * (*c as f64 / 100f64)).floor() as u64)
} else {
Ok(0)
}
}
async fn get_ticker(&self, currency: Currency) -> Result<TickerRate> {
let ticker = Ticker(Currency::BTC, currency);
if let Some(r) = self.rates.get_rate(ticker).await {
@ -140,7 +156,7 @@ impl PricingEngine {
async fn get_msats_amount(&self, amount: CurrencyAmount) -> Result<(u64, f32)> {
let rate = self.get_ticker(amount.0).await?;
let cost_btc = amount.1 / rate.1;
let cost_msats = (cost_btc as f64 * Self::BTC_SATS) as u64 * 1000;
let cost_msats = (cost_btc as f64 * crate::BTC_SATS) as u64 * 1000;
Ok((cost_msats, rate.1))
}
@ -174,6 +190,7 @@ impl PricingEngine {
let time_value = Self::next_template_expire(vm, &cost_plan);
Ok(CostResult::New {
amount,
tax: self.get_tax_for_user(vm.user_id, amount).await?,
currency,
rate,
time_value,
@ -215,6 +232,8 @@ pub enum CostResult {
time_value: u64,
/// The absolute expiry time of the vm if renewed
new_expiry: DateTime<Utc>,
/// Taxes to charge
tax: u64,
},
}
@ -238,8 +257,7 @@ impl PricingData {
mod tests {
use super::*;
use crate::mocks::{MockDb, MockExchangeRate};
use lnvps_db::{DiskType, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate};
const GB: u64 = 1024 * 1024 * 1024;
use lnvps_db::{DiskType, User, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate};
const MOCK_RATE: f32 = 100_000.0;
async fn add_custom_pricing(db: &MockDb) {
@ -266,8 +284,8 @@ mod tests {
VmCustomTemplate {
id: 1,
cpu: 2,
memory: 2 * GB,
disk_size: 80 * GB,
memory: 2 * crate::GB,
disk_size: 80 * crate::GB,
disk_type: DiskType::SSD,
disk_interface: Default::default(),
pricing_id: 1,
@ -313,17 +331,65 @@ mod tests {
{
let mut v = db.vms.lock().await;
v.insert(1, MockDb::mock_vm());
v.insert(
2,
Vm {
user_id: 2,
..MockDb::mock_vm()
},
);
let mut u = db.users.lock().await;
u.insert(
1,
User {
id: 1,
pubkey: vec![],
created: Default::default(),
email: None,
contact_nip17: false,
contact_email: false,
country_code: "USA".to_string(),
},
);
u.insert(
2,
User {
id: 2,
pubkey: vec![],
created: Default::default(),
email: None,
contact_nip17: false,
contact_email: false,
country_code: "IRL".to_string(),
},
);
}
let db: Arc<dyn LNVpsDb> = Arc::new(db);
let pe = PricingEngine::new(db.clone(), rates);
let price = pe.get_vm_cost(1, PaymentMethod::Lightning).await?;
let taxes = HashMap::from([(CountryCode::IRL, 23.0)]);
let pe = PricingEngine::new(db.clone(), rates, taxes);
let plan = MockDb::mock_cost_plan();
let price = pe.get_vm_cost(1, PaymentMethod::Lightning).await?;
match price {
CostResult::New { amount, .. } => {
CostResult::New { amount, tax, .. } => {
let expect_price = (plan.amount / MOCK_RATE * 1.0e11) as u64;
assert_eq!(expect_price, amount);
assert_eq!(0, tax);
}
_ => bail!("??"),
}
// with taxes
let price = pe.get_vm_cost(2, PaymentMethod::Lightning).await?;
match price {
CostResult::New { amount, tax, .. } => {
let expect_price = (plan.amount / MOCK_RATE * 1.0e11) as u64;
assert_eq!(expect_price, amount);
assert_eq!((expect_price as f64 * 0.23).floor() as u64, tax);
}
_ => bail!("??"),
}

View File

@ -1,3 +1,4 @@
use std::collections::HashMap;
use crate::dns::DnsServer;
use crate::exchange::ExchangeRateService;
use crate::fiat::FiatPaymentService;
@ -9,6 +10,7 @@ use lnvps_db::LNVpsDb;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use isocountry::CountryCode;
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
@ -31,8 +33,8 @@ pub struct Settings {
/// Provisioning profiles
pub provisioner: ProvisionerConfig,
/// Network policy
#[serde(default)]
/// Network policy
pub network_policy: NetworkPolicy,
/// Number of days after an expired VM is deleted
@ -52,6 +54,10 @@ pub struct Settings {
/// Config for accepting revolut payments
pub revolut: Option<RevolutConfig>,
#[serde(default)]
/// Tax rates to change per country as a percent of the amount
pub tax_rate: HashMap<CountryCode, f32>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]