From 67089fcb44797fdff3de81b4c9b49a379085b61e Mon Sep 17 00:00:00 2001 From: Juan Martinez Ramirez Date: Thu, 11 Jul 2024 15:39:31 -0600 Subject: [PATCH] Change validation process for session pool, if using passcode in connection string without username_password_authentication an exception will be thrown to indicate the user that the passcode should not be used if pooling is enabled or with a minimum pool size greater than 0. Additionally, if the passcode is provided by an argument and not part of the connection string, it will not be used for the session created by the session pool, and the push MFA mechanism will be triggered. --- .../MockLoginMFATokenCacheRestRequester.cs | 6 ++ .../UnitTests/ConnectionPoolManagerMFATest.cs | 78 +++++++++++++------ .../Session/SFSessionHttpClientProperties.cs | 4 - Snowflake.Data/Core/Session/SessionPool.cs | 52 ++++++------- 4 files changed, 85 insertions(+), 55 deletions(-) diff --git a/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs index d2e8d5319..163124b7d 100644 --- a/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs +++ b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs @@ -79,5 +79,11 @@ public void setHttpClient(HttpClient httpClient) { // Nothing to do } + + public void Reset() + { + LoginRequests.Clear(); + LoginResponses.Clear(); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs index c99be5a45..139c2a4f1 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs @@ -2,25 +2,28 @@ * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ -using System.Security; -using System.Threading; -using NUnit.Framework; -using Snowflake.Data.Core; -using Snowflake.Data.Core.Session; -using Snowflake.Data.Client; -using Snowflake.Data.Core.Tools; -using Snowflake.Data.Tests.Util; + namespace Snowflake.Data.Tests.UnitTests { using System; + using System.Linq; + using System.Security; + using System.Threading; using Mock; - - [TestFixture, NonParallelizable] + using NUnit.Framework; + using Snowflake.Data.Core; + using Snowflake.Data.Core.Session; + using Snowflake.Data.Client; + using Snowflake.Data.Core.Tools; + using Snowflake.Data.Tests.Util; + + [TestFixture] class ConnectionPoolManagerMFATest { private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); private const string ConnectionStringMFACache = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;authenticator=username_password_mfa"; + private const string ConnectionStringMFABasicWithoutPasscode = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=3;"; private static PoolConfig s_poolConfig; private static MockLoginMFATokenCacheRestRequester s_restRequester; @@ -44,6 +47,7 @@ public static void AfterAllTests() public void BeforeEach() { _connectionPoolManager.ClearAllPools(); + s_restRequester.Reset(); } [Test] @@ -79,6 +83,35 @@ public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringUsingMFA() Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod); } + [Test] + public void TestPoolManagerShouldOnlyUsePasscodeAsArgumentForFirstSessionWhenNotUsingMFAAuthenticator() + { + // Arrange + const string TestPasscode = "123456"; + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + authResponseSessionInfo = new SessionInfo() + }); + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + authResponseSessionInfo = new SessionInfo() + }); + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + authResponseSessionInfo = new SessionInfo() + }); + // Act + var session = _connectionPoolManager.GetSession(ConnectionStringMFABasicWithoutPasscode, null, SecureStringHelper.Encode(TestPasscode)); + Thread.Sleep(3000); + + // Assert + + Assert.AreEqual(3, s_restRequester.LoginRequests.Count); + var request = s_restRequester.LoginRequests.ToList(); + Assert.AreEqual(1, request.Count(r => r.data.extAuthnDuoMethod == "passcode" && r.data.passcode == TestPasscode)); + Assert.AreEqual(2, request.Count(r => r.data.extAuthnDuoMethod == "push" && r.data.passcode == null)); + } + [Test] public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() { @@ -86,24 +119,25 @@ public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsin var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=true"; // Act and assert var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null)); - Assert.That(thrown.Message, Does.Contain("Could not get a pool because passcode was provided using a different authenticator than username_password_mfa")); + Assert.That(thrown.Message, Does.Contain("Could not use connection pool because passcode was provided using a different authenticator than username_password_mfa")); } [Test] - public void TestPoolManagerShouldDisablePoolingWhenPassingPasscodeNotUsingMFATokenCacheAuthenticator() + public void TestPoolManagerShouldNotThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() { // Arrange - var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;"; - var pool = _connectionPoolManager.GetPool(connectionString); - // Act - var session = _connectionPoolManager.GetSession(connectionString, null); - - // Asssert - // TODO: Review pool config is not the same for session and session pool - // Assert.IsFalse(session.GetPooling()); - Assert.AreEqual(0, pool.GetCurrentPoolSize()); - Assert.IsFalse(pool.GetPooling()); + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=false"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null)); + } + [Test] + public void TestPoolManagerShouldNotThrowExceptionIfMinPoolSizeZeroNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=0;passcode=12345;POOLINGENABLED=true"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null)); } } diff --git a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs index bd857eb45..2d818f8c8 100644 --- a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs +++ b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs @@ -60,10 +60,6 @@ public void DisablePoolingDefaultIfSecretsProvidedExternally(SFSessionProperties && !properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_PWD)) { DisablePoolingIfNotExplicitlyEnabled(properties, "key pair with private key in a file"); - } else if (!MFACacheAuthenticator.AUTH_NAME.Equals(authenticator) - && properties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE)) - { - DisablePoolingIfNotExplicitlyEnabled(properties, "mfa authentication without token cache"); } } diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 66c9facb3..8514fc62c 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -145,8 +145,8 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin { s_logger.Debug("SessionPool::GetSession" + PoolIdentification()); SFSession session = null; - var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password, passcode); - ValidatePoolingIfPasscodeProvided(passcode, sessionProperties); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidatePoolingIfPasscodeProvided(sessionProperties); if (!GetPooling()) return NewNonPoolingSession(connStr, password, passcode); var sessionOrCreateTokens = GetIdleSession(connStr); @@ -156,42 +156,37 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin { _sessionPoolEventHandler.OnSessionProvided(this); } - ScheduleNewIdleSessions(connStr, password, passcode, sessionOrCreateTokens.BackgroundSessionCreationTokens()); + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); WarnAboutOverridenConfig(); return session ?? sessionOrCreateTokens.Session ?? NewSession(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken()); } - private void ValidatePoolingIfPasscodeProvided(SecureString passcode, SFSessionProperties sessionProperties) + private void ValidatePoolingIfPasscodeProvided(SFSessionProperties sessionProperties) { - if (!GetPooling()) return; - var isUsingPasscode = ((passcode != null && !SecureStringHelper.Decode(passcode).IsNullOrEmpty()) || - sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) || + if (!GetPooling() || _poolConfig.MinPoolSize == 0) return; + var isUsingPasscode = (sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) || (sessionProperties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordValue) && bool.TryParse(passcodeInPasswordValue, out var isPasscodeinPassword) && isPasscodeinPassword)); - if(!isUsingPasscode) return; var isMfaAuthenticator = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME; - - if (isMfaAuthenticator) return; - if (sessionProperties.IsPoolingEnabledValueProvided) + if(isUsingPasscode && !isMfaAuthenticator) { - const string ErrorMessage = "Could not get a pool because passcode was provided using a different authenticator than username_password_mfa"; + const string ErrorMessage = "Could not use connection pool because passcode was provided using a different authenticator than username_password_mfa"; s_logger.Error(ErrorMessage + PoolIdentification()); throw new Exception(ErrorMessage); } - s_logger.Warn("Pooling is disabled because passcode was provided using a different authenticator than username_password_mfa" + PoolIdentification()); - _poolConfig.PoolingEnabled = false; } internal async Task GetSessionAsync(string connStr, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::GetSessionAsync" + PoolIdentification()); SFSession session = null; - var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password, passcode); - ValidatePoolingIfPasscodeProvided(passcode, sessionProperties); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidatePoolingIfPasscodeProvided(sessionProperties); if (!GetPooling()) return await NewNonPoolingSessionAsync(connStr, password, passcode, cancellationToken).ConfigureAwait(false); var sessionOrCreateTokens = GetIdleSession(connStr); + WarnAboutOverridenConfig(); if (sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME) session = sessionOrCreateTokens.Session ?? @@ -201,21 +196,20 @@ await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.Session { _sessionPoolEventHandler.OnSessionProvided(this); } - ScheduleNewIdleSessions(connStr, password, passcode, sessionOrCreateTokens.BackgroundSessionCreationTokens()); - WarnAboutOverridenConfig(); + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); return session ?? sessionOrCreateTokens.Session ?? await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken(), cancellationToken).ConfigureAwait(false); } - private void ScheduleNewIdleSessions(string connStr, SecureString password, SecureString passcode, List tokens) + private void ScheduleNewIdleSessions(string connStr, SecureString password, List tokens) { - tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, passcode, token)); + tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, token)); } - private void ScheduleNewIdleSession(string connStr, SecureString password, SecureString passcode, SessionCreationToken token) + private void ScheduleNewIdleSession(string connStr, SecureString password, SessionCreationToken token) { Task.Run(() => { - var session = NewSession(connStr, password, passcode, token); + var session = NewSession(connStr, password, null, token); AddSession(session, false); // we don't want to ensure min pool size here because we could get into infinite recursion if expirationTimeout would be very low }); } @@ -258,7 +252,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(session); } s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, but could not find any idle session available in the pool" + PoolIdentification()); - var sessionsCount = AllowedNumberOfNewSessionCreations(1); + var sessionsCount = Math.Min(1, AllowedNumberOfNewSessionCreations(1)); if (sessionsCount > 0) { // there is no need to wait for a session since we can create new ones @@ -269,7 +263,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(WaitForSession(connStr)); } - private List RegisterSessionCreationsWhenReturningSessionToPool() + private List RegisterSessionCreationsToEnsureMinPoolSize() { var count = AllowedNumberOfNewSessionCreations(0); return RegisterSessionCreations(count); @@ -501,7 +495,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) ReleaseBusySession(session); if (ensureMinPoolSize) { - ScheduleNewIdleSessions(ConnectionString, Password, session.Passcode, RegisterSessionCreationsWhenReturningSessionToPool()); // passcode is probably not fresh - it could be improved + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); // passcode is probably not fresh - it could be improved } return false; } @@ -509,7 +503,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) var result = ReturnSessionToPool(session, ensureMinPoolSize); var wasSessionReturnedToPool = result.Item1; var sessionCreationTokens = result.Item2; - ScheduleNewIdleSessions(ConnectionString, Password, session.Passcode, sessionCreationTokens); // passcode is probably not fresh - it could be improved + ScheduleNewIdleSessions(ConnectionString, Password, sessionCreationTokens); return wasSessionReturnedToPool; } @@ -522,7 +516,7 @@ private Tuple> ReturnSessionToPool(SFSession se { _busySessionsCounter.Decrease(); var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -537,7 +531,7 @@ private Tuple> ReturnSessionToPool(SFSession se if (session.IsExpired(_poolConfig.ExpirationTimeout, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())) // checking again because we could have spent some time waiting for a lock { var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -552,7 +546,7 @@ private Tuple> ReturnSessionToPool(SFSession se _idleSessions.Add(session); _waitingForIdleSessionQueue.OnResourceIncrease(); var sessionCreationTokensAfterReturningToPool = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolStateAfterReturningToPool = GetCurrentState(); s_logger.Debug($"returned session with sid {session.sessionId} to pool {poolStateAfterReturningToPool}" + PoolIdentification());