diff --git a/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs index d2e8d5319..163124b7d 100644 --- a/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs +++ b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs @@ -79,5 +79,11 @@ public void setHttpClient(HttpClient httpClient) { // Nothing to do } + + public void Reset() + { + LoginRequests.Clear(); + LoginResponses.Clear(); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs index c99be5a45..139c2a4f1 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs @@ -2,25 +2,28 @@ * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. */ -using System.Security; -using System.Threading; -using NUnit.Framework; -using Snowflake.Data.Core; -using Snowflake.Data.Core.Session; -using Snowflake.Data.Client; -using Snowflake.Data.Core.Tools; -using Snowflake.Data.Tests.Util; + namespace Snowflake.Data.Tests.UnitTests { using System; + using System.Linq; + using System.Security; + using System.Threading; using Mock; - - [TestFixture, NonParallelizable] + using NUnit.Framework; + using Snowflake.Data.Core; + using Snowflake.Data.Core.Session; + using Snowflake.Data.Client; + using Snowflake.Data.Core.Tools; + using Snowflake.Data.Tests.Util; + + [TestFixture] class ConnectionPoolManagerMFATest { private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); private const string ConnectionStringMFACache = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;authenticator=username_password_mfa"; + private const string ConnectionStringMFABasicWithoutPasscode = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=3;"; private static PoolConfig s_poolConfig; private static MockLoginMFATokenCacheRestRequester s_restRequester; @@ -44,6 +47,7 @@ public static void AfterAllTests() public void BeforeEach() { _connectionPoolManager.ClearAllPools(); + s_restRequester.Reset(); } [Test] @@ -79,6 +83,35 @@ public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringUsingMFA() Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod); } + [Test] + public void TestPoolManagerShouldOnlyUsePasscodeAsArgumentForFirstSessionWhenNotUsingMFAAuthenticator() + { + // Arrange + const string TestPasscode = "123456"; + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + authResponseSessionInfo = new SessionInfo() + }); + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + authResponseSessionInfo = new SessionInfo() + }); + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + authResponseSessionInfo = new SessionInfo() + }); + // Act + var session = _connectionPoolManager.GetSession(ConnectionStringMFABasicWithoutPasscode, null, SecureStringHelper.Encode(TestPasscode)); + Thread.Sleep(3000); + + // Assert + + Assert.AreEqual(3, s_restRequester.LoginRequests.Count); + var request = s_restRequester.LoginRequests.ToList(); + Assert.AreEqual(1, request.Count(r => r.data.extAuthnDuoMethod == "passcode" && r.data.passcode == TestPasscode)); + Assert.AreEqual(2, request.Count(r => r.data.extAuthnDuoMethod == "push" && r.data.passcode == null)); + } + [Test] public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() { @@ -86,24 +119,25 @@ public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsin var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=true"; // Act and assert var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null)); - Assert.That(thrown.Message, Does.Contain("Could not get a pool because passcode was provided using a different authenticator than username_password_mfa")); + Assert.That(thrown.Message, Does.Contain("Could not use connection pool because passcode was provided using a different authenticator than username_password_mfa")); } [Test] - public void TestPoolManagerShouldDisablePoolingWhenPassingPasscodeNotUsingMFATokenCacheAuthenticator() + public void TestPoolManagerShouldNotThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() { // Arrange - var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;"; - var pool = _connectionPoolManager.GetPool(connectionString); - // Act - var session = _connectionPoolManager.GetSession(connectionString, null); - - // Asssert - // TODO: Review pool config is not the same for session and session pool - // Assert.IsFalse(session.GetPooling()); - Assert.AreEqual(0, pool.GetCurrentPoolSize()); - Assert.IsFalse(pool.GetPooling()); + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=false"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null)); + } + [Test] + public void TestPoolManagerShouldNotThrowExceptionIfMinPoolSizeZeroNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=0;passcode=12345;POOLINGENABLED=true"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null)); } } diff --git a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs index bd857eb45..2d818f8c8 100644 --- a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs +++ b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs @@ -60,10 +60,6 @@ public void DisablePoolingDefaultIfSecretsProvidedExternally(SFSessionProperties && !properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_PWD)) { DisablePoolingIfNotExplicitlyEnabled(properties, "key pair with private key in a file"); - } else if (!MFACacheAuthenticator.AUTH_NAME.Equals(authenticator) - && properties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE)) - { - DisablePoolingIfNotExplicitlyEnabled(properties, "mfa authentication without token cache"); } } diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 66c9facb3..8514fc62c 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -145,8 +145,8 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin { s_logger.Debug("SessionPool::GetSession" + PoolIdentification()); SFSession session = null; - var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password, passcode); - ValidatePoolingIfPasscodeProvided(passcode, sessionProperties); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidatePoolingIfPasscodeProvided(sessionProperties); if (!GetPooling()) return NewNonPoolingSession(connStr, password, passcode); var sessionOrCreateTokens = GetIdleSession(connStr); @@ -156,42 +156,37 @@ internal SFSession GetSession(string connStr, SecureString password, SecureStrin { _sessionPoolEventHandler.OnSessionProvided(this); } - ScheduleNewIdleSessions(connStr, password, passcode, sessionOrCreateTokens.BackgroundSessionCreationTokens()); + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); WarnAboutOverridenConfig(); return session ?? sessionOrCreateTokens.Session ?? NewSession(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken()); } - private void ValidatePoolingIfPasscodeProvided(SecureString passcode, SFSessionProperties sessionProperties) + private void ValidatePoolingIfPasscodeProvided(SFSessionProperties sessionProperties) { - if (!GetPooling()) return; - var isUsingPasscode = ((passcode != null && !SecureStringHelper.Decode(passcode).IsNullOrEmpty()) || - sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) || + if (!GetPooling() || _poolConfig.MinPoolSize == 0) return; + var isUsingPasscode = (sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) || (sessionProperties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordValue) && bool.TryParse(passcodeInPasswordValue, out var isPasscodeinPassword) && isPasscodeinPassword)); - if(!isUsingPasscode) return; var isMfaAuthenticator = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME; - - if (isMfaAuthenticator) return; - if (sessionProperties.IsPoolingEnabledValueProvided) + if(isUsingPasscode && !isMfaAuthenticator) { - const string ErrorMessage = "Could not get a pool because passcode was provided using a different authenticator than username_password_mfa"; + const string ErrorMessage = "Could not use connection pool because passcode was provided using a different authenticator than username_password_mfa"; s_logger.Error(ErrorMessage + PoolIdentification()); throw new Exception(ErrorMessage); } - s_logger.Warn("Pooling is disabled because passcode was provided using a different authenticator than username_password_mfa" + PoolIdentification()); - _poolConfig.PoolingEnabled = false; } internal async Task GetSessionAsync(string connStr, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::GetSessionAsync" + PoolIdentification()); SFSession session = null; - var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password, passcode); - ValidatePoolingIfPasscodeProvided(passcode, sessionProperties); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidatePoolingIfPasscodeProvided(sessionProperties); if (!GetPooling()) return await NewNonPoolingSessionAsync(connStr, password, passcode, cancellationToken).ConfigureAwait(false); var sessionOrCreateTokens = GetIdleSession(connStr); + WarnAboutOverridenConfig(); if (sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AUTH_NAME) session = sessionOrCreateTokens.Session ?? @@ -201,21 +196,20 @@ await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.Session { _sessionPoolEventHandler.OnSessionProvided(this); } - ScheduleNewIdleSessions(connStr, password, passcode, sessionOrCreateTokens.BackgroundSessionCreationTokens()); - WarnAboutOverridenConfig(); + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); return session ?? sessionOrCreateTokens.Session ?? await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken(), cancellationToken).ConfigureAwait(false); } - private void ScheduleNewIdleSessions(string connStr, SecureString password, SecureString passcode, List tokens) + private void ScheduleNewIdleSessions(string connStr, SecureString password, List tokens) { - tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, passcode, token)); + tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, token)); } - private void ScheduleNewIdleSession(string connStr, SecureString password, SecureString passcode, SessionCreationToken token) + private void ScheduleNewIdleSession(string connStr, SecureString password, SessionCreationToken token) { Task.Run(() => { - var session = NewSession(connStr, password, passcode, token); + var session = NewSession(connStr, password, null, token); AddSession(session, false); // we don't want to ensure min pool size here because we could get into infinite recursion if expirationTimeout would be very low }); } @@ -258,7 +252,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(session); } s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, but could not find any idle session available in the pool" + PoolIdentification()); - var sessionsCount = AllowedNumberOfNewSessionCreations(1); + var sessionsCount = Math.Min(1, AllowedNumberOfNewSessionCreations(1)); if (sessionsCount > 0) { // there is no need to wait for a session since we can create new ones @@ -269,7 +263,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(WaitForSession(connStr)); } - private List RegisterSessionCreationsWhenReturningSessionToPool() + private List RegisterSessionCreationsToEnsureMinPoolSize() { var count = AllowedNumberOfNewSessionCreations(0); return RegisterSessionCreations(count); @@ -501,7 +495,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) ReleaseBusySession(session); if (ensureMinPoolSize) { - ScheduleNewIdleSessions(ConnectionString, Password, session.Passcode, RegisterSessionCreationsWhenReturningSessionToPool()); // passcode is probably not fresh - it could be improved + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); // passcode is probably not fresh - it could be improved } return false; } @@ -509,7 +503,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) var result = ReturnSessionToPool(session, ensureMinPoolSize); var wasSessionReturnedToPool = result.Item1; var sessionCreationTokens = result.Item2; - ScheduleNewIdleSessions(ConnectionString, Password, session.Passcode, sessionCreationTokens); // passcode is probably not fresh - it could be improved + ScheduleNewIdleSessions(ConnectionString, Password, sessionCreationTokens); return wasSessionReturnedToPool; } @@ -522,7 +516,7 @@ private Tuple> ReturnSessionToPool(SFSession se { _busySessionsCounter.Decrease(); var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -537,7 +531,7 @@ private Tuple> ReturnSessionToPool(SFSession se if (session.IsExpired(_poolConfig.ExpirationTimeout, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())) // checking again because we could have spent some time waiting for a lock { var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -552,7 +546,7 @@ private Tuple> ReturnSessionToPool(SFSession se _idleSessions.Add(session); _waitingForIdleSessionQueue.OnResourceIncrease(); var sessionCreationTokensAfterReturningToPool = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolStateAfterReturningToPool = GetCurrentState(); s_logger.Debug($"returned session with sid {session.sessionId} to pool {poolStateAfterReturningToPool}" + PoolIdentification());