Skip to content

Commit

Permalink
Fix e2e lockfile and tests (#945)
Browse files Browse the repository at this point in the history
* revert e2e lockfile to old version

* remove global state from tests

* thread local

* generalize global values

* sort imports, cleanup
  • Loading branch information
t-aleksander authored Jan 15, 2025
1 parent 89c795f commit 86894d2
Show file tree
Hide file tree
Showing 10 changed files with 1,187 additions and 1,433 deletions.
2,424 changes: 1,085 additions & 1,339 deletions e2e/pnpm-lock.yaml

Large diffs are not rendered by default.

29 changes: 6 additions & 23 deletions src/db/models/settings.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,23 @@
use std::{
collections::HashMap,
str::FromStr,
sync::{RwLock, RwLockReadGuard},
};
use std::{collections::HashMap, str::FromStr};

use sqlx::{query, query_as, PgExecutor, PgPool, Type};
use struct_patch::Patch;
use thiserror::Error;

use crate::secret::SecretString;
use crate::{global_value, secret::SecretString};

// wrap in `Option` since a static cannot be initialized with a non-const function
static SETTINGS: RwLock<Option<Settings>> = RwLock::new(None);

pub(crate) fn set_settings(new_settings: Settings) {
*SETTINGS
.write()
.expect("Failed to acquire lock on current settings.") = Some(new_settings);
}

pub(crate) fn get_settings() -> RwLockReadGuard<'static, Option<Settings>> {
SETTINGS
.read()
.expect("Failed to acquire lock on current settings.")
}
global_value!(SETTINGS, Option<Settings>, None, set_settings, get_settings);

/// Initializes global `SETTINGS` struct at program startup
pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Error> {
debug!("Initializing global settings strut");
match Settings::get(pool).await? {
Some(settings) => {
set_settings(settings);
set_settings(Some(settings));
}
None => {
debug!("Settings not found in DB. Using default values to initialize global settings struct");
set_settings(Settings::default());
set_settings(Some(Settings::default()));
}
}
Ok(())
Expand All @@ -47,7 +30,7 @@ pub async fn update_current_settings(
) -> Result<(), sqlx::Error> {
debug!("Updating current settings to: {new_settings:?}");
new_settings.save(pool).await?;
set_settings(new_settings);
set_settings(Some(new_settings));
Ok(())
}

Expand Down
3 changes: 2 additions & 1 deletion src/enterprise/directory_sync/google.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::{str::FromStr, time::Duration};

use super::{DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser};
use chrono::{DateTime, TimeDelta, Utc};
#[cfg(not(test))]
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use reqwest::{header::AUTHORIZATION, Url};

use super::{DirectoryGroup, DirectorySync, DirectorySyncError, DirectoryUser};

#[cfg(not(test))]
const SCOPES: &str = "openid email profile https://www.googleapis.com/auth/admin.directory.customer.readonly https://www.googleapis.com/auth/admin.directory.group.readonly https://www.googleapis.com/auth/admin.directory.user.readonly";
const ACCESS_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
Expand Down
14 changes: 7 additions & 7 deletions src/enterprise/directory_sync/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use std::collections::{HashMap, HashSet};

use sqlx::PgPool;

use crate::{
db::{Group, Id, User},
enterprise::db::models::openid_provider::DirectorySyncUserBehavior,
};
use sqlx::error::Error as SqlxError;
use sqlx::PgPool;
use thiserror::Error;

use super::db::models::openid_provider::{DirectorySyncTarget, OpenIdProvider};
#[cfg(not(test))]
use super::is_enterprise_enabled;
use crate::{
db::{Group, Id, User},
enterprise::db::models::openid_provider::DirectorySyncUserBehavior,
};

#[derive(Debug, Error)]
pub enum DirectorySyncError {
Expand Down Expand Up @@ -600,12 +599,13 @@ pub(crate) async fn do_directory_sync(pool: &PgPool) -> Result<(), DirectorySync

#[cfg(test)]
mod test {
use secrecy::ExposeSecret;

use super::*;
use crate::{
config::DefGuardConfig, enterprise::db::models::openid_provider::DirectorySyncTarget,
SERVER_CONFIG,
};
use secrecy::ExposeSecret;

async fn make_test_provider(
pool: &PgPool,
Expand Down
32 changes: 11 additions & 21 deletions src/enterprise/license.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
sync::{RwLock, RwLockReadGuard},
time::Duration,
};
use std::time::Duration;

use anyhow::Result;
use base64::prelude::*;
Expand All @@ -13,28 +10,21 @@ use sqlx::{error::Error as SqlxError, PgPool};
use thiserror::Error;
use tokio::time::sleep;

use super::limits::Counts;
use crate::{
db::{models::settings::update_current_settings, Settings},
server_config, VERSION,
global_value, server_config, VERSION,
};

use super::limits::Counts;

const LICENSE_SERVER_URL: &str = "https://pkgs.defguard.net/api/license/renew";

static LICENSE: RwLock<Option<License>> = RwLock::new(None);

pub fn set_cached_license(license: Option<License>) {
*LICENSE
.write()
.expect("Failed to acquire lock on the license mutex.") = license;
}

pub fn get_cached_license() -> RwLockReadGuard<'static, Option<License>> {
LICENSE
.read()
.expect("Failed to acquire lock on the license mutex.")
}
global_value!(
LICENSE,
Option<License>,
None,
set_cached_license,
get_cached_license
);

tonic::include_proto!("license");

Expand Down Expand Up @@ -584,7 +574,7 @@ pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseErro
let license = get_cached_license();
debug!("Checking if the license {license:?} requires a renewal...");

if let Some(license) = &*license {
if let Some(license) = license.as_ref() {
if license.requires_renewal() {
// check if we are pass the maximum expiration date, after which we don't
// want to try to renew the license anymore
Expand Down
38 changes: 15 additions & 23 deletions src/enterprise/limits.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,24 @@
use sqlx::{error::Error as SqlxError, query, PgPool};
use std::sync::{RwLock, RwLockReadGuard};

#[cfg(test)]
use super::license::get_cached_license;
use super::license::License;
use crate::global_value;

// Limits for free users
pub const DEFAULT_USERS_LIMIT: u32 = 5;
pub const DEFAULT_DEVICES_LIMIT: u32 = 10;
pub const DEFAULT_LOCATIONS_LIMIT: u32 = 1;

#[derive(Debug, Default)]
pub(crate) struct Counts {
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Counts {
user: u32,
device: u32,
wireguard_network: u32,
}

static COUNTS: RwLock<Counts> = RwLock::new(Counts {
user: 0,
device: 0,
wireguard_network: 0,
});

fn set_counts(new_counts: Counts) {
*COUNTS
.write()
.expect("Failed to acquire lock on the enterprise limit counts.") = new_counts;
}

pub(crate) fn get_counts() -> RwLockReadGuard<'static, Counts> {
COUNTS
.read()
.expect("Failed to acquire lock on the enterprise limit counts.")
}
global_value!(COUNTS, Counts, Counts::default(), set_counts, get_counts);

/// Update the counts of users, devices, and wireguard networks stored in the memory.
// TODO: Use it with database triggers when they are implemented
Expand Down Expand Up @@ -80,6 +65,14 @@ pub async fn do_count_update(pool: &PgPool) -> Result<(), SqlxError> {
}

impl Counts {
pub(crate) const fn default() -> Self {
Self {
user: 0,
device: 0,
wireguard_network: 0,
}
}

#[cfg(test)]
pub(crate) fn new(user: u32, device: u32, wireguard_network: u32) -> Self {
Self {
Expand All @@ -97,7 +90,7 @@ impl Counts {
let maybe_license = get_cached_license();

// validate limits against license if available, use defaults otherwise
match &*maybe_license {
match &maybe_license {
Some(license) => {
debug!("Cached license found. Validating license limits...");
self.is_over_license_limits(license)
Expand Down Expand Up @@ -177,9 +170,8 @@ impl LimitsExceeded {
mod test {
use chrono::{TimeDelta, Utc};

use crate::enterprise::license::{set_cached_license, License, LicenseLimits};

use super::*;
use crate::enterprise::license::{set_cached_license, License, LicenseLimits};

#[test]
fn test_counts() {
Expand Down
54 changes: 54 additions & 0 deletions src/globals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#[macro_export]
/// Specify a global value that can be accessed from anywhere in the application.
/// Positional arguments:
/// - `$name`: The name of the global value. This will be the name of the variable that holds the value.
/// - `$type`: The type of the global value.
/// - `$init`: The initial value of the global value.
/// - `$set_fn`: The name of the function that will be used to set the global value.
/// - `$get_fn`: The name of the function that will be used to get the global value.
///
/// The macro will also automatically generate boilerplate code for unit tests to work correctly.
macro_rules! global_value {
($name:ident, $type:ty, $init:expr, $set_fn:ident, $get_fn:ident) => {
use std::sync::RwLock;
#[cfg(not(test))]
use std::sync::RwLockReadGuard;

#[cfg(test)]
thread_local! {
static $name: RwLock<$type> = const { RwLock::new($init) };
}

#[cfg(not(test))]
static $name: RwLock<$type> = RwLock::new($init);

#[cfg(not(test))]
pub fn $set_fn(value: $type) {
*$name.write().expect("Failed to acquire lock on the mutex.") = value;
}

#[cfg(not(test))]
pub fn $get_fn() -> RwLockReadGuard<'static, $type> {
$name.read().expect("Failed to acquire lock on the mutex.")
}

#[cfg(test)]
pub fn $set_fn(new_value: $type) {
$name.with(|value| {
*value.write().expect("Failed to acquire lock on the mutex.") = new_value;
});
}

// This is not really a 1:1 replacement for the non-test RwLockReadGuard<'static, $type> as the RwLock may be tried to be
// dereferenced
#[cfg(test)]
pub fn $get_fn() -> $type {
$name.with(|value| {
value
.read()
.expect("Failed to acquire lock on the mutex.")
.clone()
})
}
};
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ pub mod config;
pub mod db;
pub mod enterprise;
mod error;
pub mod globals;
pub mod grpc;
pub mod handlers;
pub mod headers;
Expand Down
1 change: 0 additions & 1 deletion src/templates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ pub fn gateway_reconnected_mail(
context.insert("gateway_ip", gateway_ip);
context.insert("network_name", network_name);
tera.add_raw_template("mail_gateway_reconnected", MAIL_GATEWAY_RECONNECTED)?;
println!("dupa");
Ok(tera.render("mail_gateway_reconnected", &context)?)
}

Expand Down
24 changes: 6 additions & 18 deletions src/updates.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::{
env,
sync::{RwLock, RwLockReadGuard},
};
use std::env;

use chrono::NaiveDate;
use semver::Version;

use crate::global_value;

const PRODUCT_NAME: &str = "Defguard";
const UPDATES_URL: &str = "https://update-service-dev.defguard.net/api/update/check";
const VERSION: &str = env!("CARGO_PKG_VERSION");

#[derive(Deserialize, Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct Update {
version: String,
release_date: NaiveDate,
Expand All @@ -20,19 +20,7 @@ pub struct Update {
notes: String,
}

static NEW_UPDATE: RwLock<Option<Update>> = RwLock::new(None);

fn set_update(update: Update) {
*NEW_UPDATE
.write()
.expect("Failed to acquire lock on the update.") = Some(update);
}

pub fn get_update() -> RwLockReadGuard<'static, Option<Update>> {
NEW_UPDATE
.read()
.expect("Failed to acquire lock on the update.")
}
global_value!(NEW_UPDATE, Option<Update>, None, set_update, get_update);

async fn fetch_update() -> Result<Update, anyhow::Error> {
let body = serde_json::json!({
Expand Down Expand Up @@ -63,7 +51,7 @@ pub(crate) async fn do_new_version_check() -> Result<(), anyhow::Error> {
update.version, update.release_date
);
}
set_update(update);
set_update(Some(update));
} else {
debug!("New version check done. You are using the latest version of Defguard.");
}
Expand Down

0 comments on commit 86894d2

Please sign in to comment.