Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade BouncyCastle #4

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ Scala wrappers for JCA/BouncyCastle classes
Add to `build.sbt`:
```scala
libraryDependencies ++= Seq(
"org.bouncycastle" % "bcprov-jdk15on" % "1.58",
"org.bouncycastle" % "bcpkix-jdk15on" % "1.58",
"com.github.karasiq" %% "cryptoutils" % "1.4.3"
"org.bouncycastle" % "bcprov-jdk15on" % "1.67",
"org.bouncycastle" % "bcpkix-jdk15on" % "1.67",
"org.bouncycastle" % "bctls-jdk15on" % "1.67",
"com.github.karasiq" %% "cryptoutils" % "2.0.0"
)
```

Expand Down Expand Up @@ -105,4 +106,4 @@ serverSocket.bind(new InetSocketAddress("0.0.0.0", 443))
val socket = serverWrapper(serverSocket.accept())
// ... Do read/write, etc ...
socket.close()
```
```
7 changes: 4 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name := "cryptoutils"

organization := "com.github.karasiq"

version := "1.4.3"
version := "2.0.0-SNAPSHOT"

isSnapshot := version.value.endsWith("SNAPSHOT")

Expand All @@ -14,8 +14,9 @@ resolvers += "softprops-maven" at "http://dl.bintray.com/content/softprops/maven

libraryDependencies ++= Seq(
"commons-io" % "commons-io" % "2.5",
"org.bouncycastle" % "bcprov-jdk15on" % "1.58" % "provided",
"org.bouncycastle" % "bcpkix-jdk15on" % "1.58" % "provided",
"org.bouncycastle" % "bcprov-jdk15on" % "1.67" % "provided",
"org.bouncycastle" % "bcpkix-jdk15on" % "1.67" % "provided",
"org.bouncycastle" % "bctls-jdk15on" % "1.67" % "provided",
"com.typesafe" % "config" % "1.3.1",
"org.scalatest" %% "scalatest" % "3.0.4" % "test"
)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/com/karasiq/tls/TLS.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package com.karasiq.tls

object TLS {
type CertificateChain = org.bouncycastle.crypto.tls.Certificate
type CertificateChain = org.bouncycastle.tls.Certificate
type Certificate = org.bouncycastle.asn1.x509.Certificate
type CertificateKeyPair = org.bouncycastle.crypto.AsymmetricCipherKeyPair

case class CertificateKey(certificateChain: CertificateChain, key: CertificateKeyPair) {
def certificate: TLS.Certificate = {
import com.karasiq.tls.internal.BCConversions._
certificateChain.toTlsCertificate
certificateChain.toCertificate
}
}

Expand Down
29 changes: 20 additions & 9 deletions src/main/scala/com/karasiq/tls/TLSClientWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import java.net.InetSocketAddress
import java.nio.channels.SocketChannel
import java.security.SecureRandom

import com.karasiq.tls.internal.BCConversions.CipherSuiteId
import com.karasiq.tls.internal.BCConversions._
import com.karasiq.tls.internal.{SocketChannelWrapper, TLSUtils}
import com.karasiq.tls.x509.CertificateVerifier
import org.bouncycastle.crypto.tls._
import org.bouncycastle.tls._
import org.bouncycastle.tls.crypto.TlsCryptoParameters
import org.bouncycastle.tls.crypto.impl.bc.{BcDefaultTlsCredentialedSigner, BcTlsCrypto}

import scala.concurrent.Await
import scala.concurrent.duration._
Expand All @@ -20,30 +22,39 @@ class TLSClientWrapper(verifier: CertificateVerifier, address: InetSocketAddress
}

override def apply(connection: SocketChannel): SocketChannel = {
val protocol = new TlsClientProtocol(SocketChannelWrapper.inputStream(connection), SocketChannelWrapper.outputStream(connection), SecureRandom.getInstanceStrong)
val client = new DefaultTlsClient() {
override def getMinimumVersion: ProtocolVersion = {
TLSUtils.minVersion()
val protocol = new TlsClientProtocol(SocketChannelWrapper.inputStream(connection), SocketChannelWrapper.outputStream(connection))
val crypto = new BcTlsCrypto(SecureRandom.getInstanceStrong)
val client = new DefaultTlsClient(crypto) {
@volatile
protected var selectedCipherSuite = 0

override def getSupportedVersions: Array[ProtocolVersion] = {
TLSUtils.maxVersion().downTo(TLSUtils.minVersion())
}

override def getCipherSuites: Array[Int] = {
TLSUtils.defaultCipherSuites()
}

override def notifySelectedCipherSuite(selectedCipherSuite: Int): Unit = {
this.selectedCipherSuite = selectedCipherSuite
}

override def notifyHandshakeComplete(): Unit = {
handshake.trySuccess(true)
this.cipherSuites
onInfo(s"Selected cipher suite: ${CipherSuiteId.asString(selectedCipherSuite)}")
}

override def getAuthentication: TlsAuthentication = new TlsAuthentication {
override def getClientCredentials(certificateRequest: CertificateRequest): TlsCredentials = wrapException("Could not provide client credentials") {
getClientCertificate(certificateRequest)
.map(ck ⇒ new DefaultTlsSignerCredentials(context, ck.certificateChain, ck.key.getPrivate, TLSUtils.signatureAlgorithm(ck.key.getPrivate))) // Ignores certificateRequest data
.map(ck ⇒ new BcDefaultTlsCredentialedSigner(new TlsCryptoParameters(context), crypto, ck.key.getPrivate, ck.certificateChain, TLSUtils.signatureAlgorithm(ck.key.getPrivate))) // Ignores certificateRequest data
.orNull
}

override def notifyServerCertificate(serverCertificate: TLS.CertificateChain): Unit = wrapException("Server certificate error") {
val chain: List[TLS.Certificate] = serverCertificate.getCertificateList.toList
override def notifyServerCertificate(serverCertificate: TlsServerCertificate): Unit = wrapException("Server certificate error") {
val chain: List[TLS.Certificate] = serverCertificate.getCertificate.getCertificateList.toList.map(_.toCertificate)

if (chain.nonEmpty) {
onInfo(s"Server certificate chain: ${chain.map(_.getSubject).mkString("; ")}")
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/com/karasiq/tls/TLSConnectionWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.karasiq.tls

import java.nio.channels.SocketChannel

import org.bouncycastle.crypto.tls.{AlertDescription, TlsFatalAlert}
import org.bouncycastle.tls.{AlertDescription, TlsFatalAlert}

import scala.concurrent.Promise
import scala.util.control.Exception
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/com/karasiq/tls/TLSKeyStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class TLSKeyStore(val keyStore: KeyStore = TLSKeyStore.defaultKeyStore(), val pa
}

def getCertificate(alias: String): TLS.Certificate = {
keyStore.getCertificate(alias).toTlsCertificate
keyStore.getCertificate(alias).toCertificate
}

def getKeySet(alias: String, password: String = password): TLS.KeySet = {
Expand Down
37 changes: 18 additions & 19 deletions src/main/scala/com/karasiq/tls/TLSServerWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import java.nio.channels.SocketChannel
import java.security.SecureRandom

import com.karasiq.tls.TLS.CertificateChain
import com.karasiq.tls.internal.BCConversions.CipherSuiteId
import com.karasiq.tls.internal.BCConversions._
import com.karasiq.tls.internal.{SocketChannelWrapper, TLSUtils}
import com.karasiq.tls.x509.{CertificateVerifier, X509Utils}
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.crypto.tls._
import org.bouncycastle.tls._
import org.bouncycastle.tls.crypto.TlsCryptoParameters
import org.bouncycastle.tls.crypto.impl.bc.{BcDefaultTlsCredentialedDecryptor, BcDefaultTlsCredentialedSigner, BcTlsCrypto}

import scala.concurrent.Await
import scala.concurrent.duration._
Expand All @@ -19,7 +21,7 @@ class TLSServerWrapper(keySet: TLS.KeySet, clientAuth: Boolean = false, verifier

@throws(classOf[TlsFatalAlert])
protected def onClientAuth(clientCertificate: CertificateChain): Unit = {
val chain: List[TLS.Certificate] = clientCertificate.getCertificateList.toList
val chain: List[TLS.Certificate] = clientCertificate.getCertificateList.toList.map(_.toCertificate)
if (chain.nonEmpty) {
onInfo(s"Client certificate chain: ${chain.map(_.getSubject).mkString("; ")}")
}
Expand All @@ -33,14 +35,11 @@ class TLSServerWrapper(keySet: TLS.KeySet, clientAuth: Boolean = false, verifier
}

def apply(connection: SocketChannel): SocketChannel = {
val protocol = new TlsServerProtocol(SocketChannelWrapper.inputStream(connection), SocketChannelWrapper.outputStream(connection), SecureRandom.getInstanceStrong)
val server = new DefaultTlsServer() {
override def getMinimumVersion: ProtocolVersion = {
TLSUtils.minVersion()
}

override def getMaximumVersion: ProtocolVersion = {
TLSUtils.maxVersion()
val protocol = new TlsServerProtocol(SocketChannelWrapper.inputStream(connection), SocketChannelWrapper.outputStream(connection))
val crypto = new BcTlsCrypto(SecureRandom.getInstanceStrong)
val server = new DefaultTlsServer(crypto) {
override def getSupportedVersions: Array[ProtocolVersion] = {
TLSUtils.maxVersion().downTo(TLSUtils.minVersion())
}

override def getCipherSuites: Array[Int] = {
Expand All @@ -52,33 +51,33 @@ class TLSServerWrapper(keySet: TLS.KeySet, clientAuth: Boolean = false, verifier
onInfo(s"Selected cipher suite: ${CipherSuiteId.asString(selectedCipherSuite)}")
}

private def signerCredentials(certOption: Option[TLS.CertificateKey]): TlsSignerCredentials = {
private def signerCredentials(certOption: Option[TLS.CertificateKey]): TlsCredentialedSigner = {
certOption.filter(c ⇒ X509Utils.isKeyUsageAllowed(c.certificate, KeyUsage.digitalSignature)).fold(throw new TLSException("No suitable signer credentials found")) { cert ⇒
new DefaultTlsSignerCredentials(context, cert.certificateChain, cert.key.getPrivate, TLSUtils.signatureAlgorithm(cert.key.getPrivate))
new BcDefaultTlsCredentialedSigner(new TlsCryptoParameters(context), crypto, cert.key.getPrivate, cert.certificateChain, TLSUtils.signatureAlgorithm(cert.key.getPrivate))
}
}

override def getRSASignerCredentials: TlsSignerCredentials = wrapException("Could not provide server RSA credentials") {
override def getRSASignerCredentials: TlsCredentialedSigner = wrapException("Could not provide server RSA credentials") {
signerCredentials(keySet.rsa)
}

override def getECDSASignerCredentials: TlsSignerCredentials = wrapException("Could not provide server ECDSA credentials") {
override def getECDSASignerCredentials: TlsCredentialedSigner = wrapException("Could not provide server ECDSA credentials") {
signerCredentials(keySet.ecdsa)
}

override def getDSASignerCredentials: TlsSignerCredentials = wrapException("Could not provide server DSA credentials") {
override def getDSASignerCredentials: TlsCredentialedSigner = wrapException("Could not provide server DSA credentials") {
signerCredentials(keySet.dsa)
}

override def getRSAEncryptionCredentials: TlsEncryptionCredentials = wrapException("Could not provide server RSA encryption credentials") {
override def getRSAEncryptionCredentials: TlsCredentialedDecryptor = wrapException("Could not provide server RSA encryption credentials") {
keySet.rsa.filter(c ⇒ X509Utils.isKeyUsageAllowed(c.certificate, KeyUsage.keyEncipherment)).fold(super.getRSAEncryptionCredentials) { cert ⇒
new DefaultTlsEncryptionCredentials(context, cert.certificateChain, cert.key.getPrivate)
new BcDefaultTlsCredentialedDecryptor(crypto, cert.certificateChain, cert.key.getPrivate)
}
}

override def getCertificateRequest: CertificateRequest = {
if (clientAuth) {
TLSUtils.certificateRequest(this.getServerVersion, verifier)
TLSUtils.certificateRequest(this.getServerVersion, verifier, context)
} else {
null
}
Expand Down
49 changes: 43 additions & 6 deletions src/main/scala/com/karasiq/tls/internal/BCConversions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,29 @@ package com.karasiq.tls.internal
import java.io.ByteArrayInputStream
import java.security.cert.CertificateFactory
import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec}
import java.security.{KeyFactory, PrivateKey, PublicKey}
import java.security.{KeyFactory, PrivateKey, PublicKey, SecureRandom}

import com.karasiq.tls.TLS
import org.apache.commons.io.IOUtils
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.asn1.x509.{AlgorithmIdentifier, SubjectPublicKeyInfo}
import org.bouncycastle.crypto.AsymmetricCipherKeyPair
import org.bouncycastle.crypto.params.{AsymmetricKeyParameter, DSAKeyParameters, ECKeyParameters, RSAKeyParameters}
import org.bouncycastle.crypto.tls.CipherSuite
import org.bouncycastle.tls.CipherSuite
import org.bouncycastle.crypto.util.{PrivateKeyFactory, PrivateKeyInfoFactory, PublicKeyFactory, SubjectPublicKeyInfoFactory}
import org.bouncycastle.operator.DefaultDigestAlgorithmIdentifierFinder
import org.bouncycastle.tls.crypto.TlsCertificate
import org.bouncycastle.tls.crypto.impl.bc.{BcTlsCertificate, BcTlsCrypto}

import scala.util.Try

/**
* Provides conversions between JCA and BouncyCastle classes
*/
object BCConversions {

private val crypto = new BcTlsCrypto(SecureRandom.getInstanceStrong)

implicit class JavaKeyOps(private val key: java.security.Key) extends AnyVal {
private def convertPKCS8Key(data: Array[Byte], public: SubjectPublicKeyInfo): AsymmetricCipherKeyPair = {
new AsymmetricCipherKeyPair(PublicKeyFactory.createKey(public), PrivateKeyFactory.createKey(data))
Expand Down Expand Up @@ -123,16 +128,24 @@ object BCConversions {
}

implicit class JavaCertificateOps(private val cert: java.security.cert.Certificate) extends AnyVal {
def toTlsCertificate: TLS.Certificate = {
def toCertificate: TLS.Certificate = {
org.bouncycastle.asn1.x509.Certificate.getInstance(cert.getEncoded)
}

def toTlsCertificate: TlsCertificate = {
new BcTlsCertificate(crypto, cert.getEncoded)
}

def toTlsCertificateChain: TLS.CertificateChain = {
toTlsCertificate.toTlsCertificateChain
toCertificate.toTlsCertificateChain
}
}

implicit class CertificateOps(private val cert: TLS.Certificate) extends AnyVal {
implicit class TlsCertificateOps(private val cert: TlsCertificate) extends AnyVal {
def toCertificate: TLS.Certificate = {
org.bouncycastle.asn1.x509.Certificate.getInstance(cert.getEncoded)
}

def toTlsCertificateChain: TLS.CertificateChain = {
new TLS.CertificateChain(Array(cert))
}
Expand All @@ -148,8 +161,32 @@ object BCConversions {
}
}

implicit class CertificateOps(private val cert: TLS.Certificate) extends AnyVal {
def toTlsCertificate: TlsCertificate = {
new BcTlsCertificate(crypto, cert.getEncoded)
}

def toTlsCertificateChain: TLS.CertificateChain = {
new TLS.CertificateChain(Array(cert.toTlsCertificate))
}

def toJavaCertificate: java.security.cert.Certificate = {
val certificateFactory = CertificateFactory.getInstance("X.509")
val inputStream = new ByteArrayInputStream(cert.getEncoded)
try {
certificateFactory.generateCertificate(inputStream)
} finally {
IOUtils.closeQuietly(inputStream)
}
}
}

implicit class CertificateChainOps(private val chain: TLS.CertificateChain) extends AnyVal {
def toTlsCertificate: TLS.Certificate = {
def toCertificate: TLS.Certificate = {
toTlsCertificate.toCertificate
}

def toTlsCertificate: TlsCertificate = {
chain.getCertificateList.headOption
.getOrElse(throw new NoSuchElementException("Empty certificate chain"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import java.nio.ByteBuffer
import java.nio.channels.SocketChannel
import java.util

import org.bouncycastle.crypto.tls.TlsProtocol
import org.bouncycastle.tls.TlsProtocol
import sun.nio.ch.{SelChImpl, SelectionKeyImpl}

private[tls] object SocketChannelWrapper {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import java.io.{InputStream, OutputStream}
import java.net.{InetAddress, Socket, SocketAddress}
import java.nio.channels.SocketChannel

import org.bouncycastle.crypto.tls.TlsProtocol
import org.bouncycastle.tls.TlsProtocol

final private[tls] class SocketWrapper(connection: Socket, protocol: TlsProtocol) extends Socket {
override def shutdownInput(): Unit = connection.shutdownInput()
Expand Down
15 changes: 8 additions & 7 deletions src/main/scala/com/karasiq/tls/internal/TLSUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package com.karasiq.tls.internal
import java.security.Provider

import com.karasiq.tls.TLS
import com.karasiq.tls.internal.BCConversions.CipherSuiteId
import com.karasiq.tls.internal.BCConversions._
import com.karasiq.tls.x509.CertificateVerifier
import com.typesafe.config.ConfigFactory
import org.bouncycastle.crypto.params._
import org.bouncycastle.crypto.tls._
import org.bouncycastle.tls._
import org.bouncycastle.jce.ECNamedCurveTable
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.jce.spec.ECParameterSpec
Expand Down Expand Up @@ -55,9 +55,9 @@ object TLSUtils {
asJavaVector(trustStore.trustedRootCertificates.map(_.getSubject))
}

def certificateRequest(protocolVersion: ProtocolVersion, verifier: CertificateVerifier): CertificateRequest = {
def certificateRequest(protocolVersion: ProtocolVersion, verifier: CertificateVerifier, context: TlsContext): CertificateRequest = {
val certificateTypes = Array(ClientCertificateType.rsa_sign, ClientCertificateType.ecdsa_sign, ClientCertificateType.dss_sign)
new CertificateRequest(certificateTypes, defaultSignatureAlgorithms(protocolVersion), authoritiesOf(verifier))
new CertificateRequest(certificateTypes, defaultSignatureAlgorithms(protocolVersion, context), authoritiesOf(verifier))
}

def certificateFor(keySet: TLS.KeySet, certificateRequest: CertificateRequest): Option[TLS.CertificateKey] = {
Expand All @@ -76,7 +76,8 @@ object TLSUtils {
}

def isInAuthorities(chain: TLS.CertificateChain, certificateRequest: CertificateRequest): Boolean = {
chain.getCertificateList.exists { cert ⇒
chain.getCertificateList.exists { tlsCert ⇒
val cert = tlsCert.toCertificate
certificateRequest.getCertificateAuthorities.contains(cert.getSubject) || certificateRequest.getCertificateAuthorities.contains(cert.getIssuer)
}
}
Expand Down Expand Up @@ -119,9 +120,9 @@ object TLSUtils {
config.getString("hash-algorithm")
}

def defaultSignatureAlgorithms(protocolVersion: ProtocolVersion): java.util.Vector[_] = {
def defaultSignatureAlgorithms(protocolVersion: ProtocolVersion, context: TlsContext): java.util.Vector[_] = {
if (TlsUtils.isSignatureAlgorithmsExtensionAllowed(protocolVersion)) {
TlsUtils.getDefaultSupportedSignatureAlgorithms
TlsUtils.getDefaultSupportedSignatureAlgorithms(context)
} else {
null
}
Expand Down
Loading