Skip to content

Commit

Permalink
Added lock mechanism when applying credential manager operations.
Browse files Browse the repository at this point in the history
Added write file validator for file permissions.
Additional PR suggestions
  • Loading branch information
sfc-gh-jmartinezramirez committed Nov 20, 2024
1 parent b74079b commit 7df5cd8
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 91 deletions.
6 changes: 3 additions & 3 deletions Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void TestReadAllTextOnWindows()
var filePath = CreateConfigTempFile(s_workingDirectory, content);

// act
var result = s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions);
var result = s_fileOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner);

// assert
Assert.AreEqual(content, result);
Expand All @@ -73,7 +73,7 @@ public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidati
Syscall.chmod(filePath, (FilePermissions)filePermissions);

// act
var result = s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions);
var result = s_fileOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner);

// assert
Assert.AreEqual(content, result);
Expand All @@ -96,7 +96,7 @@ public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadConfiguration
Syscall.chmod(filePath, (FilePermissions)filePermissions);

// act and assert
Assert.Throws<SecurityException>(() => s_fileOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions),
Assert.Throws<SecurityException>(() => s_fileOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner),
"Attempting to read a file with too broad permissions assigned");
}

Expand Down
51 changes: 48 additions & 3 deletions Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,30 @@ public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidati
Syscall.chmod(filePath, userAllowedPermissions);

// act
var result = s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions);
var result = s_unixOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner);

// assert
Assert.AreEqual(content, result);
}

[Test]
public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions,
public void TestWriteAllTextCheckingPermissionsUsingTomlConfigurationFileValidations(
[ValueSource(nameof(UserAllowedWritePermissions))] FilePermissions userAllowedPermissions)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Assert.Ignore("skip test on Windows");
}
var content = "random text";
var filePath = CreateConfigTempFile(s_workingDirectory, content);
Syscall.chmod(filePath, userAllowedPermissions);

// act and assert
Assert.DoesNotThrow(() => s_unixOperations.WriteAllText(filePath,"test", UnixOperations.ValidateFileWhenWriteIsAccessedOnlyByItsOwner));
}

[Test]
public void TestFailIfGroupOrOthersHavePermissionsToFileWhileReadingWithUnixValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions,
[ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions,
[ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions)
{
Expand All @@ -123,7 +139,31 @@ public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationVal
Syscall.chmod(filePath, filePermissions);

// act and assert
Assert.Throws<SecurityException>(() => s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), "Attempting to read a file with too broad permissions assigned");
Assert.Throws<SecurityException>(() => s_unixOperations.ReadAllText(filePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner), "Attempting to read a file with too broad permissions assigned");
}

[Test]
public void TestFailIfGroupOrOthersHavePermissionsToFileWhileWritingWithUnixValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions,
[ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions,
[ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions)
{
if(groupPermissions == 0 && othersPermissions == 0)
{
Assert.Ignore("Skip test when group and others have no permissions");
}

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Assert.Ignore("skip test on Windows");
}
var content = "random text";
var filePath = CreateConfigTempFile(s_workingDirectory, content);

var filePermissions = userPermissions | groupPermissions | othersPermissions;
Syscall.chmod(filePath, filePermissions);

// act and assert
Assert.Throws<SecurityException>(() => s_unixOperations.WriteAllText(filePath, "test", UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner), "Attempting to write a file with too broad permissions assigned");
}

public static IEnumerable<FilePermissions> UserPermissions()
Expand Down Expand Up @@ -186,6 +226,11 @@ public static IEnumerable<FilePermissions> UserAllowedPermissions()
yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR;
}

public static IEnumerable<FilePermissions> UserAllowedWritePermissions()
{
yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR;
}

public static IEnumerable<FilePermissions> GroupOrOthersReadablePermissions()
{
yield return 0;
Expand Down
2 changes: 1 addition & 1 deletion Snowflake.Data/Core/Authenticator/IAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ internal static IAuthenticator GetAuthenticator(SFSession session)

return new OAuthAuthenticator(session);
}
else if (type.Equals(MFACacheAuthenticator.AUTH_NAME, StringComparison.InvariantCultureIgnoreCase))
else if (type.Equals(MFACacheAuthenticator.AuthName, StringComparison.InvariantCultureIgnoreCase))
{
return new MFACacheAuthenticator(session);
}
Expand Down
8 changes: 4 additions & 4 deletions Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ namespace Snowflake.Data.Core.Authenticator
{
class MFACacheAuthenticator : BaseAuthenticator, IAuthenticator
{
public const string AUTH_NAME = "username_password_mfa";
private const int _MFA_LOGIN_HTTP_TIMEOUT = 60;
public const string AuthName = "username_password_mfa";
private const int MfaLoginHttpTimeout = 60;

internal MFACacheAuthenticator(SFSession session) : base(session, AUTH_NAME)
internal MFACacheAuthenticator(SFSession session) : base(session, AuthName)
{
}

Expand All @@ -36,7 +36,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat
// Only need to add the password to Data for basic authentication
data.password = session.properties[SFSessionProperty.PASSWORD];
data.SessionParameters[SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN] = true;
data.HttpTimeout = TimeSpan.FromSeconds(_MFA_LOGIN_HTTP_TIMEOUT);
data.HttpTimeout = TimeSpan.FromSeconds(MfaLoginHttpTimeout);
if (!string.IsNullOrEmpty(session._mfaToken?.ToString()))
{
data.Token = SecureStringHelper.Decode(session._mfaToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using KeyTokenDict = System.Collections.Generic.Dictionary<string, string>;

namespace Snowflake.Data.Core.CredentialManager.Infrastructure
Expand All @@ -25,6 +26,8 @@ internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager

private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<SFCredentialManagerFileImpl>();

private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim();

private readonly string _jsonCacheDirectory;

private readonly string _jsonCacheFilePath;
Expand Down Expand Up @@ -88,7 +91,7 @@ internal void WriteToJsonFile(string content)
}
else
{
_fileOperations.Write(_jsonCacheFilePath, content);
_fileOperations.Write(_jsonCacheFilePath, content, UnixOperations.ValidateFileWhenWriteIsAccessedOnlyByItsOwner);
}

var jsonPermissions = _unixOperations.GetFilePermissions(_jsonCacheFilePath);
Expand All @@ -103,45 +106,69 @@ internal void WriteToJsonFile(string content)

internal KeyTokenDict ReadJsonFile()
{
var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(_jsonCacheFilePath) : _fileOperations.ReadAllText(_jsonCacheFilePath, TomlConnectionBuilder.ValidateFilePermissions);
var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(_jsonCacheFilePath) : _fileOperations.ReadAllText(_jsonCacheFilePath, UnixOperations.ValidateFileWhenReadIsAccessedOnlyByItsOwner);
return JsonConvert.DeserializeObject<KeyTokenDict>(contentFile);
}

public string GetCredentials(string key)
{
s_logger.Debug($"Getting credentials from json file in {_jsonCacheFilePath} for key: {key}");
if (_fileOperations.Exists(_jsonCacheFilePath))
try
{
var keyTokenPairs = ReadJsonFile();
if (keyTokenPairs.TryGetValue(key, out string token))
_lock.EnterReadLock();
s_logger.Debug($"Getting credentials from json file in {_jsonCacheFilePath} for key: {key}");
if (_fileOperations.Exists(_jsonCacheFilePath))
{
return token;
var keyTokenPairs = ReadJsonFile();
if (keyTokenPairs.TryGetValue(key, out string token))
{
return token;
}
}
}

s_logger.Info("Unable to get credentials for the specified key");
return "";
s_logger.Info("Unable to get credentials for the specified key");
return "";
}
finally
{
_lock.ExitReadLock();
}
}

public void RemoveCredentials(string key)
{
s_logger.Debug($"Removing credentials from json file in {_jsonCacheFilePath} for key: {key}");
if (_fileOperations.Exists(_jsonCacheFilePath))
try
{
var keyTokenPairs = ReadJsonFile();
keyTokenPairs.Remove(key);
WriteToJsonFile(JsonConvert.SerializeObject(keyTokenPairs));
_lock.EnterWriteLock();
s_logger.Debug($"Removing credentials from json file in {_jsonCacheFilePath} for key: {key}");
if (_fileOperations.Exists(_jsonCacheFilePath))
{
var keyTokenPairs = ReadJsonFile();
keyTokenPairs.Remove(key);
WriteToJsonFile(JsonConvert.SerializeObject(keyTokenPairs));
}
}
finally
{
_lock.ExitWriteLock();
}
}

public void SaveCredentials(string key, string token)
{
s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}");
KeyTokenDict keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict();
keyTokenPairs[key] = token;
try
{
_lock.EnterWriteLock();
s_logger.Debug($"Saving credentials to json file in {_jsonCacheFilePath} for key: {key}");
KeyTokenDict keyTokenPairs = _fileOperations.Exists(_jsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict();
keyTokenPairs[key] = token;

string jsonString = JsonConvert.SerializeObject(keyTokenPairs);
WriteToJsonFile(jsonString);
string jsonString = JsonConvert.SerializeObject(keyTokenPairs);
WriteToJsonFile(jsonString);
}
finally
{
_lock.ExitWriteLock();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
*/


using System;
using System.Collections.Generic;
using System.Security;
using System.Threading;
using Snowflake.Data.Client;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Log;
Expand All @@ -15,34 +17,60 @@ internal class SFCredentialManagerInMemoryImpl : ISnowflakeCredentialManager
{
private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<SFCredentialManagerInMemoryImpl>();

private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim();

private Dictionary<string, SecureString> s_credentials = new Dictionary<string, SecureString>();

public static readonly SFCredentialManagerInMemoryImpl Instance = new SFCredentialManagerInMemoryImpl();

public string GetCredentials(string key)
{
s_logger.Debug($"Getting credentials from memory for key: {key}");
if (s_credentials.TryGetValue(key, out var secureToken))
try
{
return SecureStringHelper.Decode(secureToken);
_lock.EnterReadLock();
s_logger.Debug($"Getting credentials from memory for key: {key}");
if (s_credentials.TryGetValue(key, out var secureToken))
{
return SecureStringHelper.Decode(secureToken);
}
else
{
s_logger.Info("Unable to get credentials for the specified key");
return "";
}
}
else
finally
{
s_logger.Info("Unable to get credentials for the specified key");
return "";
_lock.ExitReadLock();
}
}

public void RemoveCredentials(string key)
{
s_logger.Debug($"Removing credentials from memory for key: {key}");
s_credentials.Remove(key);
try
{
_lock.EnterWriteLock();
s_logger.Debug($"Removing credentials from memory for key: {key}");
s_credentials.Remove(key);
}
finally
{
_lock.ExitWriteLock();
}
}

public void SaveCredentials(string key, string token)
{
s_logger.Debug($"Saving credentials into memory for key: {key}");
s_credentials[key] = SecureStringHelper.Encode(token);
try
{
_lock.EnterWriteLock();
s_logger.Debug($"Saving credentials into memory for key: {key}");
s_credentials[key] = SecureStringHelper.Encode(token);
}
finally
{
_lock.ExitWriteLock();
}
}
}
}
Loading

0 comments on commit 7df5cd8

Please sign in to comment.