diff --git a/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs index b8b311357..37b7cc48d 100644 --- a/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs @@ -50,7 +50,7 @@ public void TestReadAllTextOnWindows() var filePath = CreateConfigTempFile(s_workingDirectory, content); // act - var result = s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions); + var result = s_fileOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner); // assert Assert.AreEqual(content, result); @@ -73,7 +73,7 @@ public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidati Syscall.chmod(filePath, (FilePermissions)filePermissions); // act - var result = s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions); + var result = s_fileOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner); // assert Assert.AreEqual(content, result); @@ -96,7 +96,7 @@ public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadConfiguration Syscall.chmod(filePath, (FilePermissions)filePermissions); // act and assert - Assert.Throws(() => s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), + Assert.Throws(() => s_fileOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner), "Attempting to read a file with too broad permissions assigned"); } diff --git a/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs index 14e2df121..5a6db8ae7 100644 --- a/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs @@ -96,14 +96,30 @@ public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidati Syscall.chmod(filePath, userAllowedPermissions); // act - var result = s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions); + var result = s_unixOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner); // assert Assert.AreEqual(content, result); } [Test] - public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, + public void TestWriteAllTextCheckingPermissionsUsingTomlConfigurationFileValidations( + [ValueSource(nameof(UserAllowedWritePermissions))] FilePermissions userAllowedPermissions) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + Syscall.chmod(filePath, userAllowedPermissions); + + // act and assert + Assert.DoesNotThrow(() => s_unixOperations.WriteAllText(filePath,"test", UnixOperations.ValidateFileWhenWriteIsAccessedOnlyByItsOwner)); + } + + [Test] + public void TestFailIfGroupOrOthersHavePermissionsToFileWhileReadingWithUnixValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, [ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions, [ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions) { @@ -123,7 +139,31 @@ public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationVal Syscall.chmod(filePath, filePermissions); // act and assert - Assert.Throws(() => s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), "Attempting to read a file with too broad permissions assigned"); + Assert.Throws(() => s_unixOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner), "Attempting to read a file with too broad permissions assigned"); + } + + [Test] + public void TestFailIfGroupOrOthersHavePermissionsToFileWhileWritingWithUnixValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, + [ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions, + [ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions) + { + if(groupPermissions == 0 && othersPermissions == 0) + { + Assert.Ignore("Skip test when group and others have no permissions"); + } + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.Ignore("skip test on Windows"); + } + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + + var filePermissions = userPermissions | groupPermissions | othersPermissions; + Syscall.chmod(filePath, filePermissions); + + // act and assert + Assert.Throws(() => s_unixOperations.WriteAllText(filePath, "test", UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner), "Attempting to write a file with too broad permissions assigned"); } public static IEnumerable UserPermissions() @@ -186,6 +226,11 @@ public static IEnumerable UserAllowedPermissions() yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; } + public static IEnumerable UserAllowedWritePermissions() + { + yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; + } + public static IEnumerable GroupOrOthersReadablePermissions() { yield return 0; diff --git a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs index cfaf28a5f..267f878aa 100644 --- a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs @@ -206,7 +206,7 @@ internal static IAuthenticator GetAuthenticator(SFSession session) return new OAuthAuthenticator(session); } - else if (type.Equals(MFACacheAuthenticator.AUTH_NAME, StringComparison.InvariantCultureIgnoreCase)) + else if (type.Equals(MFACacheAuthenticator.AuthName, StringComparison.InvariantCultureIgnoreCase)) { return new MFACacheAuthenticator(session); } diff --git a/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs index 2d398352d..1e65ca376 100644 --- a/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs @@ -11,10 +11,10 @@ namespace Snowflake.Data.Core.Authenticator { class MFACacheAuthenticator : BaseAuthenticator, IAuthenticator { - public const string AUTH_NAME = "username_password_mfa"; - private const int _MFA_LOGIN_HTTP_TIMEOUT = 60; + public const string AuthName = "username_password_mfa"; + private const int MfaLoginHttpTimeout = 60; - internal MFACacheAuthenticator(SFSession session) : base(session, AUTH_NAME) + internal MFACacheAuthenticator(SFSession session) : base(session, AuthName) { } @@ -36,7 +36,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat // Only need to add the password to Data for basic authentication data.password = session.properties[SFSessionProperty.PASSWORD]; data.SessionParameters[SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN] = true; - data.HttpTimeout = TimeSpan.FromSeconds(_MFA_LOGIN_HTTP_TIMEOUT); + data.HttpTimeout = TimeSpan.FromSeconds(MfaLoginHttpTimeout); if (!string.IsNullOrEmpty(session._mfaToken?.ToString())) { data.Token = SecureStringHelper.Decode(session._mfaToken); diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs index 5b3b059ab..3dd547611 100644 --- a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs @@ -11,6 +11,7 @@ using System; using System.IO; using System.Runtime.InteropServices; +using System.Threading; using KeyTokenDict = System.Collections.Generic.Dictionary; namespace Snowflake.Data.Core.CredentialManager.Infrastructure @@ -25,6 +26,8 @@ internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim(); + private readonly string _jsonCacheDirectory; private readonly string _jsonCacheFilePath; @@ -88,7 +91,7 @@ internal void WriteToJsonFile(string content) } else { - _fileOperations.Write(_jsonCacheFilePath, content); + _fileOperations.Write(_jsonCacheFilePath, content, UnixOperations.ValidateFileWhenWriteIsAccessedOnlyByItsOwner); } var jsonPermissions = _unixOperations.GetFilePermissions(_jsonCacheFilePath); @@ -103,45 +106,69 @@ internal void WriteToJsonFile(string content) internal KeyTokenDict ReadJsonFile() { - var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(_jsonCacheFilePath) : _fileOperations.ReadAllText(_jsonCacheFilePath, TomlConnectionBuilder.ValidateFilePermissions); + var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(_jsonCacheFilePath) : _fileOperations.ReadAllText(_jsonCacheFilePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner); return JsonConvert.DeserializeObject(contentFile); } public string GetCredentials(string key) { - s_logger.Debug($"Getting credentials from json file in {_jsonCacheFilePath} for key: {key}"); - if (_fileOperations.Exists(_jsonCacheFilePath)) + try { - var keyTokenPairs = ReadJsonFile(); - if (keyTokenPairs.TryGetValue(key, out string token)) + _lock.EnterReadLock(); + s_logger.Debug($"Getting credentials from json file in {_jsonCacheFilePath} for key: {key}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) { - return token; + var keyTokenPairs = ReadJsonFile(); + if (keyTokenPairs.TryGetValue(key, out string token)) + { + return token; + } } - } - s_logger.Info("Unable to get credentials for the specified key"); - return ""; + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } + finally + { + _lock.ExitReadLock(); + } } public void RemoveCredentials(string key) { - s_logger.Debug($"Removing credentials from json file in {_jsonCacheFilePath} for key: {key}"); - if (_fileOperations.Exists(_jsonCacheFilePath)) + try { - var keyTokenPairs = ReadJsonFile(); - keyTokenPairs.Remove(key); - WriteToJsonFile(JsonConvert.SerializeObject(keyTokenPairs)); + _lock.EnterWriteLock(); + s_logger.Debug($"Removing credentials from json file in {_jsonCacheFilePath} for key: {key}"); + if (_fileOperations.Exists(_jsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + keyTokenPairs.Remove(key); + WriteToJsonFile(JsonConvert.SerializeObject(keyTokenPairs)); + } + } + finally + { + _lock.ExitWriteLock(); } } public void SaveCredentials(string key, string token) { - s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}"); - KeyTokenDict keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict(); - keyTokenPairs[key] = token; + try + { + _lock.EnterWriteLock(); + s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}"); + KeyTokenDict keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict(); + keyTokenPairs[key] = token; - string jsonString = JsonConvert.SerializeObject(keyTokenPairs); - WriteToJsonFile(jsonString); + string jsonString = JsonConvert.SerializeObject(keyTokenPairs); + WriteToJsonFile(jsonString); + } + finally + { + _lock.ExitWriteLock(); + } } } } diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs index 21b7fa555..60c20485a 100644 --- a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs @@ -3,8 +3,10 @@ */ +using System; using System.Collections.Generic; using System.Security; +using System.Threading; using Snowflake.Data.Client; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; @@ -15,34 +17,60 @@ internal class SFCredentialManagerInMemoryImpl : ISnowflakeCredentialManager { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim(); + private Dictionary s_credentials = new Dictionary(); public static readonly SFCredentialManagerInMemoryImpl Instance = new SFCredentialManagerInMemoryImpl(); public string GetCredentials(string key) { - s_logger.Debug($"Getting credentials from memory for key: {key}"); - if (s_credentials.TryGetValue(key, out var secureToken)) + try { - return SecureStringHelper.Decode(secureToken); + _lock.EnterReadLock(); + s_logger.Debug($"Getting credentials from memory for key: {key}"); + if (s_credentials.TryGetValue(key, out var secureToken)) + { + return SecureStringHelper.Decode(secureToken); + } + else + { + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } } - else + finally { - s_logger.Info("Unable to get credentials for the specified key"); - return ""; + _lock.ExitReadLock(); } } public void RemoveCredentials(string key) { - s_logger.Debug($"Removing credentials from memory for key: {key}"); - s_credentials.Remove(key); + try + { + _lock.EnterWriteLock(); + s_logger.Debug($"Removing credentials from memory for key: {key}"); + s_credentials.Remove(key); + } + finally + { + _lock.ExitWriteLock(); + } } public void SaveCredentials(string key, string token) { - s_logger.Debug($"Saving credentials into memory for key: {key}"); - s_credentials[key] = SecureStringHelper.Encode(token); + try + { + _lock.EnterWriteLock(); + s_logger.Debug($"Saving credentials into memory for key: {key}"); + s_credentials[key] = SecureStringHelper.Encode(token); + } + finally + { + _lock.ExitWriteLock(); + } } } } diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs index 3b5c42954..efed5719e 100644 --- a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs @@ -6,6 +6,7 @@ using System; using System.Runtime.InteropServices; using System.Text; +using System.Threading; using Snowflake.Data.Client; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; @@ -17,52 +18,78 @@ internal class SFCredentialManagerWindowsNativeImpl : ISnowflakeCredentialManage { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim(); + public static readonly SFCredentialManagerWindowsNativeImpl Instance = new SFCredentialManagerWindowsNativeImpl(); public string GetCredentials(string key) { - s_logger.Debug($"Getting the credentials for key: {key}"); - IntPtr nCredPtr; - if (!CredRead(key, 1 /* Generic */, 0, out nCredPtr)) + try { - s_logger.Info($"Unable to get credentials for key: {key}"); - return ""; - } + _lock.EnterReadLock(); + s_logger.Debug($"Getting the credentials for key: {key}"); + IntPtr nCredPtr; + if (!CredRead(key, 1 /* Generic */, 0, out nCredPtr)) + { + s_logger.Info($"Unable to get credentials for key: {key}"); + return ""; + } - using (var critCred = new CriticalCredentialHandle(nCredPtr)) + using (var critCred = new CriticalCredentialHandle(nCredPtr)) + { + var cred = critCred.GetCredential(); + return cred.CredentialBlob; + } + } + finally { - var cred = critCred.GetCredential(); - return cred.CredentialBlob; + _lock.ExitReadLock(); } } public void RemoveCredentials(string key) { - s_logger.Debug($"Removing the credentials for key: {key}"); + try + { + _lock.EnterWriteLock(); + s_logger.Debug($"Removing the credentials for key: {key}"); - if (!CredDelete(key, 1 /* Generic */, 0)) + if (!CredDelete(key, 1 /* Generic */, 0)) + { + s_logger.Info($"Unable to remove credentials because the specified key did not exist: {key}"); + } + } + finally { - s_logger.Info($"Unable to remove credentials because the specified key did not exist: {key}"); + _lock.ExitWriteLock(); } } public void SaveCredentials(string key, string token) { - s_logger.Debug($"Saving the credentials for key: {key}"); - byte[] byteArray = Encoding.Unicode.GetBytes(token); - Credential credential = new Credential(); - credential.AttributeCount = 0; - credential.Attributes = IntPtr.Zero; - credential.Comment = IntPtr.Zero; - credential.TargetAlias = IntPtr.Zero; - credential.Type = 1; // Generic - credential.Persist = 2; // Local Machine - credential.CredentialBlobSize = (uint)(byteArray == null ? 0 : byteArray.Length); - credential.TargetName = key; - credential.CredentialBlob = token; - credential.UserName = Environment.UserName; - - CredWrite(ref credential, 0); + try + { + _lock.EnterWriteLock(); + s_logger.Debug($"Saving the credentials for key: {key}"); + byte[] byteArray = Encoding.Unicode.GetBytes(token); + Credential credential = new Credential(); + credential.AttributeCount = 0; + credential.Attributes = IntPtr.Zero; + credential.Comment = IntPtr.Zero; + credential.TargetAlias = IntPtr.Zero; + credential.Type = 1; // Generic + credential.Persist = 2; // Local Machine + credential.CredentialBlobSize = (uint)(byteArray == null ? 0 : byteArray.Length); + credential.TargetName = key; + credential.CredentialBlob = token; + credential.UserName = Environment.UserName; + + CredWrite(ref credential, 0); + } + finally + { + _lock.ExitWriteLock(); + } } [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index a9663961d..5575f7c63 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -314,7 +314,7 @@ private static void ValidateAuthenticator(SFSessionProperties properties) OAuthAuthenticator.AUTH_NAME, KeyPairAuthenticator.AUTH_NAME, ExternalBrowserAuthenticator.AUTH_NAME, - MFACacheAuthenticator.AUTH_NAME + MFACacheAuthenticator.AuthName }; if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator)) diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index d58c06223..abadd88e5 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -142,7 +142,7 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin ValidateMinPoolSizeWithPasscode(sessionProperties, passcode); if (!GetPooling()) return NewNonPoolingSession(connStr, password, passcode); - var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME; + var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AuthName; var sessionOrCreateTokens = GetIdleSession(connStr, isMfaAuthentication ? 1 : int.MaxValue); if (sessionOrCreateTokens.Session != null) { @@ -165,7 +165,7 @@ private void ValidateMinPoolSizeWithPasscode(SFSessionProperties sessionProperti (sessionProperties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordValue) && bool.TryParse(passcodeInPasswordValue, out var isPasscodeinPassword) && isPasscodeinPassword)); var isMfaAuthenticator = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && - authenticator == MFACacheAuthenticator.AUTH_NAME; + authenticator == MFACacheAuthenticator.AuthName; if(isUsingPasscode && !isMfaAuthenticator) { const string ErrorMessage = "Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication"; @@ -181,7 +181,7 @@ internal async Task GetSessionAsync(string connStr, SecureString pass ValidateMinPoolSizeWithPasscode(sessionProperties, passcode); if (!GetPooling()) return await NewNonPoolingSessionAsync(connStr, password, passcode, cancellationToken).ConfigureAwait(false); - var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME; + var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AuthName; var sessionOrCreateTokens = GetIdleSession(connStr, isMfaAuthentication ? 1 : int.MaxValue); WarnAboutOverridenConfig(); diff --git a/Snowflake.Data/Core/TomlConnectionBuilder.cs b/Snowflake.Data/Core/TomlConnectionBuilder.cs index a8c2396b1..6206c856b 100644 --- a/Snowflake.Data/Core/TomlConnectionBuilder.cs +++ b/Snowflake.Data/Core/TomlConnectionBuilder.cs @@ -116,7 +116,7 @@ private string LoadTokenFromFile(string tokenFilePathValue) tokenFile = tokenFilePathValue; } s_logger.Info($"Read token from file path: {tokenFile}"); - return _fileOperations.Exists(tokenFile) ? _fileOperations.ReadAllText(tokenFile, ValidateFilePermissions) : null; + return _fileOperations.Exists(tokenFile) ? _fileOperations.ReadAllText(tokenFile, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner) : null; } private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName) @@ -126,7 +126,7 @@ private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName) return null; } - var tomlContent = _fileOperations.ReadAllText(tomlPath, ValidateFilePermissions) ?? string.Empty; + var tomlContent = _fileOperations.ReadAllText(tomlPath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner) ?? string.Empty; var toml = Toml.ToModel(tomlContent); if (string.IsNullOrEmpty(connectionName)) { @@ -152,20 +152,5 @@ private string ResolveConnectionTomlFile() var tomlPath = Path.Combine(tomlFolder, "connections.toml"); return tomlPath; } - - internal static void ValidateFilePermissions(UnixStream stream) - { - var allowedPermissions = new FileAccessPermissions[] - { - FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite, - FileAccessPermissions.UserRead - }; - if (stream.OwnerUser.UserId != Syscall.geteuid()) - throw new SecurityException("Attempting to read a file not owned by the effective user of the current process"); - if (stream.OwnerGroup.GroupId != Syscall.getegid()) - throw new SecurityException("Attempting to read a file not owned by the effective group of the current process"); - if (!(allowedPermissions.Any(a => stream.FileAccessPermissions == a))) - throw new SecurityException("Attempting to read a file with too broad permissions assigned"); - } } } diff --git a/Snowflake.Data/Core/Tools/FileOperations.cs b/Snowflake.Data/Core/Tools/FileOperations.cs index a03e1a22b..5d6f357ea 100644 --- a/Snowflake.Data/Core/Tools/FileOperations.cs +++ b/Snowflake.Data/Core/Tools/FileOperations.cs @@ -20,7 +20,17 @@ public virtual bool Exists(string path) return File.Exists(path); } - public virtual void Write(string path, string content) => File.WriteAllText(path, content); + public virtual void Write(string path, string content, Action validator = null) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || validator == null) + { + File.WriteAllText(path, content); + } + else + { + _unixOperations.WriteAllText(path, content, validator); + } + } public virtual string ReadAllText(string path) { diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs index f0d41a312..b09d02973 100644 --- a/Snowflake.Data/Core/Tools/UnixOperations.cs +++ b/Snowflake.Data/Core/Tools/UnixOperations.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Linq; using System.Security; using System.Text; using Mono.Unix; @@ -44,7 +45,7 @@ public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissi return (permissions & fileInfo.FileAccessPermissions) != 0; } - public string ReadAllText(string path, Action validator) + public string ReadAllText(string path, Action validator) { var fileInfo = new UnixFileInfo(path: path); @@ -57,5 +58,45 @@ public string ReadAllText(string path, Action validator) } } } + + public void WriteAllText(string path, string content, Action validator) + { + var fileInfo = new UnixFileInfo(path: path); + + using (var handle = fileInfo.OpenRead()) + { + validator?.Invoke(handle); + } + File.WriteAllText(path, content); + } + + internal static void ValidateFileWhenReadIsAccessedOnlyByItsOwner(UnixStream stream) + { + var allowedPermissions = new[] + { + FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite, + FileAccessPermissions.UserRead + }; + if (stream.OwnerUser.UserId != Syscall.geteuid()) + throw new SecurityException("Attempting to read a file not owned by the effective user of the current process"); + if (stream.OwnerGroup.GroupId != Syscall.getegid()) + throw new SecurityException("Attempting to read a file not owned by the effective group of the current process"); + if (!(allowedPermissions.Any(a => stream.FileAccessPermissions == a))) + throw new SecurityException("Attempting to read a file with too broad permissions assigned"); + } + + internal static void ValidateFileWhenWriteIsAccessedOnlyByItsOwner(UnixStream stream) + { + var allowedPermissions = new[] + { + FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite + }; + if (stream.OwnerUser.UserId != Syscall.geteuid()) + throw new SecurityException("Attempting to write a file not owned by the effective user of the current process"); + if (stream.OwnerGroup.GroupId != Syscall.getegid()) + throw new SecurityException("Attempting to write a file not owned by the effective group of the current process"); + if (!(allowedPermissions.Any(a => stream.FileAccessPermissions == a))) + throw new SecurityException("Attempting to write a file with too broad permissions assigned"); + } } }