Skip to content

Commit

Permalink
Fixed server certificate validation for encrypt=strict (#2174)
Browse files Browse the repository at this point in the history
  • Loading branch information
lilgreenbird authored Jul 26, 2023
1 parent 46a0fb0 commit c394a1b
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 65 deletions.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
xSQLv12 - - - - - - For tests not compatible with SQL Server 2008 R2 - 2014
xSQLv14 - - - - - - For tests not compatible with SQL Server 2016 - 2017
xSQLv15 - - - - - - For tests not compatible with SQL Server 2019 - - - -
xSQLv16 - - - - - - For tests not compatible with SQL Server 2022 - - - -
xAzureSQLDB - - - - For tests not compatible with Azure SQL Database - -
xAzureSQLDW - - - - For tests not compatible with Azure Data Warehouse -
xAzureSQLMI - - - - For tests not compatible with Azure SQL Managed Instance
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -1693,7 +1693,7 @@ else if (con.getTrustManagerClass() != null) {
// Otherwise, we'll validate the certificate using a real TrustManager obtained
// from the a security provider that is capable of validating X.509 certificates.
else {
if (isTDS8) {
if (isTDS8 && serverCert != null) {
if (logger.isLoggable(Level.FINEST))
logger.finest(toString() + " Verify server certificate for TDS 8");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ static void validateServerNameInCertificate(X509Certificate cert, String hostNam
*/
static void validateServerCerticate(X509Certificate cert, String certFile) throws CertificateException {
try (InputStream is = fileToStream(certFile)) {
if (!CertificateFactory.getInstance("X509").generateCertificate(is).getPublicKey()
.equals(cert.getPublicKey())) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_publicKeyMismatch"));
if (!CertificateFactory.getInstance("X509").generateCertificate(is).equals(cert)) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_serverCertError"));
Object[] msgArgs = {certFile};
throw new CertificateException(form.format(msgArgs));
}
Expand Down Expand Up @@ -353,8 +352,7 @@ static KeyManager[] readPKCS8Certificate(String certPath, String keyPath,
}

private static KeyManager[] readPKCS12Certificate(String certPath,
String keyPassword) throws NoSuchAlgorithmException, CertificateException, IOException,
UnrecoverableKeyException, KeyStoreException, SQLServerException {
String keyPassword) throws NoSuchAlgorithmException, CertificateException, IOException, UnrecoverableKeyException, KeyStoreException, SQLServerException {

KeyStore keyStore = loadPKCS12KeyStore(certPath, keyPassword);
KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(SUN_X_509);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,6 @@ protected Object[][] getContents() {
{"R_illegalArgumentTrustManager", "Interal error. Peer certificate chain or key exchange algorithem can not be null or empty."},
{"R_serverCertExpired", "Server Certificate has expired: {0}: {1}"},
{"R_serverCertNotYetValid", "Server Certificate is not yet valid: {0}: {1}"},
{"R_publicKeyMismatch", "Error validating Server Certificate: public key mismatch: {0}"},
{"R_serverCertError", "Error validating Server Certificate: {0}: \n{1}:\n{2}."},
{"R_SecureStringInitFailed", "Failed to initialize SecureStringUtil to store secure strings"},
{"R_ALPNFailed", "Failed to negotiate Application-Layer Protocol {0}. Server returned: {1}."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class MSITest extends AESetup {
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testManagedIdentityAuth() throws SQLException {
String connStr = connectionString;
Expand Down Expand Up @@ -78,6 +79,7 @@ private void testSimpleConnect(String connStr) {
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testManagedIdentityAuthWithManagedIdentityClientId() throws SQLException {
String connStr = connectionString;
Expand Down Expand Up @@ -116,6 +118,7 @@ public void testManagedIdentityAuthWithManagedIdentityClientId() throws SQLExcep
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testDSManagedIdentityAuth() throws SQLException {
String connStr = connectionString;
Expand All @@ -141,6 +144,7 @@ public void testDSManagedIdentityAuth() throws SQLException {
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testDSManagedIdentityAuthWithManagedIdentityClientId() throws SQLException {
String connStr = connectionString;
Expand Down Expand Up @@ -169,6 +173,7 @@ public void testDSManagedIdentityAuthWithManagedIdentityClientId() throws SQLExc
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testActiveDirectoryDefaultAuth() throws SQLException {
String connStr = connectionString;
Expand All @@ -188,6 +193,7 @@ public void testActiveDirectoryDefaultAuth() throws SQLException {
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testActiveDirectoryDefaultAuthDS() throws SQLException {
String connStr = connectionString;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,22 @@ public void testEncryptedConnection() throws SQLException {
try (Connection con = ds.getConnection()) {}
}

@Tag(Constants.xSQLv11)
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xAzureSQLDW)
@Tag(Constants.xAzureSQLDB)
@Test
public void testEncryptedStrictConnection() throws SQLException {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setURL(connectionString);
ds.setServerCertificate(serverCertificate);
ds.setEncrypt(Constants.STRICT);

try (Connection con = ds.getConnection()) {}
}

@Test
public void testJdbcDataSourceMethod() throws SQLFeatureNotSupportedException {
SQLServerDataSource fxds = new SQLServerDataSource();
Expand Down Expand Up @@ -936,65 +952,63 @@ public void run() {
assertTrue(status && future.isCancelled(), TestResource.getResource("R_threadInterruptNotSet"));
}

/**
* Test thread count when finding socket using threading.
*/
@Test
@Tag(Constants.xAzureSQLDB)
@Tag(Constants.xAzureSQLDW)
public void testThreadCountWhenFindingSocket() {
ExecutorService executor = null;
ManagementFactory.getThreadMXBean().resetPeakThreadCount();

// First, check to see if there is a reachable local host, or else test will fail.
try {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("localhost");
Connection con = ds.getConnection();
} catch (SQLServerException e) {
// Assume this will be an error different than 'localhost is unreachable'. If it is 'localhost is
// unreachable' abort and skip the test.
Assume.assumeFalse(e.getMessage().startsWith(TestResource.getResource("R_tcpipConnectionToHost")));
}

try {
executor = Executors.newSingleThreadExecutor(r -> new Thread(r, ""));
executor.submit(() -> {
try {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("localhost");
Thread.sleep(5000);
Connection conn2 = ds.getConnection();
} catch (Exception e) {
if (!(e instanceof SQLServerException)) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}
});
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("localhost");
Connection conn = ds.getConnection();
Thread.sleep(5000);
} catch (Exception e) {
if (!(e instanceof SQLServerException)) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
} finally {
if (executor != null) {
executor.shutdown();
}
}

// At this point, thread count has returned to normal. If the peak was more
// than 2 times the current, this is an issue and the test should fail.
if (ManagementFactory.getThreadMXBean().getPeakThreadCount() > 2
* ManagementFactory.getThreadMXBean().getThreadCount()) {
fail(TestResource.getResource("R_unexpectedThreadCount"));
}
}

/**
* Test thread count when finding socket using threading.
*/
@Test
@Tag(Constants.xAzureSQLDB)
@Tag(Constants.xAzureSQLDW)
public void testThreadCountWhenFindingSocket() {
ExecutorService executor = null;
ManagementFactory.getThreadMXBean().resetPeakThreadCount();

// First, check to see if there is a reachable local host, or else test will fail.
try {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("localhost");
Connection con = ds.getConnection();
} catch (SQLServerException e) {
// Assume this will be an error different than 'localhost is unreachable'. If it is 'localhost is
// unreachable' abort and skip the test.
Assume.assumeFalse(e.getMessage().startsWith(TestResource.getResource("R_tcpipConnectionToHost")));
}

try {
executor = Executors.newSingleThreadExecutor(r -> new Thread(r, ""));
executor.submit(() -> {
try {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("localhost");
Thread.sleep(5000);
Connection conn2 = ds.getConnection();
} catch (Exception e) {
if (!(e instanceof SQLServerException)) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}
});
SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("localhost");
Connection conn = ds.getConnection();
Thread.sleep(5000);
} catch (Exception e) {
if (!(e instanceof SQLServerException)) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
} finally {
if (executor != null) {
executor.shutdown();
}
}

// At this point, thread count has returned to normal. If the peak was more
// than 2 times the current, this is an issue and the test should fail.
if (ManagementFactory.getThreadMXBean().getPeakThreadCount()
> 2 * ManagementFactory.getThreadMXBean().getThreadCount()) {
fail(TestResource.getResource("R_unexpectedThreadCount"));
}
}


/**
* Test calling method to get redirected server string.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ public void testGetColumns() throws SQLException {
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Tag(Constants.xAzureSQLDB)
@Tag(Constants.xAzureSQLMI)
public void testGetImportedKeysDW() throws SQLException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ public void testPooledConnectionLang() throws SQLException {
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Tag(Constants.xAzureSQLDW)
@Tag(Constants.reqExternalSetup)
@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ public void testDataClassificationMetadata() throws Exception {
@Tag(Constants.xAzureSQLDB)
@Tag(Constants.xAzureSQLDW)
@Tag(Constants.xSQLv15)
@Tag(Constants.xSQLv16)
@Test
public void testDataClassificationNotSupported() throws Exception {
try (Statement stmt = connection.createStatement();) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ public abstract class AbstractTest {
protected static String trustStore = "";
protected static String trustStorePassword = "";

protected static String serverCertificate = "";

protected static String encrypt = "";
protected static String trustServerCertificate = "";

Expand Down Expand Up @@ -202,6 +204,12 @@ public static void setup() throws Exception {
trustStorePassword);
}

serverCertificate = getConfiguredProperty("serverCertificate", "");
if (!serverCertificate.trim().isEmpty()) {
connectionString = TestUtils.addOrOverrideProperty(connectionString, "serverCertificate",
serverCertificate);
}

Map<String, SQLServerColumnEncryptionKeyStoreProvider> map = new HashMap<String, SQLServerColumnEncryptionKeyStoreProvider>();
if (null == jksProvider) {
jksProvider = new SQLServerColumnEncryptionJavaKeyStoreProvider(javaKeyPath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ private Constants() {}
* xSQLv12 - - - - - - For tests not compatible with SQL Server 2008 R2 - 2014
* xSQLv14 - - - - - - For tests not compatible with SQL Server 2016 - 2017
* xSQLv15 - - - - - - For tests not compatible with SQL Server 2019
* xSQLv16 - - - - - - For tests not compatible with SQL Server 2022
* xAzureSQLDB - - - - For tests not compatible with Azure SQL Database
* xAzureSQLDW - - - - For tests not compatible with Azure Data Warehouse
* xAzureSQLMI - - - - For tests not compatible with Azure SQL Managed Instance
Expand All @@ -35,6 +36,7 @@ private Constants() {}
public static final String xSQLv12 = "xSQLv12";
public static final String xSQLv14 = "xSQLv14";
public static final String xSQLv15 = "xSQLv15";
public static final String xSQLv16 = "xSQLv16";
public static final String xAzureSQLDB = "xAzureSQLDB";
public static final String xAzureSQLDW = "xAzureSQLDW";
public static final String xAzureSQLMI = "xAzureSQLMI";
Expand Down

0 comments on commit c394a1b

Please sign in to comment.